Commit bcd6d35

Rewrite asserts as exceptions
Mechanically rewrite `assert cond[, msg]` in library code as
`if not cond: raise AssertionError[(msg)]` so the checks still fire under
`python -O`, which strips asserts. Test files are left alone since pytest
relies on `assert` for introspection.

Closes #1408

Assisted-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 2de5ec3 commit bcd6d35
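
For context, a minimal standalone sketch of the transformation (illustrative only, not part of this diff; the function names are hypothetical): CPython's -O flag sets __debug__ to False and compiles assert statements away entirely, so an assert-based guard silently disappears, while the rewritten form is ordinary control flow and runs in both modes.

# demo.py -- hypothetical before/after pair mirroring the mechanical rewrite
def check_before(quant_state):
    # compiled away under `python -O demo.py`
    assert quant_state is not None, "quant_state must be provided"

def check_after(quant_state):
    # ordinary control flow; fires with or without -O
    if quant_state is None:
        raise AssertionError("quant_state must be provided")

if __name__ == "__main__":
    print("__debug__ =", __debug__)  # False under -O
    try:
        check_before(None)
        print("assert was skipped")  # reached only under -O
    except AssertionError:
        print("assert fired")
    try:
        check_after(None)
    except AssertionError:
        print("explicit raise fired")  # reached in both modes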

9 files changed, 74 additions and 37 deletions

bitsandbytes/autograd/_functions.py

Lines changed: 2 additions & 1 deletion
@@ -381,7 +381,8 @@ def matmul_4bit(
     out: Optional[torch.Tensor] = None,
     bias: Optional[torch.Tensor] = None,
 ):
-    assert quant_state is not None
+    if quant_state is None:
+        raise AssertionError
     if A.device.type == "cpu":
         if getattr(quant_state, "packing_format_for_cpu", False):
             out = F.gemv_4bit(A, B, out, state=quant_state)

bitsandbytes/diagnostics/main.py

Lines changed: 2 additions & 1 deletion
@@ -39,7 +39,8 @@ def sanity_check():
     loss.backward()
     adam.step()
     p2 = p.data.sum().item()
-    assert p1 != p2
+    if p1 == p2:
+        raise AssertionError


 def get_package_version(name: str) -> str:

bitsandbytes/functional.py

Lines changed: 36 additions & 18 deletions
@@ -99,7 +99,8 @@ def get_paged(*shape, dtype=torch.float32, device=FIRST_CUDA_DEVICE):


 def prefetch_tensor(A: torch.Tensor, to_cpu=False):
-    assert A.is_paged, "Only paged tensors can be prefetched!"
+    if not A.is_paged:
+        raise AssertionError("Only paged tensors can be prefetched!")
     if to_cpu:
         deviceid = -1
     else:
@@ -218,7 +219,8 @@ def create_normal_map(offset=0.9677083, use_extra_value=True):
     values = values.sort().values
     values /= values.max()

-    assert values.numel() == 256
+    if values.numel() != 256:
+        raise AssertionError

     return values

@@ -254,7 +256,8 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
     e = exponent_bits
     p = precision_bits
     has_sign = 1 if signed else 0
-    assert e + p == total_bits - has_sign
+    if e + p != total_bits - has_sign:
+        raise AssertionError
     # the exponent is biased to 2^(e-1) -1 == 0
     evalues = []
     for i, val in enumerate(range(-(2 ** (exponent_bits - has_sign)), 2 ** (exponent_bits - has_sign), 1)):
@@ -279,7 +282,8 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
         if signed:
             values.append(-value)

-    assert len(values) == 2**total_bits
+    if len(values) != 2**total_bits:
+        raise AssertionError
     values.sort()
     if total_bits < 8:
         gap = 256 - len(values)
@@ -337,7 +341,8 @@ def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8):
     data.append(0)
     data.append(1.0)

-    assert len(data) == 2**total_bits
+    if len(data) != 2**total_bits:
+        raise AssertionError

     gap = 256 - len(data)
     for i in range(gap):
@@ -516,7 +521,8 @@ def from_dict(cls, qs_dict: dict[str, Any], device: torch.device) -> "QuantState
         qs_dict.update(unpack_tensor_to_dict(qs_dict.pop(first_qs_key)))

         qs_dict = {k.split(".")[-1]: v for k, v in qs_dict.items()}  # strip prefixes
-        assert set(qs_dict.keys()).issubset(cls.valid_qs_keys)
+        if not set(qs_dict.keys()).issubset(cls.valid_qs_keys):
+            raise AssertionError

         if "nested_absmax" in qs_dict:
             offset = torch.tensor(float(qs_dict["nested_offset"])).to(device)
@@ -721,7 +727,8 @@ def dequantize_blockwise(
         The dequantized tensor. The datatype is indicated by `quant_state.dtype` and defaults to `torch.float32`.
     """

-    assert quant_state is not None or absmax is not None
+    if quant_state is None and absmax is None:
+        raise AssertionError
     if code is None and quant_state is None:
         if "dynamic" not in name2qmap:
             name2qmap["dynamic"] = create_dynamic_map().to(A.device)
@@ -842,7 +849,8 @@ def get_4bit_type(typename, device=None, blocksize=64):
     data = torch.tensor(data, device=device)
     data.div_(data.abs().max())

-    assert data.numel() == 16
+    if data.numel() != 16:
+        raise AssertionError

     return data

@@ -1009,7 +1017,8 @@ def dequantize_4bit(
         blocksize = 64

     if quant_state is None:
-        assert absmax is not None and out is not None
+        if absmax is None or out is None:
+            raise AssertionError

         quant_state = QuantState(
             absmax=absmax,
@@ -1365,7 +1374,8 @@ def igemm(
         ldc = sB[1]
     elif len(sB) == 3:
         # special case
-        assert len(sA) == 3
+        if len(sA) != 3:
+            raise AssertionError
         if not (sA[0] == sB[0] and sA[1] == sB[1]):
             raise ValueError(
                 f"Only bsi,bso->io supported for tensor contractions, but dims for A x B were: {sA} x {sB}",
@@ -1658,10 +1668,13 @@ def _convert_weight_packed_for_cpu(qweight: torch.Tensor, quant_state: QuantStat
     unpacked_w[::2] = qweight >> 4
     qweight_final = unpacked_w.reshape(quant_state.shape).to(torch.uint8)  # (*, N, K)
     # pack weight: [*, N, K] -> [*, N, K/2] combine low and high bit
-    assert len(qweight_final.shape) == 2
+    if len(qweight_final.shape) != 2:
+        raise AssertionError
     N, K = qweight_final.shape[0], qweight_final.shape[1]
-    assert N % block_n == 0, "N must be divisible by block_n"
-    assert K % 2 == 0, "K must be even"
+    if N % block_n != 0:
+        raise AssertionError("N must be divisible by block_n")
+    if K % 2 != 0:
+        raise AssertionError("K must be even")
     BLOCK_N = block_n
     BIT_COUNT = 32  # (=32 low +32 high)
     new_shape = [N // BLOCK_N, BLOCK_N, K // 2, 2]
@@ -1706,18 +1719,23 @@ def _convert_weight_packed_for_cpu_inverse(
         qweight: [*, N, K] uint8, original qweight shape (quant_state.shape)
         recovered_state: QuantState with partially restored fields (best-effort inverse)
     """
-    assert quant_state.packing_format_for_cpu, "only for packing format"
-    assert packed_weight.dtype == torch.uint8
-    assert len(packed_weight.shape) == 2, "packed_weight should be [N, K/2]"
+    if not quant_state.packing_format_for_cpu:
+        raise AssertionError("only for packing format")
+    if packed_weight.dtype != torch.uint8:
+        raise AssertionError
+    if len(packed_weight.shape) != 2:
+        raise AssertionError("packed_weight should be [N, K/2]")
     N, K_half = packed_weight.shape
     K = K_half * 2

     # 1) packed [N, K/2] -> [N//BLOCK_N, BLOCK_N, K/2, 2]
     BLOCK_N = block_n
     BIT_COUNT = 32  # (=32 low + 32 high)

-    assert N % BLOCK_N == 0, "N must be divisible by block_n"
-    assert K % 2 == 0, "K must be even"
+    if N % BLOCK_N != 0:
+        raise AssertionError("N must be divisible by block_n")
+    if K % 2 != 0:
+        raise AssertionError("K must be even")

     # [N, K/2] -> [-1, 64] (32 low + 32 high)
     packed = packed_weight.reshape(-1, BIT_COUNT)  # [-1, 64]

bitsandbytes/nn/modules.py

Lines changed: 16 additions & 8 deletions
@@ -495,7 +495,8 @@ def fix_4bit_weight_quant_state_from_module(module: Union["Embedding4bit", "Line

     # the quant state got lost when the parameter got converted. This happens for example for fsdp
     # since we registered the module, we can recover the state here
-    assert module.weight.shape[1] == 1
+    if module.weight.shape[1] != 1:
+        raise AssertionError
     if not isinstance(module.weight, Params4bit):
         module.weight = Params4bit(module.weight, quant_storage=module.quant_storage, bnb_quantized=True)
     module.weight.quant_state = module.quant_state
@@ -866,8 +867,10 @@ def forward(self, input: Tensor) -> Tensor:
         rows = self.weight.data
         row_stats = self.weight.SCB

-        assert rows.shape == (self.num_embeddings, self.embedding_dim)
-        assert row_stats.shape == (self.num_embeddings,)
+        if rows.shape != (self.num_embeddings, self.embedding_dim):
+            raise AssertionError
+        if row_stats.shape != (self.num_embeddings,):
+            raise AssertionError

         compressed_output = F.embedding(input, rows)
         compressed_output_stats = F.embedding(input, row_stats.view(self.num_embeddings, 1))
@@ -928,35 +931,40 @@ def __init__(
         )

     def _forward_with_partial_dequantize(self, input: Tensor):
-        assert self.embedding_dim % self.weight.quant_state.blocksize == 0
+        if self.embedding_dim % self.weight.quant_state.blocksize != 0:
+            raise AssertionError

         w_4bit_uint8 = self.weight.data.view(torch.uint8).view(self.num_embeddings * self.embedding_dim // 2, 1)

         output_4bit = torch.nn.functional.embedding(
             weight=w_4bit_uint8.view(self.num_embeddings, self.embedding_dim // 2),
             input=input,
         ).view(-1, 1)
-        assert output_4bit.shape == (input.numel() * self.embedding_dim // 2, 1)
+        if output_4bit.shape != (input.numel() * self.embedding_dim // 2, 1):
+            raise AssertionError

         blocks_per_emb = self.embedding_dim // self.weight.blocksize

         absmax = self.weight.quant_state.absmax
-        assert absmax.shape == (self.num_embeddings * blocks_per_emb,)
+        if absmax.shape != (self.num_embeddings * blocks_per_emb,):
+            raise AssertionError

         output_absmax = torch.nn.functional.embedding(
             weight=absmax.view(self.num_embeddings, blocks_per_emb),
             input=input,
         ).view(
             -1,
         )
-        assert output_absmax.shape == (input.numel() * blocks_per_emb,)
+        if output_absmax.shape != (input.numel() * blocks_per_emb,):
+            raise AssertionError

         output_quant_state = copy.deepcopy(self.weight.quant_state)
         output_quant_state.absmax = output_absmax
         output_quant_state.shape = torch.Size((*input.shape, self.embedding_dim))

         output = bnb.functional.dequantize_4bit(output_4bit, output_quant_state)
-        assert output.shape == (*input.shape, self.embedding_dim)
+        if output.shape != (*input.shape, self.embedding_dim):
+            raise AssertionError

         return output.to(self.dtype)

bitsandbytes/nn/parametrize.py

Lines changed: 4 additions & 2 deletions
@@ -175,14 +175,16 @@ def _parametrized_state_dict_post_hook(
         clean_key = f"{prefix}{param_name}"
         state_dict[clean_key] = state_dict.pop(original_key)

-        assert P.is_parametrized(module, param_name)
+        if not P.is_parametrized(module, param_name):
+            raise AssertionError

         # Find the parametrization, which should have the quantization state.
         parametrization: Bnb4bitParametrization = next(
             filter(lambda x: isinstance(x, Bnb4bitParametrization), module.parametrizations[param_name]), None
         )

-        assert parametrization is not None, "Parametrization not found for the parameter."
+        if parametrization is None:
+            raise AssertionError("Parametrization not found for the parameter.")

         quant_state = parametrization.quant_state

bitsandbytes/optim/lars.py

Lines changed: 2 additions & 1 deletion
@@ -248,7 +248,8 @@ def step(self, closure=None):

                 update_scale = 1.0
                 if max_unorm > 0.0:
-                    assert p.dtype == torch.float32
+                    if p.dtype != torch.float32:
+                        raise AssertionError
                     pnorm = torch.norm(p.detach())
                     unorm = torch.norm(update)
                     if unorm > max_unorm * pnorm:

bitsandbytes/optim/optimizer.py

Lines changed: 6 additions & 3 deletions
@@ -100,7 +100,8 @@ def override_config(self, parameters, key=None, value=None, key_value_dict=None)
         if isinstance(parameters, torch.Tensor):
             parameters = [parameters]
         if key is not None and value is not None:
-            assert key_value_dict is None
+            if key_value_dict is not None:
+                raise AssertionError
             key_value_dict = {key: value}

         if key_value_dict is not None:
@@ -286,8 +287,10 @@ def to_gpu(self):
     def check_overrides(self):
         for module, attr, config in self.mng.module_weight_config_triple:
             pmodule = getattr(module, attr)
-            assert pmodule is not None
-            assert isinstance(pmodule, torch.Tensor) or isinstance(pmodule, torch.Parameter)
+            if pmodule is None:
+                raise AssertionError
+            if not (isinstance(pmodule, torch.Tensor) or isinstance(pmodule, torch.Parameter)):
+                raise AssertionError
             found = False
             for gindex, group in enumerate(self.param_groups):
                 if found:

bitsandbytes/utils.py

Lines changed: 4 additions & 2 deletions
@@ -9,7 +9,8 @@


 def outlier_hook(module, input):
-    assert isinstance(module, torch.nn.Linear)
+    if not isinstance(module, torch.nn.Linear):
+        raise AssertionError
     tracer = OutlierTracer.get_instance()
     hvalue = tracer.get_hvalue(module.weight)
     if hvalue not in tracer.hvalue2outlier_idx:
@@ -20,7 +21,8 @@ def outlier_hook(module, input):
         # assign the current layer the outlier idx found from the weight
         # of the previous linear layer
         if tracer.outliers[-1].numel() > 0:
-            assert tracer.outliers[-1].max() < module.weight.shape[1]
+            if tracer.outliers[-1].max() >= module.weight.shape[1]:
+                raise AssertionError
             tracer.hvalue2outlier_idx[hvalue] = tracer.outliers[-1]

     else:

check_bnb_install.py

Lines changed: 2 additions & 1 deletion
@@ -16,6 +16,7 @@

 p2 = p.data.sum().item()

-assert p1 != p2
+if p1 == p2:
+    raise AssertionError
 print("SUCCESS!")
 print("Installation was successful!")
