@@ -74,6 +74,52 @@ def _unmasked_post_mul_sdpa_script(query, key, value):
7474 return attn_output
7575
7676
@script()
def _custom_scale_pre_div_sdpa_script(query, key, value):
    """Unmasked attention graph that pre-scales both operands by division.

    The query and the transposed key are each divided by the literal 2.0
    before the score MatMul, so the scores carry an overall factor of 1/4.
    The scale is a fixed constant rather than a value computed from the
    inputs.
    """
    # Swap the last two axes of the key ahead of the score MatMul.
    key_t = op.Transpose(key, perm=[0, 1, 3, 2])
    two = op.Constant(value_float=2.0)
    query_div = op.Div(query, two)
    key_div = op.Div(key_t, two)
    scores = op.MatMul(query_div, key_div)
    # Normalize along the last axis, then weight the values.
    probs = op.Softmax(scores, axis=-1)
    attn_output = op.MatMul(probs, value)
    return attn_output
87+
88+
@script()
def _custom_scale_pre_mul_sdpa_script(query, key, value):
    """Unmasked attention graph that pre-scales both operands by multiplication.

    The query and the transposed key are each multiplied by the literal 0.5
    before the score MatMul, so the scores carry an overall factor of 0.25.
    The scale is a fixed constant rather than a value computed from the
    inputs.
    """
    # Swap the last two axes of the key ahead of the score MatMul.
    key_t = op.Transpose(key, perm=[0, 1, 3, 2])
    half = op.Constant(value_float=0.5)
    query_mul = op.Mul(query, half)
    key_mul = op.Mul(key_t, half)
    scores = op.MatMul(query_mul, key_mul)
    # Normalize along the last axis, then weight the values.
    probs = op.Softmax(scores, axis=-1)
    attn_output = op.MatMul(probs, value)
    return attn_output
99+
100+
@script()
def _custom_scale_post_div_sdpa_script(query, key, value):
    """Unmasked attention graph that scales the scores by division.

    The raw query/key scores are divided by the literal 0.1 (equivalent to
    multiplying by 10) after the score MatMul and before Softmax. The scale
    is a fixed constant rather than a value computed from the inputs.
    """
    # Swap the last two axes of the key ahead of the score MatMul.
    key_t = op.Transpose(key, perm=[0, 1, 3, 2])
    tenth = op.Constant(value_float=0.1)
    raw_scores = op.MatMul(query, key_t)
    scores = op.Div(raw_scores, tenth)
    # Normalize along the last axis, then weight the values.
    probs = op.Softmax(scores, axis=-1)
    attn_output = op.MatMul(probs, value)
    return attn_output
111+
@script()
def _custom_scale_post_mul_sdpa_script(query, key, value):
    """Unmasked attention graph that scales the scores by multiplication.

    The raw query/key scores are multiplied by the literal 0.125 after the
    score MatMul and before Softmax. The scale is a fixed constant rather
    than a value computed from the inputs.
    """
    # Swap the last two axes of the key ahead of the score MatMul.
    key_t = op.Transpose(key, perm=[0, 1, 3, 2])
    eighth = op.Constant(value_float=0.125)
    raw_scores = op.MatMul(query, key_t)
    scores = op.Mul(raw_scores, eighth)
    # Normalize along the last axis, then weight the values.
    probs = op.Softmax(scores, axis=-1)
    attn_output = op.MatMul(probs, value)
    return attn_output
121+
122+
77123@script ()
78124def _masked_pre_div_sdpa_script (query , key , value , mask ):
79125 key_transposed = op .Transpose (key , perm = [0 , 1 , 3 , 2 ])
@@ -124,6 +170,56 @@ def _masked_post_mul_sdpa_script(query, key, value, mask):
124170 return attn_output
125171
126172
@script()
def _masked_custom_scale_pre_div_sdpa_script(query, key, value, mask):
    """Masked attention graph that pre-scales both operands by division.

    The query and the transposed key are each divided by the literal 2.0
    before the score MatMul (overall factor 1/4); ``mask`` is then added to
    the scores ahead of Softmax.

    This function previously reused the name of the three-input
    ``_custom_scale_pre_div_sdpa_script``, which rebound that name and left
    the unmasked script untested. It carries the ``_masked_`` prefix so both
    variants stay addressable.

    NOTE(review): the ``custom_scale_pre_div_masked`` entry in the
    parameterized test list should reference this function.
    """
    # Swap the last two axes of the key ahead of the score MatMul.
    key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
    divisor = op.Constant(value_float=2.0)
    scaled_query = op.Div(query, divisor)
    scaled_key = op.Div(key_transposed, divisor)
    attn_score = op.MatMul(scaled_query, scaled_key)
    # The mask is additive and is applied before normalization.
    masked_attn_score = op.Add(attn_score, mask)
    attn_weight = op.Softmax(masked_attn_score, axis=-1)
    attn_output = op.MatMul(attn_weight, value)
    return attn_output
