-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathexp_CIFAR-10_PAIRS.py
More file actions
70 lines (51 loc) · 1.93 KB
/
exp_CIFAR-10_PAIRS.py
File metadata and controls
70 lines (51 loc) · 1.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import numpy as np
import torch
import torch.nn as nn
import torchvision
from scipy.stats import norm
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, ToTensor, Normalize
from core import PAIRS, multi_evaluate, exp_aggregator, IdentityConv2d
def experiment(num_neurons, batch_size):
    """Run one PAIRS trial on CIFAR-10 and return its evaluation result.

    The CIFAR-10 training split is randomly permuted and cut into two halves:
    the first half is fed to ``PAIRS``, the second half to ``multi_evaluate``.
    A linear layer with ``3 * 32 * 32`` input features gets standard-normal
    weights and a constant bias, is wrapped in ``IdentityConv2d``, and is then
    processed by ``PAIRS`` before evaluation.

    Args:
        num_neurons: Number of output features of the linear layer.
        batch_size: Batch size used by both data loaders; it also determines
            the bias through the normal quantile at ``1 / batch_size``.

    Returns:
        The value returned by ``multi_evaluate`` for the prepared model.
    """
    input_dim = 3 * 32 * 32

    # Per-channel normalisation applied after conversion to tensors.
    preprocessing = Compose([
        ToTensor(),
        Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.247, 0.243, 0.261]),
    ])
    cifar_train = torchvision.datasets.CIFAR10(
        root='data', train=True, transform=preprocessing, download=True
    )

    # Random 50/50 split of the training set; depends on the global torch seed.
    shuffled = torch.randperm(len(cifar_train)).tolist()
    half = len(cifar_train) // 2
    train_loader = DataLoader(
        torch.utils.data.Subset(cifar_train, shuffled[:half]),
        batch_size=batch_size,
        shuffle=True,
    )
    val_loader = DataLoader(
        torch.utils.data.Subset(cifar_train, shuffled[half:]),
        batch_size=batch_size,
        shuffle=True,
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    layer = nn.Linear(input_dim, num_neurons).to(device)

    # Standard-normal quantile at 1 / batch_size, scaled by sqrt(input_dim).
    # NOTE(review): presumably chosen so a neuron activates for roughly one
    # sample per batch -- confirm against the PAIRS method description.
    bias_value = norm.ppf(1 / batch_size) * np.sqrt(input_dim)
    with torch.no_grad():
        layer.weight.data.normal_()
        layer.bias.data.fill_(bias_value)

    model = IdentityConv2d(layer, 10)
    PAIRS(
        layer=model.fc1,
        train_dataloader=train_loader,
        batch_size=batch_size,
        n_neurons=num_neurons,
    )

    return multi_evaluate(
        model=model,
        val_dataloader=val_loader,
        batch_size=batch_size,
        num_neurons=num_neurons,
        eval_iters=10,
    )
def main():
    """Seed PyTorch and launch the CIFAR-10 PAIRS sweep.

    Fixes the global torch seed to 42, then hands ``experiment`` to
    ``exp_aggregator`` together with the output file name, the layer widths,
    the batch sizes, and the number of repetitions per setting.
    """
    torch.manual_seed(42)

    neuron_counts = [200, 500, 1000]
    batch_size_options = [20, 50, 100, 200]
    repeats = 10

    exp_aggregator(
        'results_CIFAR-10_PAIRS.csv',
        experiment,
        neuron_counts,
        batch_size_options,
        repeats,
    )
# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()