
Commit 786e73a

fxmarty authored and iantbutler01 committed
Fix SDPA dispatch & make SDPA CI compatible with torch<2.1.1 (huggingface#27940)
fix sdpa dispatch
1 parent 743e6e7 commit 786e73a
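
For context, the dispatch being fixed is the path taken when a model is loaded with or without an explicit attention implementation. A minimal user-facing sketch (the checkpoint name is hypothetical; attn_implementation is the existing from_pretrained argument):

from transformers import AutoModelForCausalLM

# Explicit request: with this fix the requested backend is hard-checked, so
# loading fails loudly if SDPA is unavailable for the model or torch version.
model = AutoModelForCausalLM.from_pretrained(
    "org/model-name",            # hypothetical checkpoint
    attn_implementation="sdpa",
)

# No request: SDPA is tried automatically and the model silently falls back
# to "eager" when SDPA cannot be used.
model = AutoModelForCausalLM.from_pretrained("org/model-name")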

2 files changed, 10 additions & 8 deletions


src/transformers/modeling_utils.py

Lines changed: 8 additions & 7 deletions
@@ -1268,6 +1268,7 @@ def _autoset_attn_implementation(
         # Here we use config._attn_implementation_internal to check whether the attention implementation was explicitely set by the user.
         # The property `PretrainedConfig._attn_implementation` is never `None`, for backward compatibility (always fall back on "eager").
         # The `hasattr` here is used as some Transformers tests for some reason do not call PretrainedConfig __init__ (e.g. test_no_super_init_config_and_model)
+        requested_attn_implementation = None
         if hasattr(config, "_attn_implementation_internal") and config._attn_implementation_internal is not None:
             if config._attn_implementation != "flash_attention_2" and use_flash_attention_2:
                 raise ValueError(
@@ -1284,9 +1285,7 @@
                 raise ValueError(message + ".")

             # If a config is passed with a preset attn_implementation, we skip the automatic dispatch and use the user-provided config, with hard checks that the requested attention implementation is available.
-            hard_check_only = True
-        else:
-            hard_check_only = False
+            requested_attn_implementation = config._attn_implementation_internal

         if use_flash_attention_2:
             logger.warning_once(
@@ -1299,13 +1298,15 @@
                 config,
                 torch_dtype=torch_dtype,
                 device_map=device_map,
-                hard_check_only=hard_check_only,
+                hard_check_only=False,
                 check_device_map=check_device_map,
             )
-        elif cls._supports_sdpa or config._attn_implementation == "sdpa":
+        elif requested_attn_implementation in [None, "sdpa"]:
             # use_flash_attention_2 takes priority over SDPA, hence SDPA treated in this elif.
-            config = cls._check_and_enable_sdpa(config, hard_check_only=hard_check_only)
-        elif not hard_check_only:
+            config = cls._check_and_enable_sdpa(
+                config, hard_check_only=False if requested_attn_implementation is None else True
+            )
+        else:
             config._attn_implementation = "eager"

         return config
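
Read as plain control flow, the corrected dispatch amounts to the sketch below (illustrative only, flash-attention-2 handling omitted; the function and argument names are made up for this note):

def resolve_attn_implementation(requested, supports_sdpa, sdpa_available):
    # requested is None when the user never set an attention implementation.
    if requested == "sdpa" and not (supports_sdpa and sdpa_available):
        # Explicit request -> hard check, mirroring hard_check_only=True above.
        raise ValueError("SDPA was requested but is not supported or available.")
    if requested in (None, "sdpa") and supports_sdpa and sdpa_available:
        return "sdpa"
    # Any other case, including a failed soft check, falls back to eager attention.
    return "eager"

The key behavioural point is that hard checks now apply only to an explicit user request, while the default path soft-checks SDPA and otherwise falls back to "eager".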

tests/test_modeling_common.py

Lines changed: 2 additions & 1 deletion
@@ -83,6 +83,7 @@
     is_flax_available,
     is_tf_available,
     is_torch_fx_available,
+    is_torch_sdpa_available,
 )
 from transformers.utils.generic import ModelOutput

@@ -778,7 +779,7 @@ def _create_and_check_torchscript(self, config, inputs_dict):
         configs_no_init.torchscript = True
         for model_class in self.all_model_classes:
             for attn_implementation in ["eager", "sdpa"]:
-                if attn_implementation == "sdpa" and not model_class._supports_sdpa:
+                if attn_implementation == "sdpa" and (not model_class._supports_sdpa or not is_torch_sdpa_available()):
                     continue

                 configs_no_init._attn_implementation = attn_implementation
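
The test change guards the SDPA variant on the torch runtime as well as on the model class, since the commit title targets torch<2.1.1. The same guard pattern outside the test suite might look like this (assuming is_torch_sdpa_available remains importable from transformers.utils):

from transformers.utils import is_torch_sdpa_available

# Pick SDPA only when the installed torch version supports it in Transformers;
# otherwise stay on the eager attention path.
attn_implementation = "sdpa" if is_torch_sdpa_available() else "eager"
print(f"Requesting attention implementation: {attn_implementation}")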
