-
Notifications
You must be signed in to change notification settings - Fork 24
Expand file tree
/
Copy pathtest.py
More file actions
81 lines (63 loc) · 2.76 KB
/
test.py
File metadata and controls
81 lines (63 loc) · 2.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import os
os.environ["CHAINER_TYPE_CHECK"] = "0"
import sys
reload(sys)
sys.setdefaultencoding('utf-8')
from base.corpus import *
import pickle
import argparse
import nltk
import random
import numpy as np
import chainer
from chainer import cuda, optimizers, serializers,training,reporter,iterators
from chainer.training.updaters import MultiprocessParallelUpdater
from chainer.iterators import MultiprocessIterator
import codecs
import collections
from att_rec import *
from repeat_net import *
from base.recsys import *
# Print the versions of Chainer, CuPy and cuDNN that this run is using.
# Each call passes a single pre-built string so the output is the same as the
# comma-separated print statement form (label, one extra space, then value).
print('Chainer Version: ' + chainer.__version__)
print(' '.join(['Cupy Version: ', str(cuda.cupy.__version__)]))
print(' '.join(['CuDNN Version: ', str(chainer._cudnn_version)]))
def evaluates_mode(valid_dataset,test_dataset,batch_size,model):
    """Run ``evaluate_mode`` on the validation set, then on the test set.

    Each call prints its own accuracy; nothing is returned from here.

    Args:
        valid_dataset: dataset evaluated first.
        test_dataset: dataset evaluated second.
        batch_size: batch size forwarded to ``evaluate_mode``.
        model: model forwarded to ``evaluate_mode``.
    """
    for dataset in (valid_dataset, test_dataset):
        evaluate_mode(dataset, batch_size, model)
def evaluate_mode(test_dataset, batch_size, model):
    """Evaluate the second output of ``model.predict`` over a dataset in batches.

    For every sample the function records a pair ``[score, flag]`` where
    ``score`` is the sample's entry in ``model.predict(...)[1]`` and ``flag``
    tells whether ``sample[1][0]`` occurs in ``sample[0]``. The collected pairs
    are handed to ``accuracy``, whose result is returned.

    Args:
        test_dataset: sliceable sequence of samples; ``sample[0]`` is what is
            fed to the model and ``sample[1][0]`` is looked up in it
            (presumably the session items and the first target item --
            TODO confirm against SessionCorpus).
        batch_size: number of samples per ``model.predict`` call; must be > 0.
        model: object exposing ``predict(list_of_inputs)`` whose result is
            indexable, with element ``[1]`` holding one entry per input.

    Returns:
        Whatever ``accuracy`` returns for the collected pairs.

    Raises:
        ValueError: if ``batch_size`` is not positive (the batching loop could
            otherwise never reach the end of the dataset).
    """
    if batch_size <= 0:
        raise ValueError('batch_size must be a positive integer, got %r' % (batch_size,))

    total = len(test_dataset)
    eval_results = []
    for start in range(0, total, batch_size):
        # The last batch may be shorter than batch_size.
        batch = test_dataset[start:min(start + batch_size, total)]
        results = model.predict([sample[0] for sample in batch])
        per_sample_scores = results[1]
        for idx, sample in enumerate(batch):
            eval_results.append([per_sample_scores[idx], sample[1][0] in sample[0]])
    return accuracy(eval_results)
def accuracy(eval_results):
    """Print and return the fraction of entries counted as correct.

    An entry ``[scores, flag]`` counts as correct only when ``flag`` is truthy
    and ``scores[1].data > scores[0].data``.
    NOTE(review): entries whose flag is falsy are never counted as correct,
    even when ``scores[0]`` is the larger one -- confirm this is intended.

    Args:
        eval_results: sequence of ``[scores, flag]`` entries, where ``scores``
            is indexable at 0 and 1 and each element exposes a ``.data``
            attribute supporting ``>``.

    Returns:
        ``correct / len(eval_results)`` as a float, or ``0.0`` for an empty
        sequence (instead of dividing by zero).
    """
    if not eval_results:
        ratio = 0.0
    else:
        correct = sum(
            1 for entry in eval_results
            if entry[1] and entry[0][1].data > entry[0][0].data
        )
        ratio = float(correct) / len(eval_results)
    # Single-argument print(...) behaves the same on Python 2 and Python 3.
    print(ratio)
    return ratio
if __name__ == '__main__':
    # Script entry point: load a trained AttRec checkpoint and print its
    # evaluation on the LastFM "repeat" and "nonrepeat" test splits.
    with chainer.using_config('cudnn_deterministic', True):
        with chainer.using_config('use_cudnn', 'auto'):
            # show() writes the configuration to stdout itself and returns
            # None, so it must not be wrapped in print (that emitted a stray
            # "None" line).
            chainer.config.show()

            device = 3  # NOTE(review): not referenced anywhere below -- confirm whether GPU placement was intended.

            item2id, id2item = load_item(file='data/lastfm/lastfm_items.artist.txt')
            test_batchsize = 1024

            # To evaluate RepeatNet instead, build
            # RepeatNet(len(item2id), embed_size=100, hidden_size=100, joint_train=False)
            # here and load the matching checkpoint.
            model = AttRec(len(item2id), embed_size=100, hidden_size=100)
            recsys = RecSys(model, 20)
            serializers.load_npz('model/att_rec_100_100_lastfm.model.26274.npz', recsys)

            # Same model, same settings, one report per test split.
            test_files = (
                'data/lastfm/lastfm_test.repeat.artist.txt',
                'data/lastfm/lastfm_test.nonrepeat.artist.txt',
            )
            for test_file in test_files:
                test_dataset = SessionCorpus(file_path=test_file, item2id=item2id).load()
                print(evaluate(test_dataset, test_batchsize, recsys, prefix='test'))