Skip to content

Commit

Permalink
Update nemo/core/optim/mcore_optim.py
Browse files Browse the repository at this point in the history
Co-authored-by: mikolajblaz <[email protected]>
Signed-off-by: Alexandros Koumparoulis <[email protected]>
  • Loading branch information
akoumpa and mikolajblaz committed May 22, 2024
1 parent d7cf7f0 commit 0a9cd71
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion nemo/core/optim/mcore_optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ def sharded_state_dict(
self, model_sharded_state_dict, optimizer_state_dict=None, is_loading=False, dist_ckpt_parallel_save=False
):
# TODO(@akoumparouli, @mikolajblaz): switch to sharding_type once support for fully_sharded_model_space merged in mcore.
sharding_type = 'fully_sharded_model_space' if dist_ckpt_parallel_save else 'dp_zero_gather_scatter'
# sharding_type = 'fully_sharded_model_space' if dist_ckpt_parallel_save else 'dp_zero_gather_scatter'
sharding_type = 'dp_zero_gather_scatter'
return self.mcore_optimizer.sharded_state_dict(
model_sharded_state_dict, is_loading=is_loading, sharding_type=sharding_type
)
Expand Down

0 comments on commit 0a9cd71

Please sign in to comment.