
Commit f09812d (parent: b4370b8)

attempt at conforming to flatten/unflatten protocol

1 file changed: bitsandbytes/nn/modules.py (44 additions, 1 deletion)
@@ -344,6 +344,49 @@ def to(self, *args, **kwargs):
 
         return new_param
 
+    def __tensor_flatten__(self):
+        """Return data tensor and non-tensor context"""
+        ctx = {
+            "quant_state": self.quant_state,
+            "blocksize": self.blocksize,
+            "compress_statistics": self.compress_statistics,
+            "quant_type": self.quant_type,
+            "quant_storage": self.quant_storage,
+            "module": self.module,
+            "bnb_quantized": self.bnb_quantized,
+        }
+        return ["data"], ctx
+
+    @staticmethod
+    def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
+        """Reconstruct Params4bit from components"""
+        data = inner_tensors["data"]
+        return Params4bit(
+            data,
+            requires_grad=data.requires_grad,
+            quant_state=ctx["quant_state"],
+            blocksize=ctx["blocksize"],
+            compress_statistics=ctx["compress_statistics"],
+            quant_type=ctx["quant_type"],
+            quant_storage=ctx["quant_storage"],
+            module=ctx["module"],
+            bnb_quantized=ctx["bnb_quantized"],
+        )
+
+    def detach(self):
+        """Create new instance preserving quantization state"""
+        return type(self)(
+            self.data.detach(),
+            requires_grad=self.requires_grad,
+            quant_state=self.quant_state,
+            blocksize=self.blocksize,
+            compress_statistics=self.compress_statistics,
+            quant_type=self.quant_type,
+            quant_storage=self.quant_storage,
+            module=self.module,
+            bnb_quantized=self.bnb_quantized,
+        )
+
 
 def fix_4bit_weight_quant_state_from_module(module: Union["Embedding4bit", "Linear4bit"]):
     if getattr(module.weight, "quant_state", None) is not None:
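
The two dunder methods above follow PyTorch's traceable-tensor-subclass protocol: `__tensor_flatten__` reports which attributes hold plain inner tensors (here just `data`) plus a context of non-tensor metadata, and `__tensor_unflatten__` rebuilds the subclass from those pieces. That is what lets machinery such as `torch.compile` and FSDP decompose a `Params4bit` into ordinary tensors and reassemble it without dropping the quantization state; the `detach` override serves the same goal on the autograd side, since a default detach would hand back a tensor without the quantization fields. Below is a minimal sketch of the round trip these methods enable; it assumes a CUDA build of bitsandbytes (moving the parameter to the GPU is what triggers quantization), and the tensor shape is arbitrary:

```python
import torch
from bitsandbytes.nn import Params4bit

# Quantize a parameter by moving it to the GPU (CUDA build assumed).
p = Params4bit(torch.randn(64, 64), quant_type="nf4").cuda()

# Flatten: names of inner-tensor attributes plus non-tensor context.
inner_names, ctx = p.__tensor_flatten__()
inner_tensors = {name: getattr(p, name) for name in inner_names}

# Unflatten: rebuild an equivalent Params4bit. PyTorch supplies
# outer_size/outer_stride when it drives this; the sketch reuses p's own.
q = Params4bit.__tensor_unflatten__(inner_tensors, ctx, p.size(), p.stride())

assert torch.equal(q.data, p.data)
assert q.quant_state is p.quant_state  # context carries the state by reference

# The detach override keeps the subclass and its quantization metadata intact.
d = p.detach()
assert isinstance(d, Params4bit) and d.quant_state is p.quant_state
```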
@@ -480,7 +523,7 @@ def forward(self, x: torch.Tensor):
 
         bias = None if self.bias is None else self.bias.to(self.compute_dtype)
 
-        return bnb.matmul_4bit(x, self.weight.data, bias=bias, quant_state=self.weight.quant_state).to(inp_dtype)
+        return bnb.matmul_4bit(x, self.weight, bias=bias, quant_state=self.weight.quant_state).to(inp_dtype)
 
 
 class LinearFP4(Linear4bit):
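
The one-line change in `forward` passes the `Params4bit` parameter itself to `bnb.matmul_4bit` rather than its `.data` attribute. Accessing `.data` yields a plain tensor, which strips the subclass (and with it everything the flatten/unflatten protocol exposes), so keeping the subclass on the call path is what makes the new protocol methods reachable during tracing. A small illustration of the distinction, again assuming a CUDA build of bitsandbytes; the layer sizes are arbitrary:

```python
import torch
import bitsandbytes as bnb

lin = bnb.nn.Linear4bit(128, 64, quant_type="nf4").cuda()

print(type(lin.weight))       # Params4bit: subclass with quant_state attached
print(type(lin.weight.data))  # plain torch.Tensor: subclass identity stripped
```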
