Skip to content

Commit db3e52a

Browse files
authored
Fixed implementation error leading to only 1 MLP layer instead of 2 and normalization layer
1 parent 0f9778a commit db3e52a

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

transformerXL.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def __call__(self, values_keys:jnp.ndarray, queries:jnp.ndarray, pos_embed:jnp.n
7575
out = self.dense1(out_attention_n)
7676
out = nn.activation.gelu(out)
7777
#out = nn.activation.relu(out)
78-
out = self.dense2(out_attention)
78+
out = self.dense2(out)
7979
if(self.gating):
8080
out= self.gate2(out,jax.nn.relu(out_attention))
8181
else:

0 commit comments

Comments
 (0)