-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathensemble_results.py
More file actions
153 lines (117 loc) · 4.58 KB
/
ensemble_results.py
File metadata and controls
153 lines (117 loc) · 4.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import json
import os
import argparse
from collections import defaultdict
import torch
class Evaluator:
    """Compute top-k ranking metrics (recall / NDCG) and ensemble ranked
    prediction lists produced by multiple runs.

    Each example is assumed to have exactly one ground-truth item.
    """

    def __init__(self, metrics, topk):
        # Dispatch table mapping metric name -> implementation.
        self.metric2func = {
            'recall': self.recall_at_k,
            'ndcg': self.ndcg_at_k
        }
        self.metrics = metrics   # metric names, e.g. ['ndcg', 'recall']
        self.topk = topk         # cutoffs, e.g. [5, 10]
        self.maxk = max(topk)    # deepest rank position ever inspected

    def ensemble_results(self, results_info_list):
        """Merge several ranked prediction lists into one consensus ranking.

        Args:
            results_info_list: list of dicts, each with keys 'pred_ids'
                (per-example ranked item lists), 'scores' (parallel score
                lists) and 'label_ids' (one ground-truth item per example).
                All dicts must cover the same examples in the same order.

        Returns:
            (final_preds, labels): consensus ranked item lists, padded with
            "None" up to ``maxk``, and the shared ground-truth labels.
        """
        labels = results_info_list[0]['label_ids']
        N = len(labels)
        # preds[i] maps item -> list of scores it received across runs.
        preds = [defaultdict(list) for _ in range(N)]
        for results_info in results_info_list:
            pred_ids = results_info['pred_ids']
            scores = results_info['scores']
            label_ids = results_info['label_ids']
            for i in range(N):
                # Sanity check: every result file must agree on the labels.
                assert label_ids[i] == labels[i]
                for item, score in zip(pred_ids[i], scores[i]):
                    preds[i][item].append(score)
        final_preds = []
        for i in range(N):
            pred = defaultdict(float)
            for item, score_list in preds[i].items():
                # Average score plus a vote-count bonus: items predicted by
                # more runs rank higher. The "None" placeholder gets no bonus.
                pred[item] = sum(score_list) / len(score_list)
                if item != "None":
                    pred[item] += len(score_list)
            sorted_items = sorted(pred.items(), key=lambda x: x[1], reverse=True)
            ranked = [item for item, _ in sorted_items]
            # Pad so every example exposes at least maxk candidates.
            if len(ranked) < self.maxk:
                ranked += ["None"] * (self.maxk - len(ranked))
            final_preds.append(ranked)
        return final_preds, labels

    def calculate_pos_index(self, preds, labels):
        """Return an (N, maxk) boolean tensor marking where the ground-truth
        item appears in each ranked list (at most one True per row)."""
        N = len(preds)
        pos_index = torch.zeros((N, self.maxk), dtype=torch.bool)
        for i in range(N):
            cur_label = labels[i]
            for j in range(self.maxk):
                try:
                    cur_pred = preds[i][j]
                except IndexError as e:
                    # A prediction list shorter than maxk is a caller bug;
                    # surface the offending example in the error message
                    # instead of dumping it with bare prints.
                    raise RuntimeError(
                        f"prediction list for example {i} has only "
                        f"{len(preds[i])} items; expected at least {self.maxk}"
                    ) from e
                if cur_pred == cur_label:
                    pos_index[i, j] = True
                    break  # single ground truth: stop at the first hit
        return pos_index

    def recall_at_k(self, pos_index, k):
        """Per-example Recall@k (a 0/1 vector, since there is one label)."""
        return pos_index[:, :k].sum(dim=1).cpu().float()

    def ndcg_at_k(self, pos_index, k):
        """Per-example NDCG@k. With a single ground-truth item the ideal DCG
        is 1, so NDCG reduces to 1/log2(rank+1) at the hit position."""
        ranks = torch.arange(1, pos_index.shape[-1] + 1).to(pos_index.device)
        dcg = 1.0 / torch.log2(ranks + 1)
        dcg = torch.where(pos_index, dcg, 0)
        return dcg[:, :k].sum(dim=1).cpu().float()

    def calculate_metrics(self, preds, labels):
        """Compute every configured metric at every configured cutoff.

        Returns:
            dict like {'recall@5': float, ...}, averaged over all examples.
        """
        pos_index = self.calculate_pos_index(preds, labels)
        results = {}
        for metric in self.metrics:
            for k in self.topk:
                results[f"{metric}@{k}"] = self.metric2func[metric](pos_index, k)
        return {k: v.mean().item() for k, v in results.items()}
def main(args):
    """Evaluate each per-run result file under ``args.results_dir``, report
    the per-metric maximum across runs, then evaluate the ensemble of all
    runs combined.

    Args:
        args: namespace with ``results_dir`` (str), ``metrics`` (list[str])
            and ``topk`` (list[int]).
    """
    evaluator = Evaluator(args.metrics, args.topk)
    results_file_list = sorted(
        f for f in os.listdir(args.results_dir) if f.endswith('.json')
    )
    results_info_list = []
    max_results = {}
    for results_file in results_file_list:
        # Context manager ensures the file handle is closed (the original
        # json.load(open(...)) leaked descriptors).
        with open(os.path.join(args.results_dir, results_file), 'r') as fh:
            results_info = json.load(fh)
        results_info_list.append(results_info)
        temp_res = evaluator.calculate_metrics(
            results_info['pred_ids'], results_info['label_ids'])
        print(f"Results for {results_file}:")
        print(temp_res)
        # Track the best value seen for each metric across individual runs.
        for k, v in temp_res.items():
            max_results[k] = v if k not in max_results else max(max_results[k], v)
    print("Max Results:")
    print(max_results)
    preds, labels = evaluator.ensemble_results(results_info_list)
    metrics_results = evaluator.calculate_metrics(preds, labels)
    print("Ensemble Results:")
    print(metrics_results)
def parse_args():
    """Define and parse the command-line options for ensemble evaluation.

    Returns:
        argparse.Namespace with string fields ``results_dir``, ``metrics``
        (comma-separated) and ``topk`` (comma-separated).
    """
    parser = argparse.ArgumentParser(description="Index")
    # Every option is a plain string; comma-separated ones are split later.
    option_defaults = {
        "--results_dir": "./results/Video_Games/",
        "--metrics": "ndcg,recall",
        "--topk": "5,10",
    }
    for flag, default in option_defaults.items():
        parser.add_argument(flag, type=str, default=default)
    return parser.parse_args()
if __name__ == '__main__':
    cli_args = parse_args()
    # Convert the raw comma-separated CLI strings into the list types
    # that Evaluator/main expect.
    cli_args.metrics = cli_args.metrics.split(",")
    cli_args.topk = [int(k) for k in cli_args.topk.split(",")]
    main(cli_args)