Skip to content

Commit 1a024e8

Browse files
gramalingam and Copilot committed
Add unit tests for MultiHeadAttention fusion rules (mha.py)
5 structural tests covering 4 rule variants: - Basic MHA with key transposed (BHSd format) - Basic MHA with key not transposed (BSHd format) - MHA with past key/value (has_past_present=True, 3 outputs) - MHA with RotaryEmbedding on Q and K - Rank-2 query shape rejection (negative) Tests are structural-only (no ORT run) since the pattern requires internal SDPA nodes (ai.onnxruntime._fusion) that ORT cannot execute. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 87be039 commit 1a024e8

File tree

1 file changed

+259
-0
lines changed

1 file changed

+259
-0
lines changed
Lines changed: 259 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,259 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
4+
"""Unit tests for MultiHeadAttention fusion rules (mha.py).
5+
6+
The MHA rule matches the pattern:
7+
Q/K/V → Reshape → Transpose → [RotaryEmbedding] → [Concat past] → SDPA → Transpose → Reshape
8+
and fuses it into a single MultiHeadAttention contrib op.
9+
10+
These are structural tests (no ORT run) because the pattern requires internal
11+
SDPA nodes (ai.onnxruntime._fusion domain) which ORT cannot execute directly.
12+
"""
13+
14+
from __future__ import annotations
15+
16+
import unittest
17+
18+
import numpy as np
19+
import onnx_ir as ir
20+
21+
from onnxscript import FLOAT, script, values
22+
from onnxscript import opset18 as op
23+
from onnxscript.optimizer import optimize
24+
from onnxscript.rewriter.ort_fusions.mha import fuse_mha1, fuse_mha2
25+
26+
# Custom opsets
27+
msft_op = values.Opset("com.microsoft", 1)
28+
fusion_op = values.Opset("ai.onnxruntime._fusion", 1)
29+
30+
_B, _S, _H, _Dh = 2, 8, 4, 4
31+
_D = _H * _Dh # 16
32+
_Skv = 8
33+
_Spast = 4
34+
35+
_RESHAPE_Q = ir.tensor(np.array([0, 0, _H, _Dh], dtype=np.int64))
36+
_RESHAPE_K = ir.tensor(np.array([0, 0, _H, _Dh], dtype=np.int64))
37+
_RESHAPE_V = ir.tensor(np.array([0, 0, _H, _Dh], dtype=np.int64))
38+
_RESHAPE_OUT = ir.tensor(np.array([0, 0, _D], dtype=np.int64))
39+
40+
# --- Simplest: no rotary, no past, key transposed ---


# Q, K, V are each reshaped BSD -> BSHd and transposed to BHSd before SDPA;
# the SDPA output is transposed and reshaped back to BSD.
@script()
def _mha_basic_key_transposed(query_BSD, key_BSD, value_BSD):
    q_shape = op.Constant(value=_RESHAPE_Q)
    q_4d = op.Reshape(query_BSD, q_shape)
    q_BHSDh = op.Transpose(q_4d, perm=[0, 2, 1, 3])

    k_shape = op.Constant(value=_RESHAPE_K)
    k_4d = op.Reshape(key_BSD, k_shape)
    k_BHSDh = op.Transpose(k_4d, perm=[0, 2, 1, 3])

    v_shape = op.Constant(value=_RESHAPE_V)
    v_4d = op.Reshape(value_BSD, v_shape)
    v_BHSDh = op.Transpose(v_4d, perm=[0, 2, 1, 3])

    sdpa_out = fusion_op.SDPA(q_BHSDh, k_BHSDh, v_BHSDh, key_format="BHSd")

    att_transposed = op.Transpose(sdpa_out, perm=[0, 2, 1, 3])
    out_shape = op.Constant(value=_RESHAPE_OUT)
    return op.Reshape(att_transposed, out_shape)
# --- No rotary, no past, key NOT transposed ---


