Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 72 additions & 5 deletions onnxscript/function_libs/torch_lib/ops/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1741,6 +1741,64 @@ def _attention_scale(query: TFloat) -> TFloat:
return scale


def _attention_repeat_kv_for_group_query(
    query: TFloat, key: TFloat, value: TFloat
) -> Tuple[TFloat, TFloat]:
    """Expand key and value along the head axis for grouped-query attention.

    Equivalent to ``repeat_interleave`` on dim 1 of key and value so that both
    end up with the same number of heads as query. Implemented with
    Unsqueeze -> Expand -> Reshape so the graph stays shape-polymorphic
    (all target shapes are computed from runtime ``Shape`` ops).

    Args:
        query: Tensor of shape [B, q_num_heads, q_S, E]
        key: Tensor of shape [B, kv_num_heads, kv_S, E]
        value: Tensor of shape [B, kv_num_heads, kv_S, Ev]
            (key and value must have the same number of heads, and
            q_num_heads must be a strict multiple of kv_num_heads)

    Returns:
        Tuple of (expanded_key, expanded_value) where:
            - expanded_key: Tensor of shape [B, q_num_heads, kv_S, E]
            - expanded_value: Tensor of shape [B, q_num_heads, kv_S, Ev]
    """

    assert (
        query.shape[1] > key.shape[1] == value.shape[1] and query.shape[1] % key.shape[1] == 0
    ), (
        "SDPA (GQA or MQA) requires q_num_heads > kv_num_heads & q_num_heads % kv_num_heads == 0"
    )

    # NOTE: QKV are expected to be 4D tensors

    # Gather symbolic dims so the expansion works for dynamic shapes too.
    batch_size = op.Shape(query, start=0, end=1)  # [B]
    q_num_heads = op.Shape(query, start=1, end=2)  # [Hq]
    kv_num_heads = op.Shape(key, start=1, end=2)  # [Hk]
    qk_head_size = op.Shape(key, start=3, end=4)  # [Dk]
    v_head_size = op.Shape(value, start=3, end=4)  # [Dv]
    new_kv_seq_len = op.Shape(key, start=2, end=3)  # [T]

    # Number of query heads that share each kv head.
    interleave_dim = op.Div(q_num_heads, kv_num_heads)  # Hq / Hk
    two = op.Constant(value_int=2)
    # Insert a singleton axis after the head axis, then broadcast it to
    # Hq/Hk copies via Expand (Expand broadcasts the size-1 dim).
    k_unsqueezed = op.Unsqueeze(key, two)  # [B, Hk, 1, T, Dk]
    v_unsqueezed = op.Unsqueeze(value, two)  # [B, Hv, 1, T, Dv]

    k_expand_shape = op.Concat(
        batch_size, kv_num_heads, interleave_dim, new_kv_seq_len, qk_head_size, axis=0
    )
    k_expand = op.Expand(k_unsqueezed, k_expand_shape)  # [B, Hk, Hq/Hk, T, Dk]
    v_expand_shape = op.Concat(
        batch_size, kv_num_heads, interleave_dim, new_kv_seq_len, v_head_size, axis=0
    )
    v_expand = op.Expand(v_unsqueezed, v_expand_shape)  # [B, Hv, Hq/Hv, T, Dv]

    # Collapse [Hk, Hq/Hk] back into a single head axis of size Hq.
    k_attention_shape = op.Concat(
        batch_size, q_num_heads, new_kv_seq_len, qk_head_size, axis=0
    )
    v_attention_shape = op.Concat(batch_size, q_num_heads, new_kv_seq_len, v_head_size, axis=0)

    expanded_key = op.Reshape(k_expand, k_attention_shape)
    expanded_value = op.Reshape(v_expand, v_attention_shape)

    return expanded_key, expanded_value


