Add move_to_device kwarg to the optimizer's load_state_dict #1344
matthewdouglas merged 1 commit into bitsandbytes-foundation:main

Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
hansonw left a comment
I encountered this issue as well, but when using the paged variants of the optimizers (load_state_dict should re-create paged tensors instead of just calling .to(param.device)).
My solution (see the suggestion below) was to alter the initialization to use self.get_state_buffer instead. It's somewhat orthogonal to this PR (though the intent is similar); I can submit a separate PR, but I'm curious what the maintainers think.
Suggested change:

```diff
-if move_to_device:
-    value[k] = v.to(param.device)
+buffer = self.get_state_buffer(v, v.dtype)
+buffer.copy_(v)
+value[k] = buffer
```
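For context, here is a minimal sketch of what such an allocation path could look like. This is illustrative only: the self.is_paged attribute and the paged helpers bitsandbytes.functional.get_paged / fill are assumptions about the library's internals, and the real get_state_buffer may differ.

```python
import torch
import bitsandbytes.functional as F


def get_state_buffer(self, p, dtype=torch.float32):
    """Allocate an optimizer-state buffer shaped like p (illustrative sketch).

    A plain tensor is enough for non-paged optimizers. Paged optimizers need the
    buffer to come from the paged allocator so it can be swapped between CPU and
    GPU on demand; calling .to(param.device) on a checkpointed tensor would
    produce an ordinary CUDA tensor and silently drop the paging behavior.
    """
    if not getattr(self, "is_paged", False):  # assumed attribute on the optimizer
        return torch.zeros_like(p, dtype=dtype, device=p.device)
    buffer = F.get_paged(*p.shape, dtype=dtype, device=p.device)  # assumed paged allocator
    F.fill(buffer, 0)  # zero the paged buffer before copying checkpointed values into it
    return buffer
```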
Thanks @hansonw! This seems reasonable as a separate PR!
This makes it possible to load an optimizer checkpoint without automatically moving the optimizer's state to the GPU.
Force-pushed from 4c6793a to 8b4857a.
This PR makes it possible to load an optimizer checkpoint without automatically moving the optimizer's state to the GPU.
Some background as to why: I keep the optimizer's state on the CPU to save VRAM and manually move it to the GPU as needed. Unfortunately, load_state_dict moves all of the optimizer's tensors to whatever device the model's parameters are currently on, which results in an OOM crash. So currently, before loading an optimizer checkpoint, I have to unnecessarily move my model to the CPU, call the optimizer's load_state_dict, and then move the model back to the GPU. With this PR I can skip this silly dance.
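A minimal sketch of the workflow this enables is shown below. The model, optimizer class, and checkpoint path are illustrative; the only new piece is the move_to_device keyword, which presumably defaults to True so the old behavior is preserved.

```python
import torch
import bitsandbytes as bnb

# Illustrative setup: a small model whose parameters live on the GPU.
model = torch.nn.Linear(1024, 1024).cuda()
optimizer = bnb.optim.AdamW8bit(model.parameters())

# Load the checkpoint onto the CPU and keep the optimizer state there.
# With the previous behavior (move_to_device=True), every state tensor would
# be moved to param.device, i.e. onto the GPU, potentially causing an OOM.
state = torch.load("optimizer.pt", map_location="cpu")  # hypothetical checkpoint path
optimizer.load_state_dict(state, move_to_device=False)

# The state stays on the CPU and can be moved to the GPU manually, only when needed.
```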