2626MUL_SCALE_FACTOR = 1.0 / SCALE_FACTOR
2727SQRT_SCALE_FACTOR = math .sqrt (SCALE_FACTOR )
2828SQRT_MUL_SCALE_FACTOR = math .sqrt (MUL_SCALE_FACTOR )
29+ CUSTOM_SCALE_FACTOR = 2.0
2930
3031
3132@script ()
@@ -74,6 +75,65 @@ def _unmasked_post_mul_sdpa_script(query, key, value):
7475 return attn_output
7576
7677
78+ @script ()
79+ def _custom_scale_pre_div_sdpa_script (query , key , value ):
80+ key_transposed = op .Transpose (key , perm = [0 , 1 , 3 , 2 ])
81+ divisor = op .Constant (value_float = CUSTOM_SCALE_FACTOR )
82+ scaled_query = op .Div (query , divisor )
83+ scaled_key = op .Div (key_transposed , divisor )
84+ attn_score = op .MatMul (scaled_query , scaled_key )
85+ attn_weight = op .Softmax (attn_score , axis = - 1 )
86+ attn_output = op .MatMul (attn_weight , value )
87+ return attn_output
88+
89+
90+ @script ()
91+ def _custom_scale_pre_mul_sdpa_script (query , key , value ):
92+ key_transposed = op .Transpose (key , perm = [0 , 1 , 3 , 2 ])
93+ multiplier = op .Constant (value_float = CUSTOM_SCALE_FACTOR )
94+ scaled_query = op .Mul (query , multiplier )
95+ scaled_key = op .Mul (key_transposed , multiplier )
96+ attn_score = op .MatMul (scaled_query , scaled_key )
97+ attn_weight = op .Softmax (attn_score , axis = - 1 )
98+ attn_output = op .MatMul (attn_weight , value )
99+ return attn_output
100+
101+
102+ @script ()
103+ def _custom_multi_scale_pre_mul_sdpa_script (query , key , value ):
104+ key_transposed = op .Transpose (key , perm = [0 , 1 , 3 , 2 ])
105+ multiplier_q = op .Constant (value_float = CUSTOM_SCALE_FACTOR )
106+ multiplier_k = op .Constant (value_float = CUSTOM_SCALE_FACTOR )
107+ scaled_query = op .Mul (query , multiplier_q )
108+ scaled_key = op .Mul (key_transposed , multiplier_k )
109+ attn_score = op .MatMul (scaled_query , scaled_key )
110+ attn_weight = op .Softmax (attn_score , axis = - 1 )
111+ attn_output = op .MatMul (attn_weight , value )
112+ return attn_output
113+
114+
115+ @script ()
116+ def _custom_scale_post_div_sdpa_script (query , key , value ):
117+ key_transposed = op .Transpose (key , perm = [0 , 1 , 3 , 2 ])
118+ divisor = op .Constant (value_float = CUSTOM_SCALE_FACTOR )
119+ attn_score = op .MatMul (query , key_transposed )
120+ scaled_attn_score = op .Div (attn_score , divisor )
121+ attn_weight = op .Softmax (scaled_attn_score , axis = - 1 )
122+ attn_output = op .MatMul (attn_weight , value )
123+ return attn_output
124+
125+
126+ @script ()
127+ def _custom_scale_post_mul_sdpa_script (query , key , value ):
128+ key_transposed = op .Transpose (key , perm = [0 , 1 , 3 , 2 ])
129+ multiplier = op .Constant (value_float = CUSTOM_SCALE_FACTOR )
130+ attn_score = op .MatMul (query , key_transposed )
131+ scaled_attn_score = op .Mul (attn_score , multiplier )
132+ attn_weight = op .Softmax (scaled_attn_score , axis = - 1 )
133+ attn_output = op .MatMul (attn_weight , value )
134+ return attn_output
135+
136+
77137@script ()
78138def _masked_pre_div_sdpa_script (query , key , value , mask ):
79139 key_transposed = op .Transpose (key , perm = [0 , 1 , 3 , 2 ])
@@ -124,6 +184,56 @@ def _masked_post_mul_sdpa_script(query, key, value, mask):
124184 return attn_output
125185
126186
187+ @script ()
188+ def _custom_scale_pre_div_sdpa_script (query , key , value , mask ):
189+ key_transposed = op .Transpose (key , perm = [0 , 1 , 3 , 2 ])
190+ divisor = op .Constant (value_float = CUSTOM_SCALE_FACTOR )
191+ scaled_query = op .Div (query , divisor )
192+ scaled_key = op .Div (key_transposed , divisor )
193+ attn_score = op .MatMul (scaled_query , scaled_key )
194+ masked_attn_score = op .Add (attn_score , mask )
195+ attn_weight = op .Softmax (masked_attn_score , axis = - 1 )
196+ attn_output = op .MatMul (attn_weight , value )
197+ return attn_output
198+
199+
200+ @script ()
201+ def _custom_scale_pre_mul_sdpa_script (query , key , value , mask ):
202+ key_transposed = op .Transpose (key , perm = [0 , 1 , 3 , 2 ])
203+ multiplier = op .Constant (value_float = CUSTOM_SCALE_FACTOR )
204+ scaled_query = op .Mul (query , multiplier )
205+ scaled_key = op .Mul (key_transposed , multiplier )
206+ attn_score = op .MatMul (scaled_query , scaled_key )
207+ masked_attn_score = op .Add (attn_score , mask )
208+ attn_weight = op .Softmax (masked_attn_score , axis = - 1 )
209+ attn_output = op .MatMul (attn_weight , value )
210+ return attn_output
211+
212+
213+ @script ()
214+ def _custom_scale_post_div_sdpa_script (query , key , value , mask ):
215+ key_transposed = op .Transpose (key , perm = [0 , 1 , 3 , 2 ])
216+ divisor = op .Constant (value_float = CUSTOM_SCALE_FACTOR )
217+ attn_score = op .MatMul (query , key_transposed )
218+ scaled_attn_score = op .Div (attn_score , divisor )
219+ masked_attn_score = op .Add (scaled_attn_score , mask )
220+ attn_weight = op .Softmax (masked_attn_score , axis = - 1 )
221+ attn_output = op .MatMul (attn_weight , value )
222+ return attn_output
223+
224+
225+ @script ()
226+ def _custom_scale_post_mul_sdpa_script (query , key , value , mask ):
227+ key_transposed = op .Transpose (key , perm = [0 , 1 , 3 , 2 ])
228+ multiplier = op .Constant (value_float = CUSTOM_SCALE_FACTOR )
229+ attn_score = op .MatMul (query , key_transposed )
230+ scaled_attn_score = op .Mul (attn_score , multiplier )
231+ masked_attn_score = op .Add (scaled_attn_score , mask )
232+ attn_weight = op .Softmax (masked_attn_score , axis = - 1 )
233+ attn_output = op .MatMul (attn_weight , value )
234+ return attn_output
235+
236+
127237class SDPATestCase :
128238 def __init__ (self , script_func ):
129239 self .script_func = script_func
@@ -161,6 +271,18 @@ class TestSDPAFusion(unittest.TestCase):
161271 ("pre_mul" , _masked_pre_mul_sdpa_script ),
162272 ("post_div" , _masked_post_div_sdpa_script ),
163273 ("post_mul" , _masked_post_mul_sdpa_script ),
274+ ("custom_scale_post_mul" , _custom_scale_post_mul_sdpa_script ),
275+ ("custom_scale_post_div" , _custom_scale_post_div_sdpa_script ),
276+ ("custom_scale_pre_mul" , _custom_scale_pre_mul_sdpa_script ),
277+ ("custom_scale_pre_div" , _custom_scale_pre_div_sdpa_script ),
278+ ("custom_scale_post_mul_masked" , _custom_scale_post_mul_sdpa_script ),
279+ ("custom_scale_post_div_masked" , _custom_scale_post_div_sdpa_script ),
280+ ("custom_scale_pre_mul_masked" , _custom_scale_pre_mul_sdpa_script ),
281+ ("custom_scale_pre_div_masked" , _custom_scale_pre_div_sdpa_script ),
282+ (
283+ "_custom_multi_scale_pre_mul_sdpa_script" ,
284+ _custom_multi_scale_pre_mul_sdpa_script ,
285+ ),
164286 ]
165287 )
166288 def test_sdpa_fusion (self , name , script_func ):
@@ -178,6 +300,24 @@ def test_sdpa_fusion(self, name, script_func):
178300 op_types = [n .op_type for n in model .graph ]
179301 self .assertIn ("SDPA" , op_types )
180302
303+ # Ensure that the scale of the SDPA node is set correctly
304+ sdpa_node = next (n for n in model .graph if n .op_type == "SDPA" )
305+ self .assertEqual (sdpa_node .op_type , "SDPA" )
306+
307+ if "custom" in name :
308+ self .assertIsNotNone (sdpa_node .attributes .get ("scale" ))
309+ scale_factor = sdpa_node .attributes ["scale" ].value
310+ self .assertIsNotNone (scale_factor )
311+ if "pre" in name :
312+ self .assertEqual (scale_factor , CUSTOM_SCALE_FACTOR * CUSTOM_SCALE_FACTOR )
313+ elif "post" in name :
314+ self .assertEqual (scale_factor , CUSTOM_SCALE_FACTOR )
315+ else :
316+ # These tests are for the default scaling factors, no scale factor is passed to SDPA
317+ # pattern rewriting check functions should be sufficient to check if expected value
318+ # of scale_factor (is =default_scaling_factor)
319+ self .assertIsNone (sdpa_node .attributes .get ("scale" ))
320+
181321 # new_outputs = ort_run("optimized", model, inputs)
182322 # assert_allclose(new_outputs, original_outputs)
183323
0 commit comments