-
Notifications
You must be signed in to change notification settings - Fork 24
Expand file tree
/
Copy pathrepeat_net.py
More file actions
89 lines (74 loc) · 3.02 KB
/
repeat_net.py
File metadata and controls
89 lines (74 loc) · 3.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
# -*- coding:utf-8 -*-
import os
os.environ["CHAINER_TYPE_CHECK"] = "0"
import numpy as np
import chainer
import chainer.functions as F
import chainer.links as L
from chainer import cuda, Variable
from base.encoder import *
from base.decoder import *
from base.utils import *
from base.function import *
class NoAttRepeatNet(chainer.Chain):
    """Sequence recommender: GRU encoder + ``NoAttReDecoder``.

    The decoder returns three outputs ``(p_r, p_e, p)``; :meth:`predict`
    sums the first two to form the per-item scores and passes ``p``
    through unchanged.

    Args:
        item_size: Number of items; given to both encoder and decoder.
        embed_size: Embedding size given to the encoder.
        hidden_size: Hidden size shared by encoder and decoder.
        joint_train: If True, :meth:`train` additionally returns a loss on
            the decoder's ``p`` output.
    """

    def __init__(self, item_size, embed_size, hidden_size, joint_train=False):
        super(NoAttRepeatNet, self).__init__(
            enc=NStepGRUEncoder(item_size, embed_size, hidden_size),
            dec=NoAttReDecoder(item_size, hidden_size),
        )
        # Plain (non-Link) attribute, set once Chain initialization is done
        # rather than on a half-initialized Link.
        self.joint_train = joint_train

    def predict(self, input_list):
        """Score items for each input sequence in the batch.

        Args:
            input_list: Batch of input item sequences, in whatever form
                ``mask`` and ``NStepGRUEncoder`` accept.

        Returns:
            tuple: ``(p_r + p_e, p)`` built from the decoder outputs.
        """
        # mask() comes from base.utils; presumably flags valid (non-padding)
        # positions -- TODO confirm.
        x_enable = chainer.Variable(self.xp.array(mask(input_list)))
        batch_last_h, batch_seq_h = self.enc(input_list, x_enable)
        p_r, p_e, p = self.dec(batch_last_h, input_list, batch_seq_h, x_enable)
        return p_r + p_e, p

    def train(self, input_list, output_list):
        """Compute the batch-averaged negative log score of the targets.

        Args:
            input_list: Batch of input sequences (see :meth:`predict`).
            output_list: Targets; ``output_list[i]`` is used as the column
                index into row ``i`` of the score matrix.

        Returns:
            ``loss`` -- ``-sum(log(score of each target)) / batch size``.
            When ``joint_train`` is True, the tuple ``(loss, p_loss)``, where
            ``p_loss`` averages ``-log p[i, 1]`` if the target occurs in
            ``input_list[i]`` and ``-log p[i, 0]`` otherwise.
        """
        predicts, p = self.predict(input_list)
        batch_size = len(input_list)
        slices, p_slices = self._target_masks(
            predicts.shape, p.shape, input_list, output_list)
        loss = -F.sum(F.log(F.get_item(predicts, slices))) / batch_size
        if not self.joint_train:
            return loss
        p_loss = -F.sum(F.log(F.get_item(p, p_slices))) / batch_size
        return loss, p_loss

    def _target_masks(self, score_shape, mode_shape, input_list, output_list):
        """Build the boolean masks that select targets from the model outputs.

        The masks are filled on the host with NumPy and moved to the model's
        array module in a single transfer; setting individual elements of a
        CuPy array inside a Python loop would cost one device op per target.

        Returns:
            ``(slices, p_slices)``; ``p_slices`` is None unless
            ``joint_train`` is set.
        """
        slices = np.zeros(score_shape, dtype=bool)
        p_slices = np.zeros(mode_shape, dtype=bool) if self.joint_train else None
        for i, target in enumerate(output_list):
            slices[i, target] = True
            if p_slices is not None:
                # Column 1 when the target already appears in the input
                # sequence, column 0 otherwise.
                p_slices[i, 1 if target in input_list[i] else 0] = True
        xp = self.xp
        slices = xp.asarray(slices)
        if p_slices is not None:
            p_slices = xp.asarray(p_slices)
        return slices, p_slices
