@@ -66,10 +66,6 @@ def check(self, op, query, key_transposed, value, mask, query_scale, key_scale,
6666 "Query scale is not a scalar." ,
6767 query_scale ,
6868 )
69- if not math .isclose (query_scale_value , sqrt_scaling_factor , rel_tol = 1e-3 ):
70- self ._scale = query_scale_value * query_scale_value
71- else :
72- self ._scale = expected_scaling_factor
7369 # Ensure the scaling factor for key is the same as for query
7470 if (key_scale_value := _ir_utils .get_singleton_value (key_scale )) is None :
7571 return check_result .fail (
@@ -81,6 +77,11 @@ def check(self, op, query, key_transposed, value, mask, query_scale, key_scale,
8177 "Query and key scales are not equal." ,
8278 query_scale ,
8379 )
80+ if not math .isclose (query_scale_value , sqrt_scaling_factor , rel_tol = 1e-3 ):
81+ self ._scale = query_scale_value * query_scale_value
82+ else :
83+ # Pass no scaling factor to SDPA, SDPA will use the default scaling factor
84+ self ._scale = None
8485 else :
8586 # Check if qk_scale is a scalar == expected_scaling_factor)
8687 # If it is a scalar but != sqrt(expected_scaling_factor), a custom scale is being used
@@ -92,7 +93,8 @@ def check(self, op, query, key_transposed, value, mask, query_scale, key_scale,
9293 if not math .isclose (qk_scale_value , expected_scaling_factor , rel_tol = 1e-3 ):
9394 self ._scale = qk_scale_value
9495 else :
95- self ._scale = expected_scaling_factor
96+ # Pass no scaling factor to SDPA, SDPA will use the default scaling factor
97+ self ._scale = None
9698
9799 # check ranks/shapes
98100
0 commit comments