Skip to content

Commit 0c33cfe

Browse files
committed
Remove bloat
1 parent 06a8693 commit 0c33cfe

File tree

3 files changed

+9
-27
lines changed

3 files changed

+9
-27
lines changed

examples/train.zig

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,25 +4,10 @@
44

55
const std = @import("std");
66
const kiwigrad = @import("kiwigrad");
7+
const zbench = @import("zbench");
78

89
const print = std.debug.print;
910

10-
/// Write the computational graph to a Graphviz file
11-
pub fn draw_graph(comptime T: type, graph: *kiwigrad.engine.Value(T), name: []const u8, writer: anytype) !void {
12-
const dot_name = try std.fmt.allocPrint(std.heap.page_allocator, "{s}.dot", .{name});
13-
defer std.heap.page_allocator.free(dot_name);
14-
const png_name = try std.fmt.allocPrint(std.heap.page_allocator, "{s}.png", .{name});
15-
defer std.heap.page_allocator.free(png_name);
16-
17-
const file = try std.fs.cwd().createFile(dot_name, .{});
18-
defer file.close();
19-
const file_writer = file.writer();
20-
try graph.draw_dot(file_writer, std.heap.page_allocator);
21-
22-
try writer.print("Computational graph written to {s}\n", .{dot_name});
23-
try writer.print("You can visualize it by running: dot -Tpng {s} -o {s}\n", .{ dot_name, png_name });
24-
}
25-
2611
pub fn main() !void {
2712
const stdout_file = std.io.getStdOut().writer();
2813
var bw = std.io.bufferedWriter(stdout_file);
@@ -71,6 +56,6 @@ pub fn main() !void {
7156
print("output.data: {d:.4}\n", .{output.data});
7257
print("output.grad: {d:.4}\n", .{output.grad});
7358

74-
try draw_graph(f64, output, "assets/img/train", stdout);
59+
output.draw_graph("assets/img/train", stdout);
7560
try bw.flush(); // Don't forget to flush!
7661
}

src/engine.zig

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -341,21 +341,18 @@ pub fn Value(comptime T: type) type {
341341

342342
/// Write the computational graph to a Graphviz file
343343
pub fn draw_graph(graph: *Self, name: []const u8, writer: anytype) void {
344-
const dot_name = try std.fmt.allocPrint(std.heap.page_allocator, "{s}.dot", .{name});
344+
const dot_name = std.fmt.allocPrint(std.heap.page_allocator, "{s}.dot", .{name}) catch unreachable;
345345
defer std.heap.page_allocator.free(dot_name);
346-
const png_name = try std.fmt.allocPrint(std.heap.page_allocator, "{s}.png", .{name});
346+
const png_name = std.fmt.allocPrint(std.heap.page_allocator, "{s}.png", .{name}) catch unreachable;
347347
defer std.heap.page_allocator.free(png_name);
348348

349-
const file = try std.fs.cwd().createFile(dot_name, .{});
349+
const file = std.fs.cwd().createFile(dot_name, .{}) catch unreachable;
350350
defer file.close();
351351
const file_writer = file.writer();
352-
graph.draw_dot(file_writer, std.heap.page_allocator) catch |err| {
353-
std.debug.print("Failed to write dot file: {}\n", .{err});
354-
return;
355-
};
352+
graph.draw_dot(file_writer, std.heap.page_allocator) catch unreachable;
356353

357-
try writer.print("Computational graph written to {s}\n", .{dot_name});
358-
try writer.print("You can visualize it by running: dot -Tpng {s} -o {s}\n", .{ dot_name, png_name });
354+
writer.print("Computational graph written to {s}\n", .{dot_name}) catch unreachable;
355+
writer.print("You can visualize it by running: dot -Tpng {s} -o {s}\n", .{ dot_name, png_name }) catch unreachable;
359356
}
360357
};
361358
}

src/nn.zig

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ pub fn Layer(comptime T: type) type {
164164
}
165165

166166
layer.* = Self{
167+
.nin = nin,
167168
.neurons = neurons[0..],
168169
.nout = nout,
169170
};
@@ -174,7 +175,6 @@ pub fn Layer(comptime T: type) type {
174175
/// Forward pass through the layer
175176
pub fn forward(self: *Self, inputs: []*ValueType) []*ValueType {
176177
var list = arena.allocator().alloc(*ValueType, self.nout) catch unreachable;
177-
defer arena.allocator().free(list);
178178
for (self.neurons, 0..) |neuron, i| {
179179
list[i] = neuron.forward(inputs);
180180
}

0 commit comments

Comments
 (0)