Architecture Improvement: Residual part #47

@LoserCheems

Description

The paper arXiv:2409.19606 shows that deep transformer models suffer from overly similar gradients across different layers.

We can try to address this by adding learnable parameters to the residual connection, scaling the previous layer's state before adding it to the output of the current sublayer.

import torch
import torch.nn as nn

class Residual(nn.Module):
    """Residual connection with a learnable per-channel scale on the skip path."""

    def __init__(self, hidden_size):
        super().__init__()
        # Initialized to ones, so training starts as a standard residual connection.
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, residual_states, hidden_states):
        return self.weight * residual_states + hidden_states

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}"
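As a quick sanity check, a sketch of the module in isolation (shapes here are illustrative; the class is repeated so the snippet runs standalone):

```python
import torch
import torch.nn as nn

# Residual as defined above, repeated so this example is self-contained.
class Residual(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, residual_states, hidden_states):
        return self.weight * residual_states + hidden_states

res = Residual(hidden_size=4)
residual = torch.ones(2, 3, 4)
hidden = torch.zeros(2, 3, 4)
out = res(residual, hidden)
# At init weight == 1, so this behaves exactly like a plain residual add.
print(out.shape)                      # torch.Size([2, 3, 4])
print(torch.allclose(out, residual))  # True
```

The per-channel `weight` broadcasts over the batch and sequence dimensions, so each hidden channel of the skip path gets its own learned scale.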

We can use it in the decoder layer as follows.

# sequence transformation
residual = hidden_states
hidden_states = self.pre_layernorm(hidden_states)
hidden_states, self_attn_weights = self.self_attn(
    hidden_states=hidden_states,
    attention_mask=attention_mask,
    position_ids=position_ids,
    past_key_value=past_key_value,
    output_attentions=output_attentions,
    use_cache=use_cache,
    cache_position=cache_position,
    position_embeddings=position_embeddings,
    **kwargs,
)
if not output_attentions:
    self_attn_weights = None
hidden_states = F.dropout(hidden_states, p=self.hidden_dropout, training=self.training)
hidden_states = self.pre_residual(residual, hidden_states)

# state transformation
residual = hidden_states
hidden_states = self.post_layernorm(hidden_states)
hidden_states = self.feed_forward(hidden_states)
hidden_states = F.dropout(hidden_states, p=self.hidden_dropout, training=self.training)
hidden_states = self.post_residual(residual, hidden_states)
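Put together, the wiring looks like the following minimal sketch. Note the attention and feed-forward sublayers here are hypothetical stand-ins (a single `nn.Linear` and a small MLP) just to show where `pre_residual` and `post_residual` sit; a real decoder layer would use proper causal self-attention:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Residual(nn.Module):
    """Learnable per-channel scale on the skip path (as proposed above)."""
    def __init__(self, hidden_size):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, residual_states, hidden_states):
        return self.weight * residual_states + hidden_states

class ToyDecoderLayer(nn.Module):
    """Illustrative layer: self_attn and feed_forward are stand-ins."""
    def __init__(self, hidden_size, dropout=0.0):
        super().__init__()
        self.pre_layernorm = nn.LayerNorm(hidden_size)
        self.post_layernorm = nn.LayerNorm(hidden_size)
        self.self_attn = nn.Linear(hidden_size, hidden_size)  # stand-in for attention
        self.feed_forward = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size),
            nn.GELU(),
            nn.Linear(4 * hidden_size, hidden_size),
        )
        self.pre_residual = Residual(hidden_size)
        self.post_residual = Residual(hidden_size)
        self.hidden_dropout = dropout

    def forward(self, hidden_states):
        # sequence transformation
        residual = hidden_states
        hidden_states = self.pre_layernorm(hidden_states)
        hidden_states = self.self_attn(hidden_states)
        hidden_states = F.dropout(hidden_states, p=self.hidden_dropout, training=self.training)
        hidden_states = self.pre_residual(residual, hidden_states)

        # state transformation
        residual = hidden_states
        hidden_states = self.post_layernorm(hidden_states)
        hidden_states = self.feed_forward(hidden_states)
        hidden_states = F.dropout(hidden_states, p=self.hidden_dropout, training=self.training)
        return self.post_residual(residual, hidden_states)

layer = ToyDecoderLayer(hidden_size=8)
x = torch.randn(2, 5, 8)
print(layer(x).shape)  # torch.Size([2, 5, 8])
```

Each layer adds only `2 * hidden_size` extra parameters, and because both scales start at 1, the model is numerically identical to the baseline at initialization.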

Metadata

Assignees: none
Status: Done
Milestone: none