Skip to content

Commit 9f5eca0

Browse files
committed
Value -> Scalar
1 parent 205392f commit 9f5eca0

File tree

4 files changed

+31
-31
lines changed

4 files changed

+31
-31
lines changed

examples/train.zig

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ pub fn main() !void {
1414
const alloc = std.heap.page_allocator;
1515

1616
// Initialize the required components
17-
const ValueType = kiwigrad.engine.Value(f64);
17+
const ValueType = kiwigrad.engine.Scalar(f64);
1818
const NeuronType = kiwigrad.nn.Neuron(f64);
1919
const LayerType = kiwigrad.nn.Layer(f64);
2020
const MLPType = kiwigrad.nn.MLP(f64);

src/engine/engine.zig

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,4 +39,5 @@ pub const BinaryType = enum {
3939
}
4040
};
4141

42-
pub const Value = @import("value.zig").Value;
42+
pub const Scalar = @import("scalar.zig").Scalar;
43+
// pub const Tensor = @import("tensor.zig").Tensor;
Lines changed: 25 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,13 @@ const engine = @import("engine.zig");
99
///
1010
/// # Example
1111
/// ```zig
12-
/// const Value = @import("engine").Value;
13-
/// const value = Value(f32).new(2.0);
12+
/// const Scalar = @import("engine").Scalar;
13+
/// const Scalar = Scalar(f32).new(2.0);
1414
/// ```
1515
///
1616
/// # Operations
1717
///
18-
/// The Value type supports the following operations:
18+
/// The Scalar type supports the following operations:
1919
///
2020
/// - Addition
2121
/// - Subtraction
@@ -24,16 +24,15 @@ const engine = @import("engine.zig");
2424
/// - Division
2525
/// - Rectified Linear Unit (ReLU)
2626
/// - Softmax
27-
pub fn Value(comptime T: type) type {
27+
pub fn Scalar(comptime T: type) type {
2828
// Check that T is a valid type
2929
if (@typeInfo(T) != .int and @typeInfo(T) != .float) {
3030
@compileError("Expected @int or @float type, got: " ++ @typeName(T));
3131
}
3232

3333
return struct {
3434
const Self = @This();
35-
const BackpropFn = *const fn (*Self) void;
36-
// const BackpropFn = *const fn (self: *Self) void;
35+
const BackpropFn = *const fn (self: *Self) void;
3736

3837
const Expr = union(engine.ExprType) {
3938
nop: void,
@@ -73,21 +72,21 @@ pub fn Value(comptime T: type) type {
7372
arena.deinit();
7473
}
7574

76-
/// Create a new Value with no expression
77-
pub fn new(value: T) *Self {
78-
return create(value, .{ .nop = {} });
75+
/// Create a new Scalar value with no expression
76+
pub fn new(scalar: T) *Self {
77+
return create(scalar, .{ .nop = {} });
7978
}
8079

81-
/// Create a new Value with an expression
82-
fn create(value: T, expr: Expr) *Self {
80+
/// Create a new Scalar with an expression
81+
fn create(scalar: T, expr: Expr) *Self {
8382
const v = arena.allocator().create(Self) catch unreachable;
84-
v.* = Self{ .data = value, .grad = @as(T, 0), .expr = expr };
83+
v.* = Self{ .data = scalar, .grad = @as(T, 0), .expr = expr };
8584
return v;
8685
}
8786

88-
// Create a new Value with an unary expression
89-
fn unary(value: T, op: engine.UnaryType, backprop_fn: BackpropFn, arg0: *Self) *Self {
90-
return create(value, Expr{
87+
// Create a new Scalar with an unary expression
88+
fn unary(scalar: T, op: engine.UnaryType, backprop_fn: BackpropFn, arg0: *Self) *Self {
89+
return create(scalar, Expr{
9190
.unary = .{
9291
.op = op,
9392
.backprop_fn = backprop_fn,
@@ -96,9 +95,9 @@ pub fn Value(comptime T: type) type {
9695
});
9796
}
9897

99-
// Create a new Value with a binary expression
100-
fn binary(value: T, op: engine.BinaryType, backprop_fn: BackpropFn, arg0: *Self, arg1: *Self) *Self {
101-
return create(value, Expr{
98+
// Create a new Scalar with a binary expression
99+
fn binary(scalar: T, op: engine.BinaryType, backprop_fn: BackpropFn, arg0: *Self, arg1: *Self) *Self {
100+
return create(scalar, Expr{
102101
.binary = .{
103102
.op = op,
104103
.backprop_fn = backprop_fn,
@@ -116,7 +115,7 @@ pub fn Value(comptime T: type) type {
116115
}
117116
}
118117

119-
/// Add two values
118+
/// Add two Scalars
120119
pub inline fn add(self: *Self, other: *Self) *Self {
121120
return binary(self.data + other.data, .add, add_back, self, other);
122121
}
@@ -127,7 +126,7 @@ pub fn Value(comptime T: type) type {
127126
self.expr.binary.prev[1].grad += self.grad;
128127
}
129128

130-
/// Multiply two values
129+
/// Multiply two Scalars
131130
pub inline fn mul(self: *Self, other: *Self) *Self {
132131
return binary(self.data * other.data, .mul, mul_back, self, other);
133132
}
@@ -138,7 +137,7 @@ pub fn Value(comptime T: type) type {
138137
self.expr.binary.prev[1].grad += self.grad * self.expr.binary.prev[0].data;
139138
}
140139

141-
/// Exponentiate a value
140+
/// Exponentiate a Scalar
142141
pub inline fn exp(self: *Self) *Self {
143142
return unary(std.math.exp(self.data), .exp, exp_back, self);
144143
}
@@ -148,7 +147,7 @@ pub fn Value(comptime T: type) type {
148147
self.expr.unary.prev[0].grad += self.grad * std.math.exp(self.data);
149148
}
150149

151-
/// Subtract two values
150+
/// Subtract two Scalars
152151
pub inline fn sub(self: *Self, other: *Self) *Self {
153152
return binary(self.data - other.data, .sub, sub_back, self, other);
154153
}
@@ -159,7 +158,7 @@ pub fn Value(comptime T: type) type {
159158
self.expr.binary.prev[1].grad -= self.grad;
160159
}
161160

162-
/// Divide two values
161+
/// Divide two Scalars
163162
pub inline fn div(self: *Self, other: *Self) *Self {
164163
return binary(self.data / other.data, .div, div_back, self, other);
165164
}
@@ -170,7 +169,7 @@ pub fn Value(comptime T: type) type {
170169
self.expr.binary.prev[1].grad -= self.grad * self.expr.binary.prev[0].data / (self.expr.binary.prev[1].data * self.expr.binary.prev[1].data);
171170
}
172171

173-
/// Apply the ReLU function to a value
172+
/// Apply the ReLU function to a Scalar
174173
pub inline fn relu(self: *Self) *Self {
175174
return unary(if (self.data > 0) self.data else @as(T, 0), .relu, relu_back, self);
176175
}
@@ -180,7 +179,7 @@ pub fn Value(comptime T: type) type {
180179
self.expr.unary.prev[0].grad += if (self.data > 0) self.grad else @as(T, 0);
181180
}
182181

183-
/// Apply the softmax function to a value
182+
/// Apply the softmax function to a Scalar
184183
pub inline fn softmax(self: *Self) *Self {
185184
return unary(std.math.exp(self.data), .softmax, softmax_back, self);
186185
}
@@ -218,7 +217,7 @@ pub fn Value(comptime T: type) type {
218217

219218
try writer.print(" \"{}\" [label=\"data {s} | grad {s}\", shape=record];\n", .{ node_id, data_str, grad_str });
220219

221-
// If this value is a result of some operation, create an op node for it
220+
// If this Scalar is a result of some operation, create an op node for it
222221
switch (node.expr) {
223222
.nop => {},
224223
.unary, .binary => {

src/nn.zig

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ const zprob = @import("zprob");
1515
/// const output = try neuron.forward(&inputs);
1616
/// ```
1717
pub fn Neuron(comptime T: type) type {
18-
const ValueType = engine.Value(T);
18+
const ValueType = engine.Scalar(T);
1919
return struct {
2020
const Self = @This();
2121

@@ -146,7 +146,7 @@ pub fn Neuron(comptime T: type) type {
146146
/// const output = try layer.forward(&inputs);
147147
/// ```
148148
pub fn Layer(comptime T: type) type {
149-
const ValueType = engine.Value(T);
149+
const ValueType = engine.Scalar(T);
150150
const NeuronType = Neuron(T);
151151
return struct {
152152
const Self = @This();
@@ -225,7 +225,7 @@ pub fn Layer(comptime T: type) type {
225225
/// const output = try mlp.forward(&inputs);
226226
/// ```
227227
pub fn MLP(comptime T: type) type {
228-
const ValueType = engine.Value(T);
228+
const ValueType = engine.Scalar(T);
229229
const LayerType = Layer(T);
230230
return struct {
231231
const Self = @This();

0 commit comments

Comments
 (0)