# Same as the basic variant except the key skips the Transpose, so it is fed
# to SDPA in BSHd format (exercising the key_format="BSHd" rule variant).
@script()
def _mha_basic_key_not_transposed(query_BSD, key_BSD, value_BSD):
    q_shape = op.Constant(value=_RESHAPE_Q)
    q_4d = op.Reshape(query_BSD, q_shape)
    q_BHSDh = op.Transpose(q_4d, perm=[0, 2, 1, 3])

    k_shape = op.Constant(value=_RESHAPE_K)
    k_4d = op.Reshape(key_BSD, k_shape)
    # Key is NOT transposed — stays in BSHd format

    v_shape = op.Constant(value=_RESHAPE_V)
    v_4d = op.Reshape(value_BSD, v_shape)
    v_BHSDh = op.Transpose(v_4d, perm=[0, 2, 1, 3])

    sdpa_out = fusion_op.SDPA(q_BHSDh, k_4d, v_BHSDh, key_format="BSHd")

    att_transposed = op.Transpose(sdpa_out, perm=[0, 2, 1, 3])
    out_shape = op.Constant(value=_RESHAPE_OUT)
    return op.Reshape(att_transposed, out_shape)
# --- With past key/value (has_past_present=True) ---


# Past key/value caches are concatenated with the new key/value along the
# sequence axis before SDPA; the concatenated tensors are also returned as
# graph outputs (present_key / present_value), giving three outputs total.
@script()
def _mha_with_past(query_BSD, key_BSD, value_BSD, past_key, past_value):
    q_shape = op.Constant(value=_RESHAPE_Q)
    q_4d = op.Reshape(query_BSD, q_shape)
    q_BHSDh = op.Transpose(q_4d, perm=[0, 2, 1, 3])

    k_shape = op.Constant(value=_RESHAPE_K)
    k_4d = op.Reshape(key_BSD, k_shape)
    k_BHSDh = op.Transpose(k_4d, perm=[0, 2, 1, 3])

    v_shape = op.Constant(value=_RESHAPE_V)
    v_4d = op.Reshape(value_BSD, v_shape)
    v_BHSDh = op.Transpose(v_4d, perm=[0, 2, 1, 3])

    # Concat with past
    key_seq = op.Concat(past_key, k_BHSDh, axis=-2)
    value_seq = op.Concat(past_value, v_BHSDh, axis=-2)

    sdpa_out = fusion_op.SDPA(q_BHSDh, key_seq, value_seq, key_format="BHSd")

    att_transposed = op.Transpose(sdpa_out, perm=[0, 2, 1, 3])
    out_shape = op.Constant(value=_RESHAPE_OUT)
    attention = op.Reshape(att_transposed, out_shape)
    return attention, key_seq, value_seq
# --- With rotary embedding (no past) ---


# RotaryEmbedding (com.microsoft) is applied to Q and K after the transpose
# to BHSd and before SDPA, exercising the rotary-variant fusion rule.
@script()
def _mha_with_rotary(query_BSD, key_BSD, value_BSD, position_ids, cos, sin):
    q_shape = op.Constant(value=_RESHAPE_Q)
    q_4d = op.Reshape(query_BSD, q_shape)
    q_BHSDh = op.Transpose(q_4d, perm=[0, 2, 1, 3])

    k_shape = op.Constant(value=_RESHAPE_K)
    k_4d = op.Reshape(key_BSD, k_shape)
    k_BHSDh = op.Transpose(k_4d, perm=[0, 2, 1, 3])

    v_shape = op.Constant(value=_RESHAPE_V)
    v_4d = op.Reshape(value_BSD, v_shape)
    v_BHSDh = op.Transpose(v_4d, perm=[0, 2, 1, 3])

    q_emb = msft_op.RotaryEmbedding(q_BHSDh, position_ids, cos, sin)
    k_emb = msft_op.RotaryEmbedding(k_BHSDh, position_ids, cos, sin)

    sdpa_out = fusion_op.SDPA(q_emb, k_emb, v_BHSDh, key_format="BHSd")

    att_transposed = op.Transpose(sdpa_out, perm=[0, 2, 1, 3])
    out_shape = op.Constant(value=_RESHAPE_OUT)
    return op.Reshape(att_transposed, out_shape)
