@@ -98,6 +98,19 @@ def _custom_scale_pre_mul_sdpa_script(query, key, value):
9898 return attn_output
9999
100100
101+ @script ()
102+ def _custom_multi_scale_pre_mul_sdpa_script (query , key , value ):
103+ key_transposed = op .Transpose (key , perm = [0 , 1 , 3 , 2 ])
104+ multiplier_q = op .Constant (value_float = 0.5 )
105+ multiplier_k = op .Constant (value_float = 0.5 )
106+ scaled_query = op .Mul (query , multiplier_q )
107+ scaled_key = op .Mul (key_transposed , multiplier_k )
108+ attn_score = op .MatMul (scaled_query , scaled_key )
109+ attn_weight = op .Softmax (attn_score , axis = - 1 )
110+ attn_output = op .MatMul (attn_weight , value )
111+ return attn_output
112+
113+
101114@script ()
102115def _custom_scale_post_div_sdpa_script (query , key , value ):
103116 key_transposed = op .Transpose (key , perm = [0 , 1 , 3 , 2 ])
@@ -265,6 +278,7 @@ class TestSDPAFusion(unittest.TestCase):
265278 ("custom_scale_post_div_masked" , _custom_scale_post_div_sdpa_script ),
266279 ("custom_scale_pre_mul_masked" , _custom_scale_pre_mul_sdpa_script ),
267280 ("custom_scale_pre_div_masked" , _custom_scale_pre_div_sdpa_script ),
281+ (_custom_multi_scale_pre_mul_sdpa_script , _custom_multi_scale_pre_mul_sdpa_script ),
268282 ]
269283 )
270284 def test_sdpa_fusion (self , name , script_func ):
0 commit comments