Skip to content

Commit 16e760f

Browse files
gramalingam and Copilot
committed
Add ORT numerical validation to skip normalization tests
All 6 positive tests now verify that the original and fused model outputs match using ORT inference. Uses concrete dims for test data while keeping the structural assertions. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 543a584 commit 16e760f

File tree

1 file changed

+41
-11
lines changed

1 file changed

+41
-11
lines changed

onnxscript/rewriter/ort_fusions/skip_normalization_unit_test.py

Lines changed: 41 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from onnxscript import FLOAT, script
2424
from onnxscript import opset18 as op
2525
from onnxscript.optimizer import optimize
26+
from onnxscript.rewriter.ort_fusions import _test_utils as test_utils
2627
from onnxscript.rewriter.ort_fusions.skip_normalization import (
2728
fuse_skip_layer_normalization,
2829
fuse_skip_rms_normalization,
@@ -111,9 +112,28 @@ def _build(self, script_fn, input_types, output_types) -> ir.Model:
111112
def _count_op(self, model: ir.Model, op_type: str, domain: str = "") -> int:
112113
return sum(1 for n in model.graph if n.op_type == op_type and n.domain == domain)
113114

114-
_3D = FLOAT["B", "S", _D]
115+
def _check_numerical_equivalence(
116+
self, model: ir.Model, inputs: dict, fuse_fn, expected_count: int
117+
):
118+
original_output = test_utils.ort_run("Original", model, inputs)
119+
count = fuse_fn(model)
120+
self.assertGreaterEqual(count, expected_count)
121+
fused_output = test_utils.ort_run("Fused", model, inputs)
122+
test_utils.assert_allclose(original_output, fused_output)
123+
124+
_3D = FLOAT[_B, _S, _D]
115125
_1D = FLOAT[_D]
116126

127+
def _make_inputs(self, *names):
128+
shapes = {
129+
"input": (_B, _S, _D),
130+
"skip": (_B, _S, _D),
131+
"gamma": (_D,),
132+
"beta": (_D,),
133+
"bias": (_D,),
134+
}
135+
return {n: np.random.randn(*shapes[n]).astype(np.float32) for n in names}
136+
117137
# ---- Skip RMS Norm positive tests ----
118138

119139
@parameterized.expand(
@@ -129,8 +149,10 @@ def test_skip_rms_no_bias(self, _name, script_fn):
129149
input_types=[self._3D, self._3D, self._1D],
130150
output_types=[self._3D],
131151
)
132-
count = fuse_skip_rms_normalization(model)
133-
self.assertGreater(count, 0)
152+
inputs = self._make_inputs("input", "skip", "gamma")
153+
self._check_numerical_equivalence(
154+
model, inputs, fuse_skip_rms_normalization, expected_count=1
155+
)
134156
self.assertEqual(
135157
self._count_op(model, "SkipSimplifiedLayerNormalization", "com.microsoft"), 1
136158
)
@@ -143,8 +165,10 @@ def test_skip_rms_post_bias(self):
143165
input_types=[self._3D, self._3D, self._1D, self._1D],
144166
output_types=[self._3D],
145167
)
146-
count = fuse_skip_rms_normalization(model)
147-
self.assertGreater(count, 0)
168+
inputs = self._make_inputs("input", "skip", "gamma", "bias")
169+
self._check_numerical_equivalence(
170+
model, inputs, fuse_skip_rms_normalization, expected_count=1
171+
)
148172
self.assertEqual(
149173
self._count_op(model, "SkipSimplifiedLayerNormalization", "com.microsoft"), 1
150174
)
@@ -156,8 +180,10 @@ def test_skip_rms_pre_bias(self):
156180
input_types=[self._3D, self._3D, self._1D, self._1D],
157181
output_types=[self._3D],
158182
)
159-
count = fuse_skip_rms_normalization(model)
160-
self.assertGreater(count, 0)
183+
inputs = self._make_inputs("input", "skip", "gamma", "bias")
184+
self._check_numerical_equivalence(
185+
model, inputs, fuse_skip_rms_normalization, expected_count=1
186+
)
161187
self.assertEqual(
162188
self._count_op(model, "SkipSimplifiedLayerNormalization", "com.microsoft"), 1
163189
)
@@ -171,8 +197,10 @@ def test_skip_ln_no_bias(self):
171197
input_types=[self._3D, self._3D, self._1D, self._1D],
172198
output_types=[self._3D],
173199
)
174-
count = fuse_skip_layer_normalization(model)
175-
self.assertGreater(count, 0)
200+
inputs = self._make_inputs("input", "skip", "gamma", "beta")
201+
self._check_numerical_equivalence(
202+
model, inputs, fuse_skip_layer_normalization, expected_count=1
203+
)
176204
self.assertEqual(self._count_op(model, "SkipLayerNormalization", "com.microsoft"), 1)
177205
self.assertEqual(self._count_op(model, "LayerNormalization"), 0)
178206

@@ -183,8 +211,10 @@ def test_skip_ln_post_bias(self):
183211
input_types=[self._3D, self._3D, self._1D, self._1D, self._1D],
184212
output_types=[self._3D],
185213
)
186-
count = fuse_skip_layer_normalization(model)
187-
self.assertGreater(count, 0)
214+
inputs = self._make_inputs("input", "skip", "gamma", "beta", "bias")
215+
self._check_numerical_equivalence(
216+
model, inputs, fuse_skip_layer_normalization, expected_count=1
217+
)
188218
self.assertEqual(self._count_op(model, "SkipLayerNormalization", "com.microsoft"), 1)
189219

190220
# ---- Negative tests ----

0 commit comments

Comments
 (0)