File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -992,12 +992,10 @@ def _load_from_state_dict(
992992 def init_8bit_state (self ):
993993 self .state .CB = self .weight .CB
994994 self .state .SCB = self .weight .SCB
995- self .weight .CB = None
996- self .weight .SCB = None
997995
998996 def forward (self , x : torch .Tensor ):
999997 self .state .is_training = self .training
1000- if self .weight .CB is not None :
998+ if self .state .CB is None :
1001999 self .init_8bit_state ()
10021000
10031001 # weights are cast automatically as Int8Params, but the bias has to be cast manually
@@ -1069,13 +1067,11 @@ def __init__(
10691067 def init_8bit_state (self ):
10701068 self .state .CB = self .weight .CB
10711069 self .state .SCB = self .weight .SCB
1072- self .weight .CB = None
1073- self .weight .SCB = None
10741070
10751071 def forward (self , x ):
10761072 self .state .is_training = self .training
10771073
1078- if self .weight .CB is not None :
1074+ if self .state .CB is None :
10791075 self .init_8bit_state ()
10801076
10811077 out = bnb .matmul_mixed (x .half (), self .weight .half (), bias = None , state = self .state ) + self .bias
You can’t perform that action at this time.
0 commit comments