2020import bitsandbytes as bnb
2121
2222
23+ def _current_accelerator_type ():
24+ if hasattr (torch , "accelerator" ) and torch .accelerator .is_available ():
25+ return str (torch .accelerator .current_accelerator ())
26+ if hasattr (torch , "xpu" ) and torch .xpu .is_available ():
27+ return "xpu"
28+ if torch .cuda .is_available ():
29+ return "cuda"
30+ return "cpu"
31+
32+
33+ def _set_device_index (index : int , device_type : str ):
34+ if hasattr (torch , "accelerator" ):
35+ torch .accelerator .set_device_index (index )
36+ return
37+ if device_type == "cuda" :
38+ torch .cuda .set_device (index )
39+ elif device_type == "xpu" and hasattr (torch , "xpu" ) and hasattr (torch .xpu , "set_device" ):
40+ torch .xpu .set_device (index )
41+
42+
def _get_device_and_backend():
    """Auto-detect accelerator device and distributed backend."""
    accel = _current_accelerator_type()
    # Pick the collective-communication backend that matches the hardware;
    # anything unrecognized (including plain CPU) runs over gloo.
    if accel == "cuda":
        return accel, "nccl"
    if accel == "xpu":
        return accel, "xccl"
    return accel, "gloo"
@@ -44,7 +64,7 @@ def main():
4464 device_type , backend = _get_device_and_backend ()
4565 dist .init_process_group (backend = backend )
4666 rank = dist .get_rank ()
47- torch . accelerator . set_device_index (rank )
67+ _set_device_index (rank , device_type )
4868
4969 errors = []
5070
0 commit comments