Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 32 additions & 3 deletions tests/fsdp_state_dict_save.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,34 @@
import bitsandbytes as bnb


def _current_accelerator_type():
if hasattr(torch, "accelerator") and torch.accelerator.is_available():
return str(torch.accelerator.current_accelerator())
if hasattr(torch, "xpu") and torch.xpu.is_available():
return "xpu"
if torch.cuda.is_available():
return "cuda"
return "cpu"


def _set_device_index(index: int, device_type: str):
if hasattr(torch, "accelerator"):
torch.accelerator.set_device_index(index)
return
if device_type == "cuda":
torch.cuda.set_device(index)
elif device_type == "xpu" and hasattr(torch, "xpu") and hasattr(torch.xpu, "set_device"):
torch.xpu.set_device(index)


def _get_device_and_backend():
    """Auto-detect the accelerator device and matching distributed backend.

    Returns:
        tuple[str, str]: ``(device_type, backend)`` where the backend is
        "nccl" for CUDA, "xccl" for XPU, and "gloo" (CPU-capable) otherwise.
    """
    device_type = _current_accelerator_type()
    if device_type == "cuda":
        backend = "nccl"
    elif device_type == "xpu":
        backend = "xccl"
    else:
        backend = "gloo"
    return device_type, backend


class SimpleQLoRAModel(nn.Module):
"""Minimal model with a frozen 4-bit base layer and a trainable adapter."""

Expand All @@ -33,15 +61,16 @@ def forward(self, x):


def main():
dist.init_process_group(backend="nccl")
device_type, backend = _get_device_and_backend()
dist.init_process_group(backend=backend)
rank = dist.get_rank()
torch.cuda.set_device(rank)
_set_device_index(rank, device_type)

errors = []

for quant_type in ("nf4", "fp4"):
model = SimpleQLoRAModel(quant_type=quant_type)
model = model.to("cuda")
model = model.to(device_type)

# Freeze quantized base weights (as in real QLoRA)
for p in model.base.parameters():
Expand Down
4 changes: 2 additions & 2 deletions tests/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,10 @@ class Test8BitBlockwiseQuantizeFunctional:
def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize, signed):
iters = 100

if device != "cuda":
if device not in ["cuda", "xpu"]:
iters = 10

# This test is slow in our non-CUDA implementations, so avoid atypical use cases.
# This test is slow in our non-cuda/non-xpu implementations, so avoid atypical use cases.
if nested:
pytest.skip("Not a typical use case.")
if blocksize != 256:
Expand Down
8 changes: 2 additions & 6 deletions tests/test_linear4bit.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,7 @@ def test_linear4bit_torch_compile_activation_checkpointing(device, quant_type, c
"""
if device == "hpu" and not is_supported_on_hpu(quant_type):
pytest.skip("This configuration is not supported on HPU.")
if device == "cuda" and platform.system() == "Windows":
if platform.system() == "Windows":
Comment thread
jiqing-feng marked this conversation as resolved.
Outdated
pytest.skip("Triton is not officially supported on Windows")
dim = 256
batch_size = 16
Expand Down Expand Up @@ -569,11 +569,7 @@ def test_params4bit_quant_state_attr_access(device, quant_type, compress_statist
assert w.bnb_quantized is True


@pytest.mark.skipif(not torch.cuda.is_available(), reason="FSDP requires CUDA")
@pytest.mark.skipif(
not getattr(torch.distributed, "is_nccl_available", lambda: False)(),
reason="FSDP test requires NCCL backend",
)
@pytest.mark.skipif(not get_available_devices(no_cpu=True), reason="FSDP requires an accelerator device")
def test_fsdp_state_dict_save_4bit():
"""Integration test: FSDP get_model_state_dict with cpu_offload on a 4-bit model (#1405).

Expand Down
Loading