Skip to content

Commit 1c96a34

Browse files
fix when A.numel() is not divisible by blocksize
1 parent a5a7f5d commit 1c96a34

2 files changed

Lines changed: 32 additions & 2 deletions

File tree

bitsandbytes/backends/triton/ops.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,12 @@ def quantize_4bit(
7676

7777
n = A.numel()
7878

79-
# TODO: Support when weight matrix is not divisible by blocksize
80-
# torch._check(n % blocksize == 0, lambda: f"n must be divisible by blocksize, got {n} and {blocksize}")
79+
# Pad to next multiple of blocksize so the kernel always processes full blocks
80+
remainder = n % blocksize
81+
if remainder != 0:
82+
padding = blocksize - remainder
83+
A = torch.nn.functional.pad(A.view(-1), (0, padding), value=0.0)
84+
n = A.numel()
8185

8286
blocks = -(n // -(blocksize * 2))
8387

tests/test_ops.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,32 @@ def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize
172172

173173
opcheck(torch.ops.bitsandbytes.quantize_4bit.default, (A, blocksize, quant_type, storage_dtype))
174174

175+
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
@pytest.mark.parametrize("blocksize", [64, 128, 256])
def test_quantize_4bit_not_divisible_by_blocksize(self, device, dtype, quant_type, blocksize):
    """Round-trip quantize/dequantize an input whose element count is not a blocksize multiple."""
    storage_dtype = torch.uint8
    # 7 * (blocksize - 1) = 7*blocksize - 7 is never a multiple of blocksize,
    # so this exercises the ragged-final-block path.
    shape = (7, blocksize - 1)
    original = torch.randn(shape, dtype=dtype, device=device)

    # Quantization must succeed rather than raise on the non-divisible size.
    packed, absmax = torch.ops.bitsandbytes.quantize_4bit(original, blocksize, quant_type, storage_dtype)
    assert packed.device == original.device
    assert absmax.device == original.device

    # The inverse op must restore the caller-visible shape and dtype.
    restored = torch.ops.bitsandbytes.dequantize_4bit(packed, absmax, blocksize, quant_type, shape, dtype)
    assert restored.shape == shape
    assert restored.dtype == dtype

    # No NaN/Inf may leak in from the zero-padding of the final block.
    assert torch.isfinite(restored).all(), "Dequantized output contains NaN or Inf"
200+
175201
@pytest.mark.parametrize("device", get_available_devices())
176202
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
177203
@pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype"))

0 commit comments

Comments
 (0)