Skip to content

Commit 98f3a34

Browse files
another one
1 parent 01d060d commit 98f3a34

1 file changed

Lines changed: 7 additions & 4 deletions

File tree

  • bitsandbytes/backends/default

bitsandbytes/backends/default/ops.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -277,17 +277,20 @@ def _dequantize_4bit_impl(
277277
A = A.reshape(-1)
278278
# Map nf4 to [-1, 1]
279279
out_dq = torch.empty(A.size(0) * 2, dtype=torch.int32, device=A.device)
280-
n = out_dq.numel()
281280
out_dq[1::2] = A & 0xF
282281
out_dq[::2] = A >> 4
283282
# code is fp32, cast to dtype to avoid the mismatch issue
284283
code = CODE[quant_type].to(dtype).to(A.device)
285284
out_dq = code[out_dq]
286285

286+
# Use the actual output size, not the unpacked size (which may include padding)
287+
n = 1
288+
for s in shape:
289+
n *= s
290+
# Trim any extra elements from padding during quantization
291+
out_dq = out_dq[:n]
292+
287293
# Apply scales
288-
if out_dq.numel() != n:
289-
assert out_dq.numel() == n + 1
290-
out_dq = torch.narrow(out_dq, 0, 0, n)
291294
blocks = n // blocksize
292295
blocks += 1 if n % blocksize > 0 else 0
293296
rem = n % blocksize

0 commit comments

Comments (0)