Skip to content

Commit 95b5977

Browse files
gramalingam and Copilot
committed
Use symbolic dims ("B", "S") in mha_bias_test input/output types
Models now use symbolic dimension names in input_types/output_types to better reflect real-world models, while concrete values (_B, _S) are still used for numpy test data generation.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 5869997 commit 95b5977

File tree

1 file changed

+30
-30
lines changed

1 file changed

+30
-30
lines changed

onnxscript/rewriter/ort_fusions/mha_bias_test.py

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -118,14 +118,14 @@ def test_all_three_biases_fused(self):
118118
model = self._build(
119119
_mha_all_biases,
120120
input_types=[
121-
FLOAT[_B, _S, _D],
122-
FLOAT[_B, _S, _Dk],
123-
FLOAT[_B, _S, _Dv],
121+
FLOAT["B", "S", _D],
122+
FLOAT["B", "S", _Dk],
123+
FLOAT["B", "S", _Dv],
124124
FLOAT[_D],
125125
FLOAT[_Dk],
126126
FLOAT[_Dv],
127127
],
128-
output_types=[FLOAT[_B, _S, _D]],
128+
output_types=[FLOAT["B", "S", _D]],
129129
)
130130
inputs = {
131131
"query_matmul": np.random.randn(_B, _S, _D).astype(np.float32),
@@ -145,12 +145,12 @@ def test_only_q_bias(self):
145145
model = self._build(
146146
_mha_q_bias_only,
147147
input_types=[
148-
FLOAT[_B, _S, _D],
149-
FLOAT[_B, _S, _Dk],
150-
FLOAT[_B, _S, _Dv],
148+
FLOAT["B", "S", _D],
149+
FLOAT["B", "S", _Dk],
150+
FLOAT["B", "S", _Dv],
151151
FLOAT[_D],
152152
],
153-
output_types=[FLOAT[_B, _S, _D]],
153+
output_types=[FLOAT["B", "S", _D]],
154154
)
155155
inputs = {
156156
"query_matmul": np.random.randn(_B, _S, _D).astype(np.float32),
@@ -167,12 +167,12 @@ def test_only_k_bias(self):
167167
model = self._build(
168168
_mha_k_bias_only,
169169
input_types=[
170-
FLOAT[_B, _S, _D],
171-
FLOAT[_B, _S, _Dk],
172-
FLOAT[_B, _S, _Dv],
170+
FLOAT["B", "S", _D],
171+
FLOAT["B", "S", _Dk],
172+
FLOAT["B", "S", _Dv],
173173
FLOAT[_Dk],
174174
],
175-
output_types=[FLOAT[_B, _S, _D]],
175+
output_types=[FLOAT["B", "S", _D]],
176176
)
177177
inputs = {
178178
"query_matmul": np.random.randn(_B, _S, _D).astype(np.float32),
@@ -188,12 +188,12 @@ def test_only_v_bias(self):
188188
model = self._build(
189189
_mha_v_bias_only,
190190
input_types=[
191-
FLOAT[_B, _S, _D],
192-
FLOAT[_B, _S, _Dk],
193-
FLOAT[_B, _S, _Dv],
191+
FLOAT["B", "S", _D],
192+
FLOAT["B", "S", _Dk],
193+
FLOAT["B", "S", _Dv],
194194
FLOAT[_Dv],
195195
],
196-
output_types=[FLOAT[_B, _S, _D]],
196+
output_types=[FLOAT["B", "S", _D]],
197197
)
198198
inputs = {
199199
"query_matmul": np.random.randn(_B, _S, _D).astype(np.float32),
@@ -209,13 +209,13 @@ def test_q_and_k_bias_only(self):
209209
model = self._build(
210210
_mha_qk_biases,
211211
input_types=[
212-
FLOAT[_B, _S, _D],
213-
FLOAT[_B, _S, _Dk],
214-
FLOAT[_B, _S, _Dv],
212+
FLOAT["B", "S", _D],
213+
FLOAT["B", "S", _Dk],
214+
FLOAT["B", "S", _Dv],
215215
FLOAT[_D],
216216
FLOAT[_Dk],
217217
],
218-
output_types=[FLOAT[_B, _S, _D]],
218+
output_types=[FLOAT["B", "S", _D]],
219219
)
220220
inputs = {
221221
"query_matmul": np.random.randn(_B, _S, _D).astype(np.float32),
@@ -232,8 +232,8 @@ def test_no_biases_no_fusion(self):
232232
"""No bias Adds at all → rule should not apply."""
233233
model = self._build(
234234
_mha_no_biases,
235-
input_types=[FLOAT[_B, _S, _D], FLOAT[_B, _S, _Dk], FLOAT[_B, _S, _Dv]],
236-
output_types=[FLOAT[_B, _S, _D]],
235+
input_types=[FLOAT["B", "S", _D], FLOAT["B", "S", _Dk], FLOAT["B", "S", _Dv]],
236+
output_types=[FLOAT["B", "S", _D]],
237237
)
238238
count = self._apply(model)
239239
self.assertEqual(count, 0)
@@ -244,12 +244,12 @@ def test_int32_dtype_no_fusion(self):
244244
model = self._build(
245245
_mha_int32_with_bias,
246246
input_types=[
247-
INT32[_B, _S, _D],
248-
INT32[_B, _S, _Dk],
249-
INT32[_B, _S, _Dv],
247+
INT32["B", "S", _D],
248+
INT32["B", "S", _Dk],
249+
INT32["B", "S", _Dv],
250250
INT32[_D],
251251
],
252-
output_types=[INT32[_B, _S, _D]],
252+
output_types=[INT32["B", "S", _D]],
253253
)
254254
count = self._apply(model)
255255
self.assertEqual(count, 0)
@@ -259,12 +259,12 @@ def test_shape_mismatch_no_fusion(self):
259259
model = self._build(
260260
_mha_rank2_query_with_bias,
261261
input_types=[
262-
FLOAT[_S, _D],
263-
FLOAT[_B, _S, _Dk],
264-
FLOAT[_B, _S, _Dv],
262+
FLOAT["S", _D],
263+
FLOAT["B", "S", _Dk],
264+
FLOAT["B", "S", _Dv],
265265
FLOAT[_D],
266266
],
267-
output_types=[FLOAT[_B, _S, _D]],
267+
output_types=[FLOAT["B", "S", _D]],
268268
)
269269
count = self._apply(model)
270270
self.assertEqual(count, 0)

0 commit comments

Comments
 (0)