Skip to content

Commit 2de5ec3

Browse files
authored
Enable XPU for blockwise quantization and FSDP tests (#1921)
Commit message body (each bullet originally carried a `Signed-off-by: jiqing-feng <jiqing.feng@intel.com>` trailer):

* enable xpu tests
* enable fsdp tests
* fix backend
* fix torch version accelerator
* skip fsdp if cpu
* fix windows skip
* enable 8bit and fsdp tests for xpu
* skip fsdp on Windows

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
1 parent 3031919 commit 2de5ec3

5 files changed

Lines changed: 48 additions & 19 deletions

File tree

tests/fsdp_state_dict_save.py

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,34 @@
2020
import bitsandbytes as bnb
2121

2222

23+
def _current_accelerator_type():
24+
if hasattr(torch, "accelerator") and torch.accelerator.is_available():
25+
return str(torch.accelerator.current_accelerator())
26+
if hasattr(torch, "xpu") and torch.xpu.is_available():
27+
return "xpu"
28+
if torch.cuda.is_available():
29+
return "cuda"
30+
return "cpu"
31+
32+
33+
def _set_device_index(index: int, device_type: str):
34+
if hasattr(torch, "accelerator"):
35+
torch.accelerator.set_device_index(index)
36+
return
37+
if device_type == "cuda":
38+
torch.cuda.set_device(index)
39+
elif device_type == "xpu" and hasattr(torch, "xpu") and hasattr(torch.xpu, "set_device"):
40+
torch.xpu.set_device(index)
41+
42+
43+
def _get_device_and_backend():
44+
"""Auto-detect accelerator device and distributed backend."""
45+
device_type = _current_accelerator_type()
46+
backend_map = {"cuda": "nccl", "xpu": "xccl"}
47+
backend = backend_map.get(device_type, "gloo")
48+
return device_type, backend
49+
50+
2351
class SimpleQLoRAModel(nn.Module):
2452
"""Minimal model with a frozen 4-bit base layer and a trainable adapter."""
2553

@@ -33,15 +61,16 @@ def forward(self, x):
3361

3462

3563
def main():
36-
dist.init_process_group(backend="nccl")
64+
device_type, backend = _get_device_and_backend()
65+
dist.init_process_group(backend=backend)
3766
rank = dist.get_rank()
38-
torch.cuda.set_device(rank)
67+
_set_device_index(rank, device_type)
3968

4069
errors = []
4170

4271
for quant_type in ("nf4", "fp4"):
4372
model = SimpleQLoRAModel(quant_type=quant_type)
44-
model = model.to("cuda")
73+
model = model.to(device_type)
4574

4675
# Freeze quantized base weights (as in real QLoRA)
4776
for p in model.base.parameters():

tests/test_functional.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,10 +98,10 @@ class Test8BitBlockwiseQuantizeFunctional:
9898
def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize, signed):
9999
iters = 100
100100

101-
if device != "cuda":
101+
if device not in ["cuda", "xpu"]:
102102
iters = 10
103103

104-
# This test is slow in our non-CUDA implementations, so avoid atypical use cases.
104+
# This test is slow in our non-cuda/non-xpu implementations, so avoid atypical use cases.
105105
if nested:
106106
pytest.skip("Not a typical use case.")
107107
if blocksize != 256:

tests/test_linear4bit.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -569,11 +569,8 @@ def test_params4bit_quant_state_attr_access(device, quant_type, compress_statist
569569
assert w.bnb_quantized is True
570570

571571

572-
@pytest.mark.skipif(not torch.cuda.is_available(), reason="FSDP requires CUDA")
573-
@pytest.mark.skipif(
574-
not getattr(torch.distributed, "is_nccl_available", lambda: False)(),
575-
reason="FSDP test requires NCCL backend",
576-
)
572+
@pytest.mark.skipif(platform.system() == "Windows", reason="FSDP is not supported on Windows")
573+
@pytest.mark.skipif(not get_available_devices(no_cpu=True), reason="FSDP requires an accelerator device")
577574
def test_fsdp_state_dict_save_4bit():
578575
"""Integration test: FSDP get_model_state_dict with cpu_offload on a 4-bit model (#1405).
579576

tests/test_linear8bitlt.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -172,8 +172,9 @@ def test_linear_serialization(
172172
assert torch.allclose(x_first.grad, x_third.grad, atol=1e-5)
173173

174174

175-
@pytest.fixture
176-
def linear8bit(requires_cuda):
175+
@pytest.fixture(params=get_available_devices(no_cpu=True))
176+
def linear8bit(request):
177+
device = request.param
177178
linear = torch.nn.Linear(32, 96)
178179
linear_custom = Linear8bitLt(
179180
linear.in_features,
@@ -188,7 +189,7 @@ def linear8bit(requires_cuda):
188189
has_fp16_weights=False,
189190
)
190191
linear_custom.bias = linear.bias
191-
linear_custom = linear_custom.cuda()
192+
linear_custom = linear_custom.to(device)
192193
return linear_custom
193194

194195

tests/test_modules.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -448,34 +448,36 @@ def test_4bit_embedding_warnings(device, caplog):
448448
assert any("inference" in msg for msg in caplog.messages)
449449

450450

451-
def test_4bit_embedding_weight_fsdp_fix(requires_cuda):
451+
@pytest.mark.parametrize("device", get_available_devices(no_cpu=True))
452+
def test_4bit_embedding_weight_fsdp_fix(device):
452453
num_embeddings = 64
453454
embedding_dim = 32
454455

455456
module = bnb.nn.Embedding4bit(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
456457

457-
module.cuda()
458+
module.to(device)
458459

459460
module.weight.quant_state = None
460461

461-
input_tokens = torch.randint(low=0, high=num_embeddings, size=(1,), device="cuda")
462+
input_tokens = torch.randint(low=0, high=num_embeddings, size=(1,), device=device)
462463

463464
module(input_tokens)
464465

465466
assert module.weight.quant_state is not None
466467

467468

468-
def test_4bit_linear_weight_fsdp_fix(requires_cuda):
469+
@pytest.mark.parametrize("device", get_available_devices(no_cpu=True))
470+
def test_4bit_linear_weight_fsdp_fix(device):
469471
inp_size = 64
470472
out_size = 32
471473

472474
module = bnb.nn.Linear4bit(inp_size, out_size)
473475

474-
module.cuda()
476+
module.to(device)
475477

476478
module.weight.quant_state = None
477479

478-
input_tensor = torch.randn((1, inp_size), device="cuda")
480+
input_tensor = torch.randn((1, inp_size), device=device)
479481

480482
module(input_tensor)
481483

0 commit comments

Comments (0)