2626MUL_SCALE_FACTOR = 1.0 / SCALE_FACTOR
2727SQRT_SCALE_FACTOR = math .sqrt (SCALE_FACTOR )
2828SQRT_MUL_SCALE_FACTOR = math .sqrt (MUL_SCALE_FACTOR )
29+ CUSTOM_SCALE_FACTOR = 2.0
2930
3031
3132@script ()
@@ -77,7 +78,7 @@ def _unmasked_post_mul_sdpa_script(query, key, value):
7778@script ()
7879def _custom_scale_pre_div_sdpa_script (query , key , value ):
7980 key_transposed = op .Transpose (key , perm = [0 , 1 , 3 , 2 ])
80- divisor = op .Constant (value_float = 2.0 )
81+ divisor = op .Constant (value_float = CUSTOM_SCALE_FACTOR )
8182 scaled_query = op .Div (query , divisor )
8283 scaled_key = op .Div (key_transposed , divisor )
8384 attn_score = op .MatMul (scaled_query , scaled_key )
@@ -89,7 +90,7 @@ def _custom_scale_pre_div_sdpa_script(query, key, value):
8990@script ()
9091def _custom_scale_pre_mul_sdpa_script (query , key , value ):
9192 key_transposed = op .Transpose (key , perm = [0 , 1 , 3 , 2 ])
92- multiplier = op .Constant (value_float = 0.5 )
93+ multiplier = op .Constant (value_float = CUSTOM_SCALE_FACTOR )
9394 scaled_query = op .Mul (query , multiplier )
9495 scaled_key = op .Mul (key_transposed , multiplier )
9596 attn_score = op .MatMul (scaled_query , scaled_key )
@@ -101,8 +102,8 @@ def _custom_scale_pre_mul_sdpa_script(query, key, value):
101102@script ()
102103def _custom_multi_scale_pre_mul_sdpa_script (query , key , value ):
103104 key_transposed = op .Transpose (key , perm = [0 , 1 , 3 , 2 ])
104- multiplier_q = op .Constant (value_float = 0.5 )
105- multiplier_k = op .Constant (value_float = 0.5 )
105+ multiplier_q = op .Constant (value_float = CUSTOM_SCALE_FACTOR )
106+ multiplier_k = op .Constant (value_float = CUSTOM_SCALE_FACTOR )
106107 scaled_query = op .Mul (query , multiplier_q )
107108 scaled_key = op .Mul (key_transposed , multiplier_k )
108109 attn_score = op .MatMul (scaled_query , scaled_key )
@@ -114,7 +115,7 @@ def _custom_multi_scale_pre_mul_sdpa_script(query, key, value):
114115@script ()
115116def _custom_scale_post_div_sdpa_script (query , key , value ):
116117 key_transposed = op .Transpose (key , perm = [0 , 1 , 3 , 2 ])
117- divisor = op .Constant (value_float = 0.1 )
118+ divisor = op .Constant (value_float = CUSTOM_SCALE_FACTOR )
118119 attn_score = op .MatMul (query , key_transposed )
119120 scaled_attn_score = op .Div (attn_score , divisor )
120121 attn_weight = op .Softmax (scaled_attn_score , axis = - 1 )
@@ -125,7 +126,7 @@ def _custom_scale_post_div_sdpa_script(query, key, value):
125126@script ()
126127def _custom_scale_post_mul_sdpa_script (query , key , value ):
127128 key_transposed = op .Transpose (key , perm = [0 , 1 , 3 , 2 ])
128- multiplier = op .Constant (value_float = 0.125 )
129+ multiplier = op .Constant (value_float = CUSTOM_SCALE_FACTOR )
129130 attn_score = op .MatMul (query , key_transposed )
130131 scaled_attn_score = op .Mul (attn_score , multiplier )
131132 attn_weight = op .Softmax (scaled_attn_score , axis = - 1 )
@@ -186,7 +187,7 @@ def _masked_post_mul_sdpa_script(query, key, value, mask):
186187@script ()
187188def _custom_scale_pre_div_sdpa_script (query , key , value , mask ):
188189 key_transposed = op .Transpose (key , perm = [0 , 1 , 3 , 2 ])
189- divisor = op .Constant (value_float = 2.0 )
190+ divisor = op .Constant (value_float = CUSTOM_SCALE_FACTOR )
190191 scaled_query = op .Div (query , divisor )
191192 scaled_key = op .Div (key_transposed , divisor )
192193 attn_score = op .MatMul (scaled_query , scaled_key )
@@ -199,7 +200,7 @@ def _custom_scale_pre_div_sdpa_script(query, key, value, mask):
199200@script ()
200201def _custom_scale_pre_mul_sdpa_script (query , key , value , mask ):
201202 key_transposed = op .Transpose (key , perm = [0 , 1 , 3 , 2 ])
202- multiplier = op .Constant (value_float = 0.5 )
203+ multiplier = op .Constant (value_float = CUSTOM_SCALE_FACTOR )
203204 scaled_query = op .Mul (query , multiplier )
204205 scaled_key = op .Mul (key_transposed , multiplier )
205206 attn_score = op .MatMul (scaled_query , scaled_key )
@@ -212,7 +213,7 @@ def _custom_scale_pre_mul_sdpa_script(query, key, value, mask):
212213@script ()
213214def _custom_scale_post_div_sdpa_script (query , key , value , mask ):
214215 key_transposed = op .Transpose (key , perm = [0 , 1 , 3 , 2 ])
215- divisor = op .Constant (value_float = 0.1 )
216+ divisor = op .Constant (value_float = CUSTOM_SCALE_FACTOR )
216217 attn_score = op .MatMul (query , key_transposed )
217218 scaled_attn_score = op .Div (attn_score , divisor )
218219 masked_attn_score = op .Add (scaled_attn_score , mask )
@@ -224,7 +225,7 @@ def _custom_scale_post_div_sdpa_script(query, key, value, mask):
224225@script ()
225226def _custom_scale_post_mul_sdpa_script (query , key , value , mask ):
226227 key_transposed = op .Transpose (key , perm = [0 , 1 , 3 , 2 ])
227- multiplier = op .Constant (value_float = 0.125 )
228+ multiplier = op .Constant (value_float = CUSTOM_SCALE_FACTOR )
228229 attn_score = op .MatMul (query , key_transposed )
229230 scaled_attn_score = op .Mul (attn_score , multiplier )
230231 masked_attn_score = op .Add (scaled_attn_score , mask )
@@ -278,7 +279,10 @@ class TestSDPAFusion(unittest.TestCase):
278279 ("custom_scale_post_div_masked" , _custom_scale_post_div_sdpa_script ),
279280 ("custom_scale_pre_mul_masked" , _custom_scale_pre_mul_sdpa_script ),
280281 ("custom_scale_pre_div_masked" , _custom_scale_pre_div_sdpa_script ),
281- ("_custom_multi_scale_pre_mul_sdpa_script" , _custom_multi_scale_pre_mul_sdpa_script ),
282+ (
283+ "_custom_multi_scale_pre_mul_sdpa_script" ,
284+ _custom_multi_scale_pre_mul_sdpa_script ,
285+ ),
282286 ]
283287 )
284288 def test_sdpa_fusion (self , name , script_func ):
@@ -296,6 +300,24 @@ def test_sdpa_fusion(self, name, script_func):
296300 op_types = [n .op_type for n in model .graph ]
297301 self .assertIn ("SDPA" , op_types )
298302
303+ # Ensure that the scale of the SDPA node is set correctly
304+ sdpa_node = next (n for n in model .graph if n .op_type == "SDPA" )
305+ self .assertEqual (sdpa_node .op_type , "SDPA" )
306+ self .assertIsNotNone (sdpa_node .attributes .get ("scale" ))
307+
308+ scale_factor = sdpa_node .attributes ["scale" ].value
309+ self .assertIsNotNone (scale_factor )
310+ if "custom" in name :
311+ if "pre" in name :
312+ self .assertEqual (scale_factor , CUSTOM_SCALE_FACTOR * CUSTOM_SCALE_FACTOR )
313+ elif "post" in name :
314+ self .assertEqual (scale_factor , CUSTOM_SCALE_FACTOR )
315+ else :
316+ if "div" in name :
317+ self .assertEqual (scale_factor , SCALE_FACTOR )
318+ elif "mul" in name :
319+ self .assertEqual (scale_factor , MUL_SCALE_FACTOR )
320+
299321 # new_outputs = ort_run("optimized", model, inputs)
300322 # assert_allclose(new_outputs, original_outputs)
301323
0 commit comments