-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathtransformer_helpers.py
More file actions
171 lines (156 loc) · 4.63 KB
/
Copy pathtransformer_helpers.py
File metadata and controls
171 lines (156 loc) · 4.63 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
import torch.nn
from layers.transformer import Transformer
from layers.transformer import GatedTransformer
from layers.transformer.transformer import TransformerDecoderWithLayer
from models import TransformerEncDecModel, TransformerDecModel
from interfaces import (
TransformerEncDecInterface,
TransformerDecOnlyInterface,
TransformerLMInterface,
TransformerPrefixLMInterface,
TransformerDecoderLMInterface,
TransformerEncoderCLSInterface,
TransformerMLMInterface,
)
from models.transformer_lm import TransformerLM
from models.transformer_dec import TransformerDecoderLM
from models.transformer_enc import TransformerEncoderCLS
from models.transformer_mlm import TransformerMLM
def create_lm(
    in_vocab_size,
    vec_dim,
    n_heads,
    encoder_n_layers,
    mode="enc_dec",
    use_pos_embeddig=True,
    pos_scale=1,
    dropout=0.1,
    tied_embedding=False,
    gated_model=False,
) -> torch.nn.Module:
    """Build a Transformer language model for the requested ``mode``.

    Args:
        in_vocab_size: Vocabulary size passed as the first positional argument.
        vec_dim: Model/state dimension.
        n_heads: Number of attention heads.
        encoder_n_layers: Layer count (``num_encoder_layers`` for "enc_dec",
            ``nlayers`` for "dec").
        mode: "enc_dec" builds a ``TransformerLM``; "dec" builds a
            ``TransformerDecoderLM``.
        use_pos_embeddig: Forwarded only in "dec" mode.
        pos_scale: Forwarded only in "enc_dec" mode.
        dropout: Forwarded only in "enc_dec" mode.
        tied_embedding: Forwarded only in "enc_dec" mode.
        gated_model: In "enc_dec" mode, selects ``GatedTransformer`` instead of
            ``Transformer`` as the layer stack.

    Returns:
        The constructed model.

    Raises:
        ValueError: If ``mode`` is not "enc_dec" or "dec". Previously such a
            mode silently returned ``None``.
    """
    # Shared construction options for both model types.
    args = dict(embedding_init="xavier", scale_mode="opennmt")
    if mode == "enc_dec":
        return TransformerLM(
            in_vocab_size,
            vec_dim,
            n_heads,
            num_encoder_layers=encoder_n_layers,
            pos_scale=pos_scale,
            transformer=GatedTransformer if gated_model else Transformer,
            dropout=dropout,
            tied_embedding=tied_embedding,
            **args,
        )
    if mode == "dec":
        # NOTE(review): pos_scale, dropout, tied_embedding and gated_model are
        # not forwarded here; confirm TransformerDecoderLM's defaults are the
        # intended values.
        return TransformerDecoderLM(
            in_vocab_size,
            vec_dim,
            n_heads,
            nlayers=encoder_n_layers,
            use_pos_embeddig=use_pos_embeddig,
            **args,
        )
    raise ValueError(f"Unknown mode {mode!r}; expected 'enc_dec' or 'dec'")
def create_cls(
    in_vocab_size,
    out_vocab_size,
    vec_dim,
    n_heads,
    encoder_n_layers,
    use_pos_embeddig=True,
    causal_encoder=False,
    **args
) -> torch.nn.Module:
    """Construct a ``TransformerEncoderCLS`` classification model.

    Args:
        in_vocab_size: Input vocabulary size (first positional argument).
        out_vocab_size: Output vocabulary / class count (second positional).
        vec_dim: Passed as ``state_size``.
        n_heads: Passed as ``nhead``.
        encoder_n_layers: Passed as ``nlayers``.
        use_pos_embeddig: Forwarded unchanged.
        causal_encoder: Forwarded unchanged.
        **args: Accepted but NOT forwarded to the model.

    Returns:
        The constructed model.
    """
    # NOTE(review): extra keyword arguments collected in ``args`` are
    # discarded. The original rebinds ``args`` to the fixed options below.
    # Confirm whether callers expect them to be forwarded.
    init_options = {"embedding_init": "xavier", "scale_mode": "opennmt"}
    return TransformerEncoderCLS(
        in_vocab_size,
        out_vocab_size,
        state_size=vec_dim,
        nhead=n_heads,
        nlayers=encoder_n_layers,
        use_pos_embeddig=use_pos_embeddig,
        causal_encoder=causal_encoder,
        **init_options,
    )
