Skip to content

Commit 543a584

Browse files
gramalingam and Copilot
committed
Add numerical validation to rms_normalization extended tests
Use ONNX reference implementation (ORT lacks RMSNormalization kernel) to verify original and fused models produce identical results for float32 and fp16 tests. Double precision remains structural-only since the reference impl doesn't support stash_type=DOUBLE. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 5461135 commit 543a584

File tree

1 file changed

+38
-11
lines changed

1 file changed

+38
-11
lines changed

onnxscript/rewriter/rules/fusion/_rms_normalization_extended_test.py

Lines changed: 38 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from onnxscript import DOUBLE, FLOAT, FLOAT16, script
1919
from onnxscript import opset18 as op
2020
from onnxscript.optimizer import optimize
21+
from onnxscript.rewriter import testing as rewriter_testing
2122
from onnxscript.rewriter.rules.fusion._rms_normalization import fuse_rms_normalization
2223

2324
_EPS = ir.tensor(np.array([1e-6], dtype=np.float32))
@@ -100,6 +101,8 @@ def _rms_int_input(x, scale):
100101
class RmsNormOnnxFusionExtendedTest(unittest.TestCase):
101102
"""Extended tests for RmsNormFusion (rules/fusion variant producing RMSNormalization)."""
102103

104+
_B, _S, _D = 2, 4, 16
105+
103106
def _build(self, script_fn, input_types, output_types) -> ir.Model:
104107
model_proto = script_fn.to_model_proto(
105108
input_types=input_types, output_types=output_types
@@ -111,6 +114,20 @@ def _build(self, script_fn, input_types, output_types) -> ir.Model:
111114
def _count_op(self, model: ir.Model, op_type: str) -> int:
112115
return sum(1 for n in model.graph if n.op_type == op_type)
113116

117+
def _check_numerical_equivalence(self, model: ir.Model, inputs: dict, expected_count: int):
    """Run the RMS-normalization fusion on ``model`` and compare before/after outputs.

    Asserts that the fusion reports ``expected_count`` and that the model
    serialized before the fusion and the model serialized after it are
    numerically equal on ``inputs``.

    The comparison runs on the ONNX reference implementation because ORT
    does not yet have a kernel for RMSNormalization.
    """
    # Snapshot first: the same ``model`` object is serialized again after the
    # fusion call, so the pre-fusion form has to be captured up front.
    proto_before = ir.serde.serialize_model(model)

    num_fused = fuse_rms_normalization(model)
    self.assertEqual(num_fused, expected_count)

    proto_after = ir.serde.serialize_model(model)
    rewriter_testing.assert_numerically_equal(
        proto_before,
        proto_after,
        args=inputs,
        use_reference=True,
    )
130+
114131
# --- Positive tests ---
115132

116133
@parameterized.expand(
@@ -123,31 +140,41 @@ def test_mul_order_variants(self, _name, script_fn):
123140
"""Both Mul orderings should fuse to RMSNormalization."""
124141
model = self._build(
125142
script_fn,
126-
input_types=[FLOAT["B", "S", 16], FLOAT[16]],
127-
output_types=[FLOAT["B", "S", 16]],
143+
input_types=[FLOAT[self._B, self._S, self._D], FLOAT[self._D]],
144+
output_types=[FLOAT[self._B, self._S, self._D]],
128145
)
129-
count = fuse_rms_normalization(model)
130-
self.assertEqual(count, 1)
146+
inputs = {
147+
"x": np.random.randn(self._B, self._S, self._D).astype(np.float32),
148+
"scale": np.random.randn(self._D).astype(np.float32),
149+
}
150+
self._check_numerical_equivalence(model, inputs, expected_count=1)
131151
self.assertEqual(self._count_op(model, "RMSNormalization"), 1)
132152
self.assertEqual(self._count_op(model, "Pow"), 0)
133153

134154
def test_mixed_precision_cast(self):
135155
"""fp16 input Cast to fp32 for compute, Cast back → fuses."""
136156
model = self._build(
137157
_rms_mixed_precision,
138-
input_types=[FLOAT16["B", "S", 16], FLOAT16[16]],
139-
output_types=[FLOAT16["B", "S", 16]],
158+
input_types=[FLOAT16[self._B, self._S, self._D], FLOAT16[self._D]],
159+
output_types=[FLOAT16[self._B, self._S, self._D]],
140160
)
141-
count = fuse_rms_normalization(model)
142-
self.assertEqual(count, 1)
161+
inputs = {
162+
"x": np.random.randn(self._B, self._S, self._D).astype(np.float16),
163+
"scale": np.random.randn(self._D).astype(np.float16),
164+
}
165+
self._check_numerical_equivalence(model, inputs, expected_count=1)
143166
self.assertEqual(self._count_op(model, "RMSNormalization"), 1)
144167

145168
def test_double_precision(self):
146-
"""Double-precision inputs → fuses (double is a valid compute type)."""
169+
"""Double-precision inputs → fuses (double is a valid compute type).
170+
171+
Structural check only: ONNX reference impl does not support
172+
RMSNormalization with stash_type=DOUBLE.
173+
"""
147174
model = self._build(
148175
_rms_double,
149-
input_types=[DOUBLE["B", "S", 16], DOUBLE[16]],
150-
output_types=[DOUBLE["B", "S", 16]],
176+
input_types=[DOUBLE[self._B, self._S, self._D], DOUBLE[self._D]],
177+
output_types=[DOUBLE[self._B, self._S, self._D]],
151178
)
152179
count = fuse_rms_normalization(model)
153180
self.assertEqual(count, 1)

0 commit comments

Comments
 (0)