Skip to content

Commit 03b0c7e

Browse files
gramalingam and Copilot
committed
Add unit tests for FuseMHAScale rule (mha_scale.py)
5 tests covering:
- Scalar float constant scale → fused into MHA scale attribute
- Integer scale constant → fused
- Existing MHA scale attribute → combined with external scale
- No Mul before MHA → no fusion (negative)
- Dynamic (non-constant) scale input → no fusion (negative)
All positive tests include numerical validation via ORT. Uses symbolic dims ("B", "S") in input/output types.
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 95b5977 commit 03b0c7e

File tree

1 file changed

+198
-0
lines changed

1 file changed

+198
-0
lines changed
Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
4+
"""Unit tests for FuseMHAScale rule (mha_scale.py).
5+
6+
The rule detects Mul(query, constant_scale) before MultiHeadAttention and
7+
fuses the scaling into the MHA's ``scale`` attribute.
8+
"""
9+
10+
from __future__ import annotations
11+
12+
import math
13+
import unittest
14+
15+
import numpy as np
16+
import onnx_ir as ir
17+
18+
import onnxscript
19+
import onnxscript.rewriter.ort_fusions._test_utils as test_utils
20+
from onnxscript import FLOAT, script
21+
from onnxscript import opset18 as op
22+
from onnxscript.optimizer import optimize
23+
from onnxscript.rewriter.ort_fusions.mha_scale import fuse_mha_scale
24+
25+
# Opset handle for the "com.microsoft" custom domain, where the
# MultiHeadAttention contrib op is defined.
msft_op = onnxscript.values.Opset("com.microsoft", 1)

# Test-model dimensions: batch size, sequence length, hidden size.
_B, _S, _D = 2, 8, 16
_NUM_HEADS = 4
_HEAD_SIZE = _D // _NUM_HEADS
# 1/sqrt(head_size) — presumably MHA's implicit default scale; the positive
# tests expect the fused attribute to equal external_scale * this value.
_DEFAULT_SCALE = 1.0 / math.sqrt(_HEAD_SIZE)

# Pre-computed constant for use inside @script functions
_SCALE_VALUE = 0.25
34+
35+
36+
# --- Script models ---
37+
38+
39+
@script()
def _mha_with_scalar_scale(query, key, value, scale):
    query_scaled = op.Mul(query, scale)
    return msft_op.MultiHeadAttention(query_scaled, key, value, num_heads=_NUM_HEADS)
43+
44+
45+
@script()
def _mha_no_scale(query, key, value):
    attn_out = msft_op.MultiHeadAttention(query, key, value, num_heads=_NUM_HEADS)
    return attn_out
48+
49+
50+
@script()
def _mha_with_dynamic_scale(query, key, value, scale):
    """Scale is a graph input (not constant) → fusion should not apply."""
    q_times_scale = op.Mul(query, scale)
    return msft_op.MultiHeadAttention(q_times_scale, key, value, num_heads=_NUM_HEADS)
