-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathImagnette_ResNet.py
More file actions
51 lines (39 loc) · 1.8 KB
/
Imagnette_ResNet.py
File metadata and controls
51 lines (39 loc) · 1.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import numpy as np
import torch
from torchvision import transforms
from modules.imagenette import Imagenette
from modules.parser import DATA_ROOT
from modules.resnet import resnet18
class SquarePad:
    """Zero-pad an image so that it becomes square.

    The shorter side is padded up to the length of the longer side; the longer
    side is left untouched. The padding is split as evenly as possible between
    the two opposite edges.

    NOTE(review): the original author left the remark "fail as longest side is
    reduced, not shortest". Presumably this means that when this transform is
    followed by a resize, the *longest* side ends up scaled to the target size
    (instead of the shortest, as with Resize + CenterCrop). Confirm the intent
    before using this in a pipeline.
    """

    @staticmethod
    def compute_padding(width, height):
        """Return the ``(left, top, right, bottom)`` padding that squares a ``width x height`` image.

        If the size difference is odd, the extra pixel goes on the right (or
        bottom) edge so the padded result is exactly ``max(width, height)`` on
        both sides.
        """
        side = max(width, height)
        left = (side - width) // 2
        top = (side - height) // 2
        # Use the remainder for the far edges so odd differences are not lost.
        right = side - width - left
        bottom = side - height - top
        return (left, top, right, bottom)

    def __call__(self, image):
        """Return ``image`` zero-padded to a square.

        Args:
            image: an image whose ``.size`` is ``(width, height)`` (as with PIL
                images — TODO confirm no tensors are passed here). It is handed
                to ``transforms.functional.pad``.
        """
        w, h = image.size
        # The 4-tuple order expected by torchvision's pad is (left, top, right, bottom).
        return transforms.functional.pad(image, self.compute_padding(w, h), 0, 'constant')
# Folder under DATA_ROOT where checkpoints for this run are stored.
CHECKPOINTS_FOLDER = DATA_ROOT / "checkpoints/resnette_single"

# Shared preprocessing for both splits: center-crop to 320x320, then convert to
# a tensor. It is defined once so the train and test pipelines cannot drift
# apart. No augmentation or normalization is applied here.
_imagenette_root = DATA_ROOT / 'data'
_imagenette_transform = transforms.Compose([transforms.CenterCrop(320), transforms.ToTensor()])

train_set = Imagenette(root=_imagenette_root, split="train", download=True, size="320px",
                       transform=_imagenette_transform)
# The "val" split serves as the held-out test set.
test_set = Imagenette(root=_imagenette_root, split="val", download=True, size="320px",
                      transform=_imagenette_transform)

# Dataloaders
# Disabled alternative: carve a fake test set and/or a validation set out of train.
#train_set, fake_test_set = random_split(train_set, [0.8, 0.2])
#train_set, val_set = random_split(train_set, [0.9, 0.1])
train_loader = torch.utils.data.DataLoader(train_set, batch_size=128, shuffle=True, num_workers=6)
#val_loader = torch.utils.data.DataLoader(val_set, batch_size=128, shuffle=False)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False, num_workers=6)
# Fresh ResNet-18 built by the project's modules.resnet factory.
Model = resnet18()
# Checkpoint naming pieces; presumably joined as NAME + EXTENSION by callers — TODO confirm.
NAME, EXTENSION = "model", ".pt"
# Grid of values stored in `imagenette_galphas`.
# Exponents are evenly spaced over [start, end] using `num_values` points
# (step = (end - start) / (num_values - 1)), and only the first `num_slice`
# of them are kept. Each value is 10**exponent, so the grid is log10-spaced
# (despite the "ln" in `old_ln_space`, the base is 10, not e).
start = 3
end = 17
num_values = 37
num_slice = 27
old_ln_space = [start + i * (end - start) / (num_values - 1) for i in range(num_slice)]
imagenette_galphas = [10**el for el in old_ln_space]
# Spacing between consecutive exponents. It is computed from the grid
# parameters rather than by indexing old_ln_space[5], which would raise
# IndexError whenever num_slice < 6.
diff = (end - start) / (num_values - 1)
# Disabled alternative: shift every exponent down by half a step, then merge
# with the unshifted grid and sort, which doubles the grid density.
#imagenette_galphas = [10**(el - diff/2) for el in old_ln_space]
#
#imagenette_galphas += [10**el for el in old_ln_space]
#imagenette_galphas.sort()