diff --git a/nemo/core/optim/mcore_optim.py b/nemo/core/optim/mcore_optim.py
index 0d4b524049ca..5103cce2dd84 100644
--- a/nemo/core/optim/mcore_optim.py
+++ b/nemo/core/optim/mcore_optim.py
@@ -37,6 +37,8 @@ def __init__(self, optim):
         self.mcore_optimizer = optim
         self.param_groups = self.mcore_optimizer.param_groups
         self.state = self.mcore_optimizer.state
+        self.sharding_type = 'dp_zero_gather_scatter'
+        # 'fully_sharded_bucket_space' if args.ckpt_fully_parallel_save else 'dp_zero_gather_scatter'

     def zero_grad(self, set_to_none: bool = True):
         """We only need to zero the model related parameters, i.e.,
@@ -55,8 +57,9 @@ def state_dict(self):
     def load_state_dict(self, state_dict):
         self.mcore_optimizer.load_state_dict(state_dict)

-    def sharded_state_dict(self, model_sharded_state_dict, is_loading: bool = False, **kwargs):
-        return self.mcore_optimizer.sharded_state_dict(model_sharded_state_dict, is_loading, **kwargs)
+    def sharded_state_dict(self, model_sharded_state_dict, optimizer_state_dict=None):
+        return self.mcore_optimizer.sharded_state_dict(
+            model_sharded_state_dict, is_loading=False, sharding_type='dp_zero_gather_scatter')

     def step(self, closure):
         """Clip gradients (if needed) and step the base optimizer.
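
For illustration only, a minimal runnable sketch of the call pattern this change pins down: the wrapper now passes `is_loading=False` and `sharding_type='dp_zero_gather_scatter'` to the wrapped optimizer instead of forwarding `is_loading`/`**kwargs` from the caller. `DummyMcoreOptimizer`, `WrapperSketch`, and the toy state dict are hypothetical stand-ins, not NeMo or Megatron-Core classes.

```python
# Hypothetical sketch (not NeMo/Megatron-Core code) of the call pattern in the diff above.

class DummyMcoreOptimizer:
    """Stand-in for the wrapped Megatron-Core optimizer (assumption)."""

    def sharded_state_dict(self, model_sharded_state_dict, is_loading=False,
                           sharding_type='dp_zero_gather_scatter'):
        # Echo the arguments so the resulting call pattern is visible.
        return {
            'is_loading': is_loading,
            'sharding_type': sharding_type,
            'num_model_keys': len(model_sharded_state_dict),
        }


class WrapperSketch:
    """Mirrors the diff: sharding_type is stored, and the mcore call is pinned."""

    def __init__(self, optim):
        self.mcore_optimizer = optim
        self.sharding_type = 'dp_zero_gather_scatter'

    def sharded_state_dict(self, model_sharded_state_dict, optimizer_state_dict=None):
        return self.mcore_optimizer.sharded_state_dict(
            model_sharded_state_dict, is_loading=False, sharding_type='dp_zero_gather_scatter'
        )


if __name__ == '__main__':
    wrapper = WrapperSketch(DummyMcoreOptimizer())
    print(wrapper.sharded_state_dict({'model.weight': object()}))
    # -> {'is_loading': False, 'sharding_type': 'dp_zero_gather_scatter', 'num_model_keys': 1}
```

The hardcoded `sharding_type` in the call matches the `self.sharding_type` attribute added in `__init__`, whose inline comment notes the alternative `'fully_sharded_bucket_space'` value.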