-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathexp_ImageNet_QBI_Batchnorm.py
More file actions
58 lines (42 loc) · 1.49 KB
/
exp_ImageNet_QBI_Batchnorm.py
File metadata and controls
58 lines (42 loc) · 1.49 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import numpy as np
import torch
import torch.nn as nn
import torchvision
from scipy.stats import norm
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from core import multi_evaluate, exp_aggregator, IdentityConv2d
def experiment(num_neurons, batch_size):
    """Evaluate one randomly initialised linear layer on the ImageNet validation split.

    The layer maps flattened 3x224x224 images to ``num_neurons`` outputs. Its
    weights are drawn i.i.d. from a standard normal, and every bias is set to
    ``norm.ppf(1 / batch_size) * sqrt(in_features)``. The layer is then wrapped
    in ``IdentityConv2d`` and handed to ``multi_evaluate`` with batch norm enabled.

    Args:
        num_neurons: Number of output units of the linear layer.
        batch_size: Images per validation batch; also sets the bias quantile.
            Must be at least 2: ``norm.ppf(1 / batch_size)`` is infinite for
            ``batch_size == 1`` (every bias would become ``inf``) and undefined
            for smaller values.

    Returns:
        Whatever ``multi_evaluate`` returns for this configuration.

    Raises:
        ValueError: If ``batch_size`` is smaller than 2.
    """
    if batch_size < 2:
        raise ValueError(f"batch_size must be at least 2, got {batch_size}")

    # Single source of truth for the input geometry used by the crop and the layer.
    channels, height, width = 3, 224, 224
    in_features = channels * height * width

    transforms = Compose([
        Resize(size=256),
        CenterCrop(size=(height, width)),
        ToTensor(),
    ])
    base_dataset = torchvision.datasets.ImageNet(
        root='data/imagenet', split="val", transform=transforms,
    )
    # Shuffle order is drawn from PyTorch's global RNG (no explicit generator).
    val_loader = DataLoader(base_dataset, batch_size=batch_size, shuffle=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    layer = nn.Linear(in_features, num_neurons).to(device)

    # For w ~ N(0, I) and an input with ||x||^2 ~= in_features, w.x ~ N(0, in_features),
    # so this (negative) bias gives P(w.x + b > 0) ~= 1 / batch_size per unit.
    # NOTE(review): assumes inputs have roughly unit second moment per feature
    # (presumably arranged by batch_norm=True inside multi_evaluate) — confirm.
    optimal_bias = float(norm.ppf(1 / batch_size) * np.sqrt(in_features))
    with torch.no_grad():
        # Overwrite nn.Linear's default init in place without recording autograd history.
        layer.weight.normal_()
        layer.bias.fill_(optimal_bias)

    # NOTE(review): 1000 is presumably the number of ImageNet classes — confirm
    # against core.IdentityConv2d.
    model = IdentityConv2d(layer, 1000)
    return multi_evaluate(
        model=model,
        val_dataloader=val_loader,
        batch_size=batch_size,
        num_neurons=num_neurons,
        eval_iters=10,
        batch_norm=True,
    )
def main():
    """Run the ImageNet QBI + batch-norm sweep under a fixed PyTorch seed.

    Passes ``experiment`` to ``exp_aggregator`` along with the following:

    - the output file name ``results_ImageNet_QBI_Batchnorm.csv``
    - layer widths 200, 500 and 1000
    - batch sizes 20, 50, 100 and 200
    - a repeat count of 10

    NOTE(review): presumably ``exp_aggregator`` runs every width/batch-size
    combination and writes the results to that CSV — confirm in ``core``.
    """
    # Seed PyTorch's global RNG, which experiment() draws on for weight
    # initialisation and DataLoader shuffling.
    torch.manual_seed(42)

    neuron_counts = [200, 500, 1000]
    batch_size_grid = [20, 50, 100, 200]
    exp_aggregator(
        'results_ImageNet_QBI_Batchnorm.csv',
        experiment,
        neuron_counts,
        batch_size_grid,
        10,  # runs per setting
    )


if __name__ == "__main__":
    main()