Commit fcb657b: Mcore dist opt ckpt fix
Signed-off-by: Alexandros Koumparoulis <[email protected]>
akoumpa committed May 15, 2024
1 parent 6cb618a commit fcb657b
Showing 1 changed file with 2 additions and 2 deletions.
nemo/collections/nlp/parts/nlp_overrides.py: 2 additions & 2 deletions
@@ -78,7 +78,7 @@
     from apex.transformer.pipeline_parallel.utils import get_num_microbatches
 
     from nemo.core.optim.distributed_adam import MegatronDistributedFusedAdam
-
+    from nemo.core.optim.mcore_optim import McoreDistributedOptimizer
     HAVE_APEX = True
 
 except (ImportError, ModuleNotFoundError):
@@ -294,7 +294,7 @@ def optimizer_sharded_state_dict(self, unsharded_optim_state=None):
             key: value for key, value in model_sharded_state_dict.items() if not key.endswith('_extra_state')
         }
 
-        if isinstance(optimizer, MegatronDistributedFusedAdam):
+        if isinstance(optimizer, MegatronDistributedFusedAdam) or isinstance(optimizer, McoreDistributedOptimizer):
            return optimizer.sharded_state_dict(model_sharded_state_dict, unsharded_optim_state)
        elif not isinstance(optimizer, MainParamsOptimizerWrapper):
            # Regular optimizer, e.g. Adam or FusedAdam
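For context, below is a minimal, self-contained Python sketch (not NeMo code; the stub class names and simplified fallback are hypothetical) of the dispatch pattern this commit extends: distributed optimizers that can shard their own state are asked for a sharded state dict directly, and with this change McoreDistributedOptimizer takes the same branch as MegatronDistributedFusedAdam.

# Minimal standalone sketch of the dispatch pattern extended by this commit.
# The *Stub classes and the helper below are hypothetical stand-ins for the
# real NeMo optimizer wrappers; they exist only to make the example runnable.

class MegatronDistributedFusedAdamStub:
    def sharded_state_dict(self, model_sharded_state_dict, unsharded_optim_state=None):
        return {"state": "fused-adam sharded optimizer state (stub)"}


class McoreDistributedOptimizerStub:
    def sharded_state_dict(self, model_sharded_state_dict, unsharded_optim_state=None):
        return {"state": "mcore sharded optimizer state (stub)"}


def optimizer_sharded_state_dict(optimizer, model_sharded_state_dict, unsharded_optim_state=None):
    # After this commit, both distributed-optimizer flavors take the same branch
    # and build their own sharded state dict from the model's sharded state.
    if isinstance(optimizer, (MegatronDistributedFusedAdamStub, McoreDistributedOptimizerStub)):
        return optimizer.sharded_state_dict(model_sharded_state_dict, unsharded_optim_state)
    # Anything else is treated as a regular optimizer here (simplified fallback).
    return {"state": "regular optimizer state"}


if __name__ == "__main__":
    sharded_model_state = {"weights": "sharded model tensors (stub)"}
    print(optimizer_sharded_state_dict(McoreDistributedOptimizerStub(), sharded_model_state))

Note that the two chained checks in the diff are equivalent to a single call with a tuple of types, i.e. isinstance(optimizer, (MegatronDistributedFusedAdam, McoreDistributedOptimizer)).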
