Skip to content

Commit bc6fa69

Browse files
committed
Update CUDA/ROCm setup tests
1 parent 925d83e commit bc6fa69

2 files changed

Lines changed: 39 additions & 32 deletions

File tree

bitsandbytes/cextension.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,28 +30,34 @@ def get_cuda_bnb_library_path(cuda_specs: CUDASpecs) -> Path:
3030
prefix = "rocm" if torch.version.hip else "cuda"
3131
library_name = f"libbitsandbytes_{prefix}{cuda_specs.cuda_version_string}{DYNAMIC_LIBRARY_SUFFIX}"
3232

33-
override_value = os.environ.get("BNB_CUDA_VERSION")
33+
cuda_override_value = os.environ.get("BNB_CUDA_VERSION")
3434
rocm_override_value = os.environ.get("BNB_ROCM_VERSION")
3535

36-
if rocm_override_value and torch.version.hip:
36+
if rocm_override_value:
3737
library_name = re.sub(r"rocm\d+", f"rocm{rocm_override_value}", library_name, count=1)
38+
if torch.version.cuda:
39+
raise RuntimeError(
40+
f"BNB_ROCM_VERSION={rocm_override_value} detected for CUDA!\n"
41+
"Use BNB_CUDA_VERSION instead: export BNB_CUDA_VERSION=<version>\n"
42+
"Clear the variable and retry: unset BNB_ROCM_VERSION\n"
43+
)
3844
logger.warning(
3945
f"WARNING: BNB_ROCM_VERSION={rocm_override_value} environment variable detected; loading {library_name}.\n"
4046
"This can be used to load a bitsandbytes version built with a ROCm version that is different from the PyTorch ROCm version.\n"
41-
"If this was unintended set the BNB_ROCM_VERSION variable to an empty string: export BNB_ROCM_VERSION=\n"
47+
"If this was unintended clear the variable and retry: unset BNB_ROCM_VERSION\n"
4248
)
43-
elif override_value:
44-
library_name = re.sub(r"cuda\d+", f"cuda{override_value}", library_name, count=1)
49+
elif cuda_override_value:
50+
library_name = re.sub(r"cuda\d+", f"cuda{cuda_override_value}", library_name, count=1)
4551
if torch.version.hip:
4652
raise RuntimeError(
47-
f"BNB_CUDA_VERSION={override_value} detected for ROCm!! \n"
53+
f"BNB_CUDA_VERSION={cuda_override_value} detected for ROCm!\n"
4854
f"Use BNB_ROCM_VERSION instead: export BNB_ROCM_VERSION=<version>\n"
49-
f"Clear the variable and retry: export BNB_CUDA_VERSION=\n"
55+
f"Clear the variable and retry: unset BNB_CUDA_VERSION\n"
5056
)
5157
logger.warning(
52-
f"WARNING: BNB_CUDA_VERSION={override_value} environment variable detected; loading {library_name}.\n"
58+
f"WARNING: BNB_CUDA_VERSION={cuda_override_value} environment variable detected; loading {library_name}.\n"
5359
"This can be used to load a bitsandbytes version built with a CUDA version that is different from the PyTorch CUDA version.\n"
54-
"If this was unintended set the BNB_CUDA_VERSION variable to an empty string: export BNB_CUDA_VERSION=\n"
60+
"If this was unintended clear the variable and retry: unset BNB_CUDA_VERSION\n"
5561
)
5662

5763
return PACKAGE_DIR / library_name

tests/test_cuda_setup_evaluator.py

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,72 +1,73 @@
11
import pytest
22

3-
from bitsandbytes.cextension import HIP_ENVIRONMENT, get_cuda_bnb_library_path
3+
from bitsandbytes.cextension import BNB_BACKEND, get_cuda_bnb_library_path
44
from bitsandbytes.cuda_specs import CUDASpecs
55

66

77
@pytest.fixture
88
def cuda120_spec() -> CUDASpecs:
9+
"""Simulates torch+cuda12.0 and a representative Ampere-class capability."""
910
return CUDASpecs(
1011
cuda_version_string="120",
1112
highest_compute_capability=(8, 6),
1213
cuda_version_tuple=(12, 0),
1314
)
1415

1516

16-
@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm")
17+
@pytest.mark.skipif(BNB_BACKEND != "CUDA", reason="this test requires a CUDA backend")
1718
def test_get_cuda_bnb_library_path(monkeypatch, cuda120_spec):
19+
"""Without overrides, library path uses the detected CUDA 12.0 version."""
20+
monkeypatch.delenv("BNB_ROCM_VERSION", raising=False)
1821
monkeypatch.delenv("BNB_CUDA_VERSION", raising=False)
1922
assert get_cuda_bnb_library_path(cuda120_spec).stem == "libbitsandbytes_cuda120"
2023

2124

