Skip to content

Commit 2cee3d9

Browse files
committed
MLP
1 parent 73be3d4 commit 2cee3d9

File tree

9 files changed

+347
-54
lines changed

9 files changed

+347
-54
lines changed

.justfile

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ build:
2222
run:
2323
@echo "Running..."
2424
@zig build run -Doptimize=ReleaseFast
25+
@dot -Tpng assets/img/mlp.dot -o assets/img/mlp.png
26+
@dot -Tpng assets/img/perceptron.dot -o assets/img/perceptron.png
2527

2628
# Test the project
2729
test:

README.md

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,13 @@ A scalar-valued automatic differentiation (autograd) engine for deep learning wr
3535
## Preview
3636

3737
<p align="center">
38-
<img src="assets/img/train.png"
38+
<img src="assets/img/mlp.png"
3939
width = "80%"
40-
alt = "Training a model on MNIST dataset"
40+
alt = "MLP"
41+
/>
42+
<img src="assets/img/perceptron.png"
43+
width = "80%"
44+
alt = "Perceptron"
4145
/>
4246
</p>
4347

assets/img/mlp.png

20.5 KB
Loading

assets/img/perceptron.png

167 KB
Loading

assets/img/train.png

-64.4 KB
Binary file not shown.

examples/train.zig

Lines changed: 28 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@ const std = @import("std");
66
const kiwigrad = @import("kiwigrad");
77
const zbench = @import("zbench");
88

9-
const print = std.debug.print;
10-
119
pub fn main() !void {
1210
const stdout_file = std.io.getStdOut().writer();
1311
var bw = std.io.bufferedWriter(stdout_file);
@@ -19,43 +17,53 @@ pub fn main() !void {
1917
const ValueType = kiwigrad.engine.Value(f64);
2018
const NeuronType = kiwigrad.nn.Neuron(f64);
2119
const LayerType = kiwigrad.nn.Layer(f64);
22-
// const MLPType = kiwigrad.nn.MLP;
20+
const MLPType = kiwigrad.nn.MLP(f64);
2321

2422
// Initialize allocators and components
2523
ValueType.init(alloc);
2624
NeuronType.init(alloc);
2725
LayerType.init(alloc);
26+
MLPType.init(alloc);
2827
defer {
2928
ValueType.deinit();
3029
NeuronType.deinit();
3130
LayerType.deinit();
32-
// MLPType.deinit();
31+
MLPType.deinit();
3332
}
3433

34+
var sizes = [_]usize{ 3, 2, 1 };
35+
3536
// Initialize the MLP (layer sizes taken from `sizes`)
36-
const neuron = NeuronType.new(3);
37+
const mlp = MLPType.new(sizes.len - 1, sizes[0..]);
3738

38-
// Create sample input data
39-
var input_data = [_]*ValueType{
40-
ValueType.new(1.0),
41-
ValueType.new(2.0),
42-
ValueType.new(3.0),
39+
const inputs = [_][3]*ValueType{
40+
[_]*ValueType{ ValueType.new(2), ValueType.new(3), ValueType.new(-1) },
41+
[_]*ValueType{ ValueType.new(3), ValueType.new(-1), ValueType.new(0.5) },
42+
[_]*ValueType{ ValueType.new(0.5), ValueType.new(1), ValueType.new(1) },
43+
[_]*ValueType{ ValueType.new(1), ValueType.new(2), ValueType.new(3) },
4344
};
4445

45-
// Forward pass through the layer
46-
const output = neuron.forward(input_data[0..]);
46+
mlp.draw_graph("assets/img/mlp", stdout);
47+
48+
for (inputs) |in| {
49+
// Forward pass through the MLP
50+
const output = mlp.forward(@constCast(&in));
51+
stdout.print("{d:7.4} ", .{output[0].data}) catch unreachable;
52+
for (output) |o| {
53+
_ = o.draw_graph("assets/img/perceptron", stdout);
54+
}
55+
}
4756

48-
// outputs now contains 2 ValueType pointers (one for each neuron)
49-
print("Layer output: {d:.4}\n", .{output.data});
57+
// // outputs now contains 2 ValueType pointers (one for each neuron)
58+
// print("Layer output: {d:.4}\n", .{output.data});
5059

51-
print("output.data: {d:.4}\n", .{output.data});
52-
print("output.grad: {d:.4}\n", .{output.grad});
60+
// print("output.data: {d:.4}\n", .{output.data});
61+
// print("output.grad: {d:.4}\n", .{output.grad});
5362

54-
output.backwardPass(alloc);
63+
// output.backwardPass(alloc);
5564

56-
print("output.data: {d:.4}\n", .{output.data});
57-
print("output.grad: {d:.4}\n", .{output.grad});
65+
// print("output.data: {d:.4}\n", .{output.data});
66+
// print("output.grad: {d:.4}\n", .{output.grad});
5867

59-
output.draw_graph("assets/img/train", stdout);
6068
try bw.flush(); // Don't forget to flush!
6169
}

src/engine.zig

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ pub const UnaryType = enum {
1818
pub fn toString(self: UnaryType) []const u8 {
1919
return switch (self) {
2020
.tanh => "tanh",
21-
.exp => "exp",
21+
.exp => "^",
2222
.relu => "ReLU",
2323
.softmax => "Softmax",
2424
};
@@ -58,6 +58,7 @@ pub const BinaryType = enum {
5858
/// - Addition
5959
/// - Subtraction
6060
/// - Multiplication
61+
/// - Exponentiation
6162
/// - Division
6263
/// - Rectified Linear Unit (ReLU)
6364
/// - Softmax
@@ -175,6 +176,16 @@ pub fn Value(comptime T: type) type {
175176
self.expr.binary.prev[1].grad += self.grad * self.expr.binary.prev[0].data;
176177
}
177178

179+
/// Apply the natural exponential function to this value.
///
/// Builds a new unary node tagged `.exp` whose data is `e^(self.data)`,
/// with `self` as its operand and `exp_back` as its backward function.
pub inline fn exp(self: *Self) *Self {
    const result = std.math.exp(self.data);
    return unary(result, .exp, exp_back, self);
}
183+
184+
/// Backpropagation function for exponentiation.
///
/// `self` is the output node produced by `exp`; its single operand is
/// `self.expr.unary.prev[0]`. Since d/dx e^x = e^x, the operand's gradient
/// receives `self.grad * e^(operand.data)`.
fn exp_back(self: *Self) void {
    const operand = self.expr.unary.prev[0];
    // The derivative must be evaluated at the *input* value. Using
    // `std.math.exp(self.data)` here would compute e^(e^x), because
    // `self.data` already holds the forward result e^x.
    operand.grad += self.grad * std.math.exp(operand.data);
}
188+
178189
/// Subtract two values
179190
pub inline fn sub(self: *Self, other: *Self) *Self {
180191
return binary(self.data - other.data, .sub, sub_back, self, other);

0 commit comments

Comments
 (0)