Skip to content

Commit 00f0613

Browse files
committed
Jump tables > branches
1 parent 43f6feb commit 00f0613

File tree

5 files changed

+183
-3
lines changed

5 files changed

+183
-3
lines changed

assets/img/perceptron.png

-1.02 KB
Loading

examples/train.zig

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,23 @@ pub fn main() !void {
1414
const NeuronType = kiwigrad.nn.Neuron(f64);
1515
const LayerType = kiwigrad.nn.Layer(f64);
1616
const MLPType = kiwigrad.nn.MLP(f64);
17+
const ArrayType = kiwigrad.engine.Array(f64);
18+
const TensorType = kiwigrad.engine.Tensor(f64);
1719

1820
// Initialize allocators and components
1921
ValueType.init(alloc);
2022
NeuronType.init(alloc);
2123
LayerType.init(alloc);
2224
MLPType.init(alloc);
25+
ArrayType.init(alloc);
26+
TensorType.init(alloc);
2327
defer {
2428
ValueType.deinit();
2529
NeuronType.deinit();
2630
LayerType.deinit();
2731
MLPType.deinit();
32+
ArrayType.deinit();
33+
TensorType.deinit();
2834
}
2935

3036
var sizes = [_]usize{ 3, 2, 1 };
@@ -50,6 +56,9 @@ pub fn main() !void {
5056
}
5157
}
5258

59+
const t1 = TensorType.new(&[_]f64{ 1, 2, 3, 4 });
60+
std.debug.print("t1: {d:.4}\n", .{t1.data[0].data});
61+
5362
// // outputs now contains 2 ValueType pointers (one for each neuron)
5463
// print("Layer output: {d:.4}\n", .{output.data});
5564

src/engine/engine.zig

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,11 @@ pub const UnaryType = enum {
2323
}
2424
};
2525

