1818from onnxscript import DOUBLE , FLOAT , FLOAT16 , script
1919from onnxscript import opset18 as op
2020from onnxscript .optimizer import optimize
21+ from onnxscript .rewriter import testing as rewriter_testing
2122from onnxscript .rewriter .rules .fusion ._rms_normalization import fuse_rms_normalization
2223
# Epsilon shared by the RMSNorm test graphs: a rank-1 fp32 tensor holding 1e-6.
_EPS = ir.tensor(np.full((1,), 1e-6, dtype=np.float32))
@@ -100,6 +101,8 @@ def _rms_int_input(x, scale):
100101class RmsNormOnnxFusionExtendedTest (unittest .TestCase ):
101102 """Extended tests for RmsNormFusion (rules/fusion variant producing RMSNormalization)."""
102103
104+ _B , _S , _D = 2 , 4 , 16
105+
103106 def _build (self , script_fn , input_types , output_types ) -> ir .Model :
104107 model_proto = script_fn .to_model_proto (
105108 input_types = input_types , output_types = output_types
@@ -111,6 +114,20 @@ def _build(self, script_fn, input_types, output_types) -> ir.Model:
111114 def _count_op (self , model : ir .Model , op_type : str ) -> int :
112115 return sum (1 for n in model .graph if n .op_type == op_type )
113116
117+ def _check_numerical_equivalence (self , model : ir .Model , inputs : dict , expected_count : int ):
118+ """Apply fusion and verify numerical equivalence using ONNX reference impl.
119+
120+ ORT does not yet have a kernel for RMSNormalization, so we use
121+ the ONNX reference implementation for validation.
122+ """
123+ original_proto = ir .serde .serialize_model (model )
124+ count = fuse_rms_normalization (model )
125+ self .assertEqual (count , expected_count )
126+ fused_proto = ir .serde .serialize_model (model )
127+ rewriter_testing .assert_numerically_equal (
128+ original_proto , fused_proto , args = inputs , use_reference = True
129+ )
130+
114131 # --- Positive tests ---
115132
116133 @parameterized .expand (
@@ -123,31 +140,41 @@ def test_mul_order_variants(self, _name, script_fn):
123140 """Both Mul orderings should fuse to RMSNormalization."""
124141 model = self ._build (
125142 script_fn ,
126- input_types = [FLOAT ["B" , "S" , 16 ], FLOAT [16 ]],
127- output_types = [FLOAT ["B" , "S" , 16 ]],
143+ input_types = [FLOAT [self . _B , self . _S , self . _D ], FLOAT [self . _D ]],
144+ output_types = [FLOAT [self . _B , self . _S , self . _D ]],
128145 )
129- count = fuse_rms_normalization (model )
130- self .assertEqual (count , 1 )
146+ inputs = {
147+ "x" : np .random .randn (self ._B , self ._S , self ._D ).astype (np .float32 ),
148+ "scale" : np .random .randn (self ._D ).astype (np .float32 ),
149+ }
150+ self ._check_numerical_equivalence (model , inputs , expected_count = 1 )
131151 self .assertEqual (self ._count_op (model , "RMSNormalization" ), 1 )
132152 self .assertEqual (self ._count_op (model , "Pow" ), 0 )
133153
134154 def test_mixed_precision_cast (self ):
135155 """fp16 input Cast to fp32 for compute, Cast back → fuses."""
136156 model = self ._build (
137157 _rms_mixed_precision ,
138- input_types = [FLOAT16 ["B" , "S" , 16 ], FLOAT16 [16 ]],
139- output_types = [FLOAT16 ["B" , "S" , 16 ]],
158+ input_types = [FLOAT16 [self . _B , self . _S , self . _D ], FLOAT16 [self . _D ]],
159+ output_types = [FLOAT16 [self . _B , self . _S , self . _D ]],
140160 )
141- count = fuse_rms_normalization (model )
142- self .assertEqual (count , 1 )
161+ inputs = {
162+ "x" : np .random .randn (self ._B , self ._S , self ._D ).astype (np .float16 ),
163+ "scale" : np .random .randn (self ._D ).astype (np .float16 ),
164+ }
165+ self ._check_numerical_equivalence (model , inputs , expected_count = 1 )
143166 self .assertEqual (self ._count_op (model , "RMSNormalization" ), 1 )
144167
145168 def test_double_precision (self ):
146- """Double-precision inputs → fuses (double is a valid compute type)."""
169+ """Double-precision inputs → fuses (double is a valid compute type).
170+
171+ Structural check only: ONNX reference impl does not support
172+ RMSNormalization with stash_type=DOUBLE.
173+ """
147174 model = self ._build (
148175 _rms_double ,
149- input_types = [DOUBLE ["B" , "S" , 16 ], DOUBLE [16 ]],
150- output_types = [DOUBLE ["B" , "S" , 16 ]],
176+ input_types = [DOUBLE [self . _B , self . _S , self . _D ], DOUBLE [self . _D ]],
177+ output_types = [DOUBLE [self . _B , self . _S , self . _D ]],
151178 )
152179 count = fuse_rms_normalization (model )
153180 self .assertEqual (count , 1 )
0 commit comments