Can't load on rank 0 only with set_optimizer_state_dict #125177

Open

mvpatel2000 opened this issue Apr 29, 2024 · 4 comments
Labels: module: distributed_checkpoint, oncall: distributed, triaged

Comments

mvpatel2000 (Contributor) commented Apr 29, 2024

🐛 Describe the bug

To avoid CPU OOMs, our training library loads monolithic checkpoints only on rank 0 and broadcasts them to all other ranks (as PyTorch checkpointing supports). When migrating to the new distributed state-dict APIs, we call:

    from torch.distributed.checkpoint.state_dict import (
        StateDictOptions,
        set_optimizer_state_dict,
    )

    set_optimizer_state_dict(
        model=self.model,
        optimizers=optimizer,
        optim_state_dict=optim_state_dict,
        options=StateDictOptions(
            full_state_dict=self.fsdp_state_dict_type == 'full',
            strict=strict,
            cpu_offload=True,
        ),
    )
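
For reference, a rough sketch of how optim_state_dict is produced in our setup (the checkpoint path and key below are placeholders, not our actual code):

    import torch
    import torch.distributed as dist

    if dist.get_rank() == 0:
        # Only rank 0 materializes the monolithic checkpoint in CPU memory,
        # which is how we avoid host-side OOMs.
        optim_state_dict = torch.load('checkpoint.pt', map_location='cpu')['optimizers']
    else:
        # All other ranks never hold the full optimizer state dict.
        optim_state_dict = None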

we hit an error with this approach in the call below inside PyTorch:

    optim_state_dict = _split_optim_state_dict(model, optim, state_dict, info)

In our code, optim_state_dict is non-None only on rank 0, because rank0_only is set to True in the PyTorch code.

As a workaround, I currently need to do:

    from unittest.mock import MagicMock
    optim_state_dict = MagicMock() if optim_state_dict is None else optim_state_dict

which is not ideal. I think the split function should simply not be run if the context manager is rank0_only, but I am not sure about the right fix here.
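
To make the suggestion concrete, here is a purely illustrative sketch of the guard I have in mind in the PyTorch loading path (loaded_on_rank0_only is a hypothetical flag, not an actual attribute of the info object):

    if loaded_on_rank0_only and not state_dict:
        # Under a rank0_only load, non-zero ranks have no optimizer state to
        # split; skip the split and rely on a broadcast from rank 0 instead.
        optim_state_dict = {}
    else:
        optim_state_dict = _split_optim_state_dict(model, optim, state_dict, info)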

Versions

PyTorch 2.3

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k @LucasLLC

mvpatel2000 (Contributor, Author) commented:

@pytorchbot label "oncall: distributed"

pytorch-bot added the oncall: distributed label Apr 29, 2024
LucasLLC self-assigned this Apr 29, 2024
fegin (Contributor) commented Apr 30, 2024

@mvpatel2000 Curious, how many nodes/ranks are you using? Also, is the state_dict/optimizer state_dict DTensor-based? I'm also implementing the broadcasting feature at the moment, but it will only support DTensor-based state_dict/optimizer state_dict.

mvpatel2000 (Contributor, Author) commented:

> Curious, how many nodes/ranks are you using? Also, is the state_dict/optimizer state_dict DTensor-based? I'm also implementing the broadcasting feature at the moment, but it will only support DTensor-based state_dict/optimizer state_dict.

This is for multiple GPUs (2 is sufficient).

We currently unit test this with ShardedTensors (the default unless a device mesh is passed in), but I believe it's fine if the fix only supports DTensor.

fegin (Contributor) commented May 3, 2024

@mvpatel2000 You can try #125339 to see if that PR helps. Note that it only supports DTensor.
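
If that lands as a StateDictOptions flag, usage would presumably look roughly like the following (broadcast_from_rank0 is an assumed flag name based on that PR, not a confirmed API here):

    import torch.distributed as dist
    from torch.distributed.checkpoint.state_dict import (
        StateDictOptions,
        set_optimizer_state_dict,
    )

    set_optimizer_state_dict(
        model=model,
        optimizers=optimizer,
        # Full state dict on rank 0 only; other ranks pass an empty dict.
        optim_state_dict=optim_state_dict if dist.get_rank() == 0 else {},
        options=StateDictOptions(
            full_state_dict=True,
            broadcast_from_rank0=True,  # assumed flag name from the PR
            cpu_offload=True,
        ),
    )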

LucasLLC added the triaged label May 7, 2024