-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsaliency.py
More file actions
107 lines (81 loc) · 3.67 KB
/
saliency.py
File metadata and controls
107 lines (81 loc) · 3.67 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
from pathlib import Path
import numpy as np
import quantus
import torch
from tqdm import tqdm
from main import test_set, CHECKPOINTS_FOLDER, Model
from modules.constants import SALIENCY_FOLDER_NAME
from modules.faithfulness_estimate import BatchedFaithfulnessEstimate
from modules.parser import DATA_ROOT, OPTION, ITERATOR_VALUE, SALIENCY_METHOD
from modules.perturb_func import flip_perturb_func
# Enable GPU.
# Falls back to CPU automatically when CUDA is unavailable; all model and
# metric evaluation below runs on this device.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
class AccuracyMetric:
    """Top-1 classification accuracy over a single batch.

    Mimics the quantus metric call signature (`model`, `x_batch`,
    `y_batch`, `device`, plus ignored extras via ``**kwargs``) so it can
    be dropped into the same evaluation loop as the quantus metrics.
    """
    name = "accuracy"

    def __call__(self, model, x_batch, y_batch, device, **kwargs):
        # Inference only: no_grad avoids building an autograd graph, and
        # argmax on the output tensor replaces the deprecated `.data` access.
        with torch.no_grad():
            outputs = model(torch.tensor(x_batch, device=device))
        predicted = outputs.argmax(dim=1)
        # y_batch is a numpy array, so `.size` is its total element count.
        return (predicted.cpu().numpy() == y_batch).sum().item() / y_batch.size
def evaluate_metric(
    model, metric, dataloader, n_batches=1, average=True, method=SALIENCY_METHOD
):  # requires shuffled dataloader
    """Run `metric` on up to `n_batches` batches and average the scores.

    The dataloader should be shuffled so the sampled batches are
    representative. `average` is accepted for interface compatibility but
    is currently unused — batch scores are always averaged.

    Returns a numpy array with one averaged score per metric output.
    """
    print("evaluating", metric.name)
    scores = []
    for i, (inputs, labels) in enumerate(dataloader):
        if i == n_batches:
            break
        scores.append(metric(
            model=model,
            x_batch=inputs.cpu().numpy(),
            y_batch=labels.cpu().numpy(),
            # Use the detected device instead of hard-coding "cuda" so the
            # evaluation also works on CPU-only machines.
            device=str(device),
            explain_func=quantus.explain,
            explain_func_kwargs={"method": method}  # e.g. DeepLift, IntegratedGradients
        ))
    scores = np.average(scores, axis=0)  # always averages over multiple batches
    # Collapse any remaining per-sample/per-feature axes to one score per output.
    return np.average(scores, axis=tuple(range(1, scores.ndim)))
def evaluate_metrics(model, metrics, dataloader, n_batches=1):
    """Evaluate every metric in `metrics` and collect the results keyed by metric name."""
    results = {}
    for metric in metrics:
        results[metric.name] = evaluate_metric(model, metric, dataloader, n_batches)
    return results
def extract_alpha_from_filename(filename):
    """Parse the alpha value out of a checkpoint filename.

    Expects names like "model_3.pt" or "model_0.5.pt": alpha is the final
    underscore-separated token of the stem. Stripping the extension with
    `Path.stem` first keeps fractional alphas such as 0.5 intact, where the
    previous `split('.')[0]` truncated "0.5.pt" to 0.
    """
    return float(Path(filename).stem.split('_')[-1])
def get_models_in_file(directory_path: Path):
    """Yield checkpoint files (``*.pt``) in `directory_path` in sorted name order.

    Matches on the actual file suffix rather than the old substring test
    (``".pt" in name``), which also caught unrelated names like
    "weights.pth" or "model.pt.bak".
    """
    for entry in sorted(directory_path.iterdir()):
        if entry.suffix == ".pt":
            yield entry
# Metrics evaluated per checkpoint: one custom batched faithfulness metric
# plus three quantus robustness/complexity metrics.
USED_METRICS = (
    BatchedFaithfulnessEstimate(perturb_func=flip_perturb_func), quantus.LocalLipschitzEstimate(), quantus.MaxSensitivity(), quantus.Complexity()
)
# Checkpoints to score come from the training run's checkpoint folder.
INPUT_DIRECTORY = CHECKPOINTS_FOLDER
# Results are grouped by run option and saliency method under the data root.
OUTPUT_DIRECTORY = DATA_ROOT / SALIENCY_FOLDER_NAME / OPTION / SALIENCY_METHOD
# NOTE(review): CACHE is never read anywhere in this file — possibly dead config.
CACHE = False
if __name__ == "__main__":
    OUTPUT_DIRECTORY.mkdir(parents=True, exist_ok=True)
    if ITERATOR_VALUE is None:
        # Full-run mode: refuse to clobber results from a previous run.
        assert not any(OUTPUT_DIRECTORY.iterdir()), "output directory exists and is not empty"
    print(f"INPUT_DIRECTORY:{INPUT_DIRECTORY}")
    print(f"OUTPUT_DIRECTORY:{OUTPUT_DIRECTORY}")
    evaluation_batch_size = 128
    # Dataloaders — shuffled because evaluate_metric samples only the first
    # n_batches batches.
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=evaluation_batch_size, shuffle=True
    )
    # Iterate the generator directly; no need to materialize a list first.
    for i, checkpoint_file in enumerate(get_models_in_file(INPUT_DIRECTORY)):
        print(i)
        output_path = OUTPUT_DIRECTORY / checkpoint_file.name
        if ITERATOR_VALUE is not None:
            # Array-job mode: each invocation handles exactly one checkpoint.
            if i != ITERATOR_VALUE:
                continue
            assert not output_path.exists(), "output path already exists"
        print(checkpoint_file)
        model = Model.to(device)
        model.load_state_dict(torch.load(checkpoint_file))
        result = evaluate_metrics(model, USED_METRICS, test_loader)
        torch.save(result, output_path)
        print("first layer backprop", result)
        print("")