Skip to content

Architecture Improvement: Self-Attention part #48

@LoserCheems

Description

@LoserCheems

The paper arXiv:2502.11089 shows that different self-attention modes can be assigned to different tokens of a sequence, through a routing- or function-based strategy, or by other methods.

Here is an example of dynamic-mask attention implemented with FlexAttention:

def flex_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: Optional[float] = None,
    is_causal: Optional[bool] = None,
    softcap: Optional[float] = None,
    head_mask: Optional[torch.Tensor] = None,
    **kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
    causal_mask = attention_mask
    if attention_mask is not None:
        causal_mask = causal_mask[:, :, :, : key.shape[-2]]

    if is_causal is None:
        is_causal = causal_mask is None and query.shape[2] > 1

    def causal_mod(score, batch, head, q_idx, kv_idx):
        if softcap is not None:
            score = softcap * torch.tanh(score / softcap)
        if causal_mask is not None:
            score = score + causal_mask[batch][0][q_idx][kv_idx]
        if head_mask is not None:
            score = score + head_mask[batch][head][0][0]
        return score

    def dynamic_mod(score, batch, head, q_idx, kv_idx):
        if softcap is not None:
            score = softcap * torch.tanh(score / softcap)
        if causal_mask is not None:
            score = score + causal_mask[batch][head][q_idx][kv_idx]
        if head_mask is not None:
            score = score + head_mask[batch][head][0][0]
        return score

    mask_mod = causal_mod if is_causal or module.training else dynamic_mod

    attn_output, attention_weights = flex_attention(
        query=query,
        key=key,
        value=value,
        score_mod=mask_mod,
        enable_gqa=True,
        scale=scaling,
        # Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless.
        # For simplification, we thus always return it as no additional computations are introduced.
        return_lse=True,
    )
    # lse is returned in float32
    attention_weights = attention_weights.to(value.dtype)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attention_weights

Metadata

Metadata

Assignees

No one assigned

    Labels

    Projects

    Status

    Done

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions