Calling get_model_state_dict/set_model_state_dict requires forward pass for _lazy_init #125170

Closed
mvpatel2000 opened this issue Apr 29, 2024 · 8 comments

@mvpatel2000 (Contributor) commented Apr 29, 2024

🐛 Describe the bug

The new distributed checkpoint APIs get_model_state_dict/set_model_state_dict require that at least one forward pass has already run, because that is what triggers FSDP's _lazy_init. For example, we currently have to call the private _lazy_init ourselves before loading:

from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp._runtime_utils import _lazy_init

# Workaround: manually trigger FSDP's lazy initialization (normally done by
# the first forward pass) so that a root FSDP module exists before loading.
for module in self.model.modules():
    if isinstance(module, FSDP):
        _lazy_init(module, module)

set_model_state_dict(
    model=self.model,
    model_state_dict=state_dict['model'],
    options=StateDictOptions(
        full_state_dict=self.fsdp_state_dict_type != 'sharded',
        strict=strict,
        cpu_offload=True,
    ),
)

I believe get/set_model_state_dict (and maybe get/set_optim_state_dict) should call _lazy_init as well?
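
For reference, a minimal sketch of how the failure surfaces without the workaround above (not from the original report; the toy model, checkpoint path, and single-rank process group are hypothetical assumptions):

import torch
import torch.nn as nn
from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Wrap a toy model with FSDP; no forward pass runs, so _lazy_init never
# executes and no module has been marked as the FSDP root yet.
model = FSDP(nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2, bias=False)))

# Hypothetical checkpoint containing a full (unsharded) model state dict.
model_sd = torch.load('model_checkpoint.pt')

# Raises "The model has FSDP modules but no FSDP root module exists." from
# _verify_state_dict, because the root flag is only set during _lazy_init.
set_model_state_dict(
    model=model,
    model_state_dict=model_sd,
    options=StateDictOptions(full_state_dict=True, cpu_offload=True),
)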

Versions

Torch 2.3

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k

@mvpatel2000 (Contributor Author)

@pytorchbot label "oncall: distributed"

@pytorch-bot (bot) added the "oncall: distributed" label (Add this issue/PR to distributed oncall triage queue) on Apr 29, 2024
@LucasLLC (Contributor)

@mvpatel2000 I'm a little confused by the issue: does the above code fail, or is the issue that _lazy_init is not called on a submodule of the model?

@mvpatel2000 (Contributor Author)

@mvpatel2000 I'm a little confused by the issue: does the above code fail, or is the issue that _lazy_init is not called on a submodule of the model?

@LucasLLC Sorry, I was not clear -- I am suggesting that _lazy_init should be called inside get/set_model_state_dict and get/set_optimizer_state_dict, as the error is otherwise quite confusing. _lazy_init is a private function, so the code above is not really a desirable workaround.
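
For illustration only, one hedged sketch of what such a call could look like; _init_fsdp_roots_if_needed is a hypothetical helper (not the fix that actually landed) that the state-dict APIs could run before _verify_state_dict:

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp._runtime_utils import _lazy_init

def _init_fsdp_roots_if_needed(model: nn.Module) -> None:
    # Hypothetical helper: make sure FSDP state (including the root flag) is
    # initialized before verification, instead of requiring callers to have
    # run a forward pass or to call the private _lazy_init themselves.
    for module in model.modules():
        if isinstance(module, FSDP):
            _lazy_init(module, module)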

@fegin (Contributor) commented Apr 29, 2024

@mvpatel2000 Can you show the error message? I thought FSDP.state_dict and FSDP.load_state_dict called _lazy_init, so I'm curious what error you encountered.

@mvpatel2000 (Contributor Author)

@fegin

/composer/tests/trainer/test_fsdp_checkpoint.py:136: in get_trainer
    trainer = Trainer(
/composer/composer/trainer/trainer.py:1715: in __init__
    self._rng_state = checkpoint.load_checkpoint(
/composer/composer/utils/checkpoint.py:531: in load_checkpoint
    rng_state_dicts = _restore_checkpoint(
/composer/composer/utils/checkpoint.py:989: in _restore_checkpoint
    state.load_state_dict(
/composer/composer/core/state.py:1377: in load_state_dict
    self.load_model_state(
/composer/composer/core/state.py:1238: in load_model_state
    set_model_state_dict(
/usr/lib/python3/dist-packages/torch/distributed/checkpoint/state_dict.py:859: in set_model_state_dict
    _verify_state_dict(model_state_dict, {}, info)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

model_state_dict = {'module.0.bias': tensor([-0.2733,  0.1886]), 'module.0.weight': tensor([[-0.0063,  0.3783],
        [-0.5830, -0.5194]]), 'module.2.weight': tensor([[-0.0150,  0.5597],
        [-0.0618,  0.1881]])}
optim_state_dict = {}
info = _StateDictInfo(full_state_dict=False, cpu_offload=True, ignore_frozen_params=False, keep_submodule_prefixes=True, stri...bias=True)
), FullyShardedDataParallel(
  (_fsdp_wrapped_module): Linear(in_features=2, out_features=2, bias=False)
)])

    def _verify_state_dict(
        model_state_dict: Dict[str, ValueType],
        optim_state_dict: OptimizerStateType,
        info: _StateDictInfo,
    ) -> None:
        # FSDP root must exist otherwise FSDP state_dict will be incorrect.
        has_fsdp_root = False
        for module in info.fsdp_modules:
            fsdp_state = _get_module_fsdp_state_if_fully_sharded_module(module)
            assert fsdp_state is not None, "Expected a fsdp_state with a fsdp module."
            if fsdp_state._is_root:
                has_fsdp_root = True
                break
        if info.fsdp_modules and not has_fsdp_root:
>           raise RuntimeError("The model has FSDP modules but no FSDP root module exists.")
E           RuntimeError: The model has FSDP modules but no FSDP root module exists.
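
(As the issue title says, the error also disappears if any forward pass runs before loading, since that triggers _lazy_init. A minimal hedged sketch, reusing the toy model above with a hypothetical input shape:)

import torch

# Any forward pass marks the outermost FSDP wrapper as the root via
# _lazy_init, after which _verify_state_dict finds an FSDP root.
with torch.no_grad():
    model(torch.zeros(2, 2))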

@fegin (Contributor) commented Apr 30, 2024

@mvpatel2000 This has been fixed by #121544. Can you check whether that PR solves the issue?

@fegin self-assigned this on Apr 30, 2024
@mvpatel2000 (Contributor Author)

@fegin Yep, that looks good to me! It would be nice to include this in 2.3.1.

@mvpatel2000 (Contributor Author)

@fegin do you think we can backport for #125425?
