The paper arXiv:2502.11089 shows that we can assign different self-attention mods (score modifications) to different tokens of a sequence, through a routing strategy or other methods.
Here's an example of dynamic mask attention:
def flex_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: Optional[float] = None,
    is_causal: Optional[bool] = None,
    softcap: Optional[float] = None,
    head_mask: Optional[torch.Tensor] = None,
    **kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run attention through PyTorch's ``flex_attention`` with a per-call score mod.

    Two mask-reading strategies are supported by a single closure factory:
    a "broadcast" variant that always reads head index 0 of the additive
    attention mask (used when causal or when ``module.training``), and a
    "dynamic" per-head variant that indexes the mask by the actual head.

    Args:
        module: The calling attention module; only its ``training`` flag is read.
        query, key, value: Attention inputs. Presumably shaped
            (batch, heads, seq, head_dim) — TODO confirm against callers.
        attention_mask: Optional additive mask, sliced to the key length.
        scaling: Optional softmax scale forwarded to ``flex_attention``.
        is_causal: If ``None``, inferred as "no mask and more than one query token".
        softcap: Optional tanh soft-capping applied to raw scores.
        head_mask: Optional additive per-head bias.
        **kwargs: Ignored; accepted for interface compatibility.

    Returns:
        Tuple of (attn_output, attention_weights) where the second element is
        the LSE cast to ``value``'s dtype.
    """
    mask = attention_mask
    if mask is not None:
        # Trim the mask down to the actual key length.
        mask = mask[:, :, :, : key.shape[-2]]
    if is_causal is None:
        # Causal only when no explicit mask is supplied and we are processing
        # more than a single query token.
        is_causal = mask is None and query.shape[2] > 1

    def _build_score_mod(per_head: bool):
        """Return a score_mod closure; ``per_head`` picks the mask head index."""

        def score_mod(score, batch, head, q_idx, kv_idx):
            if softcap is not None:
                score = softcap * torch.tanh(score / softcap)
            if mask is not None:
                # Broadcast variant reads head 0; dynamic variant reads `head`.
                mask_head = head if per_head else 0
                score = score + mask[batch][mask_head][q_idx][kv_idx]
            if head_mask is not None:
                score = score + head_mask[batch][head][0][0]
            return score

        return score_mod

    # Causal decoding and training use the broadcast mask; otherwise fall back
    # to the per-head ("dynamic") mask.
    use_broadcast_mask = is_causal or module.training
    score_mod = _build_score_mod(per_head=not use_broadcast_mask)

    attn_output, attention_weights = flex_attention(
        query=query,
        key=key,
        value=value,
        score_mod=score_mod,
        enable_gqa=True,
        scale=scaling,
        # Last time checked on PyTorch == 2.5.1: Flex Attention always computes
        # the lse regardless, so returning it introduces no extra computation.
        return_lse=True,
    )
    # The lse is returned in float32; match the value dtype for the caller.
    attention_weights = attention_weights.to(value.dtype)
    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output, attention_weights
The paper arXiv:2502.11089 shows that we can assign different self-attention mods to different tokens of a sequence, through a routing strategy or other methods. Here's an example of dynamic mask attention: