feat(seed): set global seed for every model initialization (#496)
* bind seed

* bind seed
li126com committed Nov 17, 2023
1 parent 679ed3c commit eba2b85
Showing 2 changed files with 19 additions and 0 deletions.
4 changes: 4 additions & 0 deletions internlm/initialize/launch.py
@@ -27,6 +27,7 @@
 get_numa = True

 logger = get_logger(__file__)
+GLOBAL_SEED = 1024


 def get_default_parser():
@@ -543,6 +544,9 @@ def initialize_distributed_env(
     else:
        assert launcher in ["slurm", "torch"], "launcher only support slurm or torch"

+    global GLOBAL_SEED
+    GLOBAL_SEED = seed
+
     if args_check:
         args_sanity_check()

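The launch.py half of the change is the standard module-level rebinding pattern: the default of 1024 stays in place unless initialize_distributed_env is called with a different seed, in which case the global statement rebinds the module attribute instead of creating a function-local name. A minimal standalone sketch of that pattern (the set_global_seed helper is hypothetical, for illustration only):

    GLOBAL_SEED = 1024  # module-level default, as in launch.py


    def set_global_seed(seed: int) -> None:
        # Without `global`, this assignment would only create a local name
        # and leave the module-level default untouched.
        global GLOBAL_SEED
        GLOBAL_SEED = seed


    set_global_seed(42)
    assert GLOBAL_SEED == 42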
15 changes: 15 additions & 0 deletions internlm/model/modeling_internlm.py
@@ -2,6 +2,7 @@
 # -*- encoding: utf-8 -*-

 import math
+from functools import wraps
 from typing import Optional

 import torch
@@ -11,7 +12,9 @@

 from internlm.core.context import IS_SEQUENCE_PARALLEL, IS_TENSOR_PARALLEL, ParallelMode
 from internlm.core.context.parallel_context import global_context as gpc
+from internlm.core.context.random import _SEED_MANAGER
 from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal
+from internlm.initialize.launch import GLOBAL_SEED
 from internlm.model.embedding import Embedding1D
 from internlm.model.linear import (
     FeedForward,
@@ -81,6 +84,7 @@ def __init__(
         self.use_flash_attn = use_flash_attn

         head_dim = hidden_size // num_attention_heads
+
         self.mixer = MHA(
             embed_dim=hidden_size,
             num_heads=num_attention_heads,
@@ -410,6 +414,16 @@ def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=N
         return hidden_states


+def fix_seed(func):
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        _SEED_MANAGER.reset()
+        gpc.set_seed(GLOBAL_SEED)
+        func(*args, **kwargs)
+
+    return wrapper
+
+
 def _build_generic_model_1d(num_layers, num_chunks, device=torch.device("cuda"), **kwargs):
     """
     build generic model 1d
@@ -429,6 +443,7 @@ def _build_generic_model_1d(num_layers, num_chunks, device=torch.device("cuda"),
     logger.info(f"The layer sharding is {all_parts}.")

     models = []
+    PackedFlashInternLm1D.__init__ = fix_seed(PackedFlashInternLm1D.__init__)

     for start, end in parts:
         kwargs["num_layers"] = end - start
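The modeling_internlm.py half wraps the model constructor so that every pipeline chunk is built from the same seed: fix_seed resets the seed manager, re-seeds with GLOBAL_SEED, and only then runs the original __init__, and _build_generic_model_1d monkey-patches PackedFlashInternLm1D.__init__ with it before instantiating the chunks. A self-contained sketch of the same pattern, with torch.manual_seed standing in for the repository's _SEED_MANAGER.reset() / gpc.set_seed(GLOBAL_SEED) pair:

    from functools import wraps

    import torch
    import torch.nn as nn

    GLOBAL_SEED = 1024


    def fix_seed(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # Re-seed before every wrapped call so each construction starts
            # from the same RNG state.
            torch.manual_seed(GLOBAL_SEED)
            func(*args, **kwargs)

        return wrapper


    class TinyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(8, 8)


    # Monkey-patch __init__, mirroring how _build_generic_model_1d wraps
    # PackedFlashInternLm1D.__init__ before building each pipeline chunk.
    TinyModel.__init__ = fix_seed(TinyModel.__init__)

    a, b = TinyModel(), TinyModel()
    assert torch.equal(a.proj.weight, b.proj.weight)  # identical init weights

Re-seeding inside the wrapper makes repeated or multi-chunk initialization deterministic regardless of how much randomness was consumed beforehand.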