class MultiHeadAttentionFusionTest(unittest.TestCase):
    """Structural unit tests for MultiHeadAttention fusion rules."""

    def _build(self, script_fn, input_types, output_types) -> ir.Model:
        """Convert an onnxscript function to an optimized ir.Model."""
        model_proto = script_fn.to_model_proto(
            input_types=input_types, output_types=output_types
        )
        model = ir.serde.deserialize_model(model_proto)
        # Constant-fold/cleanup so the fusion patterns see canonical graphs.
        optimize(model)
        return model

    def _apply(self, model: ir.Model) -> int:
        """Run both MHA fusion rule sets; return the total fusion count."""
        count = fuse_mha1(model)
        count += fuse_mha2(model)
        return count

    def _count_op(self, model: ir.Model, op_type: str, domain: str = "") -> int:
        """Count nodes in the main graph matching op_type and domain."""
        return sum(1 for n in model.graph if n.op_type == op_type and n.domain == domain)

    def _get_mha_node(self, model: ir.Model) -> ir.Node | None:
        """Return the first com.microsoft MultiHeadAttention node, if any."""
        for node in model.graph:
            if node.op_type == "MultiHeadAttention" and node.domain == "com.microsoft":
                return node
        return None

    # Shared type signatures: three rank-3 [B, S, D] inputs, one output.
    _3D = (FLOAT["B", "S", _D],) * 3
    _OUT_1 = (FLOAT["B", "S", _D],)

    # --- Positive tests ---

    def test_basic_key_transposed(self):
        """Simplest MHA: no rotary, no past, key transposed → fuses."""
        model = self._build(_mha_basic_key_transposed, self._3D, self._OUT_1)
        count = self._apply(model)
        self.assertEqual(count, 1)
        self.assertEqual(self._count_op(model, "MultiHeadAttention", "com.microsoft"), 1)
        # The internal SDPA node must be consumed by the fusion.
        self.assertEqual(self._count_op(model, "SDPA", "ai.onnxruntime._fusion"), 0)
        mha = self._get_mha_node(model)
        self.assertIsNotNone(mha)
        self.assertEqual(mha.attributes.get_int("num_heads", 0), _H)

    def test_basic_key_not_transposed(self):
        """Key not transposed (BSHd format) → still fuses."""
        model = self._build(_mha_basic_key_not_transposed, self._3D, self._OUT_1)
        count = self._apply(model)
        self.assertEqual(count, 1)
        self.assertEqual(self._count_op(model, "MultiHeadAttention", "com.microsoft"), 1)

    def test_with_past_key_value(self):
        """Past key/value Concats → fuses with 3 outputs (attention, present_k, present_v)."""
        model = self._build(
            _mha_with_past,
            input_types=[
                FLOAT["B", "S", _D],
                FLOAT["B", "S", _D],
                FLOAT["B", "S", _D],
                FLOAT["B", _H, "Spast", _Dh],
                FLOAT["B", _H, "Spast", _Dh],
            ],
            output_types=[
                FLOAT["B", "S", _D],
                FLOAT["B", _H, "St", _Dh],
                FLOAT["B", _H, "St", _Dh],
            ],
        )
        count = self._apply(model)
        self.assertEqual(count, 1)
        mha = self._get_mha_node(model)
        self.assertIsNotNone(mha)
        self.assertEqual(len(mha.outputs), 3)
        # past_key and past_value should be connected (inputs 6, 7)
        self.assertIsNotNone(mha.inputs[6])
        self.assertIsNotNone(mha.inputs[7])

    def test_with_rotary_embedding(self):
        """RotaryEmbedding on Q and K before SDPA → fuses."""
        model = self._build(
            _mha_with_rotary,
            input_types=[
                FLOAT["B", "S", _D],
                FLOAT["B", "S", _D],
                FLOAT["B", "S", _D],
                FLOAT["B", "S"],  # position_ids
                FLOAT["S", _Dh],  # cos
                FLOAT["S", _Dh],  # sin
            ],
            output_types=[FLOAT["B", "S", _D]],
        )
        count = self._apply(model)
        self.assertEqual(count, 1)
        mha = self._get_mha_node(model)
        self.assertIsNotNone(mha)
        # Rotary should be moved to operate on BSD-format inputs in the rewrite
        rotary_count = self._count_op(model, "RotaryEmbedding", "com.microsoft")
        self.assertGreater(rotary_count, 0)

    # --- Negative test ---

    def test_rank2_query_no_fusion(self):
        """Query with rank 2 [S, D] instead of [B, S, D] → shape check rejects."""
        model = self._build(
            _mha_basic_key_transposed,
            input_types=[
                FLOAT["S", _D],
                FLOAT["B", "S", _D],
                FLOAT["B", "S", _D],
            ],
            output_types=[FLOAT["B", "S", _D]],
        )
        count = self._apply(model)
        self.assertEqual(count, 0)
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)