class RepeatNet(chainer.Chain):
    """Sequence recommender: GRU encoder + ``AttReDecoder``.

    The decoder returns three outputs ``(p_r, p_e, p)``; :meth:`predict`
    sums the first two to form the per-item scores and passes ``p``
    through unchanged.

    Args:
        item_size: Number of items; given to both encoder and decoder.
        embed_size: Embedding size given to the encoder.
        hidden_size: Hidden size shared by encoder and decoder.
        joint_train: If True, :meth:`train` additionally returns a loss on
            the decoder's ``p`` output.
    """

    def __init__(self, item_size, embed_size, hidden_size, joint_train=False):
        super(RepeatNet, self).__init__(
            enc=NStepGRUEncoder(item_size, embed_size, hidden_size),
            dec=AttReDecoder(item_size, hidden_size),
        )
        # Plain (non-Link) attribute, set once Chain initialization is done
        # rather than on a half-initialized Link.
        self.joint_train = joint_train

    def predict(self, input_list):
        """Score items for each input sequence in the batch.

        Args:
            input_list: Batch of input item sequences, in whatever form
                ``mask`` and ``NStepGRUEncoder`` accept.

        Returns:
            tuple: ``(p_r + p_e, p)`` built from the decoder outputs.
        """
        # mask() comes from base.utils; presumably flags valid (non-padding)
        # positions -- TODO confirm.
        x_enable = chainer.Variable(self.xp.array(mask(input_list)))
        batch_last_h, batch_seq_h = self.enc(input_list, x_enable)
        p_r, p_e, p = self.dec(batch_last_h, input_list, batch_seq_h, x_enable)
        return p_r + p_e, p

    def train(self, input_list, output_list):
        """Compute the batch-averaged negative log score of the targets.

        Args:
            input_list: Batch of input sequences (see :meth:`predict`).
            output_list: Targets; ``output_list[i]`` is used as the column
                index into row ``i`` of the score matrix.

        Returns:
            ``loss`` -- ``-sum(log(score of each target)) / batch size``.
            When ``joint_train`` is True, the tuple ``(loss, p_loss)``, where
            ``p_loss`` averages ``-log p[i, 1]`` if the target occurs in
            ``input_list[i]`` and ``-log p[i, 0]`` otherwise.
        """
        predicts, p = self.predict(input_list)
        batch_size = len(input_list)
        slices, p_slices = self._target_masks(
            predicts.shape, p.shape, input_list, output_list)
        loss = -F.sum(F.log(F.get_item(predicts, slices))) / batch_size
        if not self.joint_train:
            return loss
        p_loss = -F.sum(F.log(F.get_item(p, p_slices))) / batch_size
        return loss, p_loss

    def _target_masks(self, score_shape, mode_shape, input_list, output_list):
        """Build the boolean masks that select targets from the model outputs.

        The masks are filled on the host with NumPy and moved to the model's
        array module in a single transfer; setting individual elements of a
        CuPy array inside a Python loop would cost one device op per target.

        Returns:
            ``(slices, p_slices)``; ``p_slices`` is None unless
            ``joint_train`` is set.
        """
        slices = np.zeros(score_shape, dtype=bool)
        p_slices = np.zeros(mode_shape, dtype=bool) if self.joint_train else None
        for i, target in enumerate(output_list):
            slices[i, target] = True
            if p_slices is not None:
                # Column 1 when the target already appears in the input
                # sequence, column 0 otherwise.
                p_slices[i, 1 if target in input_list[i] else 0] = True
        xp = self.xp
        slices = xp.asarray(slices)
        if p_slices is not None:
            p_slices = xp.asarray(p_slices)
        return slices, p_slices