26+
// pub const UnaryOp = struct {
27+
// type: UnaryType,
28+
// func: fn (Scalar) Scalar,
29+
// };
30+
2631
pub const BinaryType = enum {
2732
add,
2833
sub,
@@ -40,4 +45,5 @@ pub const BinaryType = enum {
4045
};
4146

4247
pub const Scalar = @import("scalar.zig").Scalar;
43-
// pub const Tensor = @import("tensor.zig").Tensor;
48+
pub const Array = @import("tensor.zig").Array;
49+
pub const Tensor = @import("tensor.zig").Tensor;

src/engine/scalar.zig

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,9 @@ const engine = @import("engine.zig");
2626
/// - Softmax
2727
pub fn Scalar(comptime T: type) type {
2828
// Check that T is a valid type
29-
if (@typeInfo(T) != .int and @typeInfo(T) != .float) {
30-
@compileError("Expected @int or @float type, got: " ++ @typeName(T));
29+
switch (@typeInfo(T)) {
30+
.int, .comptime_int, .float, .comptime_float => {},
31+
else => @compileError("Expected @int or @float type, got: " ++ @typeName(T)),
3132
}
3233

3334
return struct {

src/engine/tensor.zig

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
//! This file provides the autograd engine functionality for kiwigrad
2+
3+
const std = @import("std");
4+
const engine = @import("engine.zig");
5+
6+
/// Represents a multi-dimensional array.
/// A thin strided view over a flat element buffer; the struct header is
/// arena-allocated, while data/shape/stride slices are owned by the caller.
pub fn Array(comptime T: type) type {
    return struct {
        const Self = @This();

        /// The flat element storage.
        data: []T,
        /// The shape of the array (length of each dimension).
        shape: []usize,
        /// The stride of the array (element step per dimension).
        stride: []usize,
        /// The number of dimensions of the array.
        dims: usize,
        /// The total number of elements in the array.
        size: usize,

        /// Arena backing all Array header allocations; freed wholesale in deinit().
        var arena: std.heap.ArenaAllocator = undefined;

        /// Initialize the arena allocator. Must be called before new().
        pub fn init(alloc: std.mem.Allocator) void {
            arena = std.heap.ArenaAllocator.init(alloc);
        }

        /// Cleanup allocated memory.
        pub fn deinit() void {
            arena.deinit();
        }

        /// Create a new Array. The returned pointer is arena-owned; the
        /// data/shape/stride slices are NOT copied — the caller must keep
        /// them alive for the lifetime of the Array.
        pub fn new(data: []T, shape: []usize, stride: []usize, dims: usize, size: usize) *Self {
            const a = arena.allocator().create(Self) catch unreachable;
            a.* = Self{
                .data = data,
                .shape = shape,
                .stride = stride,
                .dims = dims,
                .size = size,
            };
            return a;
        }

        /// Return a pointer to the element at the given coordinates.
        pub inline fn at(self: *Self, coords: []const usize) *T {
            // FIX: the original returned `self.data[...]` (a T value) from a
            // function declared to return `*T`; take the address instead.
            return &self.data[self.index(coords)];
        }

        /// Compute the flat index of the element at the given coordinates.
        /// Panics when coords.len does not match the number of dimensions.
        pub inline fn index(self: *Self, coords: []const usize) usize {
            if (coords.len != self.dims) {
                std.debug.panic("Input size mismatch: {d} != {d}", .{ coords.len, self.dims });
            }

            // FIX: `var idx = 0;` infers comptime_int, which is a compile
            // error for a runtime-mutated variable; annotate as usize.
            var idx: usize = 0;
            for (coords, 0..) |coord, i| {
                idx += coord * self.stride[i];
            }
            return idx;
        }

        /// Set the element at the given coordinates to `value`.
        pub inline fn set(self: *Self, coords: []const usize, value: T) void {
            self.data[self.index(coords)] = value;
        }
    };
}
70+
71+
/// Represents an auto-differentiable Tensor value.
/// Holds a data array and a matching zero-initialized gradient array, plus
/// the expression node describing how the value was produced.
pub fn Tensor(comptime T: type) type {
    // Check that T is a valid numeric type BEFORE instantiating Array(T),
    // so an invalid T reports this error rather than one from Array's body.
    switch (@typeInfo(T)) {
        .int, .comptime_int, .float, .comptime_float => {},
        else => @compileError("Expected @int or @float type, got: " ++ @typeName(T)),
    }

    const ArrayType = Array(T);

    return struct {
        const Self = @This();
        const BackpropFn = *const fn (self: *Self) void;

        /// Expression-graph node describing how this value was produced.
        const Expr = union(engine.ExprType) {
            nop: void,
            unary: struct {
                /// The unary operation that produced the value
                op: engine.UnaryType,
                backprop_fn: BackpropFn,
                /// The children used to compute the value
                prev: [1]*Self,
            },
            binary: struct {
                /// The binary operation that produced the value
                op: engine.BinaryType,
                backprop_fn: BackpropFn,
                /// The children used to compute the value
                prev: [2]*Self,
            },
        };

        /// The data
        data: []ArrayType,
        /// The gradient (same length as data, zero-initialized)
        grad: []ArrayType,
        /// The expression that produced the value
        expr: Expr,

        /// The arena allocator backing all Tensor allocations.
        var arena: std.heap.ArenaAllocator = undefined;

        /// Initialize the arena allocator. Must be called before new().
        pub fn init(alloc: std.mem.Allocator) void {
            arena = std.heap.ArenaAllocator.init(alloc);
        }

        /// Deinitialize the arena allocator, freeing all Tensor memory.
        pub fn deinit() void {
            arena.deinit();
        }

        /// Build a length-1 slice holding one 1-D ArrayType whose buffer is
        /// a copy of `data`, or zeros when `fill_zero` is set. Private helper
        /// that removes the duplicated data/grad construction in new().
        fn makeArray1d(data: []const T, fill_zero: bool) []ArrayType {
            const a = arena.allocator();

            // Element buffer: copied from input, or zero-filled for gradients.
            const buf = a.alloc(T, data.len) catch unreachable;
            if (fill_zero) @memset(buf, 0) else @memcpy(buf, data);

            // Shape/stride for a contiguous 1-D tensor.
            const shape = a.alloc(usize, 1) catch unreachable;
            const stride = a.alloc(usize, 1) catch unreachable;
            shape[0] = data.len;
            stride[0] = 1;

            // Construct the ArrayType value in place; the original built it
            // via ArrayType.new and copied the struct, leaking the arena
            // allocation behind the discarded pointer.
            const arrays = a.alloc(ArrayType, 1) catch unreachable;
            arrays[0] = ArrayType{
                .data = buf,
                .shape = shape,
                .stride = stride,
                .dims = 1,
                .size = data.len,
            };
            return arrays;
        }

        /// Create a new leaf Tensor (expr = nop) from array data.
        /// The input is copied; the caller keeps ownership of `data`.
        pub fn new(data: []const T) *Self {
            const t = arena.allocator().create(Self) catch unreachable;
            t.* = Self{
                .data = makeArray1d(data, false),
                .grad = makeArray1d(data, true),
                .expr = .{ .nop = {} },
            };
            return t;
        }

        // TODO: element-wise ops (add, etc.) with backprop functions, mirroring
        // the Scalar implementation.
    };
}

0 commit comments

Comments
 (0)