Skip to content

Commit d68e0d0

Browse files
gramalingam and Copilot committed
Use op.Constant for scale in mha_scale tests
Use op.Constant(value=ir.tensor(...)) inside @script() functions to define scale as a graph constant directly, instead of creating it as a graph input then converting post-hoc. Simpler and more realistic. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 03b0c7e commit d68e0d0

File tree

1 file changed

+31
-62
lines changed

1 file changed

+31
-62
lines changed

onnxscript/rewriter/ort_fusions/mha_scale_test.py

Lines changed: 31 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,24 @@
2929
_HEAD_SIZE = _D // _NUM_HEADS
3030
_DEFAULT_SCALE = 1.0 / math.sqrt(_HEAD_SIZE)
3131

32-
# Pre-computed constant for use inside @script functions
33-
_SCALE_VALUE = 0.25
32+
# Constants for use inside @script functions.
# One-element float32 tensors, embedded via op.Constant so the scale is a
# graph constant rather than a graph input.
_SCALE_TENSOR_025 = ir.tensor(np.array([0.25], dtype=np.float32))
_SCALE_TENSOR_2 = ir.tensor(np.array([2.0], dtype=np.float32))
3435

3536

3637
# --- Script models ---
3738

3839

3940
# Query is pre-scaled by a constant 0.25 before the com.microsoft
# MultiHeadAttention op; the fusion under test should absorb this Mul
# into the MHA "scale" attribute.
@script()
def _mha_with_scale_025(query, key, value):
    scale = op.Constant(value=_SCALE_TENSOR_025)
    scaled_q = op.Mul(query, scale)
    return msft_op.MultiHeadAttention(scaled_q, key, value, num_heads=_NUM_HEADS)
45+
46+
47+
# Same shape as _mha_with_scale_025, but with a constant scale of 2.0 —
# exercises fusion of a scale value other than the typical sub-1.0 case.
@script()
def _mha_with_scale_2(query, key, value):
    scale = op.Constant(value=_SCALE_TENSOR_2)
    scaled_q = op.Mul(query, scale)
    return msft_op.MultiHeadAttention(scaled_q, key, value, num_heads=_NUM_HEADS)
4352

@@ -74,73 +83,43 @@ def _get_mha_node(self, model: ir.Model) -> ir.Node | None:
7483
return node
7584
return None
7685

77-
def _make_scale_constant(self, model: ir.Model, scale_value: float):
78-
"""Convert the ``scale`` graph input into a constant initializer."""
79-
for node in model.graph:
80-
if node.op_type == "Mul":
81-
scale_input = node.inputs[1]
82-
assert scale_input is not None
83-
scale_input.const_value = ir.tensor(np.array([scale_value], dtype=np.float32))
84-
model.graph.inputs.pop()
85-
return
86-
raise RuntimeError("Mul node not found")
87-
88-
def _check_numerical_equivalence(
89-
self, model: ir.Model, inputs: dict, scale_value: float, expected_count: int
90-
):
91-
# Run original model *before* making scale constant (scale is a graph input)
92-
inputs_with_scale = {
93-
**inputs,
94-
"scale": np.array([scale_value], dtype=np.float32),
95-
}
96-
original_output = test_utils.ort_run("Original", model, inputs_with_scale)
97-
# Now convert scale to constant and apply fusion
98-
self._make_scale_constant(model, scale_value)
86+
# Shared shape shorthands: three rank-3 float inputs and one rank-3 output.
_3D = [FLOAT["B", "S", _D]] * 3
_OUT = [FLOAT["B", "S", _D]]

def _check_numerical_equivalence(self, model: ir.Model, inputs: dict, expected_count: int):
    """Run *model* before and after fusion and assert the outputs agree.

    Also asserts the fusion rule fired exactly *expected_count* times.
    """
    original_output = test_utils.ort_run("Original", model, inputs)
    count = self._apply(model)
    self.assertEqual(count, expected_count)
    fused_output = test_utils.ort_run("Fused", model, inputs)
    test_utils.assert_allclose(original_output, fused_output)
10395

