Skip to content

Commit e898167

Browse files
committed
Value -> Scalar
1 parent 205392f commit e898167

File tree

4 files changed

+30
-29
lines changed

4 files changed

+30
-29
lines changed

examples/train.zig

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ pub fn main() !void {
1414
const alloc = std.heap.page_allocator;
1515

1616
// Initialize the required components
17-
const ValueType = kiwigrad.engine.Value(f64);
17+
const ValueType = kiwigrad.engine.Scalar(f64);
1818
const NeuronType = kiwigrad.nn.Neuron(f64);
1919
const LayerType = kiwigrad.nn.Layer(f64);
2020
const MLPType = kiwigrad.nn.MLP(f64);

src/engine/engine.zig

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,4 +39,5 @@ pub const BinaryType = enum {
3939
}
4040
};
4141

42-
pub const Value = @import("value.zig").Value;
42+
pub const Scalar = @import("scalar.zig").Scalar;
43+
// pub const Tensor = @import("tensor.zig").Tensor;
Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,13 @@ const engine = @import("engine.zig");
99
///
1010
/// # Example
1111
/// ```zig
12-
/// const Value = @import("engine").Value;
13-
/// const value = Value(f32).new(2.0);
12+
/// const Scalar = @import("engine").Scalar;
13+
/// const Scalar = Scalar(f32).new(2.0);
1414
/// ```
1515
///
1616
/// # Operations
1717
///
18-
/// The Value type supports the following operations:
18+
/// The Scalar type supports the following operations:
1919
///
2020
/// - Addition
2121
/// - Subtraction
@@ -24,7 +24,7 @@ const engine = @import("engine.zig");
2424
/// - Division
2525
/// - Rectified Linear Unit (ReLU)
2626
/// - Softmax
27-
pub fn Value(comptime T: type) type {
27+
pub fn Scalar(comptime T: type) type {
2828
// Check that T is a valid type
2929
if (@typeInfo(T) != .int and @typeInfo(T) != .float) {
3030
@compileError("Expected @int or @float type, got: " ++ @typeName(T));
@@ -73,21 +73,21 @@ pub fn Value(comptime T: type) type {
7373
arena.deinit();
7474
}
7575

76-
/// Create a new Value with no expression
77-
pub fn new(value: T) *Self {
78-
return create(value, .{ .nop = {} });
76+
/// Create a new Scalar value with no expression
77+
pub fn new(scalar: T) *Self {
78+
return create(scalar, .{ .nop = {} });
7979
}
8080

81-
/// Create a new Value with an expression
82-
fn create(value: T, expr: Expr) *Self {
81+
/// Create a new Scalar with an expression
82+
fn create(scalar: T, expr: Expr) *Self {
8383
const v = arena.allocator().create(Self) catch unreachable;
84-
v.* = Self{ .data = value, .grad = @as(T, 0), .expr = expr };
84+
v.* = Self{ .data = scalar, .grad = @as(T, 0), .expr = expr };
8585
return v;
8686
}
8787

88-
// Create a new Value with an unary expression
89-
fn unary(value: T, op: engine.UnaryType, backprop_fn: BackpropFn, arg0: *Self) *Self {
90-
return create(value, Expr{
88+
// Create a new Scalar with an unary expression
89+
fn unary(scalar: T, op: engine.UnaryType, backprop_fn: BackpropFn, arg0: *Self) *Self {
90+
return create(scalar, Expr{
9191
.unary = .{
9292
.op = op,
9393
.backprop_fn = backprop_fn,
@@ -96,9 +96,9 @@ pub fn Value(comptime T: type) type {
9696
});
9797
}
9898

99-
// Create a new Value with a binary expression
100-
fn binary(value: T, op: engine.BinaryType, backprop_fn: BackpropFn, arg0: *Self, arg1: *Self) *Self {
101-
return create(value, Expr{
99+
// Create a new Scalar with a binary expression
100+
fn binary(scalar: T, op: engine.BinaryType, backprop_fn: BackpropFn, arg0: *Self, arg1: *Self) *Self {
101+
return create(scalar, Expr{
102102
.binary = .{
103103
.op = op,
104104
.backprop_fn = backprop_fn,
@@ -116,7 +116,7 @@ pub fn Value(comptime T: type) type {
116116
}
117117
}
118118

119-
/// Add two values
119+
/// Add two Scalars
120120
pub inline fn add(self: *Self, other: *Self) *Self {
121121
return binary(self.data + other.data, .add, add_back, self, other);
122122
}
@@ -127,7 +127,7 @@ pub fn Value(comptime T: type) type {
127127
self.expr.binary.prev[1].grad += self.grad;
128128
}
129129

130-
/// Multiply two values
130+
/// Multiply two Scalars
131131
pub inline fn mul(self: *Self, other: *Self) *Self {
132132
return binary(self.data * other.data, .mul, mul_back, self, other);
133133
}
@@ -138,7 +138,7 @@ pub fn Value(comptime T: type) type {
138138
self.expr.binary.prev[1].grad += self.grad * self.expr.binary.prev[0].data;
139139
}
140140

141-
/// Exponentiate a value
141+
/// Exponentiate a Scalar
142142
pub inline fn exp(self: *Self) *Self {
143143
return unary(std.math.exp(self.data), .exp, exp_back, self);
144144
}
@@ -148,7 +148,7 @@ pub fn Value(comptime T: type) type {
148148
self.expr.unary.prev[0].grad += self.grad * std.math.exp(self.data);
149149
}
150150

151-
/// Subtract two values
151+
/// Subtract two Scalars
152152
pub inline fn sub(self: *Self, other: *Self) *Self {
153153
return binary(self.data - other.data, .sub, sub_back, self, other);
154154
}
@@ -159,7 +159,7 @@ pub fn Value(comptime T: type) type {
159159
self.expr.binary.prev[1].grad -= self.grad;
160160
}
161161

162-
/// Divide two values
162+
/// Divide two Scalars
163163
pub inline fn div(self: *Self, other: *Self) *Self {
164164
return binary(self.data / other.data, .div, div_back, self, other);
165165
}
@@ -170,7 +170,7 @@ pub fn Value(comptime T: type) type {
170170
self.expr.binary.prev[1].grad -= self.grad * self.expr.binary.prev[0].data / (self.expr.binary.prev[1].data * self.expr.binary.prev[1].data);
171171
}
172172

173-
/// Apply the ReLU function to a value
173+
/// Apply the ReLU function to a Scalar
174174
pub inline fn relu(self: *Self) *Self {
175175
return unary(if (self.data > 0) self.data else @as(T, 0), .relu, relu_back, self);
176176
}
@@ -180,7 +180,7 @@ pub fn Value(comptime T: type) type {
180180
self.expr.unary.prev[0].grad += if (self.data > 0) self.grad else @as(T, 0);
181181
}
182182

183-
/// Apply the softmax function to a value
183+
/// Apply the softmax function to a Scalar
184184
pub inline fn softmax(self: *Self) *Self {
185185
return unary(std.math.exp(self.data), .softmax, softmax_back, self);
186186
}
@@ -218,7 +218,7 @@ pub fn Value(comptime T: type) type {
218218

219219
try writer.print(" \"{}\" [label=\"data {s} | grad {s}\", shape=record];\n", .{ node_id, data_str, grad_str });
220220

221-
// If this value is a result of some operation, create an op node for it
221+
// If this Scalar is a result of some operation, create an op node for it
222222
switch (node.expr) {
223223
.nop => {},
224224
.unary, .binary => {

src/nn.zig

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ const zprob = @import("zprob");
1515
/// const output = try neuron.forward(&inputs);
1616
/// ```
1717
pub fn Neuron(comptime T: type) type {
18-
const ValueType = engine.Value(T);
18+
const ValueType = engine.Scalar(T);
1919
return struct {
2020
const Self = @This();
2121

@@ -146,7 +146,7 @@ pub fn Neuron(comptime T: type) type {
146146
/// const output = try layer.forward(&inputs);
147147
/// ```
148148
pub fn Layer(comptime T: type) type {
149-
const ValueType = engine.Value(T);
149+
const ValueType = engine.Scalar(T);
150150
const NeuronType = Neuron(T);
151151
return struct {
152152
const Self = @This();
@@ -225,7 +225,7 @@ pub fn Layer(comptime T: type) type {
225225
/// const output = try mlp.forward(&inputs);
226226
/// ```
227227
pub fn MLP(comptime T: type) type {
228-
const ValueType = engine.Value(T);
228+
const ValueType = engine.Scalar(T);
229229
const LayerType = Layer(T);
230230
return struct {
231231
const Self = @This();

0 commit comments

Comments
 (0)