2323from onnxscript import FLOAT , script
2424from onnxscript import opset18 as op
2525from onnxscript .optimizer import optimize
26+ from onnxscript .rewriter .ort_fusions import _test_utils as test_utils
2627from onnxscript .rewriter .ort_fusions .skip_normalization import (
2728 fuse_skip_layer_normalization ,
2829 fuse_skip_rms_normalization ,
@@ -111,9 +112,28 @@ def _build(self, script_fn, input_types, output_types) -> ir.Model:
111112 def _count_op (self , model : ir .Model , op_type : str , domain : str = "" ) -> int :
112113 return sum (1 for n in model .graph if n .op_type == op_type and n .domain == domain )
113114
114- _3D = FLOAT ["B" , "S" , _D ]
115+ def _check_numerical_equivalence (
116+ self , model : ir .Model , inputs : dict , fuse_fn , expected_count : int
117+ ):
118+ original_output = test_utils .ort_run ("Original" , model , inputs )
119+ count = fuse_fn (model )
120+ self .assertGreaterEqual (count , expected_count )
121+ fused_output = test_utils .ort_run ("Fused" , model , inputs )
122+ test_utils .assert_allclose (original_output , fused_output )
123+
124+ _3D = FLOAT [_B , _S , _D ]
115125 _1D = FLOAT [_D ]
116126
127+ def _make_inputs (self , * names ):
128+ shapes = {
129+ "input" : (_B , _S , _D ),
130+ "skip" : (_B , _S , _D ),
131+ "gamma" : (_D ,),
132+ "beta" : (_D ,),
133+ "bias" : (_D ,),
134+ }
135+ return {n : np .random .randn (* shapes [n ]).astype (np .float32 ) for n in names }
136+
117137 # ---- Skip RMS Norm positive tests ----
118138
119139 @parameterized .expand (
@@ -129,8 +149,10 @@ def test_skip_rms_no_bias(self, _name, script_fn):
129149 input_types = [self ._3D , self ._3D , self ._1D ],
130150 output_types = [self ._3D ],
131151 )
132- count = fuse_skip_rms_normalization (model )
133- self .assertGreater (count , 0 )
152+ inputs = self ._make_inputs ("input" , "skip" , "gamma" )
153+ self ._check_numerical_equivalence (
154+ model , inputs , fuse_skip_rms_normalization , expected_count = 1
155+ )
134156 self .assertEqual (
135157 self ._count_op (model , "SkipSimplifiedLayerNormalization" , "com.microsoft" ), 1
136158 )
@@ -143,8 +165,10 @@ def test_skip_rms_post_bias(self):
143165 input_types = [self ._3D , self ._3D , self ._1D , self ._1D ],
144166 output_types = [self ._3D ],
145167 )
146- count = fuse_skip_rms_normalization (model )
147- self .assertGreater (count , 0 )
168+ inputs = self ._make_inputs ("input" , "skip" , "gamma" , "bias" )
169+ self ._check_numerical_equivalence (
170+ model , inputs , fuse_skip_rms_normalization , expected_count = 1
171+ )
148172 self .assertEqual (
149173 self ._count_op (model , "SkipSimplifiedLayerNormalization" , "com.microsoft" ), 1
150174 )
@@ -156,8 +180,10 @@ def test_skip_rms_pre_bias(self):
156180 input_types = [self ._3D , self ._3D , self ._1D , self ._1D ],
157181 output_types = [self ._3D ],
158182 )
159- count = fuse_skip_rms_normalization (model )
160- self .assertGreater (count , 0 )
183+ inputs = self ._make_inputs ("input" , "skip" , "gamma" , "bias" )
184+ self ._check_numerical_equivalence (
185+ model , inputs , fuse_skip_rms_normalization , expected_count = 1
186+ )
161187 self .assertEqual (
162188 self ._count_op (model , "SkipSimplifiedLayerNormalization" , "com.microsoft" ), 1
163189 )
@@ -171,8 +197,10 @@ def test_skip_ln_no_bias(self):
171197 input_types = [self ._3D , self ._3D , self ._1D , self ._1D ],
172198 output_types = [self ._3D ],
173199 )
174- count = fuse_skip_layer_normalization (model )
175- self .assertGreater (count , 0 )
200+ inputs = self ._make_inputs ("input" , "skip" , "gamma" , "beta" )
201+ self ._check_numerical_equivalence (
202+ model , inputs , fuse_skip_layer_normalization , expected_count = 1
203+ )
176204 self .assertEqual (self ._count_op (model , "SkipLayerNormalization" , "com.microsoft" ), 1 )
177205 self .assertEqual (self ._count_op (model , "LayerNormalization" ), 0 )
178206
@@ -183,8 +211,10 @@ def test_skip_ln_post_bias(self):
183211 input_types = [self ._3D , self ._3D , self ._1D , self ._1D , self ._1D ],
184212 output_types = [self ._3D ],
185213 )
186- count = fuse_skip_layer_normalization (model )
187- self .assertGreater (count , 0 )
214+ inputs = self ._make_inputs ("input" , "skip" , "gamma" , "beta" , "bias" )
215+ self ._check_numerical_equivalence (
216+ model , inputs , fuse_skip_layer_normalization , expected_count = 1
217+ )
188218 self .assertEqual (self ._count_op (model , "SkipLayerNormalization" , "com.microsoft" ), 1 )
189219
190220 # ---- Negative tests ----
0 commit comments