Skip to content

Commit a4875fc

Browse files
Bugfix: Load correct nocublaslt library variant when BNB_CUDA_VERSION override is set (#1318)
1 parent 6d714a5 commit a4875fc

2 files changed

Lines changed: 8 additions & 7 deletions

File tree

bitsandbytes/cextension.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import logging
2121
import os
2222
from pathlib import Path
23+
import re
2324

2425
import torch
2526

@@ -44,13 +45,7 @@ def get_cuda_bnb_library_path(cuda_specs: CUDASpecs) -> Path:
4445

4546
override_value = os.environ.get("BNB_CUDA_VERSION")
4647
if override_value:
47-
library_name_stem, _, library_name_ext = library_name.rpartition(".")
48-
# `library_name_stem` will now be e.g. `libbitsandbytes_cuda118`;
49-
# let's remove any trailing numbers:
50-
library_name_stem = library_name_stem.rstrip("0123456789")
51-
# `library_name_stem` will now be e.g. `libbitsandbytes_cuda`;
52-
# let's tack the new version number and the original extension back on.
53-
library_name = f"{library_name_stem}{override_value}.{library_name_ext}"
48+
library_name = re.sub("cuda\d+", f"cuda{override_value}", library_name, count=1)
5449
logger.warning(
5550
f"WARNING: BNB_CUDA_VERSION={override_value} environment variable detected; loading {library_name}.\n"
5651
"This can be used to load a bitsandbytes version that is different from the PyTorch CUDA version.\n"

tests/test_cuda_setup_evaluator.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,12 @@ def test_get_cuda_bnb_library_path_override(monkeypatch, cuda120_spec, caplog):
3333
assert "BNB_CUDA_VERSION" in caplog.text # did we get the warning?
3434

3535

36+
def test_get_cuda_bnb_library_path_override_nocublaslt(monkeypatch, cuda111_noblas_spec, caplog):
37+
monkeypatch.setenv("BNB_CUDA_VERSION", "125")
38+
assert get_cuda_bnb_library_path(cuda111_noblas_spec).stem == "libbitsandbytes_cuda125_nocublaslt"
39+
assert "BNB_CUDA_VERSION" in caplog.text # did we get the warning?
40+
41+
3642
def test_get_cuda_bnb_library_path_nocublaslt(monkeypatch, cuda111_noblas_spec):
3743
monkeypatch.delenv("BNB_CUDA_VERSION", raising=False)
3844
assert get_cuda_bnb_library_path(cuda111_noblas_spec).stem == "libbitsandbytes_cuda111_nocublaslt"

0 commit comments

Comments
 (0)