The paper 2409.19606 shows that the problem of similar gradients across different layers arises in deep transformer models. One remedy is to add learnable parameters to the residual connection, scaling the previous layer's state before adding it to the current layer's output.
import torch
from torch import nn

class Residual(nn.Module):
    """Residual connection with a learnable per-channel scale on the skip path."""
    def __init__(self, hidden_size):
        super().__init__()
        # initialized to ones, so the module starts as a plain residual add
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, residual_states, hidden_states):
        # scale the previous layer's state before adding the new one
        return self.weight * residual_states + hidden_states

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}"
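As a quick sanity check, at initialization the scale is all ones, so the module reduces to a standard residual addition. A minimal sketch (the dummy tensors and sizes here are illustrative, not from the paper):

```python
import torch
from torch import nn

class Residual(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, residual_states, hidden_states):
        return self.weight * residual_states + hidden_states

res = Residual(4)
a = torch.randn(2, 4)  # previous layer's state (skip path)
b = torch.randn(2, 4)  # current sublayer's output
out = res(a, b)
# at initialization weight == 1, so this equals a plain residual add
assert torch.allclose(out, a + b)
```

During training, each layer can then learn its own per-channel scale, which is what lets gradients differ across layers.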
We can use it in a decoder layer's forward pass as follows.
# sequence transformation (self-attention sublayer)
residual = hidden_states
hidden_states = self.pre_layernorm(hidden_states)
hidden_states, self_attn_weights = self.self_attn(
    hidden_states=hidden_states,
    attention_mask=attention_mask,
    position_ids=position_ids,
    past_key_value=past_key_value,
    output_attentions=output_attentions,
    use_cache=use_cache,
    cache_position=cache_position,
    position_embeddings=position_embeddings,
    **kwargs,
)
self_attn_weights = None  # attention weights are not used further here
hidden_states = F.dropout(hidden_states, p=self.hidden_dropout, training=self.training)
# scaled residual connection after attention
hidden_states = self.pre_residual(residual, hidden_states)

# state transformation (feed-forward sublayer)
residual = hidden_states
hidden_states = self.post_layernorm(hidden_states)
hidden_states = self.feed_forward(hidden_states)
hidden_states = F.dropout(hidden_states, p=self.hidden_dropout, training=self.training)
# scaled residual connection after feed-forward
hidden_states = self.post_residual(residual, hidden_states)
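The forward pass above assumes the layer registered `pre_residual` and `post_residual` modules in its constructor. A minimal sketch of how that wiring might look; the `DecoderLayer` class and the `nn.Identity` stand-ins for the attention and feed-forward blocks are assumptions for illustration, not the paper's actual model:

```python
import torch
from torch import nn

class Residual(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, residual_states, hidden_states):
        return self.weight * residual_states + hidden_states

class DecoderLayer(nn.Module):
    # hypothetical constructor showing where the scaled residuals live
    def __init__(self, hidden_size):
        super().__init__()
        self.pre_layernorm = nn.LayerNorm(hidden_size)
        self.post_layernorm = nn.LayerNorm(hidden_size)
        self.self_attn = nn.Identity()       # placeholder for the attention block
        self.feed_forward = nn.Identity()    # placeholder for the MLP block
        self.pre_residual = Residual(hidden_size)   # skip connection after attention
        self.post_residual = Residual(hidden_size)  # skip connection after feed-forward

layer = DecoderLayer(8)
```

Because each `Residual` holds its own parameter vector, every layer (and each of its two sublayers) learns an independent scale for its skip path.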