load_and_eval.py
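# Loads a trained (Bi)LSTM text-matching model from ./Model and reports its
# accuracy on the evaluation split returned by get_data.train_data().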
import torch
from torch.utils.data import TensorDataset, DataLoader

import get_data
from model import BiLSTM_Match, LSTM_Match

USE_CUDA = torch.cuda.is_available()
"""
embedding_dim=400
hidden_dim=256
vocab_size=51158
target=1
Batchsize=64
stringlen=25
Epoch=20
lr=0.1
"""
embedding_dim=400
hidden_dim=512
vocab_size=51158
target=1
Batchsize=16
stringlen=25
Epoch=20
lr=0.1
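# Batchsize is passed to the model constructor, so the evaluation loop below
# skips any final batch of a different size.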
USE_Bi = True
if USE_Bi:
    print("Using BiLSTM")
    model = BiLSTM_Match(embedding_dim, hidden_dim, vocab_size, target, Batchsize, stringlen)
    model_path = "./Model/BiLSTMmodel.pth"
else:
    print("Using LSTM")
    model = LSTM_Match(embedding_dim, hidden_dim, vocab_size, target, Batchsize, stringlen)
    model_path = "./Model/LSTMmodel.pth"

# map_location lets a GPU-trained checkpoint load on a CPU-only machine.
model.load_state_dict(torch.load(model_path, map_location="cuda" if USE_CUDA else "cpu"))
if USE_CUDA:
    model = model.cuda()
print(model)
# Labeled train/eval splits plus the unlabeled result pairs, padded to stringlen.
texta, textb, labels, evala, evalb, evallabels = get_data.train_data(stringlen)
resulta, resultb = get_data.result_data(stringlen)
if USE_CUDA:
    texta = texta.cuda()
    textb = textb.cuda()
    labels = labels.cuda()
    evala = evala.cuda()
    evalb = evalb.cuda()
    evallabels = evallabels.cuda()
    resulta = resulta.cuda()
    resultb = resultb.cuda()
def evaluate(net, eval_dataa, eval_datab, eval_label, batch_size):
    """Run net over the eval split and print/return accuracy."""
    net.eval()
    dataset = TensorDataset(eval_dataa, eval_datab, eval_label)
    eval_iter = DataLoader(dataset, batch_size, shuffle=False)
    total = 0
    correct = 0
    statea = None
    stateb = None
    with torch.no_grad():
        for XA, XB, y in eval_iter:
            XA = XA.long()
            XB = XB.long()
            # The model is built for a fixed batch size; skip a smaller final batch.
            if XA.size(0) != batch_size:
                break
            # Detach hidden states carried over from the previous batch.
            if statea is not None:
                if isinstance(statea, tuple):  # LSTM state: (h, c)
                    statea = (statea[0].detach(), statea[1].detach())
                else:
                    statea = statea.detach()
            if stateb is not None:
                if isinstance(stateb, tuple):  # LSTM state: (h, c)
                    stateb = (stateb[0].detach(), stateb[1].detach())
                else:
                    stateb = stateb.detach()
            output, statea, stateb = net(XA, XB, statea, stateb, False)
            total += XA.size(0)
            # Threshold the match score at 0.5 for a binary prediction.
            result = output.detach().cpu().numpy().tolist()
            result = [1 if row[0] > 0.5 else 0 for row in result]
            answer = y.cpu().numpy().tolist()
            for i in range(len(answer)):
                if answer[i][0] == result[i]:
                    correct += 1
    acc = correct / total
    print(correct, "/", total, "TestAcc: ", acc)
    return acc
evaluate(model, evala, evalb, evallabels, Batchsize)
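# Hypothetical usage sketch (not part of the original script): resulta/resultb
# are loaded above but never consumed. Assuming they have the same layout as
# the eval tensors and the same net(...) signature used in evaluate(), unlabeled
# predictions could be produced like this. The name predict() is illustrative.
def predict(net, dataa, datab, batch_size):
    net.eval()
    preds = []
    result_iter = DataLoader(TensorDataset(dataa, datab), batch_size, shuffle=False)
    statea, stateb = None, None
    with torch.no_grad():
        for XA, XB in result_iter:
            XA, XB = XA.long(), XB.long()
            if XA.size(0) != batch_size:  # model expects a fixed batch size
                break
            output, statea, stateb = net(XA, XB, statea, stateb, False)
            preds.extend(1 if row[0] > 0.5 else 0 for row in output.detach().cpu().numpy())
    return preds

# Example call (commented out to leave the script's behavior unchanged):
# predictions = predict(model, resulta, resultb, Batchsize)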