forked from YangShihao-Twave/BertSimilarity
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path: similarity.py
More file actions
74 lines (56 loc) · 2.58 KB
/
similarity.py
File metadata and controls
74 lines (56 loc) · 2.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import argparse
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModel
def mean_pooling(model_output, attention_mask):
    """Mean-pool BERT's last_hidden_state into one sentence vector.

    Padding positions are excluded from the average via the attention mask.

    Args:
        model_output: model forward output; index 0 is the
            (batch, seq_len, hidden) last_hidden_state.
        attention_mask: (batch, seq_len) tensor, 1 for real tokens, 0 for pad.

    Returns:
        (batch, hidden) tensor of masked mean embeddings.
    """
    hidden_states = model_output[0]
    # Broadcast the mask over the hidden dimension so padded tokens
    # contribute zero to both the sum and the count.
    mask = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
    summed = (hidden_states * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-9)  # guard against division by zero
    return summed / counts
def cosine_similarity(vec1, vec2):
    """Return the cosine similarity of two 1-D numpy vectors as a float.

    A tiny epsilon in each denominator guards against division by zero
    when either vector is all zeros.
    """
    eps = 1e-10
    unit1 = vec1 / (np.linalg.norm(vec1) + eps)
    unit2 = vec2 / (np.linalg.norm(vec2) + eps)
    return float(unit1 @ unit2)
def get_sentence_embedding(model, tokenizer, sentence, device):
    """Encode one sentence into a fixed-size numpy vector.

    Tokenizes the sentence (truncated to 128 tokens), runs the model with
    gradient tracking disabled, mean-pools the hidden states, and returns
    the single resulting row moved back to the CPU.

    Args:
        model: a loaded transformers model placed on `device`.
        tokenizer: the matching tokenizer.
        sentence: the input text.
        device: torch device the model lives on.

    Returns:
        1-D numpy array — the pooled sentence embedding.
    """
    batch = tokenizer(
        sentence,
        padding=True,
        truncation=True,
        max_length=128,
        return_tensors="pt",
    ).to(device)
    with torch.no_grad():
        outputs = model(**batch)
    pooled = mean_pooling(outputs, batch["attention_mask"])
    # Single sentence in, so take the first (only) row.
    return pooled.cpu().numpy()[0]
def main():
    """Command-line entry point.

    Parses three texts, embeds each with a pretrained model, and prints
    their pairwise cosine similarities.
    """
    parser = argparse.ArgumentParser(description="Sentence similarity with multilingual BERT")
    parser.add_argument("--text1", type=str, required=True, help="第一段文本")
    parser.add_argument("--text2", type=str, required=True, help="第二段文本")
    parser.add_argument("--text3", type=str, required=True, help="第三段文本")
    args = parser.parse_args()

    # NOTE(review): this checkpoint is named as a cross-encoder reranker but is
    # loaded here as a plain encoder for embeddings — confirm this is intended.
    model_name = "hotchpotch/japanese-reranker-cross-encoder-base-v1"

    # Prefer GPU when one is available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load tokenizer and model, move to device, switch to inference mode.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name).to(device)
    model.eval()

    # Embed each of the three input texts.
    emb1, emb2, emb3 = (
        get_sentence_embedding(model, tokenizer, text, device)
        for text in (args.text1, args.text2, args.text3)
    )

    # Pairwise cosine similarities.
    sim1 = cosine_similarity(emb1, emb2)
    sim2 = cosine_similarity(emb1, emb3)
    sim3 = cosine_similarity(emb2, emb3)

    print(f"Similarity between Manual answer and Retrieval answer = {sim1:.4f}")
    print(f"Similarity between Manual answer and LLM answer = {sim2:.4f}")
    print(f"Similarity between Retrieval answer and LLM answer = {sim3:.4f}")
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()