Skip to content

Commit 938a176

Browse files
committed
Cut Value fat
1 parent 957b6db commit 938a176

File tree

3 files changed

+108
-114
lines changed

3 files changed

+108
-114
lines changed

.zigversion

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
0.14.0
1+
0.14.1

src/engine.zig

Lines changed: 102 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -2,108 +2,39 @@
22

33
const std = @import("std");
44

5-
/// Operations supported by the engine
6-
pub const Operation = enum(u8) {
7-
/// Add two values
8-
ADD,
9-
/// Subtract two values
10-
SUB,
11-
/// Multiply two values
12-
MUL,
13-
/// Divide two values
14-
DIV,
15-
16-
/// Convert the Operation to its mathematical symbol
17-
pub fn toString(self: Operation) []const u8 {
18-
return switch (self) {
19-
.ADD => "+",
20-
.SUB => "-",
21-
.MUL => "*",
22-
.DIV => "/",
23-
};
24-
}
25-
};
26-
27-
pub fn info(T: type) std.builtin.Type.Vector {
28-
if (@typeInfo(T) != .vector) @compileError("Expected a @Vector type got: " ++ @typeName(T));
29-
return @typeInfo(T).vector;
30-
}
31-
325
/// Represents a singular Scalar value
336
pub fn Value(comptime T: type) type {
7+
if (@typeInfo(T) != .int and @typeInfo(T) != .float) {
8+
@compileError("Expected @int or @float type, got: " ++ @typeName(T));
9+
}
10+
3411
return struct {
3512
const Self = @This();
3613
/// The value
3714
data: T,
3815
/// The gradient
3916
grad: T,
4017
/// Function for backpropagation
41-
backprop: ?*const fn (self: *Self) void,
18+
backward: ?*const fn (self: *Self) void,
4219
/// The children used to compute the value
4320
prev: ?[]*Self,
4421
/// The operation that produced the value
45-
operation: ?Operation,
22+
operation: ?[]const u8,
4623
/// The label of the value
4724
label: ?[]const u8,
4825

4926
/// Initialize the Value
50-
pub fn init(data: T, prev: ?[]*Self, operation: ?Operation, label: ?[]const u8) Self {
27+
pub fn init(data: T, prev: ?[]*Self, operation: ?[]const u8, label: ?[]const u8) Self {
5128
return Self{
5229
.data = data,
5330
.grad = 0,
54-
.backprop = null,
31+
.backward = null,
5532
.prev = prev,
5633
.operation = operation,
5734
.label = label,
5835
};
5936
}
6037

61-
pub fn add(self: *Self, other: *Self, allocator: std.mem.Allocator) !Self {
62-
const children = try allocator.dupe(*Self, &.{ self, other });
63-
const AddBackward = struct {
64-
fn call(result: *Self) void {
65-
if (result.prev) |prev_children| {
66-
prev_children[0].grad += result.grad;
67-
prev_children[1].grad += result.grad;
68-
}
69-
}
70-
}.call;
71-
72-
return Self{
73-
.data = self.data + other.data,
74-
.grad = 0,
75-
.backprop = AddBackward,
76-
.prev = children,
77-
.operation = Operation.ADD,
78-
.label = null,
79-
};
80-
}
81-
82-
pub fn mul(self: *Self, other: *Self, allocator: std.mem.Allocator) !Self {
83-
const children = try allocator.dupe(*Self, &.{ self, other });
84-
const MulBackward = struct {
85-
fn call(result: *Self) void {
86-
if (result.prev) |prev_children| {
87-
prev_children[0].grad += prev_children[1].data * result.grad;
88-
prev_children[1].grad += prev_children[0].data * result.grad;
89-
}
90-
}
91-
}.call;
92-
93-
return Self{
94-
.data = self.data * other.data,
95-
.grad = 0,
96-
.backprop = MulBackward,
97-
.prev = children,
98-
.operation = Operation.MUL,
99-
.label = null,
100-
};
101-
}
102-
103-
pub fn backward(self: *Self) void {
104-
if (self.backprop) |bp| bp(self);
105-
}
106-
10738
/// Convert the Value to a string
10839
pub fn toString(self: Self) []const u8 {
10940
const op_name = if (self.operation) |op| @tagName(op) else "null";
@@ -125,47 +56,113 @@ pub fn Value(comptime T: type) type {
12556
return std.fmt.allocPrint(std.heap.page_allocator, "Value(data={any}, grad={any}, prev={s}, operation={s}, label={s})", .{ self.data, self.grad, prev_str, op_name, label_name }) catch unreachable;
12657
}
12758

128-
/// Subtract two values
129-
pub fn sub(self: *Self, other: *Self, allocator: std.mem.Allocator) !Self {
130-
const children = try allocator.dupe(*Self, &.{ self, other });
131-
const SubBackward = struct {
59+
pub fn add(self: *Self, other: *Self, allocator: std.mem.Allocator, label: ?[]const u8) !Self {
60+
return Self{
61+
.data = self.data + other.data,
62+
.grad = 0,
63+
.backward = struct {
64+
fn call(result: *Self) void {
65+
if (result.prev) |prev_children| {
66+
prev_children[0].grad += result.grad;
67+
prev_children[1].grad += result.grad;
68+
}
69+
}
70+
}.call,
71+
.prev = try allocator.dupe(*Self, &.{ self, other }),
72+
.operation = "+",
73+
.label = label,
74+
};
75+
}
76+
77+
pub fn mul(self: *Self, other: *Self, allocator: std.mem.Allocator, label: ?[]const u8) !Self {
78+
const MulBackward = struct {
13279
fn call(result: *Self) void {
13380
if (result.prev) |prev_children| {
134-
prev_children[0].grad += result.grad;
135-
prev_children[1].grad -= result.grad;
81+
prev_children[0].grad += prev_children[1].data * result.grad;
82+
prev_children[1].grad += prev_children[0].data * result.grad;
13683
}
13784
}
13885
}.call;
13986

87+
return Self{
88+
.data = self.data * other.data,
89+
.grad = 0,
90+
.backward = MulBackward,
91+
.prev = try allocator.dupe(*Self, &.{ self, other }),
92+
.operation = "*",
93+
.label = label,
94+
};
95+
}
96+
97+
/// Subtract two values
98+
pub fn sub(self: *Self, other: *Self, allocator: std.mem.Allocator, label: ?[]const u8) !Self {
14099
return Self{
141100
.data = self.data - other.data,
142101
.grad = 0,
143-
.backprop = SubBackward,
144-
.prev = children,
145-
.operation = Operation.SUB,
146-
.label = null,
102+
.backward = struct {
103+
fn call(result: *Self) void {
104+
if (result.prev) |prev_children| {
105+
prev_children[0].grad += result.grad;
106+
prev_children[1].grad -= result.grad;
107+
}
108+
}
109+
}.call,
110+
.prev = try allocator.dupe(*Self, &.{ self, other }),
111+
.operation = "-",
112+
.label = label,
147113
};
148114
}
149115

150116
/// Divide two values
151-
pub fn div(self: *Self, other: *Self, allocator: std.mem.Allocator) !Self {
152-
const children = try allocator.dupe(*Self, &.{ self, other });
153-
const DivBackward = struct {
154-
fn call(result: *Self) void {
155-
if (result.prev) |prev_children| {
156-
prev_children[0].grad += result.grad / other.data;
157-
prev_children[1].grad -= result.grad * self.data / (other.data * other.data);
117+
pub fn div(self: *Self, other: *Self, allocator: std.mem.Allocator, label: ?[]const u8) !Self {
118+
return Self{
119+
.data = self.data / other.data,
120+
.grad = 0,
121+
.backward = struct {
122+
fn call(result: *Self) void {
123+
if (result.prev) |prev_children| {
124+
prev_children[0].grad += result.grad / other.data;
125+
prev_children[1].grad -= result.grad * self.data / (other.data * other.data);
126+
}
158127
}
159-
}
160-
}.call;
128+
}.call,
129+
.prev = try allocator.dupe(*Self, &.{ self, other }),
130+
.operation = "/",
131+
.label = label,
132+
};
133+
}
161134

135+
pub fn relu(self: *Self, allocator: std.mem.Allocator, label: ?[]const u8) !Self {
162136
return Self{
163-
.data = self.data / other.data,
137+
.data = if (self.data > 0) self.data else 0,
164138
.grad = 0,
165-
.backprop = DivBackward,
166-
.prev = children,
167-
.operation = Operation.DIV,
168-
.label = null,
139+
.backward = struct {
140+
fn call(result: *Self) void {
141+
if (result.prev) |prev_children| {
142+
prev_children[0].grad += result.grad * (self.data > 0);
143+
}
144+
}
145+
}.call,
146+
.prev = try allocator.dupe(*Self, &.{self}),
147+
.operation = "ReLU",
148+
.label = label,
149+
};
150+
}
151+
152+
pub fn softmax(self: *Self, allocator: std.mem.Allocator, label: ?[]const u8) !Self {
153+
return Self{
154+
.data = std.math.exp(self.data),
155+
.grad = 0,
156+
.backward = struct {
157+
fn call(result: *Self) void {
158+
if (result.prev) |prev_children| {
159+
prev_children[0].grad += result.grad;
160+
}
161+
}
162+
}.call,
163+
.prev = try allocator.dupe(*Self, &.{self}),
164+
.operation = "Softmax",
165+
.label = label,
169166
};
170167
}
171168

@@ -192,15 +189,17 @@ pub fn Value(comptime T: type) type {
192189
const node_id = @intFromPtr(node);
193190
const label_str = if (node.label) |label| label else "";
194191
const data_str = try std.fmt.allocPrint(allocator, "{d:.4}", .{node.data});
192+
const grad_str = try std.fmt.allocPrint(allocator, "{d:.4}", .{node.grad});
195193
defer allocator.free(data_str);
194+
defer allocator.free(grad_str);
196195

197-
try writer.print(" \"{}\" [label=\"{{{s} | data {s}}}\", shape=record];\n", .{ node_id, label_str, data_str });
196+
try writer.print(" \"{}\" [label=\"{{{s} | data {s} | grad {s}}}\", shape=record];\n", .{ node_id, label_str, data_str, grad_str });
198197

199198
// If this value is a result of some operation, create an op node for it
200199
if (node.operation) |op| {
201200
const op_id = try std.fmt.allocPrint(allocator, "{}op", .{node_id});
202201
defer allocator.free(op_id);
203-
try writer.print(" \"{s}\" [label=\"{s}\"];\n", .{ op_id, op.toString() });
202+
try writer.print(" \"{s}\" [label=\"{s}\"];\n", .{ op_id, op });
204203
try writer.print(" \"{s}\" -> \"{}\";\n", .{ op_id, node_id });
205204
}
206205
}

train.zig

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,23 +15,18 @@ pub fn main() !void {
1515
var a = lib.engine.Value(f32).init(2.0, null, null, "a");
1616
var b = lib.engine.Value(f32).init(-3.0, null, null, "b");
1717
var c = lib.engine.Value(f32).init(10.0, null, null, "c");
18-
var d = try a.mul(&b, std.heap.page_allocator);
18+
var d = try a.mul(&b, std.heap.page_allocator, "d");
1919
// Perform operations: e = (a * b) + c
20-
var e = try d.add(&c, std.heap.page_allocator);
21-
// // Add another branch: f = a * b
22-
// var f = try a.mul(&b, std.heap.page_allocator);
23-
// f.label = "f";
24-
25-
// // Final result: g = e + f = (a + b) * c + a * b
26-
// var g = try e.add(&f, std.heap.page_allocator);
27-
// g.label = "g";
20+
var e = try d.add(&c, std.heap.page_allocator, "e");
21+
var f = lib.engine.Value(f32).init(-2.0, null, null, "f");
22+
var g = try f.mul(&e, std.heap.page_allocator, "g");
2823

2924
// Write the computational graph to a Graphviz file
3025
const file = try std.fs.cwd().createFile("graph.dot", .{});
3126
defer file.close();
3227

3328
const file_writer = file.writer();
34-
try e.draw_dot(file_writer, std.heap.page_allocator);
29+
try g.draw_dot(file_writer, std.heap.page_allocator);
3530

3631
try stdout.print("Computational graph written to graph.dot\n", .{});
3732
try stdout.print("You can visualize it by running: dot -Tpng graph.dot -o graph.png\n", .{});

0 commit comments

Comments
 (0)