2020import bitsandbytes as bnb
2121
2222
23+ def _current_accelerator_type ():
24+ if hasattr (torch , "accelerator" ) and torch .accelerator .is_available ():
25+ return str (torch .accelerator .current_accelerator ())
26+ if hasattr (torch , "xpu" ) and torch .xpu .is_available ():
27+ return "xpu"
28+ if torch .cuda .is_available ():
29+ return "cuda"
30+ return "cpu"
31+
32+
33+ def _set_device_index (index : int , device_type : str ):
34+ if hasattr (torch , "accelerator" ):
35+ torch .accelerator .set_device_index (index )
36+ return
37+ if device_type == "cuda" :
38+ torch .cuda .set_device (index )
39+ elif device_type == "xpu" and hasattr (torch , "xpu" ) and hasattr (torch .xpu , "set_device" ):
40+ torch .xpu .set_device (index )
41+
42+
def _get_device_and_backend():
    """Auto-detect accelerator device and distributed backend."""
    device_type = _current_accelerator_type()
    # Map the detected device type onto its native collective backend,
    # defaulting to gloo for CPU and anything unrecognized.
    if device_type == "cuda":
        backend = "nccl"
    elif device_type == "xpu":
        backend = "xccl"
    else:
        backend = "gloo"
    return device_type, backend
49+
50+
2351class SimpleQLoRAModel (nn .Module ):
2452 """Minimal model with a frozen 4-bit base layer and a trainable adapter."""
2553
@@ -33,15 +61,16 @@ def forward(self, x):
3361
3462
3563def main ():
36- dist .init_process_group (backend = "nccl" )
64+ device_type , backend = _get_device_and_backend ()
65+ dist .init_process_group (backend = backend )
3766 rank = dist .get_rank ()
38- torch . cuda . set_device (rank )
67+ _set_device_index (rank , device_type )
3968
4069 errors = []
4170
4271 for quant_type in ("nf4" , "fp4" ):
4372 model = SimpleQLoRAModel (quant_type = quant_type )
44- model = model .to ("cuda" )
73+ model = model .to (device_type )
4574
4675 # Freeze quantized base weights (as in real QLoRA)
4776 for p in model .base .parameters ():
0 commit comments