File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -153,12 +153,14 @@ def fill_qmap(self):
153153 def __setstate__ (self , state ):
154154 super ().__setstate__ (state )
155155
156- def load_state_dict (self , state_dict ):
156+ def load_state_dict (self , state_dict , move_to_device = True ):
157157 """Load an optimizer state.
158158
159159 Arguments:
160160 state_dict (`dict`):
161161 An optimizer state (should be returned from a call to `state_dict`) to load.
162+ move_to_device (`bool`, defaults to `True`):
163+ Whether to move the optimizer's state to the device.
162164 """
163165 # deepcopy, to be consistent with module API
164166 state_dict = deepcopy (state_dict )
@@ -195,7 +197,8 @@ def cast(param, value):
195197 elif isinstance (value , dict ):
196198 for k , v in value .items ():
197199 if k in self .non_castable_tensor_keys :
198- value [k ] = v .to (param .device )
200+ if move_to_device :
201+ value [k ] = v .to (param .device )
199202 else :
200203 value [k ] = cast (param , v )
201204
You can’t perform that action at this time.
0 commit comments