@@ -347,6 +347,23 @@ def to(self, *args, **kwargs):
             return new_param


+def fix_4bit_weight_quant_state_from_module(module: Union["Embedding4bit", "Linear4bit"]):
+    if getattr(module.weight, "quant_state", None) is not None:
+        return
+
+    if getattr(module, "quant_state", None) is None:
+        warnings.warn(
+            "FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.",
+        )
+
+    # The quant state is lost when the parameter is converted, e.g. by FSDP.
+    # Since the state was also registered on the module, we can recover it here.
+    assert module.weight.shape[1] == 1
+    if not isinstance(module.weight, Params4bit):
+        module.weight = Params4bit(module.weight, quant_storage=module.quant_storage, bnb_quantized=True)
+    module.weight.quant_state = module.quant_state
+
+
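As context for reviewers: a minimal sketch of the failure mode this helper recovers from, assuming a CUDA device. The parameter swap below is a hypothetical stand-in for what FSDP's flattening does, not a literal reproduction:

```python
import torch
import bitsandbytes as bnb
from bitsandbytes.nn.modules import fix_4bit_weight_quant_state_from_module

layer = bnb.nn.Linear4bit(64, 64).cuda()  # quantization happens on .cuda()

# Hypothetical stand-in for FSDP's conversion: the Params4bit wrapper is
# replaced by a plain nn.Parameter over the packed storage, so the
# quant_state attached to the weight is lost.
layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)

# Params4bit also registered the state on the module itself (module.quant_state),
# so the helper can rebuild the Params4bit wrapper and reattach the state.
fix_4bit_weight_quant_state_from_module(layer)
assert layer.weight.quant_state is not None
```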
 class Linear4bit(nn.Linear):
     """
     This class is the base module for the 4-bit quantization algorithm presented in [QLoRA](https://arxiv.org/abs/2305.14314).
@@ -449,22 +466,12 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
                 destination[prefix + "weight." + k] = v if keep_vars else v.detach()

     def forward(self, x: torch.Tensor):
+        fix_4bit_weight_quant_state_from_module(self)
+
         # weights are cast automatically as Int8Params, but the bias has to be cast manually
         if self.bias is not None and self.bias.dtype != x.dtype:
             self.bias.data = self.bias.data.to(x.dtype)

-        if getattr(self.weight, "quant_state", None) is None:
-            if getattr(self, "quant_state", None) is not None:
-                # the quant state got lost when the parameter got converted. This happens for example for fsdp
-                # since we registered the module, we can recover the state here
-                assert self.weight.shape[1] == 1
-                if not isinstance(self.weight, Params4bit):
-                    self.weight = Params4bit(self.weight, quant_storage=self.quant_storage, bnb_quantized=True)
-                self.weight.quant_state = self.quant_state
-            else:
-                print(
-                    "FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.",
-                )
         if not self.compute_type_is_set:
             self.set_compute_type(x)
             self.compute_type_is_set = True
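The recovery logic now runs at the top of `forward`, before the bias cast, instead of the inlined branch removed below it; the warning also goes through `warnings.warn` rather than `print`. The public behavior is unchanged; a short usage sketch (assumes a CUDA device):

```python
import torch
import bitsandbytes as bnb

fp16 = torch.nn.Linear(64, 64, dtype=torch.float16)
q = bnb.nn.Linear4bit(64, 64, compute_dtype=torch.float16)
q.load_state_dict(fp16.state_dict())
q = q.cuda()  # quantization happens here

x = torch.randn(8, 64, dtype=torch.float16, device="cuda")
out = q(x)    # forward() first reattaches quant_state if it was lost
```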
@@ -658,6 +665,191 @@ def maybe_rearrange_weight(state_dict, prefix, local_metadata, strict, missing_k
         state_dict[f"{prefix}weight"] = undo_layout(weight, tile_indices)


+class Embedding8bit(nn.Embedding):
+    """
+    This class implements the [LLM.int8()](https://arxiv.org/abs/2208.07339) algorithm for the embedding layer.
+
+    The quantization API is similar to that of Linear8bitLt:
+    ```python
+    import torch
+    import torch.nn as nn
+
+    from bitsandbytes.nn import Embedding8bit
+
+    fp16_module = nn.Embedding(128, 64)
+    int8_module = Embedding8bit(128, 64)
+
+    int8_module.load_state_dict(fp16_module.state_dict())
+
+    int8_module = int8_module.to(0)  # Quantization happens here
+    ```
+    """
+
+    def __init__(self, num_embeddings, embedding_dim, device=None, dtype=None):
+        super().__init__(num_embeddings, embedding_dim, device=device, dtype=dtype)
+        self.dtype = self.weight.data.dtype
+
+        self.weight = Int8Params(self.weight.data, has_fp16_weights=False, requires_grad=False)
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        raise NotImplementedError("Saving Embedding8bit module is not implemented")
+
+    def forward(self, input: Tensor) -> Tensor:
+        if not hasattr(self.weight, "SCB"):
+            raise RuntimeError("Embedding layer is not quantized. Please call .cuda() or .to(device) first.")
+
+        rows = self.weight.data
+        row_stats = self.weight.SCB
+
+        assert rows.shape == (self.num_embeddings, self.embedding_dim)
+        assert row_stats.shape == (self.num_embeddings,)
+
+        compressed_output = F.embedding(input, rows)
+        compressed_output_stats = F.embedding(input, row_stats.view(self.num_embeddings, 1))
+
+        output = compressed_output * (compressed_output_stats / 127.0)
+
+        return output.to(self.dtype)
+
+
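`forward` never materializes the full fp16 table: one gather fetches the int8 codes, a second fetches the per-row absmax (`SCB`), and the product `codes * absmax / 127` dequantizes just the selected rows. A numeric sketch of that scheme in plain torch; the symmetric round-to-127 quantizer below is a reconstruction for illustration, not a call into bnb:

```python
import torch

w = torch.randn(4, 8)
row_stats = w.abs().amax(dim=1)                      # per-row absmax, i.e. SCB
w_int8 = torch.round(127.0 * w / row_stats[:, None]).to(torch.int8)

idx = torch.tensor([2, 0])
rows = w_int8[idx].float()                           # gather int8 codes
stats = row_stats[idx]                               # gather row scales
approx = rows * (stats[:, None] / 127.0)             # same rescale as forward()
assert torch.allclose(approx, w[idx], atol=0.05)     # within quantization error
```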
+class Embedding4bit(nn.Embedding):
+    """
+    This is the base class, analogous to Linear4bit. It implements the 4-bit quantization algorithm presented in
+    [QLoRA](https://arxiv.org/abs/2305.14314) for embeddings.
+
+    The quantization API is similar to that of Linear4bit:
+    ```python
+    import torch
+    import torch.nn as nn
+
+    from bitsandbytes.nn import Embedding4bit
+
+    fp16_module = nn.Embedding(128, 64)
+    quantized_module = Embedding4bit(128, 64)
+
+    quantized_module.load_state_dict(fp16_module.state_dict())
+
+    quantized_module = quantized_module.to(0)  # Quantization happens here
+    ```
+    """
+
+    def __init__(
+        self,
+        num_embeddings,
+        embedding_dim,
+        dtype=None,
+        quant_type="fp4",
+        quant_storage=torch.uint8,
+        device=None,
+    ):
+        super().__init__(num_embeddings, embedding_dim, device=device, dtype=dtype)
+        self.dtype = self.weight.data.dtype
+
+        self.weight = Params4bit(
+            self.weight.data,
+            requires_grad=False,
+            compress_statistics=None,
+            quant_type=quant_type,
+            quant_storage=quant_storage,
+            module=self,
+        )
+
+        blocksize = self.weight.blocksize
+
+        if embedding_dim % blocksize != 0:
+            warnings.warn(
+                f"Embedding size {embedding_dim} is not divisible by block size {blocksize}. "
+                "This will lead to slow inference.",
+            )
+
+    def _forward_with_partial_dequantize(self, input: Tensor):
+        assert self.embedding_dim % self.weight.quant_state.blocksize == 0
+
+        w_4bit_uint8 = self.weight.data.view(torch.uint8).view(self.num_embeddings * self.embedding_dim // 2, 1)
+
+        output_4bit = torch.nn.functional.embedding(
+            weight=w_4bit_uint8.view(self.num_embeddings, self.embedding_dim // 2),
+            input=input,
+        ).view(-1, 1)
+        assert output_4bit.shape == (input.numel() * self.embedding_dim // 2, 1)
+
+        blocks_per_emb = self.embedding_dim // self.weight.blocksize
+
+        absmax = self.weight.quant_state.absmax
+        assert absmax.shape == (self.num_embeddings * blocks_per_emb,)
+
+        output_absmax = torch.nn.functional.embedding(
+            weight=absmax.view(self.num_embeddings, blocks_per_emb),
+            input=input,
+        ).view(-1)
+        assert output_absmax.shape == (input.numel() * blocks_per_emb,)
+
+        output_quant_state = copy.deepcopy(self.weight.quant_state)
+        output_quant_state.absmax = output_absmax
+        output_quant_state.shape = torch.Size((*input.shape, self.embedding_dim))
+
+        output = bnb.functional.dequantize_4bit(output_4bit, output_quant_state)
+        assert output.shape == (*input.shape, self.embedding_dim)
+
+        return output.to(self.dtype)
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        raise NotImplementedError("Saving Embedding4bit module is not implemented")
+
+    def forward(self, input: Tensor) -> Tensor:
+        fix_4bit_weight_quant_state_from_module(self)
+
+        if self.embedding_dim % self.weight.quant_state.blocksize == 0:
+            return self._forward_with_partial_dequantize(input)
+
+        dequantized_weight = bnb.functional.dequantize_4bit(self.weight.data, self.weight.quant_state)
+
+        return torch.nn.functional.embedding(
+            weight=dequantized_weight,
+            input=input,
+        ).to(self.dtype)
+
+
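The fast path gathers the packed 4-bit codes and the matching per-block `absmax` entries for only the requested indices, then runs a single `dequantize_4bit` over that subset, which is why `embedding_dim` must span a whole number of blocks. A shape-level sketch with plain tensors standing in for bnb's packed storage (all tensors below are illustrative stand-ins):

```python
import torch
import torch.nn.functional as F

num_emb, emb_dim, blocksize = 16, 64, 32
blocks_per_emb = emb_dim // blocksize            # each row owns whole blocks

# Stand-ins: two 4-bit codes per uint8 byte, one absmax scale per block.
packed = torch.randint(0, 256, (num_emb, emb_dim // 2), dtype=torch.uint8)
absmax = torch.rand(num_emb * blocks_per_emb)

idx = torch.tensor([[3, 7], [0, 15]])
codes = F.embedding(idx, packed)                                  # gather codes
scales = F.embedding(idx, absmax.view(num_emb, blocks_per_emb))   # gather scales

assert codes.shape == (*idx.shape, emb_dim // 2)
assert scales.shape == (*idx.shape, blocks_per_emb)
# One dequantize over (codes, scales) then yields only the selected embeddings.
```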
+class EmbeddingFP4(Embedding4bit):
+    def __init__(
+        self,
+        num_embeddings,
+        embedding_dim,
+        dtype=None,
+        quant_storage=torch.uint8,
+        device=None,
+    ):
+        super().__init__(
+            num_embeddings,
+            embedding_dim,
+            dtype=dtype,
+            quant_type="fp4",
+            quant_storage=quant_storage,
+            device=device,
+        )
+
+
+class EmbeddingNF4(Embedding4bit):
+    def __init__(
+        self,
+        num_embeddings,
+        embedding_dim,
+        dtype=None,
+        quant_storage=torch.uint8,
+        device=None,
+    ):
+        super().__init__(
+            num_embeddings,
+            embedding_dim,
+            dtype=dtype,
+            quant_type="nf4",
+            quant_storage=quant_storage,
+            device=device,
+        )
+
+
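The two concrete subclasses only pin `quant_type`. A usage sketch mirroring the docstring examples above, assuming the new classes are exported from `bitsandbytes.nn` and a CUDA device is available at index 0:

```python
import torch.nn as nn
from bitsandbytes.nn import EmbeddingFP4, EmbeddingNF4

fp16_module = nn.Embedding(128, 64)

for cls in (EmbeddingFP4, EmbeddingNF4):
    quantized = cls(128, 64)
    quantized.load_state_dict(fp16_module.state_dict())
    quantized = quantized.to(0)  # quantization happens on the device transfer
```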
 class Linear8bitLt(nn.Linear):
     """
     This class is the base module for the [LLM.int8()](https://arxiv.org/abs/2208.07339) algorithm.