@torch_op("aten::scaled_dot_product_attention", trace_only=True)
def aten_scaled_dot_product_attention(
query: TFloat,
Expand Down Expand Up @@ -1772,8 +1830,8 @@ def aten_scaled_dot_product_attention(
"is_causal and attn_mask cannot be set at the same time"
)

assert not enable_gqa, (
"conversion of scaled_dot_product_attention not implemented if enable_gqa is True"
assert len(query.shape) == 4 and len(key.shape) == 4 and len(value.shape) == 4, (
"only 4D query, key, and value are supported"
)

# Reference: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
Expand All @@ -1784,6 +1842,13 @@ def aten_scaled_dot_product_attention(
if is_causal:
attn_mask = _causal_attention_mask(query, key)

if enable_gqa:
key, value = _attention_repeat_kv_for_group_query(query, key, value)
else:
assert query.shape[1] == key.shape[1] == value.shape[1], (
"SDPA (MHA) requires q_num_heads = kv_num_heads"
)

if attn_mask is None:
return _aten_scaled_dot_product_attention_no_mask_onnx(
query, key, value, scale, dropout_p
Expand Down Expand Up @@ -1981,9 +2046,8 @@ def aten_scaled_dot_product_attention_bool_mask(
assert (not is_causal) or (is_causal and attn_mask is None), (
"is_causal and attn_mask cannot be set at the same time"
)

assert not enable_gqa, (
"conversion of scaled_dot_product_attention not implemented if enable_gqa is True"
assert len(query.shape) == 4 and len(key.shape) == 4 and len(value.shape) == 4, (
"only 4D query, key, and value are supported"
)

if scale is None:
Expand All @@ -1997,6 +2061,9 @@ def aten_scaled_dot_product_attention_bool_mask(
query, key, value, attn_mask, scale, dropout_p
)

if enable_gqa:
key, value = _attention_repeat_kv_for_group_query(query, key, value)

if attn_mask is None:
return _aten_scaled_dot_product_attention_no_mask_onnx(
query, key, value, scale, dropout_p
Expand Down
30 changes: 30 additions & 0 deletions tests/function_libs/torch_lib/e2e_ops_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,36 @@ def forward(self, x):
)
_testing.assert_onnx_program(onnx_program)

def test_enable_gqa_in_attention(self):
    """SDPA with enable_gqa=True exports via dynamo and matches torch numerics.

    Uses a GQA configuration: 4 query heads sharing 2 kv heads
    (q_num_heads % kv_num_heads == 0), all tensors 4D [B, H, S, E].
    """

    class Model(torch.nn.Module):
        def forward(self, q, k, v):
            return torch.nn.functional.scaled_dot_product_attention(  # pylint: disable=not-callable
                q,
                k,
                v,
                enable_gqa=True,
            )

    model = Model()

    # q has 4 heads; k/v have 2 heads -> exercises the kv repeat path.
    query = torch.randn(2, 4, 8, 16)
    key = torch.randn(2, 2, 8, 16)
    value = torch.randn(2, 2, 8, 16)

    onnx_program = torch.onnx.export(
        model,
        (
            query,
            key,
            value,
        ),
        input_names=["query", "key", "value"],
        output_names=["output"],
        opset_version=18,
        dynamo=True,
    )
    # Asserts ONNX Runtime output matches the eager torch output.
    _testing.assert_onnx_program(onnx_program)


# Allow running this test module directly (outside a pytest invocation).
if __name__ == "__main__":
    unittest.main()
12 changes: 12 additions & 0 deletions tests/function_libs/torch_lib/ops_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1908,6 +1908,12 @@ def _where_input_wrangler(
dtypes=(torch.float16,),
reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438",
test_class_name="TestOutputConsistencyFullGraph",
)
.xfail(
matcher=lambda sample: len(sample.input.shape) != 4
or len(sample.args[0].shape) != 4
or len(sample.args[1].shape) != 4,
reason="torch sdpa is expected to pass in 4d q, k, and v.",
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@justinchuby @xadupre Let me know what you think about whether we should support only 4D QKV, or whether we should fully support everything torch sdpa supports. Right now, it seems like QKV can be 3D or 4D — or even Q 3D with KV 4D — in torch sdpa.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Depending on the ATen op? Does the nn function do preprocessing on the inputs before sending them to the kernel? We just need to support whatever the kernel supports

),
TorchLibOpInfo(
"ops.aten._scaled_dot_product_flash_attention",
Expand Down Expand Up @@ -1959,6 +1965,12 @@ def _where_input_wrangler(
dtypes=(torch.float16,),
reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438",
test_class_name="TestOutputConsistencyFullGraph",
)
.xfail(
matcher=lambda sample: len(sample.input.shape) != 4
or len(sample.args[0].shape) != 4
or len(sample.args[1].shape) != 4,
reason="torch sdpa is expected to pass in 4d q, k, and v.",
),
TorchLibOpInfo(
"ops.aten.upsample_bilinear2d.default",
Expand Down
Loading