File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -154,12 +154,14 @@ def fill_qmap(self):
154154 def __setstate__ (self , state ):
155155 super ().__setstate__ (state )
156156
157- def load_state_dict (self , state_dict ):
157+ def load_state_dict (self , state_dict , move_to_device = True ):
158158 """Load an optimizer state.
159159
160160 Arguments:
161161 state_dict (`dict`):
162162 An optimizer state (should be returned from a call to `state_dict`) to load.
163+ move_to_device (`bool`, defaults to `True`):
164+ Whether to move the optimizer's state to the device.
163165 """
164166 # deepcopy, to be consistent with module API
165167 state_dict = deepcopy (state_dict )
@@ -196,7 +198,8 @@ def cast(param, value):
196198 elif isinstance (value , dict ):
197199 for k , v in value .items ():
198200 if k in self .non_castable_tensor_keys :
199- value [k ] = v .to (param .device )
201+ if move_to_device :
202+ value [k ] = v .to (param .device )
200203 else :
201204 value [k ] = cast (param , v )
202205
You can’t perform that action at this time.
0 commit comments