2929
3030
3131@script ()
32- def _unmasked_pre_div_sdpa_script (query , key , value , mask ):
32+ def _unmasked_pre_div_sdpa_script (query , key , value ):
3333 key_transposed = op .Transpose (key , perm = [0 , 1 , 3 , 2 ])
3434 divisor = op .Constant (value_float = SQRT_SCALE_FACTOR )
3535 scaled_query = op .Div (query , divisor )
@@ -41,7 +41,7 @@ def _unmasked_pre_div_sdpa_script(query, key, value, mask):
4141
4242
4343@script ()
44- def _unmasked_pre_mul_sdpa_script (query , key , value , mask ):
44+ def _unmasked_pre_mul_sdpa_script (query , key , value ):
4545 key_transposed = op .Transpose (key , perm = [0 , 1 , 3 , 2 ])
4646 multiplier = op .Constant (value_float = SQRT_MUL_SCALE_FACTOR )
4747 scaled_query = op .Mul (query , multiplier )
@@ -53,7 +53,7 @@ def _unmasked_pre_mul_sdpa_script(query, key, value, mask):
5353
5454
5555@script ()
56- def _unmasked_post_div_sdpa_script (query , key , value , mask ):
56+ def _unmasked_post_div_sdpa_script (query , key , value ):
5757 key_transposed = op .Transpose (key , perm = [0 , 1 , 3 , 2 ])
5858 divisor = op .Constant (value_float = SCALE_FACTOR )
5959 attn_score = op .MatMul (query , key_transposed )
@@ -64,7 +64,7 @@ def _unmasked_post_div_sdpa_script(query, key, value, mask):
6464
6565
6666@script ()
67- def _unmasked_post_mul_sdpa_script (query , key , value , mask ):
67+ def _unmasked_post_mul_sdpa_script (query , key , value ):
6868 key_transposed = op .Transpose (key , perm = [0 , 1 , 3 , 2 ])
6969 multiplier = op .Constant (value_float = MUL_SCALE_FACTOR )
7070 attn_score = op .MatMul (query , key_transposed )
0 commit comments