Skip to content

Commit a335d73

Browse files
committed
Fix compile issues
1 parent 3455460 commit a335d73

File tree

5 files changed

+44
-28
lines changed

5 files changed

+44
-28
lines changed

README.md

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,12 +112,27 @@ To include `micrograd` in your Zig project, follow these steps:
112112

113113
## Usage
114114

115-
micrograd is designed to be easy to use. You can include the library in your Zig project by adding the following line to your source files:
115+
`micrograd` is designed to be easy to use. You can include the library in your Zig project by adding the following line to your source files:
116116

117117
```zig
118118
const micrograd = @import("micrograd");
119119
```
120120

121+
<!-- PROJECT FILE STRUCTURE -->
122+
## Project Structure
123+
124+
```sh
125+
micrograd/
126+
├── .github/ # GitHub Actions CI/CD workflows
127+
├── src/ # Library source files
128+
│ ├── lib.zig # Public API entry point
129+
│ └── ...
130+
├── build.zig # Zig build script
131+
├── build.zig.zon # Zig build script dependencies
132+
├── LICENSE # Project license
133+
└── README.md # You are here
134+
```
135+
121136
## License
122137

123138
The source code for [Kaweees/micrograd](https://github.com/Kaweees/micrograd) is distributed under the terms of the MIT License, as I firmly believe that collaborating on free and open-source software fosters innovations that are mutually and equitably beneficial to both collaborators and users alike. See [`LICENSE`](./LICENSE) for details and more information.

shell.nix

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ pkgs.mkShell {
55
zig # Zig compiler
66
just # Just runner
77
nixfmt # Nix formatter
8+
graphviz # Graphviz
89
];
910

1011
# Shell hook to set up environment

src/engine.zig

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ pub const BinaryType = enum {
6262
/// - Rectified Linear Unit (ReLU)
6363
/// - Softmax
6464
pub fn Value(comptime T: type) type {
65+
// Check that T is a valid type
6566
if (@typeInfo(T) != .int and @typeInfo(T) != .float) {
6667
@compileError("Expected @int or @float type, got: " ++ @typeName(T));
6768
}
@@ -153,7 +154,7 @@ pub fn Value(comptime T: type) type {
153154
}
154155

155156
/// Add two values
156-
pub fn add(self: *Self, other: *Self) *Self {
157+
pub inline fn add(self: *Self, other: *Self) *Self {
157158
return binary(self.data + other.data, .add, add_back, self, other);
158159
}
159160

@@ -164,7 +165,7 @@ pub fn Value(comptime T: type) type {
164165
}
165166

166167
/// Multiply two values
167-
pub fn mul(self: *Self, other: *Self) *Self {
168+
pub inline fn mul(self: *Self, other: *Self) *Self {
168169
return binary(self.data * other.data, .mul, mul_back, self, other);
169170
}
170171

@@ -175,7 +176,7 @@ pub fn Value(comptime T: type) type {
175176
}
176177

177178
/// Subtract two values
178-
pub fn sub(self: *Self, other: *Self) *Self {
179+
pub inline fn sub(self: *Self, other: *Self) *Self {
179180
return binary(self.data - other.data, .sub, sub_back, self, other);
180181
}
181182

@@ -186,7 +187,7 @@ pub fn Value(comptime T: type) type {
186187
}
187188

188189
/// Divide two values
189-
pub fn div(self: *Self, other: *Self) *Self {
190+
pub inline fn div(self: *Self, other: *Self) *Self {
190191
return binary(self.data / other.data, .div, div_back, self, other);
191192
}
192193

@@ -197,23 +198,23 @@ pub fn Value(comptime T: type) type {
197198
}
198199

199200
/// Apply the ReLU function to a value
200-
pub fn relu(self: *Self) *Self {
201+
pub inline fn relu(self: *Self) *Self {
201202
return unary(if (self.data > 0) self.data else @as(T, 0), .relu, relu_back, self);
202203
}
203204

204205
/// Backpropagation function for ReLU
205206
fn relu_back(self: *Self) void {
206-
self.expr.unary.prev.grad += if (self.data > 0) self.grad else @as(T, 0);
207+
self.expr.unary.prev[0].grad += if (self.data > 0) self.grad else @as(T, 0);
207208
}
208209

209210
/// Apply the softmax function to a value
210-
pub fn softmax(self: *Self) *Self {
211+
pub inline fn softmax(self: *Self) *Self {
211212
return unary(std.math.exp(self.data), .softmax, softmax_back, self);
212213
}
213214

214215
/// Backpropagation function for softmax
215216
fn softmax_back(self: *Self) void {
216-
self.expr.unary.prev.grad += self.grad * std.math.exp(self.data);
217+
self.expr.unary.prev[0].grad += self.grad * std.math.exp(self.data);
217218
}
218219

219220
/// Generate Graphviz DOT format representation of the computational graph
@@ -302,25 +303,29 @@ pub fn Value(comptime T: type) type {
302303

303304
try visited.put(self, {});
304305

305-
if (self.prev) |children| {
306-
for (children) |child| {
307-
try child.buildTopo(topo, visited);
308-
}
306+
const prevNodes = switch (self.expr) {
307+
.nop => &[_]*Self{},
308+
.unary => |u| &u.prev,
309+
.binary => |b| &b.prev,
310+
};
311+
312+
for (prevNodes) |prev| {
313+
try prev.buildTopo(topo, visited);
309314
}
310315

311316
try topo.append(self);
312317
}
313318

314319
/// Backward pass - topological sort and gradient computation
315-
pub fn backwardPass(self: *Self, allocator: std.mem.Allocator) !void {
320+
pub fn backwardPass(self: *Self, allocator: std.mem.Allocator) void {
316321
// Topological ordering
317322
var topo = std.ArrayList(*Self).init(allocator);
318323
defer topo.deinit();
319324

320325
var visited = std.AutoHashMap(*Self, void).init(allocator);
321326
defer visited.deinit();
322327

323-
try self.buildTopo(&topo, &visited);
328+
self.buildTopo(&topo, &visited) catch unreachable;
324329

325330
// Apply chain rule
326331
self.grad = @as(T, 1);
@@ -330,7 +335,7 @@ pub fn Value(comptime T: type) type {
330335
var i = items.len;
331336
while (i > 0) {
332337
i -= 1;
333-
items[i].backward(items[i]);
338+
items[i].backprop();
334339
}
335340
}
336341
};

src/nn.zig

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,9 @@ const zprob = @import("zprob");
1515
/// const output = try neuron.forward(&inputs);
1616
/// ```
1717
pub fn Neuron(comptime T: type) type {
18-
if (@typeInfo(T) != .int and @typeInfo(T) != .float) {
19-
@compileError("Expected @int or @float type, got: " ++ @typeName(T));
20-
}
21-
18+
const ValueType = engine.Value(T);
2219
return struct {
2320
const Self = @This();
24-
const ValueType = engine.Value(T);
2521

2622
/// The number of inputs
2723
nin: usize,
@@ -33,13 +29,12 @@ pub fn Neuron(comptime T: type) type {
3329
var arena: std.heap.ArenaAllocator = undefined;
3430
var env: zprob.RandomEnvironment = undefined;
3531

36-
pub fn init(alloc: std.mem.Allocator) !void {
32+
pub fn init(alloc: std.mem.Allocator) void {
3733
arena = std.heap.ArenaAllocator.init(alloc);
3834
env = try zprob.RandomEnvironment.init(arena.allocator());
39-
defer env.deinit();
4035
}
4136

42-
/// Free allocated memory
37+
/// Cleanup allocated memory
4338
pub fn deinit() void {
4439
arena.deinit();
4540
}
@@ -63,20 +58,20 @@ pub fn Neuron(comptime T: type) type {
6358

6459
/// Generate a random value appropriate for the type T
6560
pub fn generate() T {
66-
return env.rNormal(@as(T, 0), @as(T, 1));
61+
return env.rNormal(@as(T, -1), @as(T, 1)) catch @as(T, 0);
6762
}
6863

6964
/// Forward pass through the neuron
7065
pub fn forward(self: *Self, inputs: []*ValueType) *ValueType {
7166
if (inputs.len != self.nin) {
72-
return error.InputSizeMismatch;
67+
std.debug.panic("Input size mismatch: {d} != {d}", .{ inputs.len, self.nin });
7368
}
7469

7570
var sum = self.bias;
7671
for (self.weights, inputs) |w, x| {
7772
sum = sum.add(w.mul(x));
7873
}
79-
74+
// Apply activation function (ReLU)
8075
return sum.relu();
8176
}
8277
};

train.zig

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ pub fn main() !void {
4444
try stdout.print("d.data: {}\n", .{d.data});
4545
try stdout.print("d.grad: {}\n", .{d.grad});
4646

47-
d.backprop(); // d.grad = 1
47+
try d.backwardPass(std.heap.page_allocator);
4848

4949
try stdout.print("d.data: {}\n", .{d.data});
5050
try stdout.print("d.grad: {}\n", .{d.grad});

0 commit comments

Comments
 (0)