@@ -37,20 +37,8 @@ pub fn Scalar(comptime T: type) type {
3737
3838 const Expr = union (engine .ExprType ) {
3939 nop : void ,
40- unary : struct {
41- /// The unary operation that produced the value
42- op : engine.UnaryType ,
43- backprop_fn : BackpropFn ,
44- /// The children used to compute the value
45- prev : [1 ]* Self ,
46- },
47- binary : struct {
48- /// The binary operation that produced the value
49- op : engine.BinaryType ,
50- backprop_fn : BackpropFn ,
51- /// The children used to compute the value
52- prev : [2 ]* Self ,
53- },
40+ unary : engine .UnaryOp (Self ),
41+ binary : engine .BinaryOp (Self ),
5442 };
5543
5644 /// The value
@@ -177,7 +165,7 @@ pub fn Scalar(comptime T: type) type {
177165
178166 /// Backpropagation function for ReLU
179167 fn relu_back (self : * Self ) void {
180- self .expr .unary .prev [0 ].grad += if (self .data > 0 ) self .grad else @as (T , 0 );
168+ self .expr .unary .prev [0 ].grad += if (self .expr . unary . prev [ 0 ]. data > 0 ) self .grad else @as (T , 0 );
181169 }
182170
183171 /// Apply the softmax function to a Scalar
@@ -303,6 +291,14 @@ pub fn Scalar(comptime T: type) type {
303291 // Apply chain rule
304292 self .grad = @as (T , 1 );
305293
294+ // Zero gradients only on intermediate computation nodes (not leaf nodes, not the loss node)
295+ // Leaf nodes (parameters/inputs with .nop) preserve their gradients to allow accumulation
296+ for (topo .items ) | node | {
297+ if (node .expr != .nop and node != self ) {
298+ node .grad = @as (T , 0 );
299+ }
300+ }
301+
306302 // Reverse the topo list and call backward on each node
307303 const items = topo .items ;
308304 var i = items .len ;
@@ -324,8 +320,8 @@ pub fn Scalar(comptime T: type) type {
324320 const file_writer = file .writer ();
325321 graph .draw_dot (file_writer , std .heap .page_allocator ) catch unreachable ;
326322
327- std .debug .print ("Computational graph written to {s}\n " , .{dot_name });
328- std .debug .print ("You can visualize it by running: dot -Tpng {s} -o {s}\n " , .{ dot_name , png_name });
323+ // std.debug.print("Computational graph written to {s}\n", .{dot_name});
324+ // std.debug.print("You can visualize it by running: dot -Tpng {s} -o {s}\n", .{ dot_name, png_name });
329325 }
330326 };
331327}
0 commit comments