55+
56+
57+
class FuseMHAScaleTest(unittest.TestCase):
    """Unit tests for the FuseMHAScale rewrite rule."""

    def _build(self, script_fn, input_types, output_types) -> ir.Model:
        # Convert the @script function into an IR model and run the standard
        # optimizer so the graph is in the shape the fusion rule expects.
        model_proto = script_fn.to_model_proto(
            input_types=input_types, output_types=output_types
        )
        model = ir.serde.deserialize_model(model_proto)
        optimize(model)
        return model

    def _apply(self, model: ir.Model) -> int:
        # Returns the number of times the rule fired (0 for negative tests).
        return fuse_mha_scale(model)

    def _get_mha_node(self, model: ir.Model) -> ir.Node | None:
        # Find the (single) com.microsoft MultiHeadAttention node, if any.
        for node in model.graph:
            if node.op_type == "MultiHeadAttention" and node.domain == "com.microsoft":
                return node
        return None

    def _make_scale_constant(self, model: ir.Model, scale_value: float):
        """Convert the ``scale`` graph input into a constant initializer."""
        for node in model.graph:
            if node.op_type == "Mul":
                # Second Mul input is the scale value fed in as a graph input.
                scale_input = node.inputs[1]
                assert scale_input is not None
                scale_input.const_value = ir.tensor(np.array([scale_value], dtype=np.float32))
                # Drop the now-constant input from the graph signature.
                # NOTE(review): assumes ``scale`` is the LAST graph input —
                # true for the models built here; verify if inputs change.
                model.graph.inputs.pop()
                return
        raise RuntimeError("Mul node not found")

    def _check_numerical_equivalence(
        self, model: ir.Model, inputs: dict, scale_value: float, expected_count: int
    ):
        # Run original model *before* making scale constant (scale is a graph input)
        inputs_with_scale = {
            **inputs,
            "scale": np.array([scale_value], dtype=np.float32),
        }
        original_output = test_utils.ort_run("Original", model, inputs_with_scale)
        # Now convert scale to constant and apply fusion
        self._make_scale_constant(model, scale_value)
        count = self._apply(model)
        self.assertEqual(count, expected_count)
        # Fused model no longer takes ``scale``; outputs must match the original.
        fused_output = test_utils.ort_run("Fused", model, inputs)
        test_utils.assert_allclose(original_output, fused_output)

    # --- Positive tests ---

    def _build_scale_model(self) -> ir.Model:
        # Mul-before-MHA model with symbolic batch/sequence dims and a
        # rank-1 (size 1) scale input.
        return self._build(
            _mha_with_scalar_scale,
            input_types=[
                FLOAT["B", "S", _D],
                FLOAT["B", "S", _D],
                FLOAT["B", "S", _D],
                FLOAT[1],
            ],
            output_types=[FLOAT["B", "S", _D]],
        )

    def _make_inputs(self) -> dict:
        # Random query/key/value tensors (scale is supplied separately).
        return {
            "query": np.random.randn(_B, _S, _D).astype(np.float32),
            "key": np.random.randn(_B, _S, _D).astype(np.float32),
            "value": np.random.randn(_B, _S, _D).astype(np.float32),
        }

    def test_scalar_scale_fused(self):
        """Mul(query, scalar_constant) before MHA → scale absorbed into attribute."""
        model = self._build_scale_model()
        inputs = self._make_inputs()
        self._check_numerical_equivalence(model, inputs, _SCALE_VALUE, expected_count=1)
        # Verify Mul is gone and MHA has scale attribute
        self.assertFalse(any(n.op_type == "Mul" for n in model.graph), "Mul should be removed")
        mha_node = self._get_mha_node(model)
        self.assertIsNotNone(mha_node)
        scale_attr = mha_node.attributes.get_float("scale", None)
        self.assertIsNotNone(scale_attr)
        # Fused attribute should fold the external scale into the implicit
        # 1/sqrt(head_size) default.
        expected = _SCALE_VALUE * _DEFAULT_SCALE
        self.assertAlmostEqual(scale_attr, expected, places=5)

    def test_integer_scale_fused(self):
        """Integer scale constant (e.g. 2) → still fused."""
        # NOTE(review): the value is passed as 2.0 (stored as float32 by the
        # helpers); "integer" here means an integral value, not an int dtype.
        model = self._build_scale_model()
        inputs = self._make_inputs()
        self._check_numerical_equivalence(model, inputs, 2.0, expected_count=1)
        mha_node = self._get_mha_node(model)
        self.assertIsNotNone(mha_node)
        scale_attr = mha_node.attributes.get_float("scale", None)
        self.assertIsNotNone(scale_attr)
        expected = 2.0 * _DEFAULT_SCALE
        self.assertAlmostEqual(scale_attr, expected, places=5)

    def test_scale_combined_with_existing_scale_attr(self):
        """MHA already has a scale attribute → external scale is multiplied with it."""
        model = self._build_scale_model()
        # Set existing MHA scale attribute before any ORT run
        existing_scale = 0.1
        for node in model.graph:
            if node.op_type == "MultiHeadAttention" and node.domain == "com.microsoft":
                node.attributes["scale"] = ir.AttrFloat32("scale", existing_scale)

        inputs = self._make_inputs()
        self._check_numerical_equivalence(model, inputs, _SCALE_VALUE, expected_count=1)
        mha_node = self._get_mha_node(model)
        self.assertIsNotNone(mha_node)
        scale_attr = mha_node.attributes.get_float("scale", None)
        self.assertIsNotNone(scale_attr)
        # With an explicit pre-existing attribute, the fused value is the
        # product of the two scales (no 1/sqrt(head_size) default involved).
        expected = _SCALE_VALUE * existing_scale
        self.assertAlmostEqual(scale_attr, expected, places=5)

    # --- Negative tests ---

    def test_no_mul_no_fusion(self):
        """No Mul before MHA → rule does not match."""
        model = self._build(
            _mha_no_scale,
            input_types=[FLOAT["B", "S", _D], FLOAT["B", "S", _D], FLOAT["B", "S", _D]],
            output_types=[FLOAT["B", "S", _D]],
        )
        count = self._apply(model)
        self.assertEqual(count, 0)

    def test_dynamic_scale_no_fusion(self):
        """Scale is a non-constant graph input → check rejects."""
        model = self._build(
            _mha_with_dynamic_scale,
            input_types=[
                FLOAT["B", "S", _D],
                FLOAT["B", "S", _D],
                FLOAT["B", "S", _D],
                FLOAT[1],
            ],
            output_types=[FLOAT["B", "S", _D]],
        )
        # Scale stays a graph input (never made constant), so the rule's
        # constant-check should reject the match.
        count = self._apply(model)
        self.assertEqual(count, 0)
195+
196+
197+
# Allow running this test module directly (outside a test runner).
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)