
Commit dac9de0

fix torch version accelerator
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
1 parent: 7925a2a

2 files changed: 30 additions & 3 deletions


tests/fsdp_state_dict_save.py

Lines changed: 22 additions & 2 deletions
@@ -20,9 +20,29 @@
 import bitsandbytes as bnb


+def _current_accelerator_type():
+    if hasattr(torch, "accelerator") and torch.accelerator.is_available():
+        return str(torch.accelerator.current_accelerator())
+    if hasattr(torch, "xpu") and torch.xpu.is_available():
+        return "xpu"
+    if torch.cuda.is_available():
+        return "cuda"
+    return "cpu"
+
+
+def _set_device_index(index: int, device_type: str):
+    if hasattr(torch, "accelerator"):
+        torch.accelerator.set_device_index(index)
+        return
+    if device_type == "cuda":
+        torch.cuda.set_device(index)
+    elif device_type == "xpu" and hasattr(torch, "xpu") and hasattr(torch.xpu, "set_device"):
+        torch.xpu.set_device(index)
+
+
 def _get_device_and_backend():
     """Auto-detect accelerator device and distributed backend."""
-    device_type = str(torch.accelerator.current_accelerator())
+    device_type = _current_accelerator_type()
     backend_map = {"cuda": "nccl", "xpu": "xccl"}
     backend = backend_map.get(device_type, "gloo")
     return device_type, backend
@@ -44,7 +64,7 @@ def main():
     device_type, backend = _get_device_and_backend()
     dist.init_process_group(backend=backend)
     rank = dist.get_rank()
-    torch.accelerator.set_device_index(rank)
+    _set_device_index(rank, device_type)

     errors = []
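
The new helpers guard every use of the unified torch.accelerator API behind hasattr, since older PyTorch builds do not expose that module and the previous unconditional call would presumably raise an AttributeError before the test could run or skip. A minimal, self-contained sketch of the same probe-and-fallback pattern outside the test (the standalone select_device helper below is an illustrative assumption, not part of this commit):

import torch

def select_device():
    # Prefer the unified accelerator API when this torch build exposes it;
    # the hasattr guard keeps older releases from raising AttributeError.
    if hasattr(torch, "accelerator") and torch.accelerator.is_available():
        return torch.device(str(torch.accelerator.current_accelerator()))
    # Fall back to the per-backend modules on older torch versions.
    if torch.cuda.is_available():
        return torch.device("cuda")
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return torch.device("xpu")
    return torch.device("cpu")

if __name__ == "__main__":
    print(f"selected device: {select_device()}")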

tests/test_linear4bit.py

Lines changed: 8 additions & 1 deletion
@@ -569,7 +569,14 @@ def test_params4bit_quant_state_attr_access(device, quant_type, compress_statist
     assert w.bnb_quantized is True


-@pytest.mark.skipif(not torch.accelerator.is_available(), reason="FSDP requires an accelerator device")
+@pytest.mark.skipif(
+    not (
+        (hasattr(torch, "accelerator") and torch.accelerator.is_available())
+        or torch.cuda.is_available()
+        or (hasattr(torch, "xpu") and torch.xpu.is_available())
+    ),
+    reason="FSDP requires an accelerator device",
+)
 def test_fsdp_state_dict_save_4bit():
     """Integration test: FSDP get_model_state_dict with cpu_offload on a 4-bit model (#1405).
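
The broadened skip condition mirrors the same probe: the test now runs whenever the unified API, CUDA, or XPU reports an available device, instead of assuming torch.accelerator exists. If more tests end up needing the same check, it could be factored into a shared marker; a hypothetical sketch (the has_accelerator helper and requires_accelerator marker are illustrative names, not part of this commit):

import pytest
import torch

def has_accelerator():
    # True if any supported accelerator backend is usable on this torch build.
    if hasattr(torch, "accelerator") and torch.accelerator.is_available():
        return True
    if torch.cuda.is_available():
        return True
    return hasattr(torch, "xpu") and torch.xpu.is_available()

# Reusable marker so individual tests keep a one-line decorator.
requires_accelerator = pytest.mark.skipif(
    not has_accelerator(), reason="FSDP requires an accelerator device"
)

@requires_accelerator
def test_fsdp_state_dict_save_4bit_example():
    ...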
