@@ -4,6 +4,9 @@ const std = @import("std");
44const engine = @import ("engine/engine.zig" );
55const zprob = @import ("zprob" );
66
/// Activation function selector for neurons and layers.
/// Alias of the engine's `UnaryType` tag so callers of this module can name
/// activations (e.g. `.relu`, `.tanh`) without importing `engine` directly.
pub const ActivationType = engine.UnaryType;
710/// Represents a neuron with a configurable input size
811///
912/// This is a generic type that can be used to create a neuron with configurable input size.
@@ -85,7 +88,7 @@ pub fn Neuron(comptime T: type) type {
8588 }
8689
8790 /// Forward pass through the neuron
88- pub fn forward (self : * Self , inputs : []* ValueType ) * ValueType {
91+ pub fn forward (self : * Self , inputs : []* ValueType , activation : ActivationType ) * ValueType {
8992 if (inputs .len != self .nin ) {
9093 std .debug .panic ("Input size mismatch: {d} != {d}" , .{ inputs .len , self .nin });
9194 }
@@ -94,8 +97,14 @@ pub fn Neuron(comptime T: type) type {
9497 for (self .weights , inputs ) | w , x | {
9598 sum = sum .add (w .mul (x ));
9699 }
97- // Apply activation function (ReLU)
98- return sum .relu ();
100+ // Apply activation function
101+ return switch (activation ) {
102+ .relu = > sum .relu (),
103+ .identity = > sum .identity (),
104+ .tanh = > sum .tanh (),
105+ .softmax = > sum .softmax (),
106+ else = > std .debug .panic ("Invalid activation function: {s}" , .{@tagName (activation )}),
107+ };
99108 }
100109
101110 /// Get all parameters (weights and bias) for optimization
@@ -173,10 +182,10 @@ pub fn Layer(comptime T: type) type {
173182 }
174183
/// Forward pass through the layer.
///
/// Runs every neuron in the layer over `inputs` with the given `activation`
/// and returns a slice with one output value per neuron (`self.nout` entries).
/// The result slice is allocated from the module-level arena, so it must not
/// be freed individually; it lives until the arena is deinitialized.
pub fn forward(self: *Self, inputs: []*ValueType, activation: ActivationType) []*ValueType {
    // `const`, not `var`: the binding is never re-assigned, and storing into
    // slice elements does not require a mutable binding. (An unmutated `var`
    // is a compile error in current Zig.) OOM from the arena is treated as
    // fatal, matching the file's existing allocation style.
    const outputs = arena.allocator().alloc(*ValueType, self.nout) catch unreachable;
    for (self.neurons, 0..) |neuron, i| {
        outputs[i] = neuron.forward(inputs, activation);
    }
    return outputs;
}
@@ -257,8 +266,19 @@ pub fn MLP(comptime T: type) type {
/// Forward pass through the whole network.
///
/// Hidden layers use tanh (ReLU can zero gradients for negative
/// pre-activations); the output layer uses identity for a single output
/// (regression head) and softmax for multiple outputs (classification head).
/// An empty network is the identity function: `inputs` is returned unchanged.
pub fn forward(self: *Self, inputs: []*ValueType) []*ValueType {
    // Guard: `self.layers.len - 1` below would underflow (usize) and the
    // slice `layers[0 .. len - 1]` would panic on an empty network.
    if (self.layers.len == 0) return inputs;

    var current_inputs = inputs;
    // All layers except the last one: tanh activation.
    for (self.layers[0 .. self.layers.len - 1]) |layer| {
        current_inputs = layer.forward(current_inputs, ActivationType.tanh);
    }
    // Output layer: pick the activation from the head's width.
    const last_layer = self.layers[self.layers.len - 1];
    if (last_layer.nout == 1) {
        // Single output: linear (identity) head for regression.
        current_inputs = last_layer.forward(current_inputs, ActivationType.identity);
    } else {
        // Multiple outputs: softmax head for classification.
        // NOTE(review): each neuron applies the activation to its own scalar
        // pre-activation (see Neuron.forward), so verify that the engine's
        // softmax is defined across the layer's outputs — a per-scalar
        // softmax would make every output constant. TODO confirm.
        current_inputs = last_layer.forward(current_inputs, ActivationType.softmax);
    }
    return current_inputs;
}
0 commit comments