Skip to content

Commit 5d9e246

Browse files
committed
Document Value
1 parent 54314a9 commit 5d9e246

File tree

1 file changed

+54
-0
lines changed

1 file changed

+54
-0
lines changed

src/engine.zig

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,20 @@
33
const std = @import("std");
44

55
/// Represents an auto-differentiable Scalar value
6+
///
7+
/// This is a generic type that can be used to create a scalar-valued value.
8+
///
9+
/// # Example
10+
/// ```zig
11+
/// const Value = @import("engine").Value;
12+
/// const value = Value(f32).new(2.0);
13+
/// ```
14+
///
15+
/// # Operations
16+
///
17+
/// The Value type supports the following operations:
18+
///
19+
/// - Addition
620
pub fn Value(comptime T: type) type {
721
if (@typeInfo(T) != .int and @typeInfo(T) != .float) {
822
@compileError("Expected @int or @float type, got: " ++ @typeName(T));
@@ -235,5 +249,45 @@ pub fn Value(comptime T: type) type {
235249
}
236250
}
237251
}
252+
253+
/// Build a topological ordering of the computational graph using Depth-First Search (DFS)
254+
fn buildTopo(self: *Self, topo: *std.ArrayList(*Self), visited: *std.AutoHashMap(*Self, void)) !void {
255+
if (visited.contains(self)) {
256+
return;
257+
}
258+
259+
try visited.put(self, {});
260+
261+
if (self.prev) |children| {
262+
for (children) |child| {
263+
try child.buildTopo(topo, visited);
264+
}
265+
}
266+
267+
try topo.append(self);
268+
}
269+
270+
/// Backward pass - topological sort and gradient computation
271+
pub fn backwardPass(self: *Self, allocator: std.mem.Allocator) !void {
272+
// Topological ordering
273+
var topo = std.ArrayList(*Self).init(allocator);
274+
defer topo.deinit();
275+
276+
var visited = std.AutoHashMap(*Self, void).init(allocator);
277+
defer visited.deinit();
278+
279+
try self.buildTopo(&topo, &visited);
280+
281+
// Apply chain rule
282+
self.grad = @as(T, 1);
283+
284+
// Reverse the topo list and call backward on each node
285+
const items = topo.items;
286+
var i = items.len;
287+
while (i > 0) {
288+
i -= 1;
289+
items[i].backward(items[i]);
290+
}
291+
}
238292
};
239293
}

0 commit comments

Comments
 (0)