Skip to content

Commit 5f0eb5f

Browse files
committed
Add variants of the activation function
1 parent 2735964 commit 5f0eb5f

File tree

3 files changed

+50
-7
lines changed

3 files changed

+50
-7
lines changed

src/engine/engine.zig

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,15 @@ pub const UnaryType = enum {
1212
exp,
1313
relu,
1414
softmax,
15+
identity,
1516

1617
pub fn toString(self: UnaryType) []const u8 {
1718
return switch (self) {
1819
.tanh => "tanh",
1920
.exp => "^",
2021
.relu => "ReLU",
2122
.softmax => "Softmax",
23+
.identity => "id",
2224
};
2325
}
2426
};

src/engine/scalar.zig

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ const engine = @import("engine.zig");
2424
/// - Division
2525
/// - Rectified Linear Unit (ReLU)
2626
/// - Softmax
27+
/// - Tanh
2728
pub fn Scalar(comptime T: type) type {
2829
// Check that T is a valid type
2930
switch (@typeInfo(T)) {
@@ -178,6 +179,26 @@ pub fn Scalar(comptime T: type) type {
178179
self.expr.unary.prev[0].grad += self.grad * std.math.exp(self.data);
179180
}
180181

182+
/// Apply the tanh function to a Scalar
183+
pub inline fn tanh(self: *Self) *Self {
184+
return unary(std.math.tanh(self.data), .tanh, tanh_back, self);
185+
}
186+
187+
/// Backpropagation function for tanh
188+
fn tanh_back(self: *Self) void {
189+
self.expr.unary.prev[0].grad += self.grad * (@as(T, 1) - self.data * self.data);
190+
}
191+
192+
/// Identity function (passes value through, creates new node in graph)
193+
pub inline fn identity(self: *Self) *Self {
194+
return unary(self.data, .identity, identity_back, self);
195+
}
196+
197+
/// Backpropagation function for identity
198+
fn identity_back(self: *Self) void {
199+
self.expr.unary.prev[0].grad += self.grad;
200+
}
201+
181202
/// Generate Graphviz DOT format representation of the computational graph
182203
pub fn draw_dot(self: *Self, writer: anytype, allocator: std.mem.Allocator) !void {
183204
// First, trace all nodes and edges in the graph

src/nn.zig

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@ const std = @import("std");
44
const engine = @import("engine/engine.zig");
55
const zprob = @import("zprob");
66

7+
/// Activation functions
8+
pub const ActivationType = engine.UnaryType;
9+
710
/// Represents a neuron with a configurable input size
811
///
912
/// This is a generic type that can be used to create a neuron with configurable input size.
@@ -85,7 +88,7 @@ pub fn Neuron(comptime T: type) type {
8588
}
8689

8790
/// Forward pass through the neuron
88-
pub fn forward(self: *Self, inputs: []*ValueType) *ValueType {
91+
pub fn forward(self: *Self, inputs: []*ValueType, activation: ActivationType) *ValueType {
8992
if (inputs.len != self.nin) {
9093
std.debug.panic("Input size mismatch: {d} != {d}", .{ inputs.len, self.nin });
9194
}
@@ -94,8 +97,14 @@ pub fn Neuron(comptime T: type) type {
9497
for (self.weights, inputs) |w, x| {
9598
sum = sum.add(w.mul(x));
9699
}
97-
// Apply activation function (ReLU)
98-
return sum.relu();
100+
// Apply activation function
101+
return switch (activation) {
102+
.relu => sum.relu(),
103+
.identity => sum.identity(),
104+
.tanh => sum.tanh(),
105+
.softmax => sum.softmax(),
106+
else => std.debug.panic("Invalid activation function: {s}", .{@tagName(activation)}),
107+
};
99108
}
100109

101110
/// Get all parameters (weights and bias) for optimization
@@ -173,10 +182,10 @@ pub fn Layer(comptime T: type) type {
173182
}
174183

175184
/// Forward pass through the layer
176-
pub fn forward(self: *Self, inputs: []*ValueType) []*ValueType {
185+
pub fn forward(self: *Self, inputs: []*ValueType, activation: ActivationType) []*ValueType {
177186
var list = arena.allocator().alloc(*ValueType, self.nout) catch unreachable;
178187
for (self.neurons, 0..) |neuron, i| {
179-
list[i] = neuron.forward(inputs);
188+
list[i] = neuron.forward(inputs, activation);
180189
}
181190
return list;
182191
}
@@ -257,8 +266,19 @@ pub fn MLP(comptime T: type) type {
257266
/// Forward pass through the layer
258267
pub fn forward(self: *Self, inputs: []*ValueType) []*ValueType {
259268
var current_inputs = inputs;
260-
for (self.layers) |layer| {
261-
current_inputs = layer.forward(current_inputs);
269+
// Process all layers except the last one with tanh (ReLU can kill gradients)
270+
for (self.layers[0 .. self.layers.len - 1]) |layer| {
271+
current_inputs = layer.forward(current_inputs, ActivationType.tanh);
272+
}
273+
// Last layer: use linear activation for regression (no activation)
274+
// For classification with multiple outputs, use softmax instead
275+
const last_layer = self.layers[self.layers.len - 1];
276+
if (last_layer.nout == 1) {
277+
// Single output: use linear activation for regression
278+
current_inputs = last_layer.forward(current_inputs, ActivationType.identity);
279+
} else {
280+
// Multiple outputs: use softmax for classification
281+
current_inputs = last_layer.forward(current_inputs, ActivationType.softmax);
262282
}
263283
return current_inputs;
264284
}

0 commit comments

Comments
 (0)