@@ -13,7 +13,7 @@ def __init__(self, name: str, *, use_mask: bool, pre_scale: bool, use_mul: bool)
1313 self ._use_mask = use_mask
1414 self ._pre_scale = pre_scale
1515 self ._use_mul = use_mul
16- self ._custom_scale = False
16+ self ._scale : float | None = None
1717
1818 def pattern (
1919 self , op , query , key_transposed , value , mask , query_scale , key_scale , qk_scale
@@ -60,14 +60,18 @@ def check(self, op, query, key_transposed, value, mask, query_scale, key_scale,
6060 # Check if query_scale and key_scale are scalars == sqrt(expected_scaling_factor)
6161 # If they are scalars but != sqrt(expected_scaling_factor), a custom scale is being used.
6262 sqrt_scaling_factor = math .sqrt (expected_scaling_factor )
63-
63+ # Calculate the scaling factor for query
6464 if _ir_utils .get_singleton_value (query_scale ) is None :
6565 return check_result .fail (
6666 "Query scale is not a scalar." ,
6767 query_scale ,
6868 )
6969 if not _ir_utils .is_singleton_value (query_scale , sqrt_scaling_factor , rtol = 1e-3 ):
70- self ._custom_scale = True
70+ query_scale_value = _ir_utils .get_singleton_value (query_scale )
71+ self ._scale = query_scale_value * query_scale_value
72+ else :
73+ self ._scale = expected_scaling_factor
74+ # Ensure the scaling factor for key is the same as for query
7175 if _ir_utils .get_singleton_value (key_scale ) is None :
7276 return check_result .fail (
7377 "Key scale is not a scalar." ,
@@ -81,7 +85,6 @@ def check(self, op, query, key_transposed, value, mask, query_scale, key_scale,
8185 "Query and key scales are not equal." ,
8286 query_scale ,
8387 )
84- self ._custom_scale = True
8588 else :
8689 # Check if qk_scale is a scalar == expected_scaling_factor)
8790 # If it is a scalar but != sqrt(expected_scaling_factor), a custom scale is being used
@@ -91,22 +94,19 @@ def check(self, op, query, key_transposed, value, mask, query_scale, key_scale,
9194 qk_scale ,
9295 )
9396 if not _ir_utils .is_singleton_value (qk_scale , expected_scaling_factor , rtol = 1e-3 ):
94- self ._custom_scale = True
97+ self ._scale = _ir_utils .get_singleton_value (qk_scale )
98+ else :
99+ self ._scale = expected_scaling_factor
95100
96101 # check ranks/shapes
97102
98103 return check_result
99104
100- def rewrite (
101- self , op , query , key_transposed , value , mask , query_scale , key_scale , qk_scale , ** _
102- ):
105+ def rewrite (self , op , query , key_transposed , value , mask , ** _ ):
103106 sdpa_args = [query , key_transposed , value ]
104107 if self ._use_mask :
105108 sdpa_args .append (mask )
106- if self ._custom_scale :
107- scale = _ir_utils .get_singleton_value (query_scale if self ._pre_scale else qk_scale )
108- return op .SDPA (* sdpa_args , scale = scale , _domain = "ai.onnxruntime.fusion" )
109- return op .SDPA (* sdpa_args , _domain = "ai.onnxruntime.fusion" )
109+ return op .SDPA (* sdpa_args , scale = self ._scale , _domain = "ai.onnxruntime.fusion" )
110110
111111
112112# Rules for SDPA without mask
0 commit comments