22
33const std = @import ("std" );
44
5- /// Operations supported by the engine
6- pub const Operation = enum (u8 ) {
7- /// Add two values
8- ADD ,
9- /// Subtract two values
10- SUB ,
11- /// Multiply two values
12- MUL ,
13- /// Divide two values
14- DIV ,
15-
16- /// Convert the Operation to its mathematical symbol
17- pub fn toString (self : Operation ) []const u8 {
18- return switch (self ) {
19- .ADD = > "+" ,
20- .SUB = > "-" ,
21- .MUL = > "*" ,
22- .DIV = > "/" ,
23- };
24- }
25- };
26-
27- pub fn info (T : type ) std.builtin.Type.Vector {
28- if (@typeInfo (T ) != .vector ) @compileError ("Expected a @Vector type got: " ++ @typeName (T ));
29- return @typeInfo (T ).vector ;
30- }
31-
325/// Represents a singular Scalar value
336pub fn Value (comptime T : type ) type {
7+ if (@typeInfo (T ) != .int and @typeInfo (T ) != .float ) {
8+ @compileError ("Expected @int or @float type, got: " ++ @typeName (T ));
9+ }
10+
3411 return struct {
3512 const Self = @This ();
3613 /// The value
3714 data : T ,
3815 /// The gradient
3916 grad : T ,
4017 /// Function for backpropagation
41- backprop : ? * const fn (self : * Self ) void ,
18+ backward : ? * const fn (self : * Self ) void ,
4219 /// The children used to compute the value
4320 prev : ? []* Self ,
4421 /// The operation that produced the value
45- operation : ? Operation ,
22+ operation : ? [] const u8 ,
4623 /// The label of the value
4724 label : ? []const u8 ,
4825
4926 /// Initialize the Value
50- pub fn init (data : T , prev : ? []* Self , operation : ? Operation , label : ? []const u8 ) Self {
27+ pub fn init (data : T , prev : ? []* Self , operation : ? [] const u8 , label : ? []const u8 ) Self {
5128 return Self {
5229 .data = data ,
5330 .grad = 0 ,
54- .backprop = null ,
31+ .backward = null ,
5532 .prev = prev ,
5633 .operation = operation ,
5734 .label = label ,
5835 };
5936 }
6037
61- pub fn add (self : * Self , other : * Self , allocator : std.mem.Allocator ) ! Self {
62- const children = try allocator .dupe (* Self , &.{ self , other });
63- const AddBackward = struct {
64- fn call (result : * Self ) void {
65- if (result .prev ) | prev_children | {
66- prev_children [0 ].grad += result .grad ;
67- prev_children [1 ].grad += result .grad ;
68- }
69- }
70- }.call ;
71-
72- return Self {
73- .data = self .data + other .data ,
74- .grad = 0 ,
75- .backprop = AddBackward ,
76- .prev = children ,
77- .operation = Operation .ADD ,
78- .label = null ,
79- };
80- }
81-
82- pub fn mul (self : * Self , other : * Self , allocator : std.mem.Allocator ) ! Self {
83- const children = try allocator .dupe (* Self , &.{ self , other });
84- const MulBackward = struct {
85- fn call (result : * Self ) void {
86- if (result .prev ) | prev_children | {
87- prev_children [0 ].grad += prev_children [1 ].data * result .grad ;
88- prev_children [1 ].grad += prev_children [0 ].data * result .grad ;
89- }
90- }
91- }.call ;
92-
93- return Self {
94- .data = self .data * other .data ,
95- .grad = 0 ,
96- .backprop = MulBackward ,
97- .prev = children ,
98- .operation = Operation .MUL ,
99- .label = null ,
100- };
101- }
102-
103- pub fn backward (self : * Self ) void {
104- if (self .backprop ) | bp | bp (self );
105- }
106-
10738 /// Convert the Value to a string
10839 pub fn toString (self : Self ) []const u8 {
10940 const op_name = if (self .operation ) | op | @tagName (op ) else "null" ;
@@ -125,47 +56,113 @@ pub fn Value(comptime T: type) type {
12556 return std .fmt .allocPrint (std .heap .page_allocator , "Value(data={any}, grad={any}, prev={s}, operation={s}, label={s})" , .{ self .data , self .grad , prev_str , op_name , label_name }) catch unreachable ;
12657 }
12758
128- /// Subtract two values
129- pub fn sub (self : * Self , other : * Self , allocator : std.mem.Allocator ) ! Self {
130- const children = try allocator .dupe (* Self , &.{ self , other });
131- const SubBackward = struct {
59+ pub fn add (self : * Self , other : * Self , allocator : std.mem.Allocator , label : ? []const u8 ) ! Self {
60+ return Self {
61+ .data = self .data + other .data ,
62+ .grad = 0 ,
63+ .backward = struct {
64+ fn call (result : * Self ) void {
65+ if (result .prev ) | prev_children | {
66+ prev_children [0 ].grad += result .grad ;
67+ prev_children [1 ].grad += result .grad ;
68+ }
69+ }
70+ }.call ,
71+ .prev = try allocator .dupe (* Self , &.{ self , other }),
72+ .operation = "+" ,
73+ .label = label ,
74+ };
75+ }
76+
77+ pub fn mul (self : * Self , other : * Self , allocator : std.mem.Allocator , label : ? []const u8 ) ! Self {
78+ const MulBackward = struct {
13279 fn call (result : * Self ) void {
13380 if (result .prev ) | prev_children | {
134- prev_children [0 ].grad += result .grad ;
135- prev_children [1 ].grad -= result .grad ;
81+ prev_children [0 ].grad += prev_children [ 1 ]. data * result .grad ;
82+ prev_children [1 ].grad += prev_children [ 0 ]. data * result .grad ;
13683 }
13784 }
13885 }.call ;
13986
87+ return Self {
88+ .data = self .data * other .data ,
89+ .grad = 0 ,
90+ .backward = MulBackward ,
91+ .prev = try allocator .dupe (* Self , &.{ self , other }),
92+ .operation = "*" ,
93+ .label = label ,
94+ };
95+ }
96+
97+ /// Subtract two values
98+ pub fn sub (self : * Self , other : * Self , allocator : std.mem.Allocator , label : ? []const u8 ) ! Self {
14099 return Self {
141100 .data = self .data - other .data ,
142101 .grad = 0 ,
143- .backprop = SubBackward ,
144- .prev = children ,
145- .operation = Operation .SUB ,
146- .label = null ,
102+ .backward = struct {
103+ fn call (result : * Self ) void {
104+ if (result .prev ) | prev_children | {
105+ prev_children [0 ].grad += result .grad ;
106+ prev_children [1 ].grad -= result .grad ;
107+ }
108+ }
109+ }.call ,
110+ .prev = try allocator .dupe (* Self , &.{ self , other }),
111+ .operation = "-" ,
112+ .label = label ,
147113 };
148114 }
149115
150116 /// Divide two values
151- pub fn div (self : * Self , other : * Self , allocator : std.mem.Allocator ) ! Self {
152- const children = try allocator .dupe (* Self , &.{ self , other });
153- const DivBackward = struct {
154- fn call (result : * Self ) void {
155- if (result .prev ) | prev_children | {
156- prev_children [0 ].grad += result .grad / other .data ;
157- prev_children [1 ].grad -= result .grad * self .data / (other .data * other .data );
117+ pub fn div (self : * Self , other : * Self , allocator : std.mem.Allocator , label : ? []const u8 ) ! Self {
118+ return Self {
119+ .data = self .data / other .data ,
120+ .grad = 0 ,
121+ .backward = struct {
122+ fn call (result : * Self ) void {
123+ if (result .prev ) | prev_children | {
124+ prev_children [0 ].grad += result .grad / other .data ;
125+ prev_children [1 ].grad -= result .grad * self .data / (other .data * other .data );
126+ }
158127 }
159- }
160- }.call ;
128+ }.call ,
129+ .prev = try allocator .dupe (* Self , &.{ self , other }),
130+ .operation = "/" ,
131+ .label = label ,
132+ };
133+ }
161134
135+ pub fn relu (self : * Self , allocator : std.mem.Allocator , label : ? []const u8 ) ! Self {
162136 return Self {
163- .data = self .data / other .data ,
137+ .data = if ( self .data > 0 ) self .data else 0 ,
164138 .grad = 0 ,
165- .backprop = DivBackward ,
166- .prev = children ,
167- .operation = Operation .DIV ,
168- .label = null ,
139+ .backward = struct {
140+ fn call (result : * Self ) void {
141+ if (result .prev ) | prev_children | {
142+ prev_children [0 ].grad += result .grad * (self .data > 0 );
143+ }
144+ }
145+ }.call ,
146+ .prev = try allocator .dupe (* Self , &.{self }),
147+ .operation = "ReLU" ,
148+ .label = label ,
149+ };
150+ }
151+
152+ pub fn softmax (self : * Self , allocator : std.mem.Allocator , label : ? []const u8 ) ! Self {
153+ return Self {
154+ .data = std .math .exp (self .data ),
155+ .grad = 0 ,
156+ .backward = struct {
157+ fn call (result : * Self ) void {
158+ if (result .prev ) | prev_children | {
159+ prev_children [0 ].grad += result .grad ;
160+ }
161+ }
162+ }.call ,
163+ .prev = try allocator .dupe (* Self , &.{self }),
164+ .operation = "Softmax" ,
165+ .label = label ,
169166 };
170167 }
171168
@@ -192,15 +189,17 @@ pub fn Value(comptime T: type) type {
192189 const node_id = @intFromPtr (node );
193190 const label_str = if (node .label ) | label | label else "" ;
194191 const data_str = try std .fmt .allocPrint (allocator , "{d:.4}" , .{node .data });
192+ const grad_str = try std .fmt .allocPrint (allocator , "{d:.4}" , .{node .grad });
195193 defer allocator .free (data_str );
194+ defer allocator .free (grad_str );
196195
197- try writer .print (" \" {}\" [label=\" {{{s} | data {s}}} \" , shape=record];\n " , .{ node_id , label_str , data_str });
196+ try writer .print (" \" {}\" [label=\" {{{s} | data {s} | grad {s}}} \" , shape=record];\n " , .{ node_id , label_str , data_str , grad_str });
198197
199198 // If this value is a result of some operation, create an op node for it
200199 if (node .operation ) | op | {
201200 const op_id = try std .fmt .allocPrint (allocator , "{}op" , .{node_id });
202201 defer allocator .free (op_id );
203- try writer .print (" \" {s}\" [label=\" {s}\" ];\n " , .{ op_id , op . toString () });
202+ try writer .print (" \" {s}\" [label=\" {s}\" ];\n " , .{ op_id , op });
204203 try writer .print (" \" {s}\" -> \" {}\" ;\n " , .{ op_id , node_id });
205204 }
206205 }
0 commit comments