Can't load on rank 0 only with set_optimizer_state_dict
#125177
Comments
@pytorchbot label "oncall: distributed"
@mvpatel2000 Curious, how many nodes/ranks are you using? Also, is the state_dict/optimizer state_dict DTensor-based? I'm also implementing the broadcasting feature at the moment, but it will only support DTensor-based state_dict/optimizer state_dict.
This is for multiple GPUs (2 is sufficient). We currently unit test this with ShardedTensor (as that is the default unless a device mesh is passed in), but I believe it's fine if it only supports DTensor.
@mvpatel2000 You can try #125339 to see if that PR can help. But it only supports DTensor.
🐛 Describe the bug
To avoid CPU OOMs, our training library only loads monolithic checkpoints on rank 0 and broadcasts to all other ranks (as PyTorch checkpointing supports). When migrating to the new distributed APIs,
we hit an error with this approach in the function below:
torch/distributed/checkpoint/state_dict.py, line 582 (at commit ae13c7e)
In our code, `optim_state_dict` is only non-None on rank 0, as `rank0only` is set to True in PyTorch code. Currently, I need to apply a workaround in our code, which is not ideal. I think the split function should simply not be run when the context manager is `rank0only`, but I am not sure here.
Versions
PyTorch 2.3
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k @LucasLLC