184+
185+
@script()
def _masked_custom_scale_pre_mul_sdpa_script(query, key, value, mask):
    """Masked attention graph that pre-scales both operands by multiplication.

    The query and the transposed key are each multiplied by the literal 0.5
    before the score MatMul (overall factor 0.25); ``mask`` is then added to
    the scores ahead of Softmax.

    This function was previously named ``_custom_scale_mul_sdpa_script``,
    which dropped the ``pre_`` part used by its sibling scripts and did not
    mark it as the masked variant. The name now follows the same
    ``_masked_..._pre_mul_...`` scheme as the other masked scripts.

    NOTE(review): the ``custom_scale_pre_mul_masked`` entry in the
    parameterized test list should reference this function.
    """
    # Swap the last two axes of the key ahead of the score MatMul.
    key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
    multiplier = op.Constant(value_float=0.5)
    scaled_query = op.Mul(query, multiplier)
    scaled_key = op.Mul(key_transposed, multiplier)
    attn_score = op.MatMul(scaled_query, scaled_key)
    # The mask is additive and is applied before normalization.
    masked_attn_score = op.Add(attn_score, mask)
    attn_weight = op.Softmax(masked_attn_score, axis=-1)
    attn_output = op.MatMul(attn_weight, value)
    return attn_output
197+
198+
@script()
def _masked_custom_scale_post_div_sdpa_script(query, key, value, mask):
    """Masked attention graph that scales the scores by division.

    The raw query/key scores are divided by the literal 0.1 (equivalent to
    multiplying by 10); ``mask`` is then added to the scaled scores ahead of
    Softmax.

    This function previously reused the name of the three-input
    ``_custom_scale_post_div_sdpa_script``, which rebound that name and left
    the unmasked script untested. It carries the ``_masked_`` prefix so both
    variants stay addressable.

    NOTE(review): the ``custom_scale_post_div_masked`` entry in the
    parameterized test list should reference this function.
    """
    # Swap the last two axes of the key ahead of the score MatMul.
    key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
    divisor = op.Constant(value_float=0.1)
    attn_score = op.MatMul(query, key_transposed)
    scaled_attn_score = op.Div(attn_score, divisor)
    # The mask is additive and is applied after scaling, before normalization.
    masked_attn_score = op.Add(scaled_attn_score, mask)
    attn_weight = op.Softmax(masked_attn_score, axis=-1)
    attn_output = op.MatMul(attn_weight, value)
    return attn_output
209+
210+
@script()
def _masked_custom_scale_post_mul_sdpa_script(query, key, value, mask):
    """Masked attention graph that scales the scores by multiplication.

    The raw query/key scores are multiplied by the literal 0.125; ``mask`` is
    then added to the scaled scores ahead of Softmax.

    This function previously reused the name of the three-input
    ``_custom_scale_post_mul_sdpa_script``, which rebound that name and left
    the unmasked script untested. It carries the ``_masked_`` prefix so both
    variants stay addressable.

    NOTE(review): the ``custom_scale_post_mul_masked`` entry in the
    parameterized test list should reference this function.
    """
    # Swap the last two axes of the key ahead of the score MatMul.
    key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
    multiplier = op.Constant(value_float=0.125)
    attn_score = op.MatMul(query, key_transposed)
    scaled_attn_score = op.Mul(attn_score, multiplier)
    # The mask is additive and is applied after scaling, before normalization.
    masked_attn_score = op.Add(scaled_attn_score, mask)
    attn_weight = op.Softmax(masked_attn_score, axis=-1)
    attn_output = op.MatMul(attn_weight, value)
    return attn_output
221+
222+
127223class SDPATestCase :
128224 def __init__ (self , script_func ):
129225 self .script_func = script_func
@@ -161,6 +257,14 @@ class TestSDPAFusion(unittest.TestCase):
161257 ("pre_mul" , _masked_pre_mul_sdpa_script ),
162258 ("post_div" , _masked_post_div_sdpa_script ),
163259 ("post_mul" , _masked_post_mul_sdpa_script ),
260+ ("custom_scale_post_mul" , _custom_scale_post_mul_sdpa_script ),
261+ ("custom_scale_post_div" , _custom_scale_post_div_sdpa_script ),
262+ ("custom_scale_pre_mul" , _custom_scale_pre_mul_sdpa_script ),
263+ ("custom_scale_pre_div" , _custom_scale_pre_div_sdpa_script ),
264+ ("custom_scale_post_mul_masked" , _custom_scale_post_mul_sdpa_script ),
265+ ("custom_scale_post_div_masked" , _custom_scale_post_div_sdpa_script ),
266+ ("custom_scale_pre_mul_masked" , _custom_scale_pre_mul_sdpa_script ),
267+ ("custom_scale_pre_div_masked" , _custom_scale_pre_div_sdpa_script ),
164268 ]
165269 )
166270 def test_sdpa_fusion (self , name , script_func ):
0 commit comments