-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathevaluator.py
More file actions
84 lines (70 loc) · 3.11 KB
/
evaluator.py
File metadata and controls
84 lines (70 loc) · 3.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import torch
class Evaluator:
    """Compute top-k ranking metrics (recall, NDCG) for ranked predictions.

    The evaluator first builds a boolean "hit" matrix that marks, per example,
    the rank at which a prediction matches the ground-truth label, and then
    derives every ``metric@k`` requested by the configuration from it.

    Args:
        config: Mapping read for ``config['topk']`` (iterable of cut-offs) and
            ``config['metrics']`` (iterable of names, each a key of
            ``metric2func``: ``'recall'`` or ``'ndcg'``).
        tokenizer: Object exposing an ``eos_token`` attribute. Its value is
            searched for inside the label token lists, so it is presumably an
            integer token id -- confirm against the tokenizer implementation.
    """

    def __init__(self, config, tokenizer):
        self.config = config
        self.tokenizer = tokenizer
        # Dispatch table from metric name to the bound method computing it.
        self.metric2func = {
            'recall': self.recall_at_k,
            'ndcg': self.ndcg_at_k,
        }
        self.eos_token = self.tokenizer.eos_token
        # Width of the hit matrix: the largest requested cut-off.
        self.maxk = max(config['topk'])

    def _trim_label(self, label_row):
        """Return ``label_row`` as a list, cut just before the first EOS token."""
        tokens = label_row.tolist()
        if self.eos_token in tokens:
            tokens = tokens[:tokens.index(self.eos_token)]
        return tokens

    def calculate_pos_index(self, preds, labels):
        """Build the hit matrix using exact sequence matching.

        Args:
            preds: Tensor whose second dimension must equal ``self.maxk``.
                Presumably shaped (batch, maxk, seq_len) so that each
                candidate is a token sequence -- TODO confirm with callers.
            labels: Tensor with one ground-truth token row per example; tokens
                from the first EOS onwards are ignored.

        Returns:
            Bool tensor of shape (batch, maxk) on the CPU. At most one entry
            per row is True: the first rank whose candidate equals the
            trimmed label.

        Raises:
            AssertionError: if ``preds.shape[1]`` differs from ``self.maxk``.
        """
        preds = preds.detach().cpu()
        labels = labels.detach().cpu()
        assert preds.shape[1] == self.maxk, f"preds.shape[1] = {preds.shape[1]} != {self.maxk}"
        pos_index = torch.zeros((preds.shape[0], self.maxk), dtype=torch.bool)
        # Convert once instead of calling .tolist() for every (row, rank) pair.
        for i, candidates in enumerate(preds.tolist()):
            target = self._trim_label(labels[i])
            # The slice keeps the loop within maxk even when the assert above
            # is stripped by ``python -O``.
            for j, candidate in enumerate(candidates[:self.maxk]):
                if candidate == target:
                    pos_index[i, j] = True
                    break  # only the first matching rank counts
        return pos_index

    def calculate_pos_index_no(self, preds, labels):
        """Build the hit matrix using lenient (prefix / single-token) matching.

        Compared with ``calculate_pos_index`` this variant does not check the
        number of candidates, accepts scalar candidates, and truncates
        candidates that are longer than the label before comparing.

        Args:
            preds: Tensor of ranked candidates per example; each candidate may
                be a single token (2-D ``preds``) or a token sequence.
            labels: Tensor with one ground-truth token row per example; tokens
                from the first EOS onwards are ignored.

        Returns:
            Bool tensor of shape (batch, maxk) on the CPU with at most one
            True per row. Rows whose label is empty after EOS trimming never
            register a hit.
        """
        preds = preds.detach().cpu()
        labels = labels.detach().cpu()
        pos_index = torch.zeros((preds.shape[0], self.maxk), dtype=torch.bool)
        for i, candidates in enumerate(preds.tolist()):
            target = self._trim_label(labels[i])
            if not target:
                # Without this guard a multi-token candidate would be truncated
                # to [] and "match" the empty label, yielding a false hit.
                continue
            for j, candidate in enumerate(candidates[:self.maxk]):
                # Handle predictions of different lengths.
                if isinstance(candidate, int):
                    candidate = [candidate]
                elif len(candidate) > len(target):
                    candidate = candidate[:len(target)]
                # Compatibility with SASRec-style single-token predictions:
                # only the first label token is compared.
                if len(candidate) == 1:
                    hit = candidate[0] == target[0]
                else:
                    hit = candidate == target
                if hit:
                    pos_index[i, j] = True
                    break  # only the first matching rank counts
        return pos_index

    def recall_at_k(self, pos_index, k):
        """Return, per example, the number of hits within the first ``k`` ranks.

        Args:
            pos_index: Bool hit matrix of shape (batch, num_ranks).
            k: Cut-off rank.

        Returns:
            Float CPU tensor of shape (batch,).
        """
        return pos_index[:, :k].sum(dim=1).cpu().float()

    def ndcg_at_k(self, pos_index, k):
        """Return, per example, the DCG accumulated over the first ``k`` ranks.

        Assumes only one ground-truth item per example, so the ideal DCG is 1
        and no normalisation is applied.

        Args:
            pos_index: Bool hit matrix of shape (batch, num_ranks).
            k: Cut-off rank.

        Returns:
            Float CPU tensor of shape (batch,).
        """
        ranks = torch.arange(1, pos_index.shape[-1] + 1, device=pos_index.device)
        gains = 1.0 / torch.log2(ranks + 1)
        # Use a zero tensor of matching dtype rather than the Python scalar 0,
        # so this does not depend on scalar type promotion inside torch.where.
        gains = torch.where(pos_index, gains, torch.zeros_like(gains))
        return gains[:, :k].sum(dim=1).cpu().float()

    def calculate_metrics(self, preds, labels):
        """Compute every configured ``metric@k`` with exact sequence matching.

        Args:
            preds: Ranked candidates, forwarded to ``calculate_pos_index``.
            labels: Ground-truth token rows, forwarded likewise.

        Returns:
            Dict mapping ``"<metric>@<k>"`` to a per-example float tensor, for
            each metric in ``config['metrics']`` and each k in
            ``config['topk']``.
        """
        pos_index = self.calculate_pos_index(preds, labels)
        results = {}
        for metric in self.config['metrics']:
            metric_fn = self.metric2func[metric]
            for k in self.config['topk']:
                results[f"{metric}@{k}"] = metric_fn(pos_index, k)
        return results