
Commit 6c52fd0

possible linear module fix for shared embeddings
1 parent: 0b47f94

1 file changed: 2 additions & 6 deletions

bitsandbytes/nn/modules.py
@@ -992,12 +992,10 @@ def _load_from_state_dict(
     def init_8bit_state(self):
         self.state.CB = self.weight.CB
         self.state.SCB = self.weight.SCB
-        self.weight.CB = None
-        self.weight.SCB = None
 
     def forward(self, x: torch.Tensor):
         self.state.is_training = self.training
-        if self.weight.CB is not None:
+        if self.state.CB is None:
             self.init_8bit_state()
 
         # weights are cast automatically as Int8Params, but the bias has to be cast manually
@@ -1069,13 +1067,11 @@ def __init__(
     def init_8bit_state(self):
         self.state.CB = self.weight.CB
         self.state.SCB = self.weight.SCB
-        self.weight.CB = None
-        self.weight.SCB = None
 
     def forward(self, x):
         self.state.is_training = self.training
 
-        if self.weight.CB is not None:
+        if self.state.CB is None:
             self.init_8bit_state()
 
         out = bnb.matmul_mixed(x.half(), self.weight.half(), bias=None, state=self.state) + self.bias
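For context on the change, here is a minimal sketch, not taken from the repository: `QuantizedWeight`, `ModuleState`, and the free-standing `forward` are hypothetical, simplified stand-ins for bitsandbytes' `Int8Params` and matmul state, meant only to show why gating initialization on `state.CB`, while no longer nulling `weight.CB`/`weight.SCB`, matters once two modules share one weight, as with tied input and output embeddings:

```python
# Minimal sketch, not from the commit: QuantizedWeight and ModuleState are
# simplified stand-ins for bitsandbytes' Int8Params and per-module matmul state.

class QuantizedWeight:
    def __init__(self):
        self.CB = "int8 weight data"   # quantized matrix
        self.SCB = "column scales"     # quantization statistics

class ModuleState:
    def __init__(self):
        self.CB = None
        self.SCB = None

def forward(weight, state):
    # New check: each module consults its *own* state. The old check,
    # `if weight.CB is not None`, combined with nulling weight.CB/SCB
    # inside init_8bit_state, meant the first module to run a forward
    # pass consumed the stats and any other module sharing `weight`
    # never initialized its state.
    if state.CB is None:
        state.CB = weight.CB
        state.SCB = weight.SCB
    return state.CB, state.SCB

shared = QuantizedWeight()                    # e.g. a tied embedding/lm_head weight
lm_head, embed = ModuleState(), ModuleState()
assert forward(shared, lm_head) == forward(shared, embed)  # both modules see the stats
```

Under the old logic, the first forward pass nulled the shared weight's `CB`/`SCB`, so the second module's `weight.CB is not None` check failed and its state was never initialized; this is presumably the shared-embeddings breakage the commit message refers to, and keying initialization off per-module state sidesteps it.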
