Skip to content

Commit 978f0db

Browse files
committed
UnaryOp and BinaryOp
1 parent fd4e0ce commit 978f0db

File tree

2 files changed

+36
-21
lines changed

2 files changed

+36
-21
lines changed

src/engine/engine.zig

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,17 @@ pub const UnaryType = enum {
2323
}
2424
};
2525

26-
// pub const UnaryOp = struct {
27-
// type: UnaryType,
28-
// func: fn (Scalar) Scalar,
29-
// };
26+
/// Record of a unary operation in the autograd graph.
/// Instantiated per value type so the node pointers are strongly typed.
pub fn UnaryOp(comptime Value: type) type {
    return struct {
        /// Which unary operation produced this value.
        op: UnaryType,
        /// Function invoked during backpropagation to push gradients
        /// from this node into its child.
        backprop_fn: *const fn (*Value) void,
        /// The single child node this value was computed from.
        prev: [1]*Value,
    };
}
3037

3138
pub const BinaryType = enum {
3239
add,
@@ -44,6 +51,18 @@ pub const BinaryType = enum {
4451
}
4552
};
4653

54+
/// Record of a binary operation in the autograd graph.
/// Instantiated per value type so the node pointers are strongly typed.
pub fn BinaryOp(comptime Value: type) type {
    return struct {
        /// Which binary operation produced this value.
        op: BinaryType,
        /// Function invoked during backpropagation to push gradients
        /// from this node into its children.
        backprop_fn: *const fn (*Value) void,
        /// The two child nodes this value was computed from.
        prev: [2]*Value,
    };
}
65+
4766
pub const Scalar = @import("scalar.zig").Scalar;
4867
pub const Array = @import("tensor.zig").Array;
4968
pub const Tensor = @import("tensor.zig").Tensor;

src/engine/scalar.zig

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -37,20 +37,8 @@ pub fn Scalar(comptime T: type) type {
3737

3838
const Expr = union(engine.ExprType) {
3939
nop: void,
40-
unary: struct {
41-
/// The unary operation that produced the value
42-
op: engine.UnaryType,
43-
backprop_fn: BackpropFn,
44-
/// The children used to compute the value
45-
prev: [1]*Self,
46-
},
47-
binary: struct {
48-
/// The binary operation that produced the value
49-
op: engine.BinaryType,
50-
backprop_fn: BackpropFn,
51-
/// The children used to compute the value
52-
prev: [2]*Self,
53-
},
40+
unary: engine.UnaryOp(Self),
41+
binary: engine.BinaryOp(Self),
5442
};
5543

5644
/// The value
@@ -177,7 +165,7 @@ pub fn Scalar(comptime T: type) type {
177165

178166
/// Backpropagation function for ReLU
179167
fn relu_back(self: *Self) void {
180-
self.expr.unary.prev[0].grad += if (self.data > 0) self.grad else @as(T, 0);
168+
self.expr.unary.prev[0].grad += if (self.expr.unary.prev[0].data > 0) self.grad else @as(T, 0);
181169
}
182170

183171
/// Apply the softmax function to a Scalar
@@ -303,6 +291,14 @@ pub fn Scalar(comptime T: type) type {
303291
// Apply chain rule
304292
self.grad = @as(T, 1);
305293

294+
// Zero gradients only on intermediate computation nodes (not leaf nodes, not the loss node)
295+
// Leaf nodes (parameters/inputs with .nop) preserve their gradients to allow accumulation
296+
for (topo.items) |node| {
297+
if (node.expr != .nop and node != self) {
298+
node.grad = @as(T, 0);
299+
}
300+
}
301+
306302
// Reverse the topo list and call backward on each node
307303
const items = topo.items;
308304
var i = items.len;
@@ -324,8 +320,8 @@ pub fn Scalar(comptime T: type) type {
324320
const file_writer = file.writer();
325321
graph.draw_dot(file_writer, std.heap.page_allocator) catch unreachable;
326322

327-
std.debug.print("Computational graph written to {s}\n", .{dot_name});
328-
std.debug.print("You can visualize it by running: dot -Tpng {s} -o {s}\n", .{ dot_name, png_name });
323+
// std.debug.print("Computational graph written to {s}\n", .{dot_name});
324+
// std.debug.print("You can visualize it by running: dot -Tpng {s} -o {s}\n", .{ dot_name, png_name });
329325
}
330326
};
331327
}

0 commit comments

Comments (0)