
Commit e2dc832

Remove dead quant_state.dtype mutation in matmul_4bit CPU path (#1917)
The mutation `quant_state.dtype = A.dtype` is unnecessary: MatMul4Bit.forward already casts via `.to(A.dtype)`, and gemv_4bit doesn't read state.dtype. Removing it eliminates the Dynamo graph break on CPU under activation checkpointing, so the regression test no longer needs a CPU skip.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 1869cd8 commit e2dc832
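
For context, here is a minimal sketch of the failure mode the commit message describes. This is my own construction, not code from this repo: `State` stands in for bitsandbytes' `QuantState` (a plain Python object), and the claim that the mutation breaks the graph only under checkpointing is hedged, since Dynamo ordinarily tolerates attribute mutation but higher-order ops such as the checkpoint wrapper cannot capture side effects on outer state.

```python
import torch
from torch.utils.checkpoint import checkpoint


class State:
    """Stand-in for QuantState: an ordinary mutable Python object."""
    dtype = torch.float32


state = State()


def inner(x):
    state.dtype = x.dtype  # side effect on an object outside the region
    return x * 2.0


def fn(x):
    # torch.compile traces checkpointed regions as higher-order ops,
    # which cannot replay mutations of outer Python state.
    return checkpoint(inner, x, use_reentrant=False)


report = torch._dynamo.explain(fn)(torch.ones(4))
print(report.graph_break_count)  # expect >= 1 while the mutation is present
```

Deleting the assignment from `inner` should bring the break count to zero on recent PyTorch versions, which is the effect the commit claims for the real code.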

2 files changed: 0 additions & 6 deletions

bitsandbytes/autograd/_functions.py (0 additions & 3 deletions)
```diff
@@ -382,10 +382,7 @@ def matmul_4bit(
     bias: Optional[torch.Tensor] = None,
 ):
     assert quant_state is not None
-    # Change dtype to input dtype on CPU
     if A.device.type == "cpu":
-        quant_state.dtype = A.dtype
-
         if getattr(quant_state, "packing_format_for_cpu", False):
             out = F.gemv_4bit(A, B, out, state=quant_state)
             if bias is not None:
```
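
The deletion is safe because the cast happens at the point of use, not in the state object. Below is a sketch of the fallback path the commit message cites, paraphrased rather than quoted from MatMul4Bit.forward: `F.dequantize_4bit` is the library's public API, the `.to(A.dtype)` cast is what the message names, and the surrounding function is illustrative.

```python
import torch
import bitsandbytes.functional as F


def forward_sketch(A, B_packed, quant_state, bias=None):
    # Dequantize to quant_state.dtype, then cast at the point of use.
    # This .to(A.dtype) is the cast the commit message cites: it makes a
    # quant_state.dtype that disagrees with A.dtype harmless downstream.
    W = F.dequantize_4bit(B_packed, quant_state)
    return torch.nn.functional.linear(A, W.to(A.dtype).t(), bias)
```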

tests/test_linear4bit.py (0 additions & 3 deletions)
```diff
@@ -453,9 +453,6 @@ def test_linear4bit_torch_compile_activation_checkpointing(device, quant_type, c
         pytest.skip("This configuration is not supported on HPU.")
     if device == "cuda" and platform.system() == "Windows":
         pytest.skip("Triton is not officially supported on Windows")
-    if device == "cpu":
-        pytest.skip("matmul_4bit mutates quant_state.dtype on CPU, causing a separate graph break (#1917)")
-
     dim = 256
     batch_size = 16
     compute_dtype = torch.bfloat16
```
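
With the skip removed, the test's CPU configuration runs again. The following is a hedged sketch of what that configuration plausibly exercises; `dim`, `batch_size`, and `compute_dtype` come from the visible diff, while the `quant_type`, the checkpoint wrapper, and the assumption of a bitsandbytes build with CPU 4-bit support are mine, not the test's verbatim body.

```python
import torch
import bitsandbytes as bnb
from torch.utils.checkpoint import checkpoint

dim, batch_size = 256, 16           # values taken from the visible diff
compute_dtype = torch.bfloat16

# Linear4bit quantizes its weight when moved to the target device.
linear = bnb.nn.Linear4bit(
    dim, dim, compute_dtype=compute_dtype, quant_type="nf4"
).to("cpu")


def fwd(x):
    return checkpoint(linear, x, use_reentrant=False)


# fullgraph=True turns any remaining graph break into a hard error,
# which is what made the removed mutation show up in this test.
compiled = torch.compile(fwd, fullgraph=True)
out = compiled(torch.randn(batch_size, dim, dtype=compute_dtype, requires_grad=True))
```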
