Skip to content

Commit 65ebb01

Browse files
authored
Fix conversion when enable_gqa is False and dimensions are different (#2763)
1 parent 41f5019 commit 65ebb01

File tree

2 files changed

+31
-4
lines changed

2 files changed

+31
-4
lines changed

onnxscript/function_libs/torch_lib/ops/nn.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1837,10 +1837,6 @@ def aten_scaled_dot_product_attention(
18371837

18381838
if enable_gqa:
18391839
key, value = _attention_repeat_kv_for_group_query(query, key, value)
1840-
else:
1841-
assert query.shape[1] == key.shape[1] == value.shape[1], (
1842-
"SDPA (MHA) requires q_num_heads = kv_num_heads"
1843-
)
18441840

18451841
if attn_mask is None:
18461842
return _aten_scaled_dot_product_attention_no_mask_onnx(

tests/function_libs/torch_lib/e2e_ops_tests.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,37 @@ def forward(self, q, k, v):
228228
)
229229
_testing.assert_onnx_program(onnx_program)
230230

231+
def test_optional_enable_gqa_in_attention(self):
232+
class Model(torch.nn.Module):
233+
def forward(self, q, k, v):
234+
return torch.nn.functional.scaled_dot_product_attention( # pylint: disable=not-callable
235+
q,
236+
k,
237+
v,
238+
)
239+
240+
model = Model()
241+
242+
# scaled_dot_product_attention works even if query.shape[1] != key.shape[1]
243+
# due to broadcasting
244+
query = torch.randn(2, 1, 8, 16)
245+
key = torch.randn(2, 2, 8, 16)
246+
value = torch.randn(2, 2, 8, 16)
247+
248+
onnx_program = torch.onnx.export(
249+
model,
250+
(
251+
query,
252+
key,
253+
value,
254+
),
255+
input_names=["query", "key", "value"],
256+
output_names=["output"],
257+
opset_version=18,
258+
dynamo=True,
259+
)
260+
_testing.assert_onnx_program(onnx_program)
261+
231262
def test_bitwise_and_scalar(self):
232263
class Model(torch.nn.Module):
233264
def forward(self, x):

0 commit comments

Comments
 (0)