Skip to content

Commit 48b139e

Browse files
titaiwangms and mansiag05
authored and committed
[ONNX] Support enable_gqa when dropout is non-zero (pytorch#162771)
Fixes pytorch#162258 Related to microsoft/onnxscript#2558 Pull Request resolved: pytorch#162771 Approved by: https://github.com/justinchuby
1 parent 9bd1a5b commit 48b139e

2 files changed

Lines changed: 94 additions & 0 deletions

File tree

test/onnx/exporter/test_small_models_e2e.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -806,6 +806,36 @@ def forward(self, x):
806806
# Test with reference evaluator because ORT does not support the op as of version 1.22
807807
onnx_testing.assert_onnx_program(onnx_program, backend="reference")
808808

809+
def test_enable_gqa_in_attention_23_with_dropout(self):
810+
class Model(torch.nn.Module):
811+
def forward(self, q, k, v):
812+
return torch.nn.functional.scaled_dot_product_attention( # pylint: disable=not-callable
813+
q, k, v, enable_gqa=True, dropout_p=0.1
814+
)
815+
816+
model = Model()
817+
818+
query = torch.randn(2, 4, 8, 16)
819+
key = torch.randn(2, 2, 8, 16)
820+
value = torch.randn(2, 2, 8, 16)
821+
822+
onnx_program = self.export(
823+
model,
824+
(
825+
query,
826+
key,
827+
value,
828+
),
829+
opset_version=23,
830+
)
831+
# opset23 only uses manually gqa path when dropout is enabled,
832+
# and dropout makes the output non-deterministic,
833+
# so we check for the presence of the ops used in that path.
834+
all_ops = [node.op_type for node in onnx_program.model.graph]
835+
self.assertIn("Unsqueeze", all_ops)
836+
self.assertIn("Expand", all_ops)
837+
self.assertIn("Reshape", all_ops)
838+
809839

810840
if __name__ == "__main__":
811841
common_utils.run_tests()

torch/onnx/_internal/exporter/_torchlib/ops/nn.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,9 @@ def aten_scaled_dot_product_attention_23(
170170
if is_causal:
171171
attn_mask = _causal_attention_mask(query, key, op23)
172172

173+
if enable_gqa:
174+
key, value = _attention_repeat_kv_for_group_query(query, key, value, op23)
175+
173176
if attn_mask is None:
174177
return _aten_scaled_dot_product_attention_no_mask_onnx(
175178
query, key, value, scale, dropout_p, op23
@@ -180,6 +183,67 @@ def aten_scaled_dot_product_attention_23(
180183
)
181184

182185

186+
def _attention_repeat_kv_for_group_query(
    query: TFloat, key: TFloat, value: TFloat, op: Opset
) -> tuple[TFloat, TFloat]:
    """Expand key and value for group query attention.

    repeat_interleave is applied on key and value to match the number of heads
    in query, implemented with Unsqueeze -> Expand -> Reshape so it stays
    expressible as ONNX ops.

    Args:
        query: Tensor of shape [B, q_num_heads, q_S, E]
        key: Tensor of shape [B, k_num_heads, kv_S, E]
        value: Tensor of shape [B, v_num_heads, kv_S, E]
        op: Opset used to emit the ONNX nodes.

    Returns:
        Tuple of (expanded_key, expanded_value) where:
        - expanded_key: Tensor of shape [B, q_num_heads, kv_S, E]
        - expanded_value: Tensor of shape [B, q_num_heads, kv_S, E]
    """

    # NOTE(review): strict `>` rejects q_num_heads == kv_num_heads; callers are
    # expected to route the equal-heads case around this helper — confirm.
    assert (
        query.shape[1] > key.shape[1] == value.shape[1]
        and query.shape[1] % key.shape[1] == 0
    ), (
        "SDPA (GQA or MQA) requires q_num_heads > kv_num_heads & q_num_heads % kv_num_heads == 0"
    )

    # NOTE: QKV are expected to be 4D tensors

    # Shapes are gathered dynamically so the exported graph also works with
    # symbolic (unknown-at-export-time) dimensions.
    batch_size = op.Shape(query, start=0, end=1)  # [B]
    q_num_heads = op.Shape(query, start=1, end=2)  # [Hq]
    kv_num_heads = op.Shape(key, start=1, end=2)  # [Hk]
    qk_head_size = op.Shape(key, start=3, end=4)  # [Dk]
    v_head_size = op.Shape(value, start=3, end=4)  # [Dv]
    new_kv_seq_len = op.Shape(key, start=2, end=3)  # [T]

    # Number of query heads sharing each KV head.
    interleave_dim = op.Div(q_num_heads, kv_num_heads)  # Hq / Hk
    two = op.Constant(value_int=2)
    k_unsqueezed = op.Unsqueeze(key, two)  # [B, Hk, 1, T, Dk]
    v_unsqueezed = op.Unsqueeze(value, two)  # [B, Hv, 1, T, Dv]

    # Broadcast each KV head interleave_dim times along the inserted axis.
    k_expand_shape = op.Concat(
        batch_size, kv_num_heads, interleave_dim, new_kv_seq_len, qk_head_size, axis=0
    )
    k_expand = op.Expand(k_unsqueezed, k_expand_shape)
    v_expand_shape = op.Concat(
        batch_size, kv_num_heads, interleave_dim, new_kv_seq_len, v_head_size, axis=0
    )
    v_expand = op.Expand(v_unsqueezed, v_expand_shape)

    # Collapse [Hk, Hq/Hk] back into a single head axis of size Hq.
    k_attention_shape = op.Concat(
        batch_size, q_num_heads, new_kv_seq_len, qk_head_size, axis=0
    )
    v_attention_shape = op.Concat(
        batch_size, q_num_heads, new_kv_seq_len, v_head_size, axis=0
    )

    expanded_key = op.Reshape(k_expand, k_attention_shape)
    expanded_value = op.Reshape(v_expand, v_attention_shape)

    return expanded_key, expanded_value
245+
246+
183247
def _attention_scale(query: TFloat, op: Opset) -> TFloat:
184248
"""Calculate the scale factor for the attention result.
185249

0 commit comments

Comments
 (0)