-
Notifications
You must be signed in to change notification settings - Fork 222
Expand file tree
/
Copy pathAlexNet.cs
More file actions
74 lines (63 loc) · 2.62 KB
/
Copy pathAlexNet.cs
File metadata and controls
74 lines (63 loc) · 2.62 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
using static TorchSharp.torch;
using static TorchSharp.torch.nn;
namespace TorchSharp.Examples
{
/// <summary>
/// Modified version of the original AlexNet, adapted to fit CIFAR10 32x32 images.
/// </summary>
/// <remarks>
/// The first convolution uses stride 2 and each of the three max-pool stages halves the
/// spatial size, so a 32x32 input reaches the adaptive average pool as a 256-channel 2x2
/// feature map (32 -> 16 -> 8 -> 4 -> 2). The classifier therefore consumes a flat vector of
/// 256 * 2 * 2 features per sample.
/// </remarks>
class AlexNet : Module<Tensor, Tensor>
{
    private readonly Module<Tensor, Tensor> features;
    private readonly Module<Tensor, Tensor> avgPool;
    private readonly Module<Tensor, Tensor> classifier;

    /// <summary>
    /// Builds the convolutional feature extractor, the 2x2 adaptive average pool and the
    /// dropout/linear classifier head, registers them as sub-modules, and optionally moves
    /// the whole module to <paramref name="device"/>.
    /// </summary>
    /// <param name="name">Module name forwarded to the base module constructor.</param>
    /// <param name="numClasses">Output width of the final linear layer (number of class scores).</param>
    /// <param name="device">
    /// Target device. When null or a CPU device, the module is not moved.
    /// </param>
    public AlexNet(string name, int numClasses, torch.Device device = null) : base(name)
    {
        features = Sequential(
            ("c1", Conv2d(3, 64, kernel_size: 3, stride: 2, padding: 1)),
            ("r1", ReLU(inplace: true)),
            ("mp1", MaxPool2d(kernel_size: new long[] { 2, 2 })),
            ("c2", Conv2d(64, 192, kernel_size: 3, padding: 1)),
            ("r2", ReLU(inplace: true)),
            ("mp2", MaxPool2d(kernel_size: new long[] { 2, 2 })),
            ("c3", Conv2d(192, 384, kernel_size: 3, padding: 1)),
            ("r3", ReLU(inplace: true)),
            ("c4", Conv2d(384, 256, kernel_size: 3, padding: 1)),
            ("r4", ReLU(inplace: true)),
            ("c5", Conv2d(256, 256, kernel_size: 3, padding: 1)),
            ("r5", ReLU(inplace: true)),
            ("mp3", MaxPool2d(kernel_size: new long[] { 2, 2 })));

        // For 32x32 inputs the feature map is already 2x2 here; the adaptive pool keeps the
        // classifier's input width fixed at 256 * 2 * 2 for other input sizes as well.
        avgPool = AdaptiveAvgPool2d(new long[] { 2, 2 });

        // The activation names follow the d/l/r numbering of their stage. ReLU holds no
        // parameters or buffers, so these names do not appear in the saved state.
        classifier = Sequential(
            ("d1", Dropout()),
            ("l1", Linear(256 * 2 * 2, 4096)),
            ("r1", ReLU(inplace: true)),
            ("d2", Dropout()),
            ("l2", Linear(4096, 4096)),
            ("r2", ReLU(inplace: true)),
            ("d3", Dropout()),
            ("l3", Linear(4096, numClasses))
        );

        RegisterComponents();

        if (device != null && device.type != DeviceType.CPU) {
            this.to(device);
        }
    }

    /// <summary>
    /// Runs the feature extractor, pools to 2x2, flattens each sample and applies the classifier.
    /// </summary>
    /// <param name="input">
    /// Batch of 3-channel images. It presumably has shape (N, 3, 32, 32); TODO confirm against
    /// callers. Dimension 0 is treated as the batch size.
    /// </param>
    /// <returns>Raw classifier output of shape (N, numClasses).</returns>
    public override Tensor forward(Tensor input)
    {
        // Dispose every intermediate tensor handle deterministically rather than leaving
        // their native storage to the finalizer. The tensor returned by the classifier is a
        // new tensor, so releasing these handles does not invalidate it.
        using (var f = features.call(input))
        using (var avg = avgPool.call(f))
        using (var x = avg.reshape(new long[] { avg.shape[0], 256 * 2 * 2 })) {
            return classifier.call(x);
        }
    }

    /// <summary>
    /// Disposes the three sub-modules, then clears the module registry so the base
    /// class does not see them again.
    /// </summary>
    /// <param name="disposing">True when called from <c>Dispose()</c> rather than a finalizer.</param>
    protected override void Dispose(bool disposing)
    {
        if (disposing) {
            features.Dispose();
            avgPool.Dispose();
            classifier.Dispose();
            ClearModules();
        }
        base.Dispose(disposing);
    }
}
}