-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathplot_saliency_maps.py
More file actions
54 lines (47 loc) · 2.51 KB
/
plot_saliency_maps.py
File metadata and controls
54 lines (47 loc) · 2.51 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import quantus
import torch
from matplotlib import pyplot as plt
from main import test_set, CHECKPOINTS_FOLDER, Model
from modules.constants import OPTIONS
from saliency import extract_alpha_from_filename
from modules.parser import DATA_ROOT
# For every dataset/architecture option, load a handful of fine-tuned
# checkpoints and plot their DeepLift saliency maps for the first test image,
# saving one PNG per model under DATA_ROOT/auto_saved_saliency/visual_eval.
for option in OPTIONS:
    print(option)
    # Checkpoints (by alpha in the filename) whose maps we compare side by side.
    models_to_compare = ("regular_double_backprop_finetune_100.0.pt", "regular_double_backprop_finetune_11513953.993264457.pt", "regular_double_backprop_finetune_10000000000.0.pt")
    if option == "KMNIST_LeNet":
        from KMNIST_LeNet import test_set, Model, CHECKPOINTS_FOLDER as base_folder
    elif option == "FMNIST_LeNet":
        from FMNIST_LeNet import test_set, Model, CHECKPOINTS_FOLDER as base_folder
    elif option == "MNIST_LeNet":
        from MNIST_LeNet import test_set, Model, CHECKPOINTS_FOLDER as base_folder
    elif option == "MNIST_ResNet":
        from MNIST_ResNet import test_set, Model, CHECKPOINTS_FOLDER as base_folder
    elif option == "CIFAR_ResNet":
        from CIFAR_ResNet import test_set, Model, CHECKPOINTS_FOLDER as base_folder
    elif option == "Imagenette_ResNet":
        # NOTE(review): module name "Imagnette_ResNet" differs from the option
        # string — presumably the module really is spelled this way; verify.
        from Imagnette_ResNet import test_set, Model, CHECKPOINTS_FOLDER as base_folder
        # Imagenette uses a different set of alpha checkpoints.
        models_to_compare = ("regular_double_backprop_finetune_1000.0.pt", "regular_double_backprop_finetune_10000000000.0.pt", "regular_double_backprop_finetune_12915496650148.828.pt")
    else:
        raise NotImplementedError
    INPUT_DIRECTORY = base_folder
    output_dir = DATA_ROOT / "auto_saved_saliency" / "visual_eval"
    # Dataloaders — batch_size=1 so each figure shows a single image.
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=1, shuffle=False)
    for inputs, labels in test_loader:
        class_name = test_set.classes[labels[0]]
        plt.imshow(inputs[0].numpy().transpose((1, 2, 0)))
        plt.title(class_name)
        # BUG FIX: '_'.join(class_name) joined every *character* with
        # underscores ("cat" -> "c_a_t"); join the words instead.
        output_class_dir = output_dir / "_".join(class_name.split())
        output_class_dir.mkdir(parents=True, exist_ok=True)
        for model_file in models_to_compare:
            plt.figure()
            alpha = extract_alpha_from_filename(model_file)
            # BUG FIX: original bound the class itself (model = Model);
            # load_state_dict needs an instance.
            model = Model()
            # map_location lets CUDA-saved checkpoints load on CPU-only hosts.
            model.load_state_dict(torch.load(INPUT_DIRECTORY / model_file, map_location="cpu"))
            # Eval mode: disable dropout/batchnorm updates so the saliency
            # maps and predictions are deterministic.
            model.eval()
            saliency = quantus.explain(model, inputs, labels, method="DeepLift")
            plt.imshow(saliency[0].transpose((1, 2, 0)), cmap='hot')
            plt.title(f"{alpha:e} model thinks {test_set.classes[model(inputs)[0].argmax()]}")
            # File name: the alpha suffix of the checkpoint, ".pt" -> ".png".
            save_path = output_class_dir / (model_file.split("_")[-1][:-3] + ".png")
            plt.savefig(save_path)
            plt.show()
        break  # only visualize the first test image per dataset