11import pytest
22
3- from bitsandbytes .cextension import HIP_ENVIRONMENT , get_cuda_bnb_library_path
3+ from bitsandbytes .cextension import BNB_BACKEND , get_cuda_bnb_library_path
44from bitsandbytes .cuda_specs import CUDASpecs
55
66
77@pytest .fixture
88def cuda120_spec () -> CUDASpecs :
9+ """Simulates torch+cuda12.0 and a representative Ampere-class capability."""
910 return CUDASpecs (
1011 cuda_version_string = "120" ,
1112 highest_compute_capability = (8 , 6 ),
1213 cuda_version_tuple = (12 , 0 ),
1314 )
1415
1516
16- @pytest .mark .skipif (HIP_ENVIRONMENT , reason = "this test is not supported on ROCm " )
17+ @pytest .mark .skipif (BNB_BACKEND != "CUDA" , reason = "this test requires a CUDA backend " )
1718def test_get_cuda_bnb_library_path (monkeypatch , cuda120_spec ):
19+ """Without overrides, library path uses the detected CUDA 12.0 version."""
20+ monkeypatch .delenv ("BNB_ROCM_VERSION" , raising = False )
1821 monkeypatch .delenv ("BNB_CUDA_VERSION" , raising = False )
1922 assert get_cuda_bnb_library_path (cuda120_spec ).stem == "libbitsandbytes_cuda120"
2023
2124
22- @pytest .mark .skipif (HIP_ENVIRONMENT , reason = "this test is not supported on ROCm " )
25+ @pytest .mark .skipif (BNB_BACKEND != "CUDA" , reason = "this test requires a CUDA backend " )
2326def test_get_cuda_bnb_library_path_override (monkeypatch , cuda120_spec , caplog ):
27+ """BNB_CUDA_VERSION=110 overrides path selection to the CUDA 11.0 binary."""
2428 monkeypatch .setenv ("BNB_CUDA_VERSION" , "110" )
2529 assert get_cuda_bnb_library_path (cuda120_spec ).stem == "libbitsandbytes_cuda110"
2630 assert "BNB_CUDA_VERSION" in caplog .text # did we get the warning?
2731
2832
29- # Simulates torch+rocm7.0 (PyTorch bundled ROCm) on a system with ROCm 7.2
33+ @pytest .mark .skipif (BNB_BACKEND != "CUDA" , reason = "this test requires a CUDA backend" )
34+ def test_get_cuda_bnb_library_path_rejects_rocm_override (monkeypatch , cuda120_spec ):
35+ """BNB_ROCM_VERSION should be rejected on CUDA with a helpful error."""
36+ monkeypatch .delenv ("BNB_CUDA_VERSION" , raising = False )
37+ monkeypatch .setenv ("BNB_ROCM_VERSION" , "72" )
38+ with pytest .raises (RuntimeError , match = r"BNB_ROCM_VERSION.*detected for CUDA" ):
39+ get_cuda_bnb_library_path (cuda120_spec )
40+
41+
3042@pytest .fixture
3143def rocm70_spec () -> CUDASpecs :
44+ """Simulates torch+rocm7.0 (bundled ROCm) when the system ROCm is newer."""
3245 return CUDASpecs (
33- cuda_version_string = "70" , # from torch.version.hip == "7.0.x"
34- highest_compute_capability = (0 , 0 ), # unused for ROCm library path resolution
46+ cuda_version_string = "70" ,
47+ highest_compute_capability = (0 , 0 ),
3548 cuda_version_tuple = (7 , 0 ),
3649 )
3750
3851
39- @pytest .mark .skipif (not HIP_ENVIRONMENT , reason = "this test is only supported on ROCm" )
52+ @pytest .mark .skipif (BNB_BACKEND != "ROCm" , reason = "this test requires a ROCm backend " )
4053def test_get_rocm_bnb_library_path (monkeypatch , rocm70_spec ):
4154 """Without override, library path uses PyTorch's ROCm 7.0 version."""
4255 monkeypatch .delenv ("BNB_ROCM_VERSION" , raising = False )
4356 monkeypatch .delenv ("BNB_CUDA_VERSION" , raising = False )
4457 assert get_cuda_bnb_library_path (rocm70_spec ).stem == "libbitsandbytes_rocm70"
4558
4659
47- @pytest .mark .skipif (not HIP_ENVIRONMENT , reason = "this test is only supported on ROCm" )
60+ @pytest .mark .skipif (BNB_BACKEND != "ROCm" , reason = "this test requires a ROCm backend " )
4861def test_get_rocm_bnb_library_path_override (monkeypatch , rocm70_spec , caplog ):
4962 """BNB_ROCM_VERSION=72 overrides to load the ROCm 7.2 library instead of 7.0."""
5063 monkeypatch .setenv ("BNB_ROCM_VERSION" , "72" )
51- monkeypatch .delenv ("BNB_CUDA_VERSION" , raising = False )
52- assert get_cuda_bnb_library_path (rocm70_spec ).stem == "libbitsandbytes_rocm72"
5364 assert "BNB_ROCM_VERSION" in caplog .text
5465
5566
56- @pytest .mark .skipif (not HIP_ENVIRONMENT , reason = "this test is only supported on ROCm" )
67+ @pytest .mark .skipif (BNB_BACKEND != "ROCm" , reason = "this test requires a ROCm backend " )
5768def test_get_rocm_bnb_library_path_rejects_cuda_override (monkeypatch , rocm70_spec ):
5869 """BNB_CUDA_VERSION should be rejected on ROCm with a helpful error."""
5970 monkeypatch .delenv ("BNB_ROCM_VERSION" , raising = False )
60- monkeypatch .setenv ("BNB_CUDA_VERSION" , "72 " )
71+ monkeypatch .setenv ("BNB_CUDA_VERSION" , "120 " )
6172 with pytest .raises (RuntimeError , match = r"BNB_CUDA_VERSION.*detected for ROCm" ):
62- get_cuda_bnb_library_path (rocm70_spec )
63-
64-
65- @pytest .mark .skipif (not HIP_ENVIRONMENT , reason = "this test is only supported on ROCm" )
66- def test_get_rocm_bnb_library_path_rocm_override_takes_priority (monkeypatch , rocm70_spec , caplog ):
67- """When both are set, BNB_ROCM_VERSION wins if HIP_ENVIRONMENT is True."""
68- monkeypatch .setenv ("BNB_ROCM_VERSION" , "72" )
69- monkeypatch .setenv ("BNB_CUDA_VERSION" , "72" )
70- assert get_cuda_bnb_library_path (rocm70_spec ).stem == "libbitsandbytes_rocm72"
71- assert "BNB_ROCM_VERSION" in caplog .text
72- assert "BNB_CUDA_VERSION" not in caplog .text
73+ get_cuda_bnb_library_path (rocm70_spec )
0 commit comments