2929_HEAD_SIZE = _D // _NUM_HEADS
3030_DEFAULT_SCALE = 1.0 / math .sqrt (_HEAD_SIZE )
3131
32- # Pre-computed constant for use inside @script functions
33- _SCALE_VALUE = 0.25
32+ # Constants for use inside @script functions
33+ _SCALE_TENSOR_025 = ir .tensor (np .array ([0.25 ], dtype = np .float32 ))
34+ _SCALE_TENSOR_2 = ir .tensor (np .array ([2.0 ], dtype = np .float32 ))
3435
3536
3637# --- Script models ---
3738
3839
3940@script ()
40- def _mha_with_scalar_scale (query , key , value , scale ):
41+ def _mha_with_scale_025 (query , key , value ):
42+ scale = op .Constant (value = _SCALE_TENSOR_025 )
43+ scaled_q = op .Mul (query , scale )
44+ return msft_op .MultiHeadAttention (scaled_q , key , value , num_heads = _NUM_HEADS )
45+
46+
47+ @script ()
48+ def _mha_with_scale_2 (query , key , value ):
49+ scale = op .Constant (value = _SCALE_TENSOR_2 )
4150 scaled_q = op .Mul (query , scale )
4251 return msft_op .MultiHeadAttention (scaled_q , key , value , num_heads = _NUM_HEADS )
4352
@@ -74,73 +83,43 @@ def _get_mha_node(self, model: ir.Model) -> ir.Node | None:
7483 return node
7584 return None
7685
77- def _make_scale_constant (self , model : ir .Model , scale_value : float ):
78- """Convert the ``scale`` graph input into a constant initializer."""
79- for node in model .graph :
80- if node .op_type == "Mul" :
81- scale_input = node .inputs [1 ]
82- assert scale_input is not None
83- scale_input .const_value = ir .tensor (np .array ([scale_value ], dtype = np .float32 ))
84- model .graph .inputs .pop ()
85- return
86- raise RuntimeError ("Mul node not found" )
87-
88- def _check_numerical_equivalence (
89- self , model : ir .Model , inputs : dict , scale_value : float , expected_count : int
90- ):
91- # Run original model *before* making scale constant (scale is a graph input)
92- inputs_with_scale = {
93- ** inputs ,
94- "scale" : np .array ([scale_value ], dtype = np .float32 ),
95- }
96- original_output = test_utils .ort_run ("Original" , model , inputs_with_scale )
97- # Now convert scale to constant and apply fusion
98- self ._make_scale_constant (model , scale_value )
86+ _3D = [FLOAT ["B" , "S" , _D ]] * 3
87+ _OUT = [FLOAT ["B" , "S" , _D ]]
88+
89+ def _check_numerical_equivalence (self , model : ir .Model , inputs : dict , expected_count : int ):
90+ original_output = test_utils .ort_run ("Original" , model , inputs )
9991 count = self ._apply (model )
10092 self .assertEqual (count , expected_count )
10193 fused_output = test_utils .ort_run ("Fused" , model , inputs )
10294 test_utils .assert_allclose (original_output , fused_output )
10395
104- # --- Positive tests ---
105-
106- def _build_scale_model (self ):
107- return self ._build (
108- _mha_with_scalar_scale ,
109- input_types = [
110- FLOAT ["B" , "S" , _D ],
111- FLOAT ["B" , "S" , _D ],
112- FLOAT ["B" , "S" , _D ],
113- FLOAT [1 ],
114- ],
115- output_types = [FLOAT ["B" , "S" , _D ]],
116- )
117-
11896 def _make_inputs (self ):
11997 return {
12098 "query" : np .random .randn (_B , _S , _D ).astype (np .float32 ),
12199 "key" : np .random .randn (_B , _S , _D ).astype (np .float32 ),
122100 "value" : np .random .randn (_B , _S , _D ).astype (np .float32 ),
123101 }
124102
103+ # --- Positive tests ---
104+
125105 def test_scalar_scale_fused (self ):
126106 """Mul(query, scalar_constant) before MHA → scale absorbed into attribute."""
127- model = self ._build_scale_model ( )
107+ model = self ._build ( _mha_with_scale_025 , self . _3D , self . _OUT )
128108 inputs = self ._make_inputs ()
129- self ._check_numerical_equivalence (model , inputs , _SCALE_VALUE , expected_count = 1 )
130- # Verify Mul is gone and MHA has scale attribute
109+ self ._check_numerical_equivalence (model , inputs , expected_count = 1 )
131110 self .assertFalse (any (n .op_type == "Mul" for n in model .graph ), "Mul should be removed" )
132111 mha_node = self ._get_mha_node (model )
133112 self .assertIsNotNone (mha_node )
134113 scale_attr = mha_node .attributes .get_float ("scale" , None )
135114 self .assertIsNotNone (scale_attr )
136- expected = _SCALE_VALUE * _DEFAULT_SCALE
115+ expected = 0.25 * _DEFAULT_SCALE
137116 self .assertAlmostEqual (scale_attr , expected , places = 5 )
138117
139118 def test_integer_scale_fused (self ):
140- """Integer scale constant (e.g. 2) → still fused."""
141- model = self ._build_scale_model ( )
119+ """Scale constant of 2.0 → still fused."""
120+ model = self ._build ( _mha_with_scale_2 , self . _3D , self . _OUT )
142121 inputs = self ._make_inputs ()
143- self ._check_numerical_equivalence (model , inputs , 2.0 , expected_count = 1 )
122+ self ._check_numerical_equivalence (model , inputs , expected_count = 1 )
144123 mha_node = self ._get_mha_node (model )
145124 self .assertIsNotNone (mha_node )
146125 scale_attr = mha_node .attributes .get_float ("scale" , None )
@@ -150,45 +129,35 @@ def test_integer_scale_fused(self):
150129
151130 def test_scale_combined_with_existing_scale_attr (self ):
152131 """MHA already has a scale attribute → external scale is multiplied with it."""
153- model = self ._build_scale_model ()
154- # Set existing MHA scale attribute before any ORT run
132+ model = self ._build (_mha_with_scale_025 , self ._3D , self ._OUT )
155133 existing_scale = 0.1
156134 for node in model .graph :
157135 if node .op_type == "MultiHeadAttention" and node .domain == "com.microsoft" :
158136 node .attributes ["scale" ] = ir .AttrFloat32 ("scale" , existing_scale )
159137
160138 inputs = self ._make_inputs ()
161- self ._check_numerical_equivalence (model , inputs , _SCALE_VALUE , expected_count = 1 )
139+ self ._check_numerical_equivalence (model , inputs , expected_count = 1 )
162140 mha_node = self ._get_mha_node (model )
163141 self .assertIsNotNone (mha_node )
164142 scale_attr = mha_node .attributes .get_float ("scale" , None )
165143 self .assertIsNotNone (scale_attr )
166- expected = _SCALE_VALUE * existing_scale
144+ expected = 0.25 * existing_scale
167145 self .assertAlmostEqual (scale_attr , expected , places = 5 )
168146
169147 # --- Negative tests ---
170148
171149 def test_no_mul_no_fusion (self ):
172150 """No Mul before MHA → rule does not match."""
173- model = self ._build (
174- _mha_no_scale ,
175- input_types = [FLOAT ["B" , "S" , _D ], FLOAT ["B" , "S" , _D ], FLOAT ["B" , "S" , _D ]],
176- output_types = [FLOAT ["B" , "S" , _D ]],
177- )
151+ model = self ._build (_mha_no_scale , self ._3D , self ._OUT )
178152 count = self ._apply (model )
179153 self .assertEqual (count , 0 )
180154
181155 def test_dynamic_scale_no_fusion (self ):
182156 """Scale is a non-constant graph input → check rejects."""
183157 model = self ._build (
184158 _mha_with_dynamic_scale ,
185- input_types = [
186- FLOAT ["B" , "S" , _D ],
187- FLOAT ["B" , "S" , _D ],
188- FLOAT ["B" , "S" , _D ],
189- FLOAT [1 ],
190- ],
191- output_types = [FLOAT ["B" , "S" , _D ]],
159+ input_types = [* self ._3D , FLOAT [1 ]],
160+ output_types = self ._OUT ,
192161 )
193162 count = self ._apply (model )
194163 self .assertEqual (count , 0 )
0 commit comments