Commit 6d714a5

Embedding4bit and Embedding8bit implementation (#1292)
* Embedding4bit and Embedding8bit implementation
* lint
* Update bitsandbytes/nn/modules.py
* Update bitsandbytes/nn/modules.py
* Update bitsandbytes/nn/modules.py
* saving -> Saving

Co-authored-by: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>
Parent: 4be1883 · Commit: 6d714a5

3 files changed: 355 additions & 13 deletions


bitsandbytes/nn/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -4,6 +4,10 @@
 # LICENSE file in the root directory of this source tree.
 from .modules import (
     Embedding,
+    Embedding4bit,
+    Embedding8bit,
+    EmbeddingFP4,
+    EmbeddingNF4,
     Int8Params,
     Linear4bit,
     Linear8bitLt,
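With this change the new embedding classes are exported from the package namespace, so they can be imported directly:

```python
from bitsandbytes.nn import Embedding4bit, Embedding8bit, EmbeddingFP4, EmbeddingNF4
```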

bitsandbytes/nn/modules.py

Lines changed: 204 additions & 12 deletions
@@ -347,6 +347,23 @@ def to(self, *args, **kwargs):
         return new_param


+def fix_4bit_weight_quant_state_from_module(module: Union["Embedding4bit", "Linear4bit"]):
+    if getattr(module.weight, "quant_state", None) is not None:
+        return
+
+    if getattr(module, "quant_state", None) is None:
+        warnings.warn(
+            "FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.",
+        )
+
+    # The quant state got lost when the parameter was converted. This happens, for example, with FSDP.
+    # Since we registered the module, we can recover the state here.
+    assert module.weight.shape[1] == 1
+    if not isinstance(module.weight, Params4bit):
+        module.weight = Params4bit(module.weight, quant_storage=module.quant_storage, bnb_quantized=True)
+    module.weight.quant_state = module.quant_state
+
+
 class Linear4bit(nn.Linear):
     """
     This class is the base module for the 4-bit quantization algorithm presented in [QLoRA](https://arxiv.org/abs/2305.14314).
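This helper centralizes the quant-state recovery that previously lived only in `Linear4bit.forward`, so the new `Embedding4bit` class below can reuse it. In normal use it is a no-op, because moving the module to an accelerator quantizes the weight and attaches `weight.quant_state`. A minimal sketch of the expected call order, assuming a CUDA build of bitsandbytes:

```python
import torch
from bitsandbytes.nn import LinearFP4

layer = LinearFP4(64, 64)
layer = layer.cuda()  # quantization happens here and attaches weight.quant_state

x = torch.randn(2, 64, dtype=torch.float16, device="cuda")
y = layer(x)  # forward first calls fix_4bit_weight_quant_state_from_module(layer)
```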
@@ -449,22 +466,12 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
             destination[prefix + "weight." + k] = v if keep_vars else v.detach()

     def forward(self, x: torch.Tensor):
+        fix_4bit_weight_quant_state_from_module(self)
+
         # weights are cast automatically as Int8Params, but the bias has to be cast manually
         if self.bias is not None and self.bias.dtype != x.dtype:
             self.bias.data = self.bias.data.to(x.dtype)

-        if getattr(self.weight, "quant_state", None) is None:
-            if getattr(self, "quant_state", None) is not None:
-                # the quant state got lost when the parameter got converted. This happens for example for fsdp
-                # since we registered the module, we can recover the state here
-                assert self.weight.shape[1] == 1
-                if not isinstance(self.weight, Params4bit):
-                    self.weight = Params4bit(self.weight, quant_storage=self.quant_storage, bnb_quantized=True)
-                self.weight.quant_state = self.quant_state
-            else:
-                print(
-                    "FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.",
-                )
         if not self.compute_type_is_set:
             self.set_compute_type(x)
             self.compute_type_is_set = True
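Note that the deleted inline block reported the uninitialized state with `print`; the extracted helper uses `warnings.warn` instead, so the message can be managed with Python's standard warnings machinery. A small sketch, reusing `layer` and `x` from the previous example:

```python
import warnings

# Escalate the "quantization state not initialized" warning to a hard error
# while debugging device-placement issues:
with warnings.catch_warnings():
    warnings.simplefilter("error")
    y = layer(x)  # raises if the quant state was never initialized or recovered
```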
@@ -658,6 +665,191 @@ def maybe_rearrange_weight(state_dict, prefix, local_metadata, strict, missing_k
         state_dict[f"{prefix}weight"] = undo_layout(weight, tile_indices)


+class Embedding8bit(nn.Embedding):
+    """
+    This class implements the [LLM.int8()](https://arxiv.org/abs/2208.07339) algorithm for the embedding layer.
+
+    The quantization API is similar to Linear8bitLt:
+    ```python
+    import torch
+    import torch.nn as nn
+
+    from bitsandbytes.nn import Embedding8bit
+
+    fp16_module = nn.Embedding(128, 64)
+    int8_module = Embedding8bit(128, 64)
+
+    int8_module.load_state_dict(fp16_module.state_dict())
+
+    int8_module = int8_module.to(0)  # Quantization happens here
+    ```
+    """
+
+    def __init__(self, num_embeddings, embedding_dim, device=None, dtype=None):
+        super().__init__(num_embeddings, embedding_dim, device=device, dtype=dtype)
+        self.dtype = self.weight.data.dtype
+
+        self.weight = Int8Params(self.weight.data, has_fp16_weights=False, requires_grad=False)
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        raise NotImplementedError("Saving Embedding8bit module is not implemented")
+
+    def forward(self, input: Tensor) -> Tensor:
+        if not hasattr(self.weight, "SCB"):
+            raise RuntimeError("Embedding layer is not quantized. Please call .cuda() or .to(device) first.")
+
+        rows = self.weight.data
+        row_stats = self.weight.SCB
+
+        assert rows.shape == (self.num_embeddings, self.embedding_dim)
+        assert row_stats.shape == (self.num_embeddings,)
+
+        compressed_output = F.embedding(input, rows)
+        compressed_output_stats = F.embedding(input, row_stats.view(self.num_embeddings, 1))
+
+        output = compressed_output * (compressed_output_stats / 127.0)
+
+        return output.to(self.dtype)
+
+
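As the forward pass shows, each looked-up int8 row is rescaled by its per-row absmax statistic (`SCB`) divided by 127. A minimal sketch that reproduces the same arithmetic by hand, assuming a CUDA build of bitsandbytes:

```python
import torch
from bitsandbytes.nn import Embedding8bit

m = Embedding8bit(128, 64)
m = m.to("cuda")  # quantization happens here and attaches weight.SCB
ids = torch.tensor([3, 7], device="cuda")

codes = m.weight.data[ids].float()       # int8 codes of the selected rows
scales = m.weight.SCB[ids].view(-1, 1)   # per-row absmax statistics
manual = (codes * (scales / 127.0)).to(m.dtype)

assert torch.allclose(manual, m(ids))
```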
+class Embedding4bit(nn.Embedding):
+    """
+    This is the base class, similar to Linear4bit. It implements the 4-bit quantization algorithm presented in
+    [QLoRA](https://arxiv.org/abs/2305.14314) for embeddings.
+
+    The quantization API is similar to Linear4bit:
+    ```python
+    import torch
+    import torch.nn as nn
+
+    from bitsandbytes.nn import Embedding4bit
+
+    fp16_module = nn.Embedding(128, 64)
+    quantized_module = Embedding4bit(128, 64)
+
+    quantized_module.load_state_dict(fp16_module.state_dict())
+
+    quantized_module = quantized_module.to(0)  # Quantization happens here
+    ```
+    """
+
+    def __init__(
+        self,
+        num_embeddings,
+        embedding_dim,
+        dtype=None,
+        quant_type="fp4",
+        quant_storage=torch.uint8,
+        device=None,
+    ):
+        super().__init__(num_embeddings, embedding_dim, device=device, dtype=dtype)
+        self.dtype = self.weight.data.dtype
+
+        self.weight = Params4bit(
+            self.weight.data,
+            requires_grad=False,
+            compress_statistics=None,
+            quant_type=quant_type,
+            quant_storage=quant_storage,
+            module=self,
+        )
+
+        blocksize = self.weight.blocksize
+
+        if embedding_dim % blocksize != 0:
+            warnings.warn(
+                f"Embedding size {embedding_dim} is not divisible by block size {blocksize}. "
+                "This will lead to slow inference.",
+            )
+
+    def _forward_with_partial_dequantize(self, input: Tensor):
+        assert self.embedding_dim % self.weight.quant_state.blocksize == 0
+
+        w_4bit_uint8 = self.weight.data.view(torch.uint8).view(self.num_embeddings * self.embedding_dim // 2, 1)
+
+        output_4bit = torch.nn.functional.embedding(
+            weight=w_4bit_uint8.view(self.num_embeddings, self.embedding_dim // 2),
+            input=input,
+        ).view(-1, 1)
+        assert output_4bit.shape == (input.numel() * self.embedding_dim // 2, 1)
+
+        blocks_per_emb = self.embedding_dim // self.weight.blocksize
+
+        absmax = self.weight.quant_state.absmax
+        assert absmax.shape == (self.num_embeddings * blocks_per_emb,)
+
+        output_absmax = torch.nn.functional.embedding(
+            weight=absmax.view(self.num_embeddings, blocks_per_emb),
+            input=input,
+        ).view(
+            -1,
+        )
+        assert output_absmax.shape == (input.numel() * blocks_per_emb,)
+
+        output_quant_state = copy.deepcopy(self.weight.quant_state)
+        output_quant_state.absmax = output_absmax
+        output_quant_state.shape = torch.Size((*input.shape, self.embedding_dim))
+
+        output = bnb.functional.dequantize_4bit(output_4bit, output_quant_state)
+        assert output.shape == (*input.shape, self.embedding_dim)
+
+        return output.to(self.dtype)
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        raise NotImplementedError("Saving Embedding4bit module is not implemented")
+
+    def forward(self, input: Tensor) -> Tensor:
+        fix_4bit_weight_quant_state_from_module(self)
+
+        if self.embedding_dim % self.weight.quant_state.blocksize == 0:
+            return self._forward_with_partial_dequantize(input)
+
+        dequantized_weight = bnb.functional.dequantize_4bit(self.weight.data, self.weight.quant_state)
+
+        return torch.nn.functional.embedding(
+            weight=dequantized_weight,
+            input=input,
+        ).to(self.dtype)
+
+
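The fast path exploits the packed storage: two 4-bit codes share one `uint8` byte, so each row occupies `embedding_dim // 2` bytes and owns `embedding_dim // blocksize` absmax entries. Gathering the packed bytes and absmax values per input id lets the layer dequantize only the selected rows instead of the whole table, which is why an `embedding_dim` divisible by the blocksize matters for speed. A minimal sketch checking the fast path against full dequantization, assuming the default blocksize of 64 and a CUDA build of bitsandbytes:

```python
import torch
import bitsandbytes as bnb
from bitsandbytes.nn import Embedding4bit

m = Embedding4bit(128, 64, quant_type="nf4").cuda()  # 64 is divisible by the blocksize
ids = torch.tensor([[1, 2], [3, 4]], device="cuda")

# Reference path: dequantize the entire table, then gather rows.
full_table = bnb.functional.dequantize_4bit(m.weight.data, m.weight.quant_state)
reference = torch.nn.functional.embedding(ids, full_table).to(m.dtype)

assert torch.allclose(m(ids), reference)
```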
+class EmbeddingFP4(Embedding4bit):
+    def __init__(
+        self,
+        num_embeddings,
+        embedding_dim,
+        dtype=None,
+        quant_storage=torch.uint8,
+        device=None,
+    ):
+        super().__init__(
+            num_embeddings,
+            embedding_dim,
+            dtype=dtype,
+            quant_type="fp4",
+            quant_storage=quant_storage,
+            device=device,
+        )
+
+
+class EmbeddingNF4(Embedding4bit):
+    def __init__(
+        self,
+        num_embeddings,
+        embedding_dim,
+        dtype=None,
+        quant_storage=torch.uint8,
+        device=None,
+    ):
+        super().__init__(
+            num_embeddings,
+            embedding_dim,
+            dtype=dtype,
+            quant_type="nf4",
+            quant_storage=quant_storage,
+            device=device,
+        )
+
+
 class Linear8bitLt(nn.Linear):
     """
     This class is the base module for the [LLM.int8()](https://arxiv.org/abs/2208.07339) algorithm.
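`EmbeddingFP4` and `EmbeddingNF4` are thin convenience subclasses that pin `quant_type` to `"fp4"` and `"nf4"` respectively, mirroring `LinearFP4`/`LinearNF4`. Following the docstring examples above, a typical drop-in replacement looks like this (a sketch assuming a CUDA device):

```python
import torch
import torch.nn as nn

from bitsandbytes.nn import EmbeddingNF4

fp16_module = nn.Embedding(1024, 512, dtype=torch.float16)
nf4_module = EmbeddingNF4(1024, 512, dtype=torch.float16)

nf4_module.load_state_dict(fp16_module.state_dict())
nf4_module = nf4_module.cuda()  # quantization happens here

ids = torch.randint(0, 1024, (2, 16), device="cuda")
embeddings = nf4_module(ids)  # dequantized on the fly, returned as fp16
```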
