Skip to content

Commit 6b32cfc

Browse files
committed
Print function
1 parent ed3eb28 commit 6b32cfc

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

examples/train.zig

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
//! is to delete this file and start with root.zig instead.
44

55
const std = @import("std");
6-
/// This imports the separate module containing `root.zig`. Take a look in `build.zig` for details.
76
const kiwigrad = @import("kiwigrad");
87

8+
const print = std.debug.print;
9+
910
/// Write the computational graph to a Graphviz file
1011
pub fn draw_graph(comptime T: type, graph: *kiwigrad.engine.Value(T), name: []const u8, writer: anytype) !void {
1112
const dot_name = try std.fmt.allocPrint(std.heap.page_allocator, "{s}.dot", .{name});
@@ -26,6 +27,7 @@ pub fn main() !void {
2627
const stdout_file = std.io.getStdOut().writer();
2728
var bw = std.io.bufferedWriter(stdout_file);
2829
const stdout = bw.writer();
30+
2931
const alloc = std.heap.page_allocator;
3032

3133
// Initialize the required components
@@ -59,7 +61,15 @@ pub fn main() !void {
5961
const output = neuron.forward(input_data[0..]);
6062

6163
// outputs now contains 2 ValueType pointers (one for each neuron)
62-
std.debug.print("Layer output: {d}\n", .{output.data});
64+
print("Layer output: {d:.4}\n", .{output.data});
65+
66+
print("output.data: {d:.4}\n", .{output.data});
67+
print("output.grad: {d:.4}\n", .{output.grad});
68+
69+
output.backwardPass(alloc);
70+
71+
print("output.data: {d:.4}\n", .{output.data});
72+
print("output.grad: {d:.4}\n", .{output.grad});
6373

6474
try draw_graph(f64, output, "n_f64", stdout);
6575
try bw.flush(); // Don't forget to flush!

0 commit comments

Comments
 (0)