Alit/optim 8k #9166

Merged (13 commits) on May 14, 2024
@@ -108,6 +108,7 @@ model:
# See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details.
# 'full' will checkpoint the entire transformer layer.
activations_checkpoint_granularity: null # 'selective' or 'full'
activations_checkpoint_recurrent: False # If set to True, checkpointing is applied only to the rglru and conv1d layers, not to the attention and MLP layers
activations_checkpoint_method: null # 'uniform', 'block'
# 'uniform' divides the total number of transformer layers and checkpoints the input activation
# of each chunk at the specified granularity. When used with 'selective', 'uniform' checkpoints all attention blocks in the model.
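For orientation, here is a minimal sketch of how these settings could be combined, written as a Python dict that mirrors the YAML keys above; the values are hypothetical and not taken from this PR:

# Hypothetical override values; key names mirror the YAML config above.
recompute_settings = {
    "activations_checkpoint_granularity": "full",   # 'selective' or 'full'
    "activations_checkpoint_recurrent": False,      # True: checkpoint only rglru/conv1d, skip attention/MLP
    "activations_checkpoint_method": "uniform",     # 'uniform' or 'block'
}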
@@ -117,6 +117,7 @@ model:
# See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details.
# 'full' will checkpoint the entire transformer layer.
activations_checkpoint_granularity: null # 'selective' or 'full'
activations_checkpoint_recurrent: False # If set to True, checkpointing is applied only to the rglru and conv1d layers, not to the attention and MLP layers
activations_checkpoint_method: null # 'uniform', 'block'
# 'uniform' divides the total number of transformer layers and checkpoints the input activation
# of each chunk at the specified granularity. When used with 'selective', 'uniform' checkpoints all attention blocks in the model.
@@ -121,7 +121,7 @@ model:
# See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details.
# 'full' will checkpoint the entire transformer layer.
activations_checkpoint_granularity: null # 'selective' or 'full'
activations_checkpoint_method: null # 'uniform', 'block'
activations_checkpoint_recurrent: False # If set to True, checkpointing is applied only to the rglru and conv1d layers, not to the attention and MLP layers
activations_checkpoint_method: null # 'uniform', 'block'
# 'uniform' divides the total number of transformer layers and checkpoints the input activation
# of each chunk at the specified granularity. When used with 'selective', 'uniform' checkpoints all attention blocks in the model.
# 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity
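As a rough illustration of the two methods (a standalone sketch with assumed values, not code from this PR), the helpers below list which layers end up checkpointed:

# Sketch only: 'uniform' wraps each chunk of recompute_num_layers layers in a single checkpoint,
# while 'block' checkpoints only the first recompute_num_layers layers of a pipeline stage.
def uniform_chunks(num_layers, recompute_num_layers):
    return [(start, start + recompute_num_layers) for start in range(0, num_layers, recompute_num_layers)]

def block_checkpointed_layers(num_layers, recompute_num_layers):
    return [layer for layer in range(num_layers) if layer < recompute_num_layers]

print(uniform_chunks(6, 2))             # [(0, 2), (2, 4), (4, 6)]
print(block_checkpointed_layers(6, 2))  # [0, 1]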
@@ -11,17 +11,26 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from megatron.core.models.common.language_module.language_module import LanguageModule
from megatron.core.transformer.custom_layers.transformer_engine import TENorm
from megatron.core.transformer.spec_utils import build_module
from megatron.core.transformer.transformer_config import TransformerConfig
from torch import nn

from torch import Tensor, nn
from nemo.collections.nlp.models.language_modeling.megatron.griffin.griffin_layer_spec import (
griffin_mqa_layer_with_transformer_engine_spec,
griffin_recurrent_layer_with_transformer_engine_spec,
)
from nemo.collections.nlp.modules.common.megatron.utils import ApexGuardDefaults

try:
from megatron.core import parallel_state, tensor_parallel
from megatron.core.models.common.language_module.language_module import LanguageModule
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.transformer.custom_layers.transformer_engine import TENorm, te_checkpoint
from megatron.core.transformer.spec_utils import build_module
from megatron.core.transformer.transformer_config import TransformerConfig

HAVE_MEGATRON_CORE = True

except (ImportError, ModuleNotFoundError):
TransformerConfig = ApexGuardDefaults
HAVE_MEGATRON_CORE = False


def get_griffin_layers(num_layers):
@@ -41,34 +50,162 @@ def get_griffin_layers(num_layers):


def create_block(
config, layer_spec, layer_idx,
config,
layer_spec,
layer_idx,
):
block = build_module(layer_spec, config,)
block = build_module(
layer_spec,
config,
)
block.layer_number = layer_idx + 1
return block
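# Hedged usage sketch (illustration only, not part of this diff), assuming `config` is an
# already-built TransformerConfig:
#   block = create_block(config, griffin_recurrent_layer_with_transformer_engine_spec, layer_idx=0)
#   block.layer_number  # -> 1, since create_block assigns 1-based layer numbers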


class GriffinStack(LanguageModule):
def __init__(
self, config: TransformerConfig,
self,
config: TransformerConfig,
):

super().__init__(config)
self.config = config
self.griffin_layers = get_griffin_layers(self.config.num_layers)

self.layers = nn.ModuleList(
[create_block(self.config, layer_spec, layer_idx=i,) for i, layer_spec in enumerate(self.griffin_layers)]
[
create_block(
self.config,
layer_spec,
layer_idx=i,
)
for i, layer_spec in enumerate(self.griffin_layers)
]
)
self.final_layernorm = TENorm(
config=self.config, hidden_size=self.config.hidden_size, eps=self.config.layernorm_epsilon,
config=self.config,
hidden_size=self.config.hidden_size,
eps=self.config.layernorm_epsilon,
)
self.num_layers = len(self.layers)

def _get_layer(self, layer_number: int):
return self.layers[layer_number]

def _checkpointed_forward(
self,
hidden_states: Tensor,
attention_mask: Tensor,
context: Tensor = None,
context_mask: Tensor = None,
rotary_pos_emb: Tensor = None,
packed_seq_params: PackedSeqParams = None,
):
"""Forward method with activation checkpointing."""

def custom(start: int, end: int):
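# Build a forward callable that runs layers [start, end), so a contiguous chunk of layers can be checkpointed as a single unit.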
def custom_forward(
hidden_states,
attention_mask,
context,
context_mask,
rotary_pos_emb,
packed_seq_params,
):
for index in range(start, end):
layer = self._get_layer(index)
hidden_states, context = layer(
hidden_states=hidden_states,
attention_mask=attention_mask,
context=context,
context_mask=context_mask,
rotary_pos_emb=rotary_pos_emb,
inference_params=None,
packed_seq_params=packed_seq_params,
)
return hidden_states, context

return custom_forward

def checkpoint_handler(forward_func):
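# With FP8 enabled, route checkpointing through Transformer Engine's te_checkpoint (which also takes the CUDA RNG tracker and the tensor-parallel group); otherwise use Megatron's tensor_parallel.checkpoint.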
if self.config.fp8:
return te_checkpoint(
forward_func,
self.config.distribute_saved_activations,
tensor_parallel.random.get_cuda_rng_tracker,
parallel_state.get_tensor_model_parallel_group(),
hidden_states,
attention_mask,
context,
context_mask,
rotary_pos_emb,
packed_seq_params,
)
else:
return tensor_parallel.checkpoint(
forward_func,
self.config.distribute_saved_activations,
hidden_states,
attention_mask,
context,
context_mask,
rotary_pos_emb,
packed_seq_params,
)

if self.config.recompute_method == 'uniform':
# Uniformly divide the total number of Transformer layers and checkpoint
# the input activation of each divided chunk.
# A method to further reduce memory usage by storing fewer checkpoints (one per chunk).
l = 0
while l < self.num_layers:
hidden_states, context = checkpoint_handler(custom(l, l + self.config.recompute_num_layers))

l += self.config.recompute_num_layers

elif self.config.recompute_method == 'block':
# Checkpoint the input activation of only a set number of individual
# Transformer layers and skip the rest.
# A method to make full use of device memory while avoiding redundant re-computation.
recompute_skip_num_layers = 0
for l in range(self.num_layers):
# Skip recomputation when input grad computation is not needed.
# Need to have at least one input tensor with gradient computation
# for the re-entrant autograd engine.
if self.config.fp8 and not hidden_states.requires_grad:
recompute_skip_num_layers += 1
if l >= recompute_skip_num_layers and l < self.config.recompute_num_layers + recompute_skip_num_layers:
hidden_states, context = checkpoint_handler(custom(l, l + 1))
else:
hidden_states, context = custom(l, l + 1)(
hidden_states,
attention_mask,
context,
context_mask,
rotary_pos_emb,
packed_seq_params,
)
else:
raise ValueError("Invalid activation recompute method.")

return hidden_states

def forward(self, hidden_states, attention_mask, rotary_pos_emb):

    if (
        self.config.recompute_granularity == 'full'
        and self.training
        and not self.config.activations_checkpoint_recurrent
    ):
        hidden_states = self._checkpointed_forward(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            rotary_pos_emb=rotary_pos_emb,
        )
    else:
        for layer in self.layers:
            hidden_states, _ = layer(hidden_states, attention_mask=attention_mask, rotary_pos_emb=rotary_pos_emb)

    hidden_states = self.final_layernorm(hidden_states)
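# Hedged sketch (not code from this PR): the branch at the top of forward() reduces to the
# predicate below. When activations_checkpoint_recurrent is True, the stack-level wrapper is
# skipped and, per the config comment, only the rglru/conv1d sub-modules are checkpointed.
def _uses_stack_level_checkpointing(config, training):
    return (
        config.recompute_granularity == 'full'
        and training
        and not config.activations_checkpoint_recurrent
    )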

@@ -13,15 +13,23 @@
# limitations under the License.

import math

import torch
from megatron.core import tensor_parallel
from megatron.core.jit import jit_fuser
from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
from megatron.core.models.common.language_module.language_module import LanguageModule
from megatron.core.transformer.transformer_config import TransformerConfig
from torch import Tensor, nn
from nemo.collections.nlp.modules.common.megatron.utils import ApexGuardDefaults

try:
from megatron.core import tensor_parallel
from megatron.core.jit import jit_fuser
from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
from megatron.core.models.common.language_module.language_module import LanguageModule
from megatron.core.transformer.transformer_config import TransformerConfig
from torch import Tensor, nn

HAVE_MEGATRON_CORE = True

except (ImportError, ModuleNotFoundError):
TransformerConfig = ApexGuardDefaults
HAVE_MEGATRON_CORE = False

from nemo.collections.nlp.models.language_modeling.megatron.griffin.griffin_block import GriffinStack
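# The guarded imports above follow NeMo's usual optional-dependency pattern. A hedged sketch of
# how a caller typically consumes the flag (hypothetical wording, not code from this diff):
#
#   if not HAVE_MEGATRON_CORE:
#       raise ImportError("megatron-core is required to use the Griffin model classes.")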

@@ -142,7 +150,7 @@ def forward(
position_ids: Tensor = None,
attention_mask: Tensor = None,
labels: Tensor = None,
**extra_arg
**extra_arg,
):
if input_ids is None:
input_ids = self.input_tensor
@@ -14,13 +14,21 @@

from dataclasses import dataclass
from typing import Union

from megatron.core.transformer.identity_op import IdentityFuncOp, IdentityOp
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import make_viewless_tensor
from torch import Tensor
from nemo.collections.nlp.modules.common.megatron.utils import ApexGuardDefaults

try:
from megatron.core.transformer.identity_op import IdentityFuncOp, IdentityOp
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import make_viewless_tensor

HAVE_MEGATRON_CORE = True

except (ImportError, ModuleNotFoundError):
TransformerConfig = ApexGuardDefaults
HAVE_MEGATRON_CORE = False


@dataclass