104-
# --- Positive tests ---
105-
106-
def _build_scale_model(self):
107-
return self._build(
108-
_mha_with_scalar_scale,
109-
input_types=[
110-
FLOAT["B", "S", _D],
111-
FLOAT["B", "S", _D],
112-
FLOAT["B", "S", _D],
113-
FLOAT[1],
114-
],
115-
output_types=[FLOAT["B", "S", _D]],
116-
)
117-
11896
def _make_inputs(self):
    """Return random float32 query/key/value tensors of shape (B, S, D)."""
    names = ("query", "key", "value")
    return {name: np.random.randn(_B, _S, _D).astype(np.float32) for name in names}
124102

103+
# --- Positive tests ---
104+
125105
def test_scalar_scale_fused(self):
    """Mul(query, scalar_constant) before MHA → scale absorbed into attribute."""
    model = self._build(_mha_with_scale_025, self._3D, self._OUT)
    inputs = self._make_inputs()
    self._check_numerical_equivalence(model, inputs, expected_count=1)
    # The Mul must be fused away and its value folded into MHA's scale.
    self.assertFalse(any(n.op_type == "Mul" for n in model.graph), "Mul should be removed")
    mha_node = self._get_mha_node(model)
    self.assertIsNotNone(mha_node)
    scale_attr = mha_node.attributes.get_float("scale", None)
    self.assertIsNotNone(scale_attr)
    # External 0.25 multiplies MHA's default 1/sqrt(head_size) scaling.
    expected = 0.25 * _DEFAULT_SCALE
    self.assertAlmostEqual(scale_attr, expected, places=5)
138117

139118
def test_integer_scale_fused(self):
140-
"""Integer scale constant (e.g. 2) → still fused."""
141-
model = self._build_scale_model()
119+
"""Scale constant of 2.0 → still fused."""
120+
model = self._build(_mha_with_scale_2, self._3D, self._OUT)
142121
inputs = self._make_inputs()
143-
self._check_numerical_equivalence(model, inputs, 2.0, expected_count=1)
122+
self._check_numerical_equivalence(model, inputs, expected_count=1)
144123
mha_node = self._get_mha_node(model)
145124
self.assertIsNotNone(mha_node)
146125
scale_attr = mha_node.attributes.get_float("scale", None)
@@ -150,45 +129,35 @@ def test_integer_scale_fused(self):
150129

151130
def test_scale_combined_with_existing_scale_attr(self):
    """MHA already has a scale attribute → external scale is multiplied with it."""
    model = self._build(_mha_with_scale_025, self._3D, self._OUT)
    existing_scale = 0.1
    # Stamp a pre-existing scale attribute onto the MHA node before fusion.
    for node in model.graph:
        if node.op_type == "MultiHeadAttention" and node.domain == "com.microsoft":
            node.attributes["scale"] = ir.AttrFloat32("scale", existing_scale)

    inputs = self._make_inputs()
    self._check_numerical_equivalence(model, inputs, expected_count=1)
    mha_node = self._get_mha_node(model)
    self.assertIsNotNone(mha_node)
    scale_attr = mha_node.attributes.get_float("scale", None)
    self.assertIsNotNone(scale_attr)
    # Fusion multiplies the external 0.25 into the existing attribute.
    expected = 0.25 * existing_scale
    self.assertAlmostEqual(scale_attr, expected, places=5)
168146

169147
# --- Negative tests ---
170148

171149
def test_no_mul_no_fusion(self):
    """No Mul before MHA → rule does not match."""
    model = self._build(_mha_no_scale, self._3D, self._OUT)
    rewrite_count = self._apply(model)
    self.assertEqual(rewrite_count, 0)
180154

181155
def test_dynamic_scale_no_fusion(self):
    """Scale is a non-constant graph input → check rejects."""
    # Scale arrives as a fourth graph input, so it is not a constant and
    # the rule's constant-check must refuse to fuse.
    model = self._build(
        _mha_with_dynamic_scale,
        input_types=[*self._3D, FLOAT[1]],
        output_types=self._OUT,
    )
    rewrite_count = self._apply(model)
    self.assertEqual(rewrite_count, 0)

0 commit comments

Comments (0)