Commit c703d8d
Create 'default' backend for fallback op implementations; initial CPU nf4 work
1 parent b599401 commit c703d8d

6 files changed

Lines changed: 106 additions & 33 deletions

bitsandbytes/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -12,6 +12,7 @@
 )
 from .backends.cpu import ops as cpu_ops
 from .backends.cuda import ops as cuda_ops  ## TODO: We would guard this for CUDA only
+from .backends.default import ops as default_ops
 from .nn import modules
 from .optim import adam
 

bitsandbytes/_ops.py

Lines changed: 0 additions & 14 deletions
@@ -35,20 +35,6 @@ def _(
     return torch.empty(shapeC, device=A.device, dtype=dtype)
 
 
-@register_kernel("bitsandbytes::int8_scaled_mm", None)
-def _(
-    A: torch.Tensor,
-    B: torch.Tensor,
-    row_stats: torch.Tensor,
-    col_stats: torch.Tensor,
-    bias: Optional[torch.Tensor] = None,
-    dtype=torch.float16,
-) -> torch.Tensor:
-    out_i32 = torch.ops.bitsandbytes.int8_linear_matmul.default(A, B)
-    out = torch.ops.bitsandbytes.int8_mm_dequant.default(out_i32, row_stats, col_stats, dtype=dtype, bias=bias)
-    return out
-
-
 torch.library.define(
     "bitsandbytes::int8_linear_matmul",
     "(Tensor A, Tensor B) -> Tensor",

bitsandbytes/backends/cpu/ops.py

Lines changed: 65 additions & 17 deletions
@@ -8,24 +8,19 @@
 from ..._ops import register_kernel
 from ...cextension import lib
 
+# torch._int_mm for s8@s8->s32 is supported on CPU from torch 2.4+.
+# However, we can overflow if we use this without AVX512_VNNI support.
+# This is fixed in torch 2.6+, so we set this as the minimum to be safe.
+# For more information: https://github.com/pytorch/pytorch/pull/136942
+# TODO(matthewdouglas): aarch64?
+if torch.__version__ >= (2, 6):
 
-@register_kernel("bitsandbytes::int8_linear_matmul", "cpu")
-def _(A: torch.Tensor, B: torch.Tensor):
-    return _int8_linear_matmul_impl(A, B)
-
-
-@register_kernel("bitsandbytes::int8_linear_matmul.out", "cpu")
-def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor):
-    torch._check(out.dtype == torch.int32)
-    _int8_linear_matmul_impl(A, B, out)
-
-
-def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Tensor] = None):
-    # Naive implementation: perform matmul in fp32
-    result = torch.matmul(A.float(), B.float().t()).to(torch.int32)
-    if out is not None:
-        result = out.copy_(result)
-    return result
+    @register_kernel("bitsandbytes::int8_linear_matmul", "cpu")
+    def _(A: torch.Tensor, B: torch.Tensor):
+        return torch._int_mm(
+            A.reshape(-1, A.shape[-1]),
+            B.t(),
+        ).reshape(*A.shape[:-1], B.shape[0])
 
 
 @register_kernel("bitsandbytes::int8_mm_dequant", "cpu")
@@ -92,3 +87,56 @@ def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int,
     )
 
     return out
+
+
+_NF4_QUANT_TABLE = torch.tensor(
+    [
+        -1.0,
+        -0.6961928009986877,
+        -0.5250730514526367,
+        -0.39491748809814453,
+        -0.28444138169288635,
+        -0.18477343022823334,
+        -0.09105003625154495,
+        0.0,
+        0.07958029955625534,
+        0.16093020141124725,
+        0.24611230194568634,
+        0.33791524171829224,
+        0.44070982933044434,
+        0.5626170039176941,
+        0.7229568362236023,
+        1.0,
+    ],
+    dtype=torch.float32,
+    device="cpu",
+)
+
+
+@register_kernel("bitsandbytes::quantize_4bit", "cpu")
+def _(
+    A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    torch._check_is_size(blocksize)
+    torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on CPU, got {quant_type}")
+
+    n = A.numel()
+
+    # TODO: Support when weight matrix is not divisible by blocksize
+    torch._check(n % blocksize == 0, lambda: f"n must be divisible by blocksize, got {n} and {blocksize}")
+
+    # Divide into blocks and normalize
+    blocks = A.reshape(-1, blocksize)
+    absmax = blocks.abs().max(dim=1).values.float()
+    scaled = blocks / absmax.unsqueeze(-1)
+
+    # Quantize with the lookup table
+    quantized = torch.argmin(torch.abs(scaled.view(-1, 1) - _NF4_QUANT_TABLE), dim=-1, keepdim=True).to(torch.uint8)
+
+    # Pack two quantized values per byte
+    packed = quantized[::2] << 4 | quantized[1::2]
+
+    if quant_storage != torch.uint8:
+        packed = packed.squeeze().view(quant_storage).unsqueeze(1)
+
+    return packed, absmax.float()
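
As a sanity check on the scheme above, here is a standalone sketch of the lookup-and-pack steps on one toy block (toy values; the real kernel works on blocks of 64 to 512 elements and keeps the table as _NF4_QUANT_TABLE):

import torch

# The 16 NF4 code points from the diff above.
NF4 = torch.tensor([
    -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
    -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
    0.07958029955625534, 0.16093020141124725, 0.24611230194568634,
    0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
    0.7229568362236023, 1.0,
])

block = torch.tensor([0.5, -0.25, 1.5, 0.0, -3.0, 0.7, 0.1, -0.4])
absmax = block.abs().max()
scaled = block / absmax  # normalize into [-1, 1], the table's range

# Nearest code point per element -> 4-bit code in [0, 15]
codes = torch.argmin((scaled.view(-1, 1) - NF4).abs(), dim=-1).to(torch.uint8)

# Pack two codes per byte: even indices take the high nibble
packed = (codes[::2] << 4) | codes[1::2]

# Dequantize: table lookup, then rescale by the block's absmax
dequant = NF4[codes.long()] * absmax
print("max abs error:", (block - dequant).abs().max().item())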

bitsandbytes/backends/default/__init__.py

Whitespace-only changes.

bitsandbytes/backends/default/ops.py

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
+from typing import Optional
+
+import torch
+
+from ..._ops import register_kernel
+
+
+@register_kernel("bitsandbytes::int8_scaled_mm", None)
+def _(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    row_stats: torch.Tensor,
+    col_stats: torch.Tensor,
+    bias: Optional[torch.Tensor] = None,
+    dtype=torch.float16,
+) -> torch.Tensor:
+    out_i32 = torch.ops.bitsandbytes.int8_linear_matmul.default(A, B)
+    out = torch.ops.bitsandbytes.int8_mm_dequant.default(out_i32, row_stats, col_stats, dtype=dtype, bias=bias)
+    return out
+
+
+@register_kernel("bitsandbytes::int8_linear_matmul", None)
+def _(A: torch.Tensor, B: torch.Tensor):
+    return _int8_linear_matmul_impl(A, B)
+
+
+@register_kernel("bitsandbytes::int8_linear_matmul.out", None)
+def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor):
+    torch._check(out.dtype == torch.int32)
+    _int8_linear_matmul_impl(A, B, out)
+
+
+def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Tensor] = None):
+    # Naive implementation: perform matmul in fp32
+    result = torch.matmul(A.float(), B.float().t()).to(torch.int32)
+    if out is not None:
+        result = out.copy_(result)
+    return result
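
The fallback expresses int8_scaled_mm in terms of the two ops it already has. Below is a numeric sketch of what that composition computes; the scale convention (symmetric absmax scaling to +/-127, as in the LLM.int8() scheme) is an assumption about the meaning of the stats tensors, not code from this commit:

import torch

torch.manual_seed(0)
A = torch.randn(4, 8)
B = torch.randn(16, 8)

row_stats = A.abs().max(dim=1).values  # per-row absmax of A
col_stats = B.abs().max(dim=1).values  # per-output-column absmax (rows of B)

A_i8 = torch.round(A / row_stats[:, None] * 127).to(torch.int8)
B_i8 = torch.round(B / col_stats[:, None] * 127).to(torch.int8)

# Step 1: the int8_linear_matmul stage, accumulating in int32
out_i32 = A_i8.int() @ B_i8.int().t()

# Step 2: the int8_mm_dequant stage, undoing both scales
out = out_i32.float() * row_stats[:, None] * col_stats[None, :] / (127 * 127)

print("max abs error vs fp32:", (out - A @ B.t()).abs().max().item())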

tests/test_ops.py

Lines changed: 2 additions & 2 deletions
@@ -146,8 +146,8 @@ class Test4bitBlockwiseQuantOps:
     @pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
     @pytest.mark.parametrize("blocksize", [64, 128, 256, 512])
     def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
-        if device == "cpu":
-            pytest.skip("CPU implementation is not available")
+        if device == "cpu" and quant_type != "nf4":
+            pytest.skip("CPU implementation is only available for nf4")
 
         A = torch.randn(1024, 1024, dtype=dtype, device=device)
 
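
A minimal smoke test of just the new CPU path, assuming the op can be invoked directly through the torch.ops namespace with the positional signature shown in cpu/ops.py:

import torch
import bitsandbytes  # importing registers the kernels

A = torch.randn(1024, 1024, dtype=torch.float32, device="cpu")
packed, absmax = torch.ops.bitsandbytes.quantize_4bit(A, 64, "nf4", torch.uint8)

# Two 4-bit codes per byte, and one absmax value per 64-element block
print(packed.numel(), absmax.numel())  # 524288, 16384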
