Commit 8b4857a

koute authored and matthewdouglas committed

Add move_to_device kwarg to the optimizer's load_state_dict

This makes it possible to load an optimizer checkpoint without automatically moving the optimizer's state to the GPU.

1 parent a0da01e commit 8b4857a

1 file changed: bitsandbytes/optim/optimizer.py

Lines changed: 5 additions & 2 deletions
@@ -153,12 +153,14 @@ def fill_qmap(self):
     def __setstate__(self, state):
         super().__setstate__(state)
 
-    def load_state_dict(self, state_dict):
+    def load_state_dict(self, state_dict, move_to_device=True):
         """Load an optimizer state.
 
         Arguments:
             state_dict (`dict`):
                 An optimizer state (should be returned from a call to `state_dict`) to load.
+            move_to_device (`bool`, defaults to `True`):
+                Whether to move the optimizer's state to the device.
         """
         # deepcopy, to be consistent with module API
         state_dict = deepcopy(state_dict)
@@ -195,7 +197,8 @@ def cast(param, value):
             elif isinstance(value, dict):
                 for k, v in value.items():
                     if k in self.non_castable_tensor_keys:
-                        value[k] = v.to(param.device)
+                        if move_to_device:
+                            value[k] = v.to(param.device)
                     else:
                         value[k] = cast(param, v)
 
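A minimal usage sketch of the new keyword. The model, optimizer choice (Adam8bit), and checkpoint layout below are illustrative assumptions, not part of the commit:

import torch
import bitsandbytes as bnb

# Hypothetical setup; the checkpoint path and "optimizer" key are assumptions.
model = torch.nn.Linear(1024, 1024).cuda()
optimizer = bnb.optim.Adam8bit(model.parameters())

# Load the checkpoint onto the CPU first.
checkpoint = torch.load("checkpoint.pt", map_location="cpu")

# By default (move_to_device=True) the optimizer's non-castable state tensors
# would be moved to the parameters' device (here, the GPU). Passing False keeps
# them on the device they were loaded to (here, the CPU).
optimizer.load_state_dict(checkpoint["optimizer"], move_to_device=False)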