Skip to content

Commit eea7f8e

Browse files
gramalingamCopilot
andcommitted
Add transpose perm corruption negative test for MHA
The original negative test had a tuple-vs-list comparison bug: get_ints() returns a tuple, so perm == [0,2,1,3] was always False, meaning the corruption never happened. Fixed to compare with tuple. Confirmed the fusion correctly rejects mismatched Transpose perms. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 1a024e8 commit eea7f8e

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

onnxscript/rewriter/ort_fusions/mha_unit_test.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,18 @@ def test_rank2_query_no_fusion(self):
254254
count = self._apply(model)
255255
self.assertEqual(count, 0)
256256

257+
def test_wrong_query_transpose_perm_no_fusion(self):
258+
"""Corrupted query Transpose perm → pattern should not match."""
259+
model = self._build(_mha_basic_key_transposed, self._3D, self._OUT_1)
260+
for node in model.graph:
261+
if node.op_type == "Transpose":
262+
perm = node.attributes.get_ints("perm")
263+
if perm == (0, 2, 1, 3):
264+
node.attributes["perm"] = ir.AttrInt64s("perm", [0, 1, 2, 3])
265+
break
266+
count = self._apply(model)
267+
self.assertEqual(count, 0)
268+
257269

258270
if __name__ == "__main__":
259271
unittest.main()

0 commit comments

Comments
 (0)