Merge remote-tracking branch 'origin/develop'
sunpengsdu committed Sep 1, 2023
2 parents 74afbb0 + 860de0a commit ad0cddc
Showing 8 changed files with 281 additions and 32 deletions.
2 changes: 2 additions & 0 deletions internlm/core/context/__init__.py
@@ -7,6 +7,7 @@
from .process_group_initializer import (
Initializer_Data,
Initializer_Model,
Initializer_Nettest,
Initializer_Pipeline,
Initializer_Tensor,
Initializer_Zero1,
@@ -34,6 +35,7 @@
"Initializer_Pipeline",
"Initializer_Data",
"Initializer_Zero1",
"Initializer_Nettest",
"ProcessGroupInitializer",
"Initializer_Model",
"seed",
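With the re-export above, the new initializer becomes part of the package's public interface; a minimal import sketch (assuming the internlm package is importable):

```python
# Hypothetical usage sketch: Initializer_Nettest is now exposed alongside the
# other process-group initializers in the package's public API.
from internlm.core.context import Initializer_Nettest, ProcessGroupInitializer

assert issubclass(Initializer_Nettest, ProcessGroupInitializer)
```
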
6 changes: 6 additions & 0 deletions internlm/core/context/parallel_context.py
@@ -143,6 +143,7 @@ def __init__(self):
self.pipeline_parallel_size = 1
self.tensor_parallel_size = 1
self.zero1_parallel_size = -1
self.nettest_parallel_size = 1
self.num_processes_on_current_node = -1
self.virtual_pipeline_parallel_size = None
self.virtual_pipeline_parallel_rank = None
@@ -442,6 +443,9 @@ def init_parallel_groups(self):
# instead, it should be calculated based on other parallel config
self.data_parallel_size = self.world_size // (self.pipeline_parallel_size * self.tensor_parallel_size)

# the recommended nettest_parallel_size is 32 GPUs
self.nettest_parallel_size = 32

if self.zero1_parallel_size <= 0:
self.zero1_parallel_size = self.data_parallel_size

@@ -454,6 +458,7 @@
self.pipeline_parallel_size,
self.tensor_parallel_size,
self.zero1_parallel_size,
self.nettest_parallel_size,
]

# run initialization of different process groups
initializers.append(pgroup_initializer.Initializer_Model(*initializer_args))
initializers.append(pgroup_initializer.Initializer_Tensor(*initializer_args))
initializers.append(pgroup_initializer.Initializer_Zero1(*initializer_args))
initializers.append(pgroup_initializer.Initializer_Nettest(*initializer_args))
if self.pipeline_parallel_size > 1:
initializers.append(pgroup_initializer.Initializer_Pipeline(*initializer_args))
for initializer in initializers:
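For context, here is a small standalone sketch of the size bookkeeping that init_parallel_groups performs after this change; the world size and pipeline/tensor sizes are assumed example values, not taken from the commit:

```python
# Sketch of the size arithmetic in init_parallel_groups(), with assumed values.
world_size = 128                 # illustrative
pipeline_parallel_size = 2       # illustrative
tensor_parallel_size = 4         # illustrative
zero1_parallel_size = -1         # "not set", as in the default config

# data parallel size is derived from the other parallel sizes
data_parallel_size = world_size // (pipeline_parallel_size * tensor_parallel_size)  # 16

# network-test groups use the recommended fixed size of 32 GPUs
nettest_parallel_size = 32

# ZeRO-1 falls back to the full data-parallel size when unset
if zero1_parallel_size <= 0:
    zero1_parallel_size = data_parallel_size

print(data_parallel_size, zero1_parallel_size, nettest_parallel_size)  # 16 16 32
```
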
55 changes: 55 additions & 0 deletions internlm/core/context/process_group_initializer.py
Expand Up @@ -3,6 +3,7 @@

# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context

import math
from abc import ABC, abstractmethod
from enum import Enum

@@ -31,6 +32,9 @@ class ParallelMode(Enum):
# zero1 parallel
ZERO1 = "zero1"

# runtime network test
NETTEST = "nettest"


class ProcessGroupInitializer(ABC):
"""An object, knowing the parallelism configuration, that initializes parallel groups.
@@ -52,13 +56,15 @@ def __init__(
pipeline_parallel_size: int,
tensor_parallel_size: int,
zero1_parallel_size: int,
nettest_parallel_size: int,
):
self.rank = rank
self.world_size = world_size
self.data_parallel_size = data_parallel_size
self.pipeline_parallel_size = pipeline_parallel_size
self.tensor_parallel_size = tensor_parallel_size
self.zero1_parallel_size = zero1_parallel_size
self.nettest_parallel_size = nettest_parallel_size
super().__init__()

@abstractmethod
@@ -332,3 +338,52 @@ def init_dist_group(self, use_cpu: bool = False):
ranks_in_group = ranks

return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode


class Initializer_Nettest(ProcessGroupInitializer):
"""A ProcessGroupInitializer for network test, especailly for NCCL.
Args:
rank (int): The rank of current process.
world_size (int): Size of whole communication world.
nettest_parallel_size (int): Size of a network test group.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.num_nettest_group = math.ceil(self.world_size / self.nettest_parallel_size)

def init_dist_group(self, use_cpu: bool = False):
"""Initialize tensor parallel groups, and assign local_ranks and groups to each gpu.
Returns:
Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
A Tensor parallelism's information tuple.
"""
local_rank = None
ranks_in_group = None
process_group = None
cpu_group = None
group_world_size = None
mode = ParallelMode.NETTEST

for i in range(self.num_nettest_group):
ranks = []
for j in range(self.nettest_parallel_size):
rank = i * self.nettest_parallel_size + j
if rank < self.world_size:
ranks.append(rank)
group = dist.new_group(ranks)
if use_cpu:
group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
else:
group_cpu = None

if self.rank in ranks:
local_rank = ranks.index(self.rank)
group_world_size = len(ranks)
process_group = group
cpu_group = group_cpu
ranks_in_group = ranks

return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
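To see what the grouping loop produces, here is a minimal sketch of the rank partitioning done by Initializer_Nettest.init_dist_group, with the dist.new_group calls omitted so it runs without an initialized process group; the world size is an assumed example:

```python
import math

# Assumed example values; the commit recommends 32 GPUs per network-test group.
world_size = 70
nettest_parallel_size = 32
num_nettest_group = math.ceil(world_size / nettest_parallel_size)  # 3 groups

groups = []
for i in range(num_nettest_group):
    # same rank enumeration as init_dist_group, without creating NCCL groups
    ranks = [
        i * nettest_parallel_size + j
        for j in range(nettest_parallel_size)
        if i * nettest_parallel_size + j < world_size
    ]
    groups.append(ranks)

# groups -> [0..31], [32..63], [64..69]; the last group is simply smaller
print([len(g) for g in groups])  # [32, 32, 6]
```
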
35 changes: 20 additions & 15 deletions internlm/solver/optimizer/hybrid_zero_optim.py
@@ -497,6 +497,7 @@ def _compute_norm_with_stage(
grads = [self.padding_grad]
params = [self.padding_tensor]

norm = 0
if self._clip_grad_norm > 0:
# this norm is before scaling, it will be very large
norm = compute_norm(
@@ -542,15 +543,15 @@ def step(self, closure=None):
self._param_store.clear_grads_of_previous_reduced_params()

# compute norm for gradients in the last bucket
total_norms = []
total_norms = {}
for group_id in range(self.num_param_groups):
total_norms.append(
self._compute_norm_with_stage(
group_id=group_id,
last_bucket=True,
last_stage=True,
previous_norm=groups_norms[group_id],
)
group_name = self.param_groups[group_id]["name"] if "name" in self.param_groups[group_id] else "default"
group_name = f"{group_id}_{group_name}"
total_norms[group_name] = self._compute_norm_with_stage(
group_id=group_id,
last_bucket=True,
last_stage=True,
previous_norm=groups_norms[group_id],
)

timer("sync_grad").start()
@@ -569,7 +570,7 @@ def _step(self, closure=None, norms=None):
# found_inf = self._check_overflow()
# Because you may encounter inf when computing norm

if -1 in norms:
if -1 in norms.values():
found_inf = True

loss_scale = float(self.loss_scale.item()) # backup
@@ -617,15 +618,17 @@

# unscale and clip grads
# get the global norm
global_norm_groups = []
global_norm_groups = {}
if self._clip_grad_norm > 0:
for norm in norms:
global_norm_groups.append(norm**0.5)
for group_name, norm in norms.items():
global_norm_groups[group_name] = norm**0.5

# the following operations are performed only on the rank to which parameters are assigned.
if gpc.config.model.dtype is not torch.float32:
if len(single_grad_partition_groups) != 0:
self._unscale_and_clip_grads(single_grad_partition_groups, global_norm_groups, loss_scale)
if len(single_grad_partition_groups) != 0 and self._clip_grad_norm > 0:
self._unscale_and_clip_grads(
single_grad_partition_groups, list(global_norm_groups.values()), loss_scale
)

# update the parameters
timer("step").start()
@@ -652,7 +655,9 @@

# update gradients may not be needed here, because the sync_params function is used in initialization,
# so synchronization is maintained
return True, [global_norm / loss_scale for global_norm in global_norm_groups]
for group_name, global_norm in global_norm_groups.items():
global_norm_groups[group_name] = global_norm / loss_scale
return True, global_norm_groups

def broadcast_params(self):
handles = []
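A short sketch of the new norm bookkeeping in this file: squared per-group norms are now keyed by f"{group_id}_{group_name}", square-rooted into global norms, and divided by the loss scale before being returned. The group names, norm values, and loss scale below are made-up examples:

```python
# Illustrative stand-ins for self.param_groups and the per-group squared norms.
param_groups = [{"name": "embedding"}, {}]   # second group has no "name"
squared_norms = [4.0, 9.0]                   # assumed outputs of _compute_norm_with_stage
loss_scale = 2.0                             # assumed

total_norms = {}
for group_id, group in enumerate(param_groups):
    group_name = group["name"] if "name" in group else "default"
    group_name = f"{group_id}_{group_name}"
    total_norms[group_name] = squared_norms[group_id]

# -1 among the values would signal inf/nan gradients, as checked in _step()
found_inf = -1 in total_norms.values()

# unscale: squared norm -> norm, then divide by the loss scale, as in _step()
global_norm_groups = {name: norm ** 0.5 for name, norm in total_norms.items()}
global_norm_groups = {name: n / loss_scale for name, n in global_norm_groups.items()}

print(global_norm_groups)  # {'0_embedding': 1.0, '1_default': 1.5}
```
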
36 changes: 22 additions & 14 deletions internlm/train/training_internlm.py
@@ -389,23 +389,31 @@ def record_current_batch_training_metrics(
line = ""
for key, value in infos.items():
line += f"{key}={value} "
writer.add_scalar(key=key, value=value, step=train_state.step_count)
if isinstance(value, dict):
writer.add_scalars(key=key, value=value, step=train_state.step_count)
else:
writer.add_scalar(key=key, value=value, step=train_state.step_count)

if update_panel:
# metrics shown with dashboard panels
panel_metrics = {
"step": batch_count,
"lr": lr,
"num_consumed_tokens": train_state.num_consumed_tokens,
"loss": loss.item(),
"flops": tflops,
"tgs": tk_per_gpu,
"acc": acc_perplex["acc"],
"perplexity": acc_perplex["perplexity"],
"fwd_bwd_time": fwd_bwd_time,
}
for norm_key, norm_value in grad_norm.items():
panel_metrics[norm_key] = norm_value

logger.info(
line,
extra={
"step": batch_count,
"lr": lr,
"num_consumed_tokens": train_state.num_consumed_tokens,
"grad_norm": grad_norm,
"loss": loss.item(),
"flops": tflops,
"tgs": tk_per_gpu,
"acc": acc_perplex["acc"],
"perplexity": acc_perplex["perplexity"],
"fwd_bwd_time": fwd_bwd_time,
},
"{line}",
line=line,
extra=panel_metrics,
)
else:
logger.info(line)
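With this change a metric value may itself be a dict (for example the per-group grad norms returned by the optimizer), in which case it is routed to add_scalars rather than add_scalar. A minimal sketch of that dispatch, using a stand-in writer instead of InternLM's real writer class:

```python
from typing import Dict, Union

class DummyWriter:
    """Stand-in for the metrics writer; only mimics the two calls used above."""

    def add_scalar(self, key: str, value: float, step: int) -> None:
        print(f"scalar  step={step} {key}={value}")

    def add_scalars(self, key: str, value: Dict[str, float], step: int) -> None:
        print(f"scalars step={step} {key}={value}")

def log_infos(writer: DummyWriter, infos: Dict[str, Union[float, Dict[str, float]]], step: int) -> None:
    # dict-valued metrics (such as per-group grad norms) go to add_scalars,
    # plain scalars keep using add_scalar -- mirroring the diff above
    line = ""
    for key, value in infos.items():
        line += f"{key}={value} "
        if isinstance(value, dict):
            writer.add_scalars(key=key, value=value, step=step)
        else:
            writer.add_scalar(key=key, value=value, step=step)
    print(line)

# assumed example metrics
log_infos(DummyWriter(), {"loss": 2.31, "grad_norm": {"0_default": 1.0, "1_embedding": 1.5}}, step=10)
```
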
