Commit 5292aa4

koute authored and matthewdouglas committed
Add move_to_device kwarg to the optimizer's load_state_dict (bitsandbytes-foundation#1344)
This makes it possible to load an optimizer checkpoint without automatically moving the optimizer's state to the GPU.
1 parent dc0f4c1 commit 5292aa4

1 file changed: bitsandbytes/optim/optimizer.py (5 additions, 2 deletions)
```diff
@@ -154,12 +154,14 @@ def fill_qmap(self):
     def __setstate__(self, state):
         super().__setstate__(state)

-    def load_state_dict(self, state_dict):
+    def load_state_dict(self, state_dict, move_to_device=True):
         """Load an optimizer state.

         Arguments:
             state_dict (`dict`):
                 An optimizer state (should be returned from a call to `state_dict`) to load.
+            move_to_device (`bool`, defaults to `True`):
+                Whether to move the optimizer's state to the device.
         """
         # deepcopy, to be consistent with module API
         state_dict = deepcopy(state_dict)
@@ -196,7 +198,8 @@ def cast(param, value):
             elif isinstance(value, dict):
                 for k, v in value.items():
                     if k in self.non_castable_tensor_keys:
-                        value[k] = v.to(param.device)
+                        if move_to_device:
+                            value[k] = v.to(param.device)
                     else:
                         value[k] = cast(param, v)
```
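The effect of the new flag can be sketched in plain Python. This is a minimal stand-in, not the bitsandbytes implementation: `FakeTensor` and `load_value` are hypothetical names used here to mirror the patched loop, where tensors under non-castable keys are moved to the parameter's device only when `move_to_device` is true.

```python
class FakeTensor:
    """Stand-in for a torch.Tensor that only tracks its device."""

    def __init__(self, device="cpu"):
        self.device = device

    def to(self, device):
        # Mimics Tensor.to: returns a tensor on the target device.
        return FakeTensor(device)


def load_value(value, param_device, non_castable_keys, move_to_device=True):
    """Mirror of the patched loop: non-castable tensors are moved to the
    parameter's device only when move_to_device is True; everything else
    is left as-is (the real code casts it to the parameter's dtype)."""
    out = {}
    for k, v in value.items():
        if k in non_castable_keys:
            out[k] = v.to(param_device) if move_to_device else v
        else:
            out[k] = v
    return out


state = {"qmap1": FakeTensor("cpu"), "step": FakeTensor("cpu")}
moved = load_value(state, "cuda:0", {"qmap1"}, move_to_device=True)
kept = load_value(state, "cuda:0", {"qmap1"}, move_to_device=False)
print(moved["qmap1"].device)  # cuda:0
print(kept["qmap1"].device)   # cpu
```

With the real API the new behavior is opt-out: `optimizer.load_state_dict(checkpoint, move_to_device=False)` keeps the checkpointed state tensors where they are instead of eagerly transferring them to the GPU.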
