Commit
pass dp_zero_gather_scatter to sharded_state_dict
Signed-off-by: Alexandros Koumparoulis <[email protected]>
akoumpa committed May 15, 2024
1 parent fcb657b commit b1df810
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions nemo/core/optim/mcore_optim.py
@@ -37,6 +37,8 @@ def __init__(self, optim):
         self.mcore_optimizer = optim
         self.param_groups = self.mcore_optimizer.param_groups
         self.state = self.mcore_optimizer.state
+        self.sharding_type = 'dp_zero_gather_scatter'
+        # 'fully_sharded_bucket_space' if args.ckpt_fully_parallel_save else 'dp_zero_gather_scatter'
 
     def zero_grad(self, set_to_none: bool = True):
         """We only need to zero the model related parameters, i.e.,
@@ -55,8 +57,9 @@ def state_dict(self):
     def load_state_dict(self, state_dict):
         self.mcore_optimizer.load_state_dict(state_dict)
 
-    def sharded_state_dict(self, model_sharded_state_dict, is_loading: bool = False, **kwargs):
-        return self.mcore_optimizer.sharded_state_dict(model_sharded_state_dict, is_loading, **kwargs)
+    def sharded_state_dict(self, model_sharded_state_dict, optimizer_state_dict=None):
+        return self.mcore_optimizer.sharded_state_dict(
+            model_sharded_state_dict, is_loading=False, sharding_type='dp_zero_gather_scatter')
 
     def step(self, closure):
         """Clip gradients (if needed) and step the base optimizer.
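For context, a minimal, self-contained sketch of what the patched method now does. The _StubMcoreOptimizer below is a hypothetical stand-in for Megatron-Core's distributed optimizer (used only so the snippet runs without a GPU cluster), and the wrapper class name is assumed from the file; only the sharded_state_dict body mirrors the diff above.

class _StubMcoreOptimizer:
    def sharded_state_dict(self, model_sharded_state_dict,
                           is_loading=False, sharding_type=None):
        # The real Megatron-Core optimizer builds sharded-tensor metadata
        # here; this stub just records which strategy it was asked to use.
        return {'sharding_type': sharding_type, 'is_loading': is_loading}

class McoreDistributedOptimizer:
    def __init__(self, optim):
        self.mcore_optimizer = optim
        self.sharding_type = 'dp_zero_gather_scatter'

    def sharded_state_dict(self, model_sharded_state_dict, optimizer_state_dict=None):
        # Mirrors the patched method: is_loading and sharding_type are now
        # fixed rather than forwarded from the caller.
        return self.mcore_optimizer.sharded_state_dict(
            model_sharded_state_dict, is_loading=False,
            sharding_type='dp_zero_gather_scatter')

wrapper = McoreDistributedOptimizer(_StubMcoreOptimizer())
print(wrapper.sharded_state_dict(model_sharded_state_dict={}))
# -> {'sharding_type': 'dp_zero_gather_scatter', 'is_loading': False}

The net effect: instead of forwarding is_loading and **kwargs, the wrapper always requests the 'dp_zero_gather_scatter' sharding type; the commented-out line in the first hunk suggests 'fully_sharded_bucket_space' as the alternative when fully parallel checkpoint saving is enabled.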
