Skip to content

Commit 5d695a5

Browse files
committed
enable fsdp tests
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
1 parent 3a98531 commit 5d695a5

2 files changed

Lines changed: 13 additions & 8 deletions

File tree

tests/fsdp_state_dict_save.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,14 @@
2020
import bitsandbytes as bnb
2121

2222

23+
def _get_device_and_backend():
24+
"""Auto-detect accelerator device and distributed backend."""
25+
device_type = str(torch.accelerator.current_accelerator())
26+
backend_map = {"cuda": "nccl", "xpu": "ccl"}
27+
backend = backend_map.get(device_type, "gloo")
28+
return device_type, backend
29+
30+
2331
class SimpleQLoRAModel(nn.Module):
2432
"""Minimal model with a frozen 4-bit base layer and a trainable adapter."""
2533

@@ -33,15 +41,16 @@ def forward(self, x):
3341

3442

3543
def main():
36-
dist.init_process_group(backend="nccl")
44+
device_type, backend = _get_device_and_backend()
45+
dist.init_process_group(backend=backend)
3746
rank = dist.get_rank()
38-
torch.cuda.set_device(rank)
47+
torch.accelerator.set_device_index(rank)
3948

4049
errors = []
4150

4251
for quant_type in ("nf4", "fp4"):
4352
model = SimpleQLoRAModel(quant_type=quant_type)
44-
model = model.to("cuda")
53+
model = model.to(device_type)
4554

4655
# Freeze quantized base weights (as in real QLoRA)
4756
for p in model.base.parameters():

tests/test_linear4bit.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -569,11 +569,7 @@ def test_params4bit_quant_state_attr_access(device, quant_type, compress_statist
569569
assert w.bnb_quantized is True
570570

571571

572-
@pytest.mark.skipif(not torch.cuda.is_available(), reason="FSDP requires CUDA")
573-
@pytest.mark.skipif(
574-
not getattr(torch.distributed, "is_nccl_available", lambda: False)(),
575-
reason="FSDP test requires NCCL backend",
576-
)
572+
@pytest.mark.skipif(not torch.accelerator.is_available(), reason="FSDP requires an accelerator device")
577573
def test_fsdp_state_dict_save_4bit():
578574
"""Integration test: FSDP get_model_state_dict with cpu_offload on a 4-bit model (#1405).
579575

0 commit comments

Comments
 (0)