def create_mlm(
    in_vocab_size,
    vec_dim,
    n_heads,
    encoder_n_layers,
    use_pos_embedding=True,
    causal_encoder=False,
) -> torch.nn.Module:
    """Construct a ``TransformerMLM`` masked-language model.

    Args:
        in_vocab_size: Passed as ``n_input_tokens``.
        vec_dim: Passed as ``state_size``.
        n_heads: Passed as ``nhead``.
        encoder_n_layers: Passed as ``nlayers``.
        use_pos_embedding: Forwarded unchanged. Note that this spelling differs
            from the ``use_pos_embeddig`` used by the other factories here.
        causal_encoder: Forwarded unchanged.

    Returns:
        The constructed model.
    """
    model_kwargs = {
        "n_input_tokens": in_vocab_size,
        "state_size": vec_dim,
        "nhead": n_heads,
        "nlayers": encoder_n_layers,
        "use_pos_embedding": use_pos_embedding,
        "causal_encoder": causal_encoder,
        # Shared construction options.
        "embedding_init": "xavier",
        "scale_mode": "opennmt",
    }
    return TransformerMLM(**model_kwargs)
def create_model(
    in_vocab_size,
    out_vocab_size,
    vec_dim,
    n_heads,
    encoder_n_layers,
    decoder_n_layers,
    is_null_encoder=False,
    dropout=0.1,
    tied_embedding=True,
    mode="enc_dec",
) -> torch.nn.Module:
    """Construct a sequence-to-sequence Transformer model.

    Both model classes receive identical arguments. Only the class differs.

    Args:
        in_vocab_size: Input vocabulary size (first positional argument).
        out_vocab_size: Output vocabulary size (second positional argument).
        vec_dim: Model/state dimension (third positional argument).
        n_heads: Number of attention heads (fourth positional argument).
        encoder_n_layers: Passed as ``num_encoder_layers``.
        decoder_n_layers: Passed as ``num_decoder_layers``.
        is_null_encoder: If truthy, build a ``TransformerDecModel``;
            otherwise build a ``TransformerEncDecModel``.
        dropout: Forwarded unchanged.
        tied_embedding: Forwarded unchanged.
        mode: Forwarded unchanged as the ``mode`` keyword.

    Returns:
        The constructed model.
    """
    model_cls = TransformerDecModel if is_null_encoder else TransformerEncDecModel
    return model_cls(
        in_vocab_size,
        out_vocab_size,
        vec_dim,
        n_heads,
        num_encoder_layers=encoder_n_layers,
        num_decoder_layers=decoder_n_layers,
        tied_embedding=tied_embedding,
        dropout=dropout,
        embedding_init="xavier",
        scale_mode="opennmt",
        mode=mode,
    )
def create_model_interface(
    model,
    label_smoothing=0.0,
    is_null_encoder=False,
    is_lm=False,
    is_cls=False,
    is_prefix_lm=False,
):
    """Wrap ``model`` in the interface class matching the given flags.

    Selection order:
      * ``is_lm``:
          - with ``is_null_encoder``: ``TransformerDecoderLMInterface``
          - else with ``is_prefix_lm``: ``TransformerPrefixLMInterface``
          - otherwise: ``TransformerLMInterface``
      * else ``is_cls``: ``TransformerEncoderCLSInterface``
      * else ``is_null_encoder``: ``TransformerDecOnlyInterface``
      * otherwise: ``TransformerEncDecInterface``

    Args:
        model: The model to wrap.
        label_smoothing: Forwarded to the interface constructor.
        is_null_encoder, is_lm, is_cls, is_prefix_lm: Selection flags (see above).

    Returns:
        The constructed interface instance.
    """
    if is_lm:
        if is_null_encoder:
            interface_cls = TransformerDecoderLMInterface
        elif is_prefix_lm:
            interface_cls = TransformerPrefixLMInterface
        else:
            interface_cls = TransformerLMInterface
    elif is_cls:
        interface_cls = TransformerEncoderCLSInterface
    elif is_null_encoder:
        interface_cls = TransformerDecOnlyInterface
    else:
        interface_cls = TransformerEncDecInterface
    return interface_cls(model, label_smoothing=label_smoothing)
def create_mlm_interface(model, label_smoothing=0.0):
    """Wrap ``model`` in a ``TransformerMLMInterface``.

    Args:
        model: The masked-language model to wrap.
        label_smoothing: Forwarded to the interface constructor.

    Returns:
        The constructed interface instance.
    """
    interface = TransformerMLMInterface(model, label_smoothing=label_smoothing)
    return interface
#### Similar interfaces for pretrained models...