@@ -344,6 +344,49 @@ def to(self, *args, **kwargs):

        return new_param

+    def __tensor_flatten__(self):
+        """Return data tensor and non-tensor context"""
+        ctx = {
+            "quant_state": self.quant_state,
+            "blocksize": self.blocksize,
+            "compress_statistics": self.compress_statistics,
+            "quant_type": self.quant_type,
+            "quant_storage": self.quant_storage,
+            "module": self.module,
+            "bnb_quantized": self.bnb_quantized,
+        }
+        return ["data"], ctx
+
+    @staticmethod
+    def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
+        """Reconstruct Params4bit from components"""
+        data = inner_tensors["data"]
+        return Params4bit(
+            data,
+            requires_grad=data.requires_grad,
+            quant_state=ctx["quant_state"],
+            blocksize=ctx["blocksize"],
+            compress_statistics=ctx["compress_statistics"],
+            quant_type=ctx["quant_type"],
+            quant_storage=ctx["quant_storage"],
+            module=ctx["module"],
+            bnb_quantized=ctx["bnb_quantized"],
+        )
+
+    def detach(self):
+        """Create new instance preserving quantization state"""
+        return type(self)(
+            self.data.detach(),
+            requires_grad=self.requires_grad,
+            quant_state=self.quant_state,
+            blocksize=self.blocksize,
+            compress_statistics=self.compress_statistics,
+            quant_type=self.quant_type,
+            quant_storage=self.quant_storage,
+            module=self.module,
+            bnb_quantized=self.bnb_quantized,
+        )
+

def fix_4bit_weight_quant_state_from_module(module: Union["Embedding4bit", "Linear4bit"]):
    if getattr(module.weight, "quant_state", None) is not None:
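These two methods implement PyTorch's traceable tensor subclass protocol: `__tensor_flatten__` splits the parameter into its inner tensor attribute names plus a context dict of non-tensor state, and `__tensor_unflatten__` rebuilds an equivalent `Params4bit` from those pieces, which is what lets `torch.compile` and FSDP2 trace through the subclass. A minimal round-trip sketch, assuming bitsandbytes with this patch applied and a CUDA device available (the variable `w` is illustrative, not part of the patch):

import torch
from bitsandbytes.nn import Params4bit

# Moving to CUDA triggers 4-bit quantization of the underlying storage.
w = Params4bit(torch.randn(64, 64, dtype=torch.float16), quant_type="nf4").cuda()

# Flatten into inner tensor names + non-tensor context, then rebuild.
attrs, ctx = w.__tensor_flatten__()
inner = {name: getattr(w, name) for name in attrs}
w2 = Params4bit.__tensor_unflatten__(inner, ctx, w.size(), w.stride())

assert torch.equal(w2.data, w.data)
assert w2.quant_state is w.quant_state  # non-tensor context carried through unchanged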
@@ -480,7 +523,7 @@ def forward(self, x: torch.Tensor):

        bias = None if self.bias is None else self.bias.to(self.compute_dtype)

-        return bnb.matmul_4bit(x, self.weight.data, bias=bias, quant_state=self.weight.quant_state).to(inp_dtype)
+        return bnb.matmul_4bit(x, self.weight, bias=bias, quant_state=self.weight.quant_state).to(inp_dtype)


class LinearFP4(Linear4bit):
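The one-line `forward` change is what makes the protocol above reachable at runtime: `self.weight.data` returns a plain `torch.Tensor` view that discards the subclass, while passing `self.weight` keeps the `Params4bit` type, and with it the flatten/unflatten hooks, visible to the dispatcher. A small self-contained illustration of that gotcha, using a hypothetical `TaggedParam` subclass in place of `Params4bit`:

import torch

class TaggedParam(torch.nn.Parameter):
    """Hypothetical stand-in for a quantized parameter subclass."""

p = TaggedParam(torch.zeros(4))
print(type(p).__name__)       # TaggedParam
print(type(p.data).__name__)  # Tensor -- .data strips subclass identity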