|
3 | 3 | const std = @import("std"); |
4 | 4 |
|
5 | 5 | /// Represents an auto-differentiable Scalar value |
| 6 | +/// |
| 7 | +/// This is a generic type that can be used to create a scalar-valued value. |
| 8 | +/// |
| 9 | +/// # Example |
| 10 | +/// ```zig |
| 11 | +/// const Value = @import("engine").Value; |
| 12 | +/// const value = Value(f32).new(2.0); |
| 13 | +/// ``` |
| 14 | +/// |
| 15 | +/// # Operations |
| 16 | +/// |
| 17 | +/// The Value type supports the following operations: |
| 18 | +/// |
| 19 | +/// - Addition |
6 | 20 | pub fn Value(comptime T: type) type { |
7 | 21 | if (@typeInfo(T) != .int and @typeInfo(T) != .float) { |
8 | 22 | @compileError("Expected @int or @float type, got: " ++ @typeName(T)); |
@@ -235,5 +249,45 @@ pub fn Value(comptime T: type) type { |
235 | 249 | } |
236 | 250 | } |
237 | 251 | } |
| 252 | + |
| 253 | + /// Build a topological ordering of the computational graph using Depth-First Search (DFS) |
| 254 | + fn buildTopo(self: *Self, topo: *std.ArrayList(*Self), visited: *std.AutoHashMap(*Self, void)) !void { |
| 255 | + if (visited.contains(self)) { |
| 256 | + return; |
| 257 | + } |
| 258 | + |
| 259 | + try visited.put(self, {}); |
| 260 | + |
| 261 | + if (self.prev) |children| { |
| 262 | + for (children) |child| { |
| 263 | + try child.buildTopo(topo, visited); |
| 264 | + } |
| 265 | + } |
| 266 | + |
| 267 | + try topo.append(self); |
| 268 | + } |
| 269 | + |
| 270 | + /// Backward pass - topological sort and gradient computation |
| 271 | + pub fn backwardPass(self: *Self, allocator: std.mem.Allocator) !void { |
| 272 | + // Topological ordering |
| 273 | + var topo = std.ArrayList(*Self).init(allocator); |
| 274 | + defer topo.deinit(); |
| 275 | + |
| 276 | + var visited = std.AutoHashMap(*Self, void).init(allocator); |
| 277 | + defer visited.deinit(); |
| 278 | + |
| 279 | + try self.buildTopo(&topo, &visited); |
| 280 | + |
| 281 | + // Apply chain rule |
| 282 | + self.grad = @as(T, 1); |
| 283 | + |
| 284 | + // Reverse the topo list and call backward on each node |
| 285 | + const items = topo.items; |
| 286 | + var i = items.len; |
| 287 | + while (i > 0) { |
| 288 | + i -= 1; |
| 289 | + items[i].backward(items[i]); |
| 290 | + } |
| 291 | + } |
238 | 292 | }; |
239 | 293 | } |
0 commit comments