diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py
index 88b5bf807e..1a31c9eac8 100644
--- a/onnxscript/function_libs/torch_lib/ops/nn.py
+++ b/onnxscript/function_libs/torch_lib/ops/nn.py
@@ -1741,6 +1741,64 @@ def _attention_scale(query: TFloat) -> TFloat:
     return scale
 
 
+def _attention_repeat_kv_for_group_query(
+    query: TFloat, key: TFloat, value: TFloat
+) -> Tuple[TFloat, TFloat]:
+    """Expand key and value for group query attention.
+
+    repeat_interleave is applied to key and value to match the number of heads in query.
+
+    Args:
+        query: Tensor of shape [B, q_num_heads, q_S, E]
+        key: Tensor of shape [B, k_num_heads, kv_S, E]
+        value: Tensor of shape [B, v_num_heads, kv_S, E]
+
+    Returns:
+        Tuple of (expanded_key, expanded_value) where:
+        - expanded_key: Tensor of shape [B, q_num_heads, kv_S, E]
+        - expanded_value: Tensor of shape [B, q_num_heads, kv_S, E]
+    """
+
+    assert (
+        query.shape[1] > key.shape[1] == value.shape[1] and query.shape[1] % key.shape[1] == 0
+    ), (
+        "SDPA (GQA or MQA) requires q_num_heads > kv_num_heads and q_num_heads % kv_num_heads == 0"
+    )
+
+    # NOTE: QKV are expected to be 4D tensors
+
+    batch_size = op.Shape(query, start=0, end=1)  # [B]
+    q_num_heads = op.Shape(query, start=1, end=2)  # [Hq]
+    kv_num_heads = op.Shape(key, start=1, end=2)  # [Hk]
+    qk_head_size = op.Shape(key, start=3, end=4)  # [Dk]
+    v_head_size = op.Shape(value, start=3, end=4)  # [Dv]
+    new_kv_seq_len = op.Shape(key, start=2, end=3)  # [T]
+
+    interleave_dim = op.Div(q_num_heads, kv_num_heads)  # Hq / Hk
+    two = op.Constant(value_int=2)
+    k_unsqueezed = op.Unsqueeze(key, two)  # [B, Hk, 1, T, Dk]
+    v_unsqueezed = op.Unsqueeze(value, two)  # [B, Hv, 1, T, Dv]
+
+    k_expand_shape = op.Concat(
+        batch_size, kv_num_heads, interleave_dim, new_kv_seq_len, qk_head_size, axis=0
+    )
+    k_expand = op.Expand(k_unsqueezed, k_expand_shape)
+    v_expand_shape = op.Concat(
+        batch_size, kv_num_heads, interleave_dim, new_kv_seq_len, v_head_size, axis=0
+    )
+    v_expand = op.Expand(v_unsqueezed, v_expand_shape)
+
+    k_attention_shape = op.Concat(
+        batch_size, q_num_heads, new_kv_seq_len, qk_head_size, axis=0
+    )
+    v_attention_shape = op.Concat(batch_size, q_num_heads, new_kv_seq_len, v_head_size, axis=0)
+
+    expanded_key = op.Reshape(k_expand, k_attention_shape)
+    expanded_value = op.Reshape(v_expand, v_attention_shape)
+
+    return expanded_key, expanded_value
+
+
 @torch_op("aten::scaled_dot_product_attention", trace_only=True)
 def aten_scaled_dot_product_attention(
     query: TFloat,
@@ -1772,8 +1830,8 @@ def aten_scaled_dot_product_attention(
         "is_causal and attn_mask cannot be set at the same time"
     )
-
-    assert not enable_gqa, (
-        "conversion of scaled_dot_product_attention not implemented if enable_gqa is True"
+    assert len(query.shape) == 4 and len(key.shape) == 4 and len(value.shape) == 4, (
+        "only 4D query, key, and value are supported"
     )
 
     # Reference: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
@@ -1784,6 +1842,13 @@ def aten_scaled_dot_product_attention(
     if is_causal:
         attn_mask = _causal_attention_mask(query, key)
 
+    if enable_gqa:
+        key, value = _attention_repeat_kv_for_group_query(query, key, value)
+    else:
+        assert query.shape[1] == key.shape[1] == value.shape[1], (
+            "SDPA (MHA) requires q_num_heads = kv_num_heads"
+        )
+
     if attn_mask is None:
         return _aten_scaled_dot_product_attention_no_mask_onnx(
             query, key, value, scale, dropout_p
@@ -1981,9 +2046,8 @@ def aten_scaled_dot_product_attention_bool_mask(
     assert (not is_causal) or (is_causal and attn_mask is None), (
         "is_causal and attn_mask cannot be set at the same time"
     )
-
-    assert not enable_gqa, (
-        "conversion of scaled_dot_product_attention not implemented if enable_gqa is True"
+    assert len(query.shape) == 4 and len(key.shape) == 4 and len(value.shape) == 4, (
+        "only 4D query, key, and value are supported"
     )
 
     if scale is None:
@@ -1997,6 +2061,9 @@ def aten_scaled_dot_product_attention_bool_mask(
             query, key, value, attn_mask, scale, dropout_p
         )
 
+    if enable_gqa:
+        key, value = _attention_repeat_kv_for_group_query(query, key, value)
+
     if attn_mask is None:
         return _aten_scaled_dot_product_attention_no_mask_onnx(
             query, key, value, scale, dropout_p
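
Note on the helper above: the Unsqueeze -> Expand -> Reshape sequence in
_attention_repeat_kv_for_group_query is the ONNX spelling of
torch.repeat_interleave on the head dimension. A minimal eager-mode sketch of
the same computation (illustrative only; the function name is hypothetical and
not part of the patch):

    import torch

    def repeat_kv_reference(kv: torch.Tensor, q_num_heads: int) -> torch.Tensor:
        """[B, kv_num_heads, T, E] -> [B, q_num_heads, T, E]."""
        batch, kv_num_heads, seq_len, head_size = kv.shape
        n_rep = q_num_heads // kv_num_heads
        expanded = kv.unsqueeze(2)  # [B, Hk, 1, T, E], mirrors op.Unsqueeze
        expanded = expanded.expand(batch, kv_num_heads, n_rep, seq_len, head_size)
        return expanded.reshape(batch, q_num_heads, seq_len, head_size)

    key = torch.randn(2, 2, 8, 16)
    # The Expand/Reshape trick matches repeat_interleave exactly:
    assert torch.equal(repeat_kv_reference(key, 4), key.repeat_interleave(2, dim=1))
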
diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py
index 253637ccd2..40a7dea180 100644
--- a/tests/function_libs/torch_lib/e2e_ops_tests.py
+++ b/tests/function_libs/torch_lib/e2e_ops_tests.py
@@ -174,6 +174,36 @@ def forward(self, x):
         )
         _testing.assert_onnx_program(onnx_program)
 
+    def test_enable_gqa_in_attention(self):
+        class Model(torch.nn.Module):
+            def forward(self, q, k, v):
+                return torch.nn.functional.scaled_dot_product_attention(  # pylint: disable=not-callable
+                    q,
+                    k,
+                    v,
+                    enable_gqa=True,
+                )
+
+        model = Model()
+
+        query = torch.randn(2, 4, 8, 16)
+        key = torch.randn(2, 2, 8, 16)
+        value = torch.randn(2, 2, 8, 16)
+
+        onnx_program = torch.onnx.export(
+            model,
+            (
+                query,
+                key,
+                value,
+            ),
+            input_names=["query", "key", "value"],
+            output_names=["output"],
+            opset_version=18,
+            dynamo=True,
+        )
+        _testing.assert_onnx_program(onnx_program)
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py
index 646a5133fa..7064419ac7 100644
--- a/tests/function_libs/torch_lib/ops_test_data.py
+++ b/tests/function_libs/torch_lib/ops_test_data.py
@@ -1908,6 +1908,12 @@ def _where_input_wrangler(
         dtypes=(torch.float16,),
         reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438",
         test_class_name="TestOutputConsistencyFullGraph",
+    )
+    .xfail(
+        matcher=lambda sample: len(sample.input.shape) != 4
+        or len(sample.args[0].shape) != 4
+        or len(sample.args[1].shape) != 4,
+        reason="torch sdpa is expected to be called with 4D q, k, and v.",
     ),
     TorchLibOpInfo(
         "ops.aten._scaled_dot_product_flash_attention",
@@ -1959,6 +1965,12 @@ def _where_input_wrangler(
         dtypes=(torch.float16,),
         reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438",
         test_class_name="TestOutputConsistencyFullGraph",
+    )
+    .xfail(
+        matcher=lambda sample: len(sample.input.shape) != 4
+        or len(sample.args[0].shape) != 4
+        or len(sample.args[1].shape) != 4,
+        reason="torch sdpa is expected to be called with 4D q, k, and v.",
     ),
     TorchLibOpInfo(
         "ops.aten.upsample_bilinear2d.default",
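
The conversion relies on the fact that PyTorch implements enable_gqa by
repeating the KV heads and then running ordinary multi-head attention, so
expanding key and value up front is numerically equivalent. A hypothetical
eager-mode check of that equivalence (assumes torch >= 2.5, where enable_gqa
is available; not part of the patch):

    import torch
    import torch.nn.functional as F

    q = torch.randn(2, 4, 8, 16)
    k = torch.randn(2, 2, 8, 16)  # 2 KV heads shared by 4 query heads
    v = torch.randn(2, 2, 8, 16)

    gqa = F.scaled_dot_product_attention(q, k, v, enable_gqa=True)
    mha = F.scaled_dot_product_attention(
        q, k.repeat_interleave(2, dim=1), v.repeat_interleave(2, dim=1)
    )
    torch.testing.assert_close(gqa, mha)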