Skip to content

Commit ad8a8b0

Browse files
gramalingamCopilot
andcommitted
Skip numerical validation when onnx lacks RMSNormalization schema
Serializing fused models with RMSNormalization requires onnx opset >= 23. On older onnx versions, tests now fall back to structural checks only (fusion count + op type assertions still run). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent cb617c0 commit ad8a8b0

File tree

1 file changed

+15
-6
lines changed

1 file changed

+15
-6
lines changed

onnxscript/rewriter/rules/fusion/_rms_normalization_extended_test.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import unittest
1313

1414
import numpy as np
15+
import onnx
1516
import onnx_ir as ir
1617
from parameterized import parameterized
1718

@@ -24,6 +25,10 @@
2425
_EPS = ir.tensor(np.array([1e-6], dtype=np.float32))
2526
_EPS_D = ir.tensor(np.array([1e-6], dtype=np.float64))
2627

28+
_has_rms_normalization_schema = hasattr(onnx.defs, "get_schema") and (
29+
onnx.defs.onnx_opset_version() >= 23
30+
)
31+
2732

2833
# --- mul_order=False: Mul(scale, normalized) ---
2934

@@ -118,15 +123,19 @@ def _check_numerical_equivalence(self, model: ir.Model, inputs: dict, expected_c
118123
"""Apply fusion and verify numerical equivalence using ONNX reference impl.
119124
120125
ORT does not yet have a kernel for RMSNormalization, so we use
121-
the ONNX reference implementation for validation.
126+
the ONNX reference implementation for validation. Serialization of
127+
the fused model requires the RMSNormalization schema (onnx opset >= 23),
128+
so numerical validation is skipped on older onnx versions.
122129
"""
123-
original_proto = ir.serde.serialize_model(model)
130+
if _has_rms_normalization_schema:
131+
original_proto = ir.serde.serialize_model(model)
124132
count = fuse_rms_normalization(model)
125133
self.assertEqual(count, expected_count)
126-
fused_proto = ir.serde.serialize_model(model)
127-
rewriter_testing.assert_numerically_equal(
128-
original_proto, fused_proto, args=inputs, use_reference=True
129-
)
134+
if _has_rms_normalization_schema:
135+
fused_proto = ir.serde.serialize_model(model)
136+
rewriter_testing.assert_numerically_equal(
137+
original_proto, fused_proto, args=inputs, use_reference=True
138+
)
130139

131140
# --- Positive tests ---
132141

0 commit comments

Comments
 (0)