22-
@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm")
25+
@pytest.mark.skipif(BNB_BACKEND != "CUDA", reason="this test requires a CUDA backend")
2326
def test_get_cuda_bnb_library_path_override(monkeypatch, cuda120_spec, caplog):
27+
"""BNB_CUDA_VERSION=110 overrides path selection to the CUDA 11.0 binary."""
2428
monkeypatch.setenv("BNB_CUDA_VERSION", "110")
2529
assert get_cuda_bnb_library_path(cuda120_spec).stem == "libbitsandbytes_cuda110"
2630
assert "BNB_CUDA_VERSION" in caplog.text # did we get the warning?
2731

2832

29-
# Simulates torch+rocm7.0 (PyTorch bundled ROCm) on a system with ROCm 7.2
33+
@pytest.mark.skipif(BNB_BACKEND != "CUDA", reason="this test requires a CUDA backend")
34+
def test_get_cuda_bnb_library_path_rejects_rocm_override(monkeypatch, cuda120_spec):
35+
"""BNB_ROCM_VERSION should be rejected on CUDA with a helpful error."""
36+
monkeypatch.delenv("BNB_CUDA_VERSION", raising=False)
37+
monkeypatch.setenv("BNB_ROCM_VERSION", "72")
38+
with pytest.raises(RuntimeError, match=r"BNB_ROCM_VERSION.*detected for CUDA"):
39+
get_cuda_bnb_library_path(cuda120_spec)
40+
41+
3042
@pytest.fixture
3143
def rocm70_spec() -> CUDASpecs:
44+
"""Simulates torch+rocm7.0 (bundled ROCm) when the system ROCm is newer."""
3245
return CUDASpecs(
33-
cuda_version_string="70", # from torch.version.hip == "7.0.x"
34-
highest_compute_capability=(0, 0), # unused for ROCm library path resolution
46+
cuda_version_string="70",
47+
highest_compute_capability=(0, 0),
3548
cuda_version_tuple=(7, 0),
3649
)
3750

3851

39-
@pytest.mark.skipif(not HIP_ENVIRONMENT, reason="this test is only supported on ROCm")
52+
@pytest.mark.skipif(BNB_BACKEND != "ROCm", reason="this test requires a ROCm backend")
4053
def test_get_rocm_bnb_library_path(monkeypatch, rocm70_spec):
4154
"""Without override, library path uses PyTorch's ROCm 7.0 version."""
4255
monkeypatch.delenv("BNB_ROCM_VERSION", raising=False)
4356
monkeypatch.delenv("BNB_CUDA_VERSION", raising=False)
4457
assert get_cuda_bnb_library_path(rocm70_spec).stem == "libbitsandbytes_rocm70"
4558

4659

47-
@pytest.mark.skipif(not HIP_ENVIRONMENT, reason="this test is only supported on ROCm")
60+
@pytest.mark.skipif(BNB_BACKEND != "ROCm", reason="this test requires a ROCm backend")
4861
def test_get_rocm_bnb_library_path_override(monkeypatch, rocm70_spec, caplog):
4962
"""BNB_ROCM_VERSION=72 overrides to load the ROCm 7.2 library instead of 7.0."""
5063
monkeypatch.setenv("BNB_ROCM_VERSION", "72")
51-
monkeypatch.delenv("BNB_CUDA_VERSION", raising=False)
52-
assert get_cuda_bnb_library_path(rocm70_spec).stem == "libbitsandbytes_rocm72"
5364
assert "BNB_ROCM_VERSION" in caplog.text
5465

5566

56-
@pytest.mark.skipif(not HIP_ENVIRONMENT, reason="this test is only supported on ROCm")
67+
@pytest.mark.skipif(BNB_BACKEND != "ROCm", reason="this test requires a ROCm backend")
5768
def test_get_rocm_bnb_library_path_rejects_cuda_override(monkeypatch, rocm70_spec):
5869
"""BNB_CUDA_VERSION should be rejected on ROCm with a helpful error."""
5970
monkeypatch.delenv("BNB_ROCM_VERSION", raising=False)
60-
monkeypatch.setenv("BNB_CUDA_VERSION", "72")
71+
monkeypatch.setenv("BNB_CUDA_VERSION", "120")
6172
with pytest.raises(RuntimeError, match=r"BNB_CUDA_VERSION.*detected for ROCm"):
62-
get_cuda_bnb_library_path(rocm70_spec)
63-
64-
65-
@pytest.mark.skipif(not HIP_ENVIRONMENT, reason="this test is only supported on ROCm")
66-
def test_get_rocm_bnb_library_path_rocm_override_takes_priority(monkeypatch, rocm70_spec, caplog):
67-
"""When both are set, BNB_ROCM_VERSION wins if HIP_ENVIRONMENT is True."""
68-
monkeypatch.setenv("BNB_ROCM_VERSION", "72")
69-
monkeypatch.setenv("BNB_CUDA_VERSION", "72")
70-
assert get_cuda_bnb_library_path(rocm70_spec).stem == "libbitsandbytes_rocm72"
71-
assert "BNB_ROCM_VERSION" in caplog.text
72-
assert "BNB_CUDA_VERSION" not in caplog.text
73+
get_cuda_bnb_library_path(rocm70_spec)

0 commit comments

Comments
 (0)