
kwai/Megatron-Kwai

 
 


Artifact Evaluation for USENIX ATC '24

This directory contains scripts used to reproduce the results in "Accelerating the Training of Large Language Models using Efficient Activation Rematerialization and Optimal Hybrid Parallelism" that is to appear at USENIX ATC '24. These scripts use OpenMPI, but can be modified for other schedulers as well.

Getting Started Instructions

Get Commit

git clone https://github.com/kwai/Megatron-Kwai.git --branch atc24ae ~/Megatron-Kwai

The "Latest Megatron-LM" we refer to is a snapshot of NVIDIA/Megatron-LM as of Jan 1, 2024 (commit 2bc6cd3). We made minor modifications (~10 lines) to it for compatibility with OpenMPI and with our dataset.

git clone https://github.com/kwai/Megatron-Kwai.git --branch jan-1-2024-main ~/Megatron-LM-Jan-1-2024

Hardware Requirements

Minimum requirements

  • Purpose: To reproduce the acceleration method on a minimum demo.
  • Hardware: One node (i.e., a server) equipped with four NVIDIA A100/A800/H100/H800 80GB GPU cards.

Software Requirements

Software requirements are consistent with NVIDIA's official Megatron-LM. The only additional requirement is TransformerEngine v1.1.0 built with NVTE_WITH_USERBUFFERS=1.

To reproduce the performance, we suggest using the same software versions as specified in Section 6.1 "Experimental Settings" of the paper. We also provide a Docker image that includes all the suggested software.

docker pull yuantailing/megatron-kwai:atc24ae-1.0.0

Dataset

A subset (~100M tokens) of the enwiki dataset is located at /root/dataset within the Docker image. The scripts can be modified for other datasets as well.

Scripts (Minimum Demo)

Train Llama-7B with a context window size of 12,288 using 4 GPUs.

cd ~/Megatron-Kwai/examples/atc24
./1_run_7b_4gpus_baseline.sh
./2_run_7b_4gpus_ours.sh
./3_run_7b_4gpus_origin_megatron.sh
./4_run_7b_4gpus_no_ckpt_oom.sh

Time estimation: Initialization for each script requires less than 1 minute. Each iteration takes approximately 5-10 seconds on H800, or 10-20 seconds on A800. For the first run, an additional 1-3 minutes may be needed to create the dataset index and compile the CUDA extensions.

The expected performance is listed as follows.

| File | Method | A800 throughput / MFU | H800 throughput / MFU |
|---|---|---|---|
| 1_run_7b_4gpus_baseline.sh | baseline | 3016 / 47.7% | 6639 / 33.1% |
| 2_run_7b_4gpus_ours.sh | ours | 3803 / 60.1% | 8273 / 41.2% |
| 3_run_7b_4gpus_origin_megatron.sh | origin Megatron-LM | 2847 / 45.0% | 6138 / 30.6% |
| 4_run_7b_4gpus_no_ckpt_oom.sh | baseline w/o full ckpt | OOM | OOM |

The throughput metric is "Tokens per Second per GPU", as printed to the screen.
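As a sanity check, MFU can be related to throughput as MFU = throughput × FLOPs-per-token / peak FLOPs per GPU. The sketch below inverts that relation for each row of the table above. The peak numbers are assumptions (dense BF16: 312 TFLOPS for A800, 989 TFLOPS for H800), not values read from the scripts, and the exact FLOPs formula behind the reported MFU is not given here. Because every row trains the same model with the same context length, the implied FLOPs per token should agree closely across rows; with these assumed peaks they all land near 49.3 GFLOPs per token.

```python
# Back out the FLOPs per token implied by each (throughput, MFU) pair.
# Assumed relation: MFU = throughput * flops_per_token / peak_flops_per_gpu.
# The peak values are assumptions (dense BF16), not read from the scripts.
PEAK_FLOPS = {"A800": 312e12, "H800": 989e12}

# (tokens per second per GPU, MFU) from the minimum-demo table
RESULTS = {
    "A800": {"baseline": (3016, 0.477), "ours": (3803, 0.601), "origin": (2847, 0.450)},
    "H800": {"baseline": (6639, 0.331), "ours": (8273, 0.412), "origin": (6138, 0.306)},
}


def implied_flops_per_token(throughput: float, mfu: float, peak: float) -> float:
    """Solve MFU = throughput * F / peak for F."""
    return mfu * peak / throughput


def all_implied() -> dict:
    """Map (gpu, method) to the implied FLOPs per token."""
    out = {}
    for gpu, rows in RESULTS.items():
        for method, (tput, mfu) in rows.items():
            out[(gpu, method)] = implied_flops_per_token(tput, mfu, PEAK_FLOPS[gpu])
    return out


if __name__ == "__main__":
    for key, flops in all_implied().items():
        print(key, f"{flops / 1e9:.1f} GFLOPs/token")
```

If a reproduced run yields an implied value far from the others, either the throughput or the MFU reading is likely off.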

Comment: The minimum demo is only meant to check that our activation rematerialization mechanism works; pipeline parallelism may not be the best configuration for training Llama-7B with s=12k.

Artifact Claims

  1. Comparing script 1 and script 2: Our method achieves a significant performance improvement over the baseline by utilizing offloading and balanced checkpointing.
  2. Comparing script 2 and script 4: Our activation rematerialization mechanism reduces GPU memory consumption.
  3. Comparing script 1 and script 3: The baseline is stronger than the origin Megatron-LM.
  4. Comparing script 1 and script 4: The baseline method runs out of memory (OOM) if full checkpointing is not used, so CKPT=full is essential when our method is not in use.
  5. Comparing script 2 and script 3: The loss curves are similar, suggesting that our techniques are correct. Comment: The loss curves cannot be identical due to different initialization of model parameters.
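The throughput ratios behind claims 1 and 3 can be computed directly from the minimum-demo table. This small sketch does only that arithmetic; it does not parse training logs. With the listed numbers, ours is about 1.25x the baseline (1.26x on A800, 1.25x on H800), and the baseline is about 1.06x (A800) to 1.08x (H800) the original Megatron-LM.

```python
# Tokens per second per GPU, copied from the minimum-demo table.
THROUGHPUT = {
    "A800": {"baseline": 3016, "ours": 3803, "origin": 2847},
    "H800": {"baseline": 6639, "ours": 8273, "origin": 6138},
}


def speedup(gpu: str, new: str, old: str) -> float:
    """Throughput ratio; a value > 1 means `new` is faster than `old`."""
    return THROUGHPUT[gpu][new] / THROUGHPUT[gpu][old]


if __name__ == "__main__":
    for gpu in THROUGHPUT:
        print(gpu, "ours vs baseline:", f"{speedup(gpu, 'ours', 'baseline'):.3f}x")   # claim 1
        print(gpu, "baseline vs origin:", f"{speedup(gpu, 'baseline', 'origin'):.3f}x")  # claim 3
```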

Detailed Instructions

Hardware Requirements

Recommended requirements

  • Purpose: To reproduce the exact performance reported in the paper.
  • Hardware: A cluster consisting of 32 nodes. Each node is equipped with eight NVIDIA H800 80GB GPUs interconnected via NVLink. For inter-node communication, each node is outfitted with eight 100 Gbps NICs. Each node is configured with two CPUs and ≥1TB of host memory. Each GPU is connected to a CPU via PCIe 5.0 x16.

Dataset

The dataset should be copied to a directory that is shared by all nodes. Change DATA_PATH in pretrain_llama.sh accordingly.

Reproduce Table 8

Train Llama-175B with a context window size of 8,192 using 256 GPUs.

cd ~/Megatron-Kwai/examples/atc24
./5_run_175b_256gpus_baseline.sh
./6_run_175b_256gpus_ours.sh
./7_run_175b_256gpus_origin_megatron.sh
./8_run_175b_256gpus_no_ckpt_oom.sh

The expected performance is listed as follows.

| File | Method | H800 throughput / MFU |
|---|---|---|
| 5_run_175b_256gpus_baseline.sh | baseline | 299 / 33.4% |
| 6_run_175b_256gpus_ours.sh | ours | 387 / 43.2% |
| 7_run_175b_256gpus_origin_megatron.sh | origin Megatron-LM | 278 / 31.0% |
| 8_run_175b_256gpus_no_ckpt_oom.sh | baseline w/o full ckpt | OOM |

Note: Different clusters may require different MPI parameters. Please update the HOSTFILE based on the cluster's node configuration and CLUSTER_MPI_ARGS according to the specific MPI settings required.

Debug tips: If you encounter problems running multi-node scripts, first try running the official examples from the original Megatron-LM on multiple nodes with all types of parallelism enabled: Tensor Parallelism (TP), Context Parallelism (CP), Pipeline Parallelism (PP), and Data Parallelism (DP). This preliminary step helps verify that the MPI arguments are configured correctly. Once it succeeds, the script ./7_run_175b_256gpus_origin_megatron.sh should execute without issues, as should the other scripts. The "TP overlap" feature requires more precise configuration of MPI arguments; as a debugging step, you can temporarily remove TP_OVERLAP_ARGS and see whether the issue is resolved.

To reproduce the results for other rows in Table 8 (Section 6.4):

  • Change source ./llama-175b to source ./llama-65b or source ./llama2-70b to use the corresponding model.
  • Change the SEQ_LENGTH=8192 variable to other sequence lengths as needed.
  • Change the variables TP, CP, PP, PP_l, CKPT, and OFFLOAD_ALPHA according to the values of $t$, $c$, $p$, $l$, ckpt, and $\alpha$ listed in the table.

Reproduce Figure 7

To reproduce the results for Figure 7 (Section 6.2):

  • Run ./9_run_65b_256gpus_verify_memory.sh, and observe the GPU memory usage after the 2nd iteration.
  • Note 1: If the GPU cards are not NVIDIA H800, the observed memory usage may vary slightly, due to the different behavior of PyTorch across GPU types.
  • Note 2: max_memory_allocated can be reproduced exactly. Other memory metrics, including max_memory_reserved and memory_info.used, may vary slightly even across multiple runs of the same script.
  • Note 3: Peak memory usage should be observed after the 2nd iteration because some optimizer states are initialized at the end of the 1st iteration.

The exact values of max_memory_allocated are listed below.

| $\alpha$ | ckpt=no max_memory_allocated (bytes) | ckpt=ours max_memory_allocated (bytes) |
|---|---|---|
| 0.00 | OOM | 65330982400 |
| 0.10 | OOM | 62128631296 |
| 0.20 | OOM | 58907405824 |
| 0.30 | 72762730496 | 55554059776 |
| 0.40 | 66951784960 | 52089073152 |
| 0.50 | 61136119808 | 48604098048 |
| 0.60 | 55733856256 | 45722611200 |
| 0.70 | 50331592704 | 42832735744 |
| 0.80 | 45824813056 | 38925740032 |
| 0.90 | 39527328256 | 36144916480 |
| 1.00 | 34491172352 | 32462145024 |
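The table reports raw byte counts. The sketch below converts them to GiB and checks two properties visible in the table: memory decreases monotonically as $\alpha$ grows, and ckpt=ours always uses less memory than ckpt=no wherever the latter fits. For example, 65330982400 bytes is about 60.8 GiB.

```python
# Figure 7 max_memory_allocated values in bytes, keyed by alpha.
# None marks a run that went OOM.
GIB = 2**30

CKPT_NO = {0.0: None, 0.1: None, 0.2: None, 0.3: 72762730496, 0.4: 66951784960,
           0.5: 61136119808, 0.6: 55733856256, 0.7: 50331592704,
           0.8: 45824813056, 0.9: 39527328256, 1.0: 34491172352}
CKPT_OURS = {0.0: 65330982400, 0.1: 62128631296, 0.2: 58907405824,
             0.3: 55554059776, 0.4: 52089073152, 0.5: 48604098048,
             0.6: 45722611200, 0.7: 42832735744, 0.8: 38925740032,
             0.9: 36144916480, 1.0: 32462145024}


def to_gib(n_bytes):
    """Convert bytes to GiB, passing None (OOM) through."""
    return None if n_bytes is None else n_bytes / GIB


def is_non_increasing(column: dict) -> bool:
    """True if memory never grows as alpha increases (OOM entries skipped)."""
    vals = [v for _, v in sorted(column.items()) if v is not None]
    return all(a >= b for a, b in zip(vals, vals[1:]))


def ours_always_lower() -> bool:
    """True if ckpt=ours < ckpt=no for every alpha where ckpt=no fits."""
    return all(CKPT_OURS[a] < v for a, v in CKPT_NO.items() if v is not None)


if __name__ == "__main__":
    for alpha in sorted(CKPT_OURS):
        no, ours = to_gib(CKPT_NO[alpha]), to_gib(CKPT_OURS[alpha])
        no_str = "   OOM" if no is None else f"{no:6.2f}"
        print(f"alpha={alpha:.1f}  ckpt=no: {no_str} GiB  ckpt=ours: {ours:6.2f} GiB")
```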

Reproduce Figure 9

To reproduce the results for Figure 9 (Section 6.5):

  • Change source ./llama-175b to the corresponding model and modify SEQ_LENGTH accordingly.
  • The values of NUM_GPUS, GLOBAL_BATCH_SIZE, TP, CP, PP, PP_l, CKPT, and OFFLOAD_ALPHA used in the experiments are provided in the following table.
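
| Model | SEQ_LENGTH | NUM_GPUS | $B$ (GLOBAL_BATCH_SIZE) | $t$ (TP) | $c$ (CP) | $p$ (PP) | $l$ (PP_l) | ckpt (CKPT) | $\alpha$ (OFFLOAD_ALPHA) | Tokens per second |
|---|---|---|---|---|---|---|---|---|---|---|
| Llama-175B | 4096 | 256 | 256 | 2 | 2 | 16 | 1 | no | 0.522 | 97577 |
| Llama-175B | 4096 | 240 | 240 | 4 | 1 | 12 | 1 | no | 0.276 | 85924 |
| Llama-175B | 4096 | 192 | 240 | 2 | 2 | 24 | 1 | no | 0.410 | 76130 |
| Llama-175B | 4096 | 160 | 240 | 4 | 1 | 8 | 2 | ours | 0.334 | 59169 |
| Llama-175B | 4096 | 144 | 252 | 4 | 1 | 6 | 2 | ours | 0.746 | 56637 |
| Llama-175B | 4096 | 128 | 256 | 2 | 1 | 16 | 1 | ours | 0.751 | 54968 |
| Llama-175B | 4096 | 120 | 270 | 4 | 1 | 6 | 2 | ours | 0.851 | 45761 |
| Llama-175B | 4096 | 96 | 264 | 4 | 1 | 12 | 2 | ours | 0.324 | 41048 |
| Llama-175B | 4096 | 64 | 272 | 4 | 1 | 8 | 1 | ours | 0.979 | 27708 |
| Llama-175B | 4096 | 48 | 264 | 4 | 1 | 12 | 1 | ours | 0.999 | 20685 |
| Llama-65B | 4096 | 256 | 256 | 2 | 1 | 8 | 2 | no | 0.355 | 233981 |
| Llama-65B | 4096 | 240 | 240 | 2 | 1 | 10 | 2 | no | 0.302 | 219415 |
| Llama-65B | 4096 | 200 | 250 | 4 | 1 | 5 | 4 | no | 0.000 | 179819 |
| Llama-65B | 4096 | 192 | 240 | 4 | 1 | 8 | 2 | no | 0.000 | 180303 |
| Llama-65B | 4096 | 160 | 240 | 2 | 1 | 10 | 2 | no | 0.331 | 158984 |
| Llama-65B | 4096 | 128 | 256 | 2 | 1 | 8 | 2 | ours | 0.000 | 132357 |
| Llama-65B | 4096 | 120 | 240 | 2 | 1 | 5 | 2 | ours | 0.383 | 121488 |
| Llama-65B | 4096 | 112 | 252 | 4 | 1 | 4 | 4 | no | 0.000 | 99046 |
| Llama-65B | 4096 | 96 | 240 | 2 | 1 | 8 | 2 | ours | 0.035 | 96758 |
| Llama-65B | 4096 | 80 | 240 | 2 | 1 | 10 | 2 | ours | 0.000 | 87396 |
| Llama-65B | 4096 | 64 | 256 | 2 | 1 | 8 | 2 | ours | 0.162 | 71884 |
| Llama-65B | 4096 | 56 | 252 | 2 | 1 | 4 | 2 | ours | 0.945 | 52219 |
| Llama-65B | 4096 | 48 | 264 | 2 | 1 | 8 | 2 | ours | 0.290 | 51124 |
| Llama-65B | 4096 | 40 | 260 | 2 | 1 | 5 | 2 | ours | 0.815 | 45716 |
| Llama-65B | 4096 | 32 | 272 | 2 | 1 | 8 | 2 | ours | 0.544 | 36740 |
| Llama2-70B | 8192 | 256 | 256 | 2 | 4 | 8 | 2 | no | 0.000 | 229338 |
| Llama2-70B | 8192 | 240 | 270 | 2 | 4 | 10 | 2 | no | 0.000 | 217316 |
| Llama2-70B | 8192 | 224 | 252 | 2 | 4 | 4 | 4 | ours | 0.352 | 192634 |
| Llama2-70B | 8192 | 200 | 250 | 2 | 2 | 5 | 2 | ours | 0.368 | 182038 |
| Llama2-70B | 8192 | 192 | 264 | 2 | 4 | 8 | 2 | no | 0.000 | 177032 |
| Llama2-70B | 8192 | 160 | 260 | 2 | 4 | 10 | 2 | no | 0.000 | 148456 |
| Llama2-70B | 8192 | 144 | 252 | 2 | 2 | 4 | 2 | ours | 0.755 | 124599 |
| Llama2-70B | 8192 | 128 | 256 | 2 | 2 | 8 | 2 | ours | 0.012 | 122265 |
| Llama2-70B | 8192 | 120 | 270 | 2 | 2 | 5 | 2 | ours | 0.460 | 114817 |
| Llama2-70B | 8192 | 112 | 252 | 2 | 2 | 4 | 2 | ours | 0.811 | 100242 |
| Llama2-70B | 8192 | 96 | 264 | 2 | 2 | 8 | 2 | ours | 0.080 | 91488 |
| Llama2-70B | 8192 | 80 | 260 | 2 | 2 | 10 | 2 | ours | 0.024 | 78698 |
| Llama2-70B | 8192 | 64 | 256 | 2 | 1 | 8 | 2 | ours | 0.654 | 65302 |
| Llama2-70B | 8192 | 48 | 264 | 2 | 1 | 8 | 2 | ours | 0.722 | 47755 |
| Llama2-70B | 8192 | 40 | 260 | 2 | 1 | 10 | 2 | ours | 0.713 | 41182 |
| Llama2-70B | 8192 | 32 | 272 | 2 | 1 | 8 | 2 | ours | 0.858 | 33262 |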
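For each row of the Figure 9 table, the data-parallel size follows from $d = N / (t \cdot c \cdot p)$, where $N$ is NUM_GPUS. The sketch below computes $d$ and $B/d$ for the Llama-175B rows and raises an error if either division is inexact; the other models work the same way. It also prints tokens per second per GPU, assuming the last column is aggregated over all NUM_GPUS GPUs; that interpretation is ours, not stated in the table. $l$ and $\alpha$ do not enter these divisions and are omitted.

```python
from typing import NamedTuple


class Row(NamedTuple):
    num_gpus: int      # NUM_GPUS
    batch: int         # B (GLOBAL_BATCH_SIZE)
    t: int             # TP
    c: int             # CP
    p: int             # PP
    tokens_per_s: int  # last column (assumed aggregate over all GPUs)


LLAMA_175B_S4096 = [
    Row(256, 256, 2, 2, 16, 97577),
    Row(240, 240, 4, 1, 12, 85924),
    Row(192, 240, 2, 2, 24, 76130),
    Row(160, 240, 4, 1, 8, 59169),
    Row(144, 252, 4, 1, 6, 56637),
    Row(128, 256, 2, 1, 16, 54968),
    Row(120, 270, 4, 1, 6, 45761),
    Row(96, 264, 4, 1, 12, 41048),
    Row(64, 272, 4, 1, 8, 27708),
    Row(48, 264, 4, 1, 12, 20685),
]


def data_parallel_size(row: Row) -> int:
    """d = N / (t * c * p); raises if the GPUs do not divide evenly."""
    model_parallel = row.t * row.c * row.p
    if row.num_gpus % model_parallel:
        raise ValueError(f"{row.num_gpus} GPUs not divisible by t*c*p={model_parallel}")
    return row.num_gpus // model_parallel


def samples_per_dp_rank(row: Row) -> int:
    """B / d; raises if the global batch does not split evenly."""
    d = data_parallel_size(row)
    if row.batch % d:
        raise ValueError(f"B={row.batch} not divisible by d={d}")
    return row.batch // d


if __name__ == "__main__":
    for r in LLAMA_175B_S4096:
        print(f"N={r.num_gpus:3d} d={data_parallel_size(r)} B/d={samples_per_dp_rank(r)} "
              f"tokens/s/GPU={r.tokens_per_s / r.num_gpus:.0f}")
```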

Artifact Claims

  1. All the memory usage values in Figure 7 (Section 6.2) are reproduced.
  2. All the "Throughput / MFU" values listed in Table 8 (Section 6.4) are reproduced.
  3. All the "achieved throughput" values in Figure 9 (Section 6.5) are reproduced.

Megatron-LM

Megatron (1, 2, and 3) is a large, powerful transformer developed by the Applied Deep Learning Research team at NVIDIA. This repository is for ongoing research on training large transformer language models at scale. We developed efficient, model-parallel (tensor, sequence, and pipeline), and multi-node pre-training of transformer-based models such as GPT, BERT, and T5 using mixed precision.

Below are some of the projects where we have directly used Megatron:

Megatron is also used in NeMo Megatron, a framework to help enterprises overcome the challenges of building and training sophisticated natural language processing models with billions and trillions of parameters.

Our codebase is capable of efficiently training very large (hundreds of billions of parameters) language models with both model and data parallelism. To demonstrate how the code scales with multiple GPUs and model sizes, we consider GPT models from 1 billion all the way to 1 trillion parameters. All models use a vocabulary size of 51,200 and a sequence length of 2048. We vary hidden size, number of attention heads, and number of layers to arrive at a specific model size. As the model size increases, we also modestly increase the batch size. We leverage NVIDIA's Selene supercomputer to perform scaling studies and use up to 3072 A100 GPUs for the largest model. Each cluster node has 8 NVIDIA 80GB A100 GPUs. The graph below shows that we scale nearly linearly up to 1 trillion parameter models running on 3072 GPUs. Note that these results are from benchmark runs and these models were not trained to convergence; however, the FLOPs are measured for end-to-end training, i.e., they include all operations, including data loading, optimization, and even logging.

Scaling Graph

The following table shows both model (MFU) and hardware (HFU) FLOPs utilization for select configurations up to 1T parameters (see our paper for a description of how these are calculated). As the model size increases, we achieve better GPU utilization, and for the one trillion parameter model we reach an MFU and HFU of 56.3% and 57.0%, respectively. Note that these numbers are also measured on benchmark runs and in this case are measured using a data parallel size of one. Data parallelism introduces some overhead due to the gradient all-reduce required between the data parallel groups. However, for large transformer models, this overhead is not large and can be almost entirely eliminated by overlapping the gradient all-reduce with backpropagation.

| Model Size | Model FLOPs Utilization | Hardware FLOPs Utilization |
|---|---|---|
| 22B | 41.5% | 43.7% |
| 175B | 51.4% | 52.8% |
| 530B | 56.0% | 57.0% |
| 1T | 56.3% | 57.0% |


Setup

We strongly recommend using the latest release of NGC's PyTorch container with DGX nodes. If you can't use this for some reason, use the latest PyTorch, CUDA, NCCL, and NVIDIA APEX releases. Data preprocessing requires NLTK, though this is not required for training, evaluation, or downstream tasks.

You can launch an instance of the PyTorch container and mount Megatron, your dataset, and checkpoints with the following Docker commands:

docker pull nvcr.io/nvidia/pytorch:xx.xx-py3
docker run --gpus all -it --rm -v /path/to/megatron:/workspace/megatron -v /path/to/dataset:/workspace/dataset -v /path/to/checkpoints:/workspace/checkpoints nvcr.io/nvidia/pytorch:xx.xx-py3

Downloading Checkpoints

We have provided pretrained BERT-345M and GPT-345M checkpoints for evaluation or finetuning on downstream tasks. To access these checkpoints, first sign up for and set up the NVIDIA GPU Cloud (NGC) Registry CLI. Further documentation for downloading models can be found in the NGC documentation.

Alternatively, you can directly download the checkpoints using:

BERT-345M-uncased: wget --content-disposition https://api.ngc.nvidia.com/v2/models/nvidia/megatron_bert_345m/versions/v0.1_uncased/zip -O megatron_bert_345m_v0.1_uncased.zip
BERT-345M-cased: wget --content-disposition https://api.ngc.nvidia.com/v2/models/nvidia/megatron_bert_345m/versions/v0.1_cased/zip -O megatron_bert_345m_v0.1_cased.zip
GPT-345M: wget --content-disposition https://api.ngc.nvidia.com/v2/models/nvidia/megatron_lm_345m/versions/v0.0/zip -O megatron_lm_345m_v0.0.zip

The models require vocabulary files to run. The BERT WordPiece vocab file can be extracted from Google's pretrained BERT models: uncased, cased. The GPT vocab file and merge table can be downloaded directly.

Usage

After installation, there are several possible workflows. The most comprehensive is:

  1. Data preprocessing
  2. Pretraining
  3. Finetuning (Optional for zero-shot tasks)
  4. Downstream task evaluation or text generation

However, steps 1 and 2 can be replaced by using one of the pretrained models mentioned above.

We've provided several scripts for pretraining both BERT and GPT in the examples directory, as well as scripts for both zero-shot and fine-tuned downstream tasks including MNLI, RACE, WikiText103, and LAMBADA evaluation. There is also a script for GPT interactive text generation.

Training

Data Preprocessing

The training data requires preprocessing. First, place your training data in a loose JSON format, with one JSON object containing a text sample per line. For example:

{"src": "www.nvidia.com", "text": "The quick brown fox", "type": "Eng", "id": "0", "title": "First Part"}
{"src": "The Internet", "text": "jumps over the lazy dog", "type": "Eng", "id": "42", "title": "Second Part"}

The name of the text field of the JSON can be changed by using the --json-key flag in preprocess_data.py. The other metadata are optional and are not used in training.
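A corpus in this loose JSON layout can be produced with the standard library alone. The sketch below writes the two example records one object per line and reads back the "text" field used in the example; the file name my-corpus.json matches the preprocessing commands below but is otherwise a placeholder.

```python
import json

# The two example records shown above.
SAMPLES = [
    {"src": "www.nvidia.com", "text": "The quick brown fox",
     "type": "Eng", "id": "0", "title": "First Part"},
    {"src": "The Internet", "text": "jumps over the lazy dog",
     "type": "Eng", "id": "42", "title": "Second Part"},
]


def write_loose_json(path: str, records) -> None:
    """Write one JSON object per line (loose JSON)."""
    with open(path, "w", encoding="utf-8") as f:
        for rec in records:
            f.write(json.dumps(rec, ensure_ascii=False) + "\n")


def read_field(path: str, key: str = "text") -> list:
    """Collect one field from every non-empty line."""
    with open(path, encoding="utf-8") as f:
        return [json.loads(line)[key] for line in f if line.strip()]


if __name__ == "__main__":
    write_loose_json("my-corpus.json", SAMPLES)
    print(read_field("my-corpus.json"))
```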

The loose json is then processed into a binary format for training. To convert the json into mmap, cached index file, or the lazy loader format use preprocess_data.py. Set the --dataset-impl flag to mmap, cached, or lazy, respectively (default is mmap). An example script to prepare data for BERT training is:

python tools/preprocess_data.py \
       --input my-corpus.json \
       --output-prefix my-bert \
       --vocab bert-vocab.txt \
       --dataset-impl mmap \
       --tokenizer-type BertWordPieceLowerCase \
       --split-sentences

The output will be two files named, in this case, my-bert_text_sentence.bin and my-bert_text_sentence.idx. The --data-path specified in later BERT training is the full path and new filename, but without the file extension.

For T5 use the same preprocessing as BERT, perhaps renaming it to:

       --output-prefix my-t5 \

Some minor modifications are required for GPT data preprocessing, namely, the addition of a merge table, an end-of-document token, removal of sentence splitting, and a change to the tokenizer type:

python tools/preprocess_data.py \
       --input my-corpus.json \
       --output-prefix my-gpt2 \
       --vocab gpt2-vocab.json \
       --dataset-impl mmap \
       --tokenizer-type GPT2BPETokenizer \
       --merge-file gpt2-merges.txt \
       --append-eod

Here the output files are named my-gpt2_text_document.bin and my-gpt2_text_document.idx. As before, in GPT training, use the longer name without the extension as --data-path.

Further command line arguments are described in the source file preprocess_data.py.

BERT Pretraining

The examples/pretrain_bert.sh script runs single GPU 345M parameter BERT pretraining. Debugging is the primary use for single GPU training, as the code base and command line arguments are optimized for highly distributed training. Most of the arguments are fairly self-explanatory. By default, the learning rate decays linearly over the training iterations starting at --lr to a minimum set by --min-lr over --lr-decay-iters iterations. The fraction of training iterations used for warmup is set by --lr-warmup-fraction. While this is single GPU training, the batch size specified by --micro-batch-size is the batch size of a single forward-backward pass, and the code will perform gradient accumulation steps until it reaches --global-batch-size, which is the batch size per iteration. The data is partitioned into a 949:50:1 ratio for training/validation/test sets (default is 969:30:1). This partitioning happens on the fly, but is consistent across runs with the same random seed (1234 by default, or specified manually with --seed). We use --train-iters as the number of training iterations requested. Alternatively, one can provide --train-samples, which is the total number of samples to train on. If this option is present, then instead of providing --lr-decay-iters, one will need to provide --lr-decay-samples.

The logging, checkpoint-saving, and evaluation intervals are specified. Checkpointing the activations facilitates the training of larger models and/or batches. Note that the --data-path now includes the additional _text_sentence suffix added in preprocessing, but does not include the file extensions.

Further command line arguments are described in the source file arguments.py.

To run examples/pretrain_bert.sh, make any desired modifications including setting the environment variables for CHECKPOINT_PATH, VOCAB_FILE, and DATA_PATH. Make sure to set these variables to their paths in the container. Then launch the container with Megatron and necessary paths mounted (as explained in Setup) and run the example script.
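Two pieces of arithmetic from the BERT description can be made concrete: the number of gradient-accumulation steps per iteration, and how a 949:50:1 ratio maps to split sizes. The helper names below are hypothetical, and the rounding in split_sizes is a simplified assumption; Megatron's own split code may round differently.

```python
# Hypothetical helpers illustrating the BERT pretraining arithmetic.
# They do not call Megatron; rounding in split_sizes is a simplification.


def grad_accum_steps(global_batch_size: int, micro_batch_size: int,
                     data_parallel_size: int = 1) -> int:
    """Forward-backward passes per iteration needed to reach the global batch."""
    per_pass = micro_batch_size * data_parallel_size
    if global_batch_size % per_pass:
        raise ValueError("global batch size must be a multiple of micro batch size * DP size")
    return global_batch_size // per_pass


def split_sizes(num_docs: int, ratio=(949, 50, 1)) -> tuple:
    """Approximate train/valid/test sizes for a given ratio."""
    total = sum(ratio)
    train = round(num_docs * ratio[0] / total)
    valid = round(num_docs * ratio[1] / total)
    test = num_docs - train - valid  # remainder goes to test
    return train, valid, test


if __name__ == "__main__":
    print(grad_accum_steps(global_batch_size=32, micro_batch_size=4))  # single GPU
    print(split_sizes(10_000))
```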

GPT Pretraining

The examples/pretrain_gpt.sh script runs single GPU 345M parameter GPT pretraining. As mentioned above, single GPU training is primarily intended for debugging purposes, as the code is optimized for distributed training.

It follows largely the same format as the previous BERT script with a few notable differences: the tokenization scheme used is BPE (which requires a merge table and a json vocabulary file) instead of WordPiece, the model architecture allows for longer sequences (note that the max position embedding must be greater than or equal to the maximum sequence length), and the --lr-decay-style has been set to cosine decay. Note that the --data-path now includes the additional _text_document suffix added in preprocessing, but does not include the file extensions.

Further command line arguments are described in the source file arguments.py.

examples/pretrain_gpt.sh can be launched the same way as described for BERT. Set the env vars and make any other modifications, launch the container with appropriate mounts, and run the script.

T5 Pretraining

Very similar to BERT and GPT, the examples/pretrain_t5.sh script runs single GPU "base" (~220M parameter) T5 pretraining. The primary difference from BERT and GPT is the addition of the following arguments to accommodate the T5 architecture:

  • --kv-channels sets the inner dimension of the "key" and "value" matrices of all attention mechanisms in the model. For BERT and GPT this defaults to the hidden size divided by the number of attention heads, but can be configured for T5.

  • --ffn-hidden-size sets the hidden size in the feed-forward networks within a transformer layer. For BERT and GPT this defaults to 4 times the transformer hidden size, but can be configured for T5.

  • --encoder-seq-length and --decoder-seq-length set the sequence length for the encoder and decoder separately.

All of the other arguments remain as they were for BERT and GPT pretraining. Run this example with the same steps described above for the other scripts.

Distributed Pretraining

The examples/pretrain_{bert,gpt,t5}_distributed.sh scripts use the PyTorch distributed launcher for distributed training. As such, multi-node training can be achieved by properly setting environment variables. See the official PyTorch documentation for further description of these environment variables. By default, multi-node training uses the nccl distributed backend. A simple set of additional arguments and the use of the PyTorch distributed module with the torchrun elastic launcher (equivalent to python -m torch.distributed.run) are the only additional requirements to adopt distributed training. See any of examples/pretrain_{bert,gpt,t5}_distributed.sh for more details.

We use two types of parallelism: data and model parallelism. We facilitate two distributed data parallel implementations: a simple one of our own that performs gradient all-reduce at the end of back propagation step, and Torch's distributed data parallel wrapper that overlaps gradient reduction with back propagation computation. To switch between these two options use --DDP-impl local or --DDP-impl torch, respectively. As expected, Torch distributed data parallelism is more efficient at larger model sizes. For example, for the 8.3 billion parameters model running on 512 GPUs, the scaling increases from 60% to 76% when Torch's distributed data parallel is used. However, the overlapping method requires more memory and for some configurations (e.g., 2.5 billion parameters using 2-way model parallel and 1.2 billion parameters with no model parallel) can make the overall training slower as a result. We empirically found that using a smaller model in those cases improves the training time.

Second, we developed a simple and efficient two-dimensional model-parallel approach. To use tensor model parallelism (splitting execution of a single transformer module over multiple GPUs, see Section 3 of our paper), add the --tensor-model-parallel-size flag to specify the number of GPUs among which to split the model, along with the arguments passed to the distributed launcher as mentioned above. To use sequence parallelism, specify --sequence-parallel, which requires tensor model parallelism as it splits across the same GPUs (more details in Section 4.2.2 of our paper).

To use pipeline model parallelism (sharding the transformer modules into stages with an equal number of transformer modules on each stage, and then pipelining execution by breaking the batch into smaller microbatches, see Section 2.2 of our paper), use the --pipeline-model-parallel-size flag to specify the number of stages to split the model into (e.g., splitting a model with 24 transformer layers across 4 stages would mean each stage gets 6 transformer layers).

We have examples of how to use these two different forms of model parallelism in the example scripts ending in distributed_with_mp.sh.

Other than these minor changes, the distributed training is identical to the training on a single GPU.

The interleaved pipelining schedule (more details in Section 2.2.2 of our paper) can be enabled using the --num-layers-per-virtual-pipeline-stage argument, which controls the number of transformer layers in a virtual stage (by default with the non-interleaved schedule, each GPU will execute a single virtual stage with NUM_LAYERS / PIPELINE_MP_SIZE transformer layers). The total number of layers in the transformer model should be divisible by this argument value. Additionally, the number of microbatches in the pipeline (computed as GLOBAL_BATCH_SIZE / (DATA_PARALLEL_SIZE * MICRO_BATCH_SIZE)) should be divisible by the PIPELINE_MP_SIZE when using this schedule (this condition is checked in an assertion in the code). The interleaved schedule is not supported for pipelines with 2 stages (PIPELINE_MP_SIZE=2).
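The constraints in the paragraph above can be collected into one pre-flight check. This is a hypothetical helper, not Megatron's assertion code. It treats the layer condition as requiring an equal whole number of virtual stages per GPU, which is slightly stricter than only requiring NUM_LAYERS to be divisible by the argument value.

```python
def interleaved_schedule_shape(num_layers: int, pp: int, layers_per_virtual_stage: int,
                               global_batch_size: int, dp: int, micro_batch_size: int):
    """Return (virtual stages per GPU, microbatches per iteration) or raise ValueError."""
    if pp <= 2:
        # Interleaving needs pipelining, and 2-stage pipelines are not supported.
        raise ValueError("interleaved schedule requires PIPELINE_MP_SIZE > 2")
    if num_layers % (pp * layers_per_virtual_stage):
        # Assumption: each GPU holds the same whole number of virtual stages.
        raise ValueError("NUM_LAYERS must split evenly into pp * layers_per_virtual_stage")
    per_pass = dp * micro_batch_size
    if global_batch_size % per_pass:
        raise ValueError("GLOBAL_BATCH_SIZE must be divisible by DATA_PARALLEL_SIZE * MICRO_BATCH_SIZE")
    num_microbatches = global_batch_size // per_pass
    if num_microbatches % pp:
        raise ValueError("number of microbatches must be divisible by PIPELINE_MP_SIZE")
    return num_layers // (pp * layers_per_virtual_stage), num_microbatches


if __name__ == "__main__":
    # 24 layers, 4 stages, 2 layers per virtual stage -> 3 virtual stages per GPU.
    print(interleaved_schedule_shape(24, 4, 2, global_batch_size=64, dp=2, micro_batch_size=1))
```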

Activation Checkpointing and Recomputation

To reduce GPU memory usage when training a large model, we support activation checkpointing and recomputation. We support two levels of recompute granularity: selective and full. Selective recomputation is the default and is recommended in almost all cases. It saves the activations that take less space and are expensive to recompute, and recomputes activations that take a lot of space but are relatively cheap to recompute (see our paper for details). To enable selective activation recompute, simply use --recompute-activations.

For cases where memory is very tight, full checkpointing saves just the inputs to a transformer layer, or a block of transformer layers, and recomputes everything else. To turn on full activation recompute use --recompute-granularity full. When using full activation recomputation, there are two methods: uniform and block, chosen using the --recompute-method argument.

  • Uniform method uniformly divides the Transformer layers into groups of layers and stores the input activations of each group in memory. The baseline group size is 1 and, in this case, the input activation of each Transformer layer is checkpointed. When the GPU memory is insufficient, increasing the number of layers per group reduces memory usage and thus enables running a bigger model. For example, when the number of layers per group is 4, the input activation of each group of 4 Transformer layers is checkpointed.

  • Block method checkpoints the input activations of a set number of individual Transformer layers per pipeline stage and runs the remaining layers without any checkpointing. This method can be used to skip checkpointing some Transformer layers until the GPU memory is fully used, which is applicable only when there is unused GPU memory. Checkpointing fewer Transformer layers avoids unnecessary activation recomputation in the backward pass and thus improves training performance. For example, when we specify 5 layers to checkpoint out of 8 layers per pipeline stage, the input activations of only the first 5 Transformer layers are checkpointed, and activation recomputation for the remaining 3 layers is not needed in the backward pass.
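The two full-recompute methods differ only in which layer inputs are stored. The toy model below (plain Python, hypothetical function names, no Megatron calls) lists those layers for the two examples given above: groups of 4 under the uniform method, and 5 of 8 layers per stage under the block method.

```python
def uniform_checkpointed_layers(num_layers: int, layers_per_group: int) -> list:
    """Uniform method: the input of the first layer of each group is stored."""
    return list(range(0, num_layers, layers_per_group))


def block_checkpointed_layers(layers_per_stage: int, num_checkpointed: int):
    """Block method: split one stage's layers into (checkpointed, not recomputed)."""
    k = min(num_checkpointed, layers_per_stage)
    layers = list(range(layers_per_stage))
    return layers[:k], layers[k:]


if __name__ == "__main__":
    print(uniform_checkpointed_layers(8, 4))  # one stored input per group of 4 layers
    print(block_checkpointed_layers(8, 5))    # first 5 checkpointed, last 3 kept
```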

Distributed Optimizer

Usage: --use-distributed-optimizer. Compatible with all model and data types.

The distributed optimizer is a memory savings technique, whereby the optimizer state is evenly distributed across data parallel ranks (versus the traditional method of replicating the optimizer state across data parallel ranks). As described in ZeRO: Memory Optimizations Toward Training Trillion Parameter Models, our implementation distributes all optimizer state that does not overlap with the model state. For example, when using fp16 model params, the distributed optimizer maintains its own separate copy of fp32 main params & grads, which are distributed across DP ranks. When using bf16 model params, however, the distributed optimizer's fp32 main grads are the same as the model's fp32 grads, and so the grads in this case are not distributed (although the fp32 main params are still distributed, as they are separate from the bf16 model params).

Theoretical memory savings vary depending on the combination of the model's param dtype and grad dtype. In our implementation, the theoretical number of bytes per parameter is (where 'd' is the data parallel size):

| Param / grad dtypes | Non-distributed optim | Distributed optim |
|---|---|---|
| fp16 param, fp16 grads | 20 | 4 + 16/d |
| bf16 param, fp32 grads | 18 | 6 + 12/d |
| fp32 param, fp32 grads | 16 | 8 + 8/d |
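The bytes-per-parameter table can be turned into a quick estimate of model-state memory (parameters, gradients, and optimizer state only; activations and communication buffers are excluded). The sketch below encodes each row as a fixed part plus a part divided by d. With d = 1 the distributed column reduces to the non-distributed one, which serves as a consistency check. The 7B-parameter example is hypothetical.

```python
# Bytes per parameter from the table, stored as (non-distributed, fixed, sharded)
# so that distributed = fixed + sharded / d.
BYTES_PER_PARAM = {
    "fp16 param, fp16 grads": (20, 4, 16),
    "bf16 param, fp32 grads": (18, 6, 12),
    "fp32 param, fp32 grads": (16, 8, 8),
}


def bytes_per_param(kind: str, dp_size: int, distributed: bool = True) -> float:
    non_dist, fixed, sharded = BYTES_PER_PARAM[kind]
    return fixed + sharded / dp_size if distributed else float(non_dist)


def model_state_gib(num_params: float, kind: str, dp_size: int, distributed: bool = True) -> float:
    """Theoretical model-state memory in GiB for `num_params` parameters held by one GPU."""
    return num_params * bytes_per_param(kind, dp_size, distributed) / 2**30


if __name__ == "__main__":
    # Hypothetical: 7B parameters on one GPU, bf16 params with fp32 grads, DP size 8.
    print(f"{model_state_gib(7e9, 'bf16 param, fp32 grads', 8):.1f} GiB")
```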

FlashAttention

Usage: --use-flash-attn. Supports attention head dimensions of at most 128.

FlashAttention is a fast and memory-efficient algorithm to compute exact attention. It speeds up model training and reduces memory requirement.

To install FlashAttention:

pip install flash-attn

GPT-3 Example

In examples/pretrain_gpt3_175B.sh we have provided an example of how to configure Megatron to run GPT-3 with 175 billion parameters on 1024 GPUs. The script is designed for slurm with the pyxis plugin but can be easily adapted to any other scheduler. It uses 8-way tensor parallelism and 16-way pipeline parallelism. With options global-batch-size 1536 and rampup-batch-size 16 16 5859375, the training will start with global batch size 16 and linearly increase the global batch size to 1536 over 5,859,375 samples in incremental steps of 16. The training dataset can be either a single set or multiple datasets combined with a set of weights.

With full global batch size of 1536 on 1024 A100 GPUs, each iteration takes around 32 seconds resulting in 138 teraFLOPs per GPU which is 44% of the theoretical peak FLOPs.
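The batch-size ramp described above (start at 16, add 16 at a time, reach 1536 over 5,859,375 samples) implies 95 increments of roughly 61,678 samples each. The sketch below models that as a step schedule; it follows the description here and may differ from Megatron's internal rounding. The function name is hypothetical.

```python
def rampup_global_batch_size(consumed_samples: int, start: int = 16, increment: int = 16,
                             rampup_samples: int = 5_859_375, target: int = 1536) -> int:
    """Global batch size after `consumed_samples`, under a linear step ramp."""
    num_increments = (target - start) // increment          # (1536 - 16) / 16 = 95
    samples_per_increment = rampup_samples / num_increments  # ~61,677.6 samples
    steps = int(consumed_samples / samples_per_increment)
    return min(start + steps * increment, target)


if __name__ == "__main__":
    for consumed in (0, 61_678, 2_000_000, 6_000_000):
        print(consumed, rampup_global_batch_size(consumed))
```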

Retro

See:

  • tools/retro/README.md for an overview.
  • tools/retro/examples/get_preprocess_cmd.sh for an example of common preprocessing arguments.
  • tools/retro/examples/preprocess_data.sh for an example of how to preprocess data.
  • tools/retro/examples/pretrain_model.sh for an example of how to pretrain a model.

Retro is a retrieval-enhanced model that is based on GPT. As described in Improving language models by retrieving from trillions of tokens, Retro retrieves from a database of document chunks by performing locality search using a sample's tokens. The retrieval database can be large -- often billions or even trillions of tokens -- and provides a more efficient storage mechanism of factual knowledge, when compared to storing factual knowledge implicitly within the network's parameters.

Using Retro requires two steps: 1) preprocessing the retrieval database and pretraining neighbors, and 2) pretraining a model using this data. Please see tools/retro/README.md for a detailed overview.

Evaluation and Tasks

We provide several command line arguments, detailed in the scripts listed below, to handle various zero-shot and fine-tuned downstream tasks. However, you can also finetune your model from a pretrained checkpoint on other corpora as desired. To do so, simply add the --finetune flag and adjust the input files and training parameters within the original training script. The iteration count will be reset to zero, and the optimizer and internal state will be reinitialized. If the fine-tuning is interrupted for any reason, be sure to remove the --finetune flag before continuing, otherwise the training will start again from the beginning.

Because evaluation requires substantially less memory than training, it may be advantageous to merge a model trained in parallel for use on fewer GPUs in downstream tasks. The following script accomplishes this. This example reads in a GPT model with 4-way tensor and 4-way pipeline model parallelism and writes out a model with 2-way tensor and 2-way pipeline model parallelism.

python tools/checkpoint_util.py \
        --model-type GPT \
        --load-dir checkpoints/gpt3_tp4_pp4 \
        --save-dir checkpoints/gpt3_tp2_pp2 \
        --target-tensor-parallel-size 2 \
        --target-pipeline-parallel-size 2

Several downstream tasks are described for both GPT and BERT models below. They can be run in distributed and model parallel modes with the same changes used in the training scripts.

GPT Text Generation

We have included a simple REST server for text generation in tools/run_text_generation_server.py. You run it much like you would start a pretraining job, specifying an appropriate pretrained checkpoint. There are also a few optional parameters: temperature, top-k, and top-p. See --help or the source file for more information. See examples/run_text_generation_server_345M.sh for an example of how to run the server.
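For readers unfamiliar with the three sampling parameters, the sketch below shows what they conventionally mean when picking the next token from a vector of logits. It is a generic, self-contained illustration (`sample_next_token` is a hypothetical name), not the server's code; in particular, whether the server allows top-k and top-p together is not covered here.

```python
import math
import random


def sample_next_token(logits, temperature=1.0, top_k=0, top_p=0.0, rng=random):
    """Sample a token index from raw logits.

    temperature rescales the logits (lower = sharper distribution),
    top_k keeps only the k most likely tokens, and top_p keeps the smallest
    set of most likely tokens whose cumulative probability reaches top_p.
    A value of 0 disables top_k / top_p filtering.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for a numerically stable softmax
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]

    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    if top_k > 0:
        order = order[:top_k]
    if top_p > 0.0:
        kept, cumulative = [], 0.0
        for i in order:
            kept.append(i)
            cumulative += probs[i]
            if cumulative >= top_p:
                break
        order = kept
    # random.choices accepts unnormalized weights.
    return rng.choices(order, weights=[probs[i] for i in order], k=1)[0]


print(sample_next_token([1.0, 3.0, 2.0], top_k=1))  # 1 (greedy)
```

With top_k=1 the choice is always the argmax; larger values, a higher temperature, or a top_p closer to 1 all increase diversity.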

Once the server is running, you can use tools/text_generation_cli.py to query it. It takes one argument: the host the server is running on.

tools/text_generation_cli.py localhost:5000

You can also use curl or any other HTTP client to query the server directly:

curl 'http://localhost:5000/api' -X 'PUT' -H 'Content-Type: application/json; charset=UTF-8'  -d '{"prompts":["Hello world"], "tokens_to_generate":1}'

See megatron/text_generation_server.py for more API options.
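The same PUT request can be built from Python with only the standard library. This sketch mirrors the curl example's URL, header and JSON body; `build_generate_request` is a hypothetical helper, and the structure of the server's reply is not assumed here (see megatron/text_generation_server.py).

```python
import json
import urllib.request


def build_generate_request(host, prompts, tokens_to_generate):
    """Build (but do not send) the same PUT request as the curl example."""
    payload = {"prompts": prompts, "tokens_to_generate": tokens_to_generate}
    return urllib.request.Request(
        f"http://{host}/api",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json; charset=UTF-8"},
        method="PUT",
    )


req = build_generate_request("localhost:5000", ["Hello world"], 1)
# With the server running, send it and decode the JSON reply:
# with urllib.request.urlopen(req) as resp:
#     print(json.loads(resp.read()))
```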

Detoxify GPT via Self-generation

We include an example in examples/detxoify_lm/ to detoxify language models by leveraging the generative power of language models.

See examples/detxoify_lm/README.md for step-by-step tutorials on how to perform domain-adaptive training and detoxify LM using self-generated corpus.

GPT Evaluation

We include example scripts for GPT evaluation on WikiText perplexity evaluation and LAMBADA Cloze accuracy.

WikiText Perplexity Evaluation

For a fair comparison with prior work, we evaluate perplexity on the word-level WikiText-103 test dataset and adjust the computed perplexity to account for the different number of tokens produced by our subword tokenizer.
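One standard way to make that adjustment is to keep the total loss summed over subword tokens, but normalize it by the number of original word-level tokens instead of the number of subword tokens. The sketch below shows that idea under this assumption; the exact bookkeeping in tasks/ (for example, off-by-one token counts) may differ, and `perplexities` is a hypothetical name.

```python
import math


def perplexities(total_nll, num_subword_tokens, num_word_tokens):
    """Return (subword-level, word-level) perplexity from one summed loss.

    total_nll: natural-log negative log-likelihood summed over all predicted
    subword tokens. Dividing the same total by the word count re-normalizes
    the loss per original WikiText-103 word.
    """
    subword_ppl = math.exp(total_nll / num_subword_tokens)
    word_ppl = math.exp(total_nll / num_word_tokens)
    return subword_ppl, word_ppl


# Hypothetical numbers: 1000 words tokenized into 1200 subwords, summed NLL of 3000 nats.
sub_ppl, word_ppl = perplexities(3000.0, 1200, 1000)
print(round(sub_ppl, 2), round(word_ppl, 2))  # 12.18 20.09
```

Because the subword tokenizer usually produces more tokens than there are words, the word-level perplexity is higher than the per-subword figure for the same model and text.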

We use the following command to run WikiText-103 evaluation on a 345M parameter model.

TASK="WIKITEXT103"

VALID_DATA=<wikitext path>.txt
VOCAB_FILE=gpt2-vocab.json
MERGE_FILE=gpt2-merges.txt
CHECKPOINT_PATH=checkpoints/gpt2_345m

COMMON_TASK_ARGS="--num-layers 24 \
                  --hidden-size 1024 \
                  --num-attention-heads 16 \
                  --seq-length 1024 \
                  --max-position-embeddings 1024 \
                  --fp16 \
                  --vocab-file $VOCAB_FILE"

python tasks/main.py \
       --task $TASK \
       $COMMON_TASK_ARGS \
       --valid-data $VALID_DATA \
       --tokenizer-type GPT2BPETokenizer \
       --merge-file $MERGE_FILE \
       --load $CHECKPOINT_PATH \
       --micro-batch-size 8 \
       --activations-checkpoint-method uniform \
       --log-interval 10 \
       --no-load-optim \
       --no-load-rng

LAMBADA Cloze Accuracy

To compute LAMBADA cloze accuracy (the accuracy of predicting the last token given the preceding tokens) we utilize a detokenized, processed version of the LAMBADA dataset.
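The sketch below shows one way to score cloze accuracy when the last word may span several subword tokens. It assumes greedy predictions under teacher forcing and models strict matching as "every subword of the final word must be correct", versus checking only the final token otherwise. Treat this as an approximation of the --strict-lambada behavior, not its definitive implementation; `lambada_accuracy` is a hypothetical name.

```python
def lambada_accuracy(samples, strict=True):
    """Cloze accuracy over (predicted_ids, target_ids) pairs for the final word.

    target_ids: the subword token ids of the last word.
    predicted_ids: the model's greedy prediction at each of those positions,
    given the true preceding tokens (teacher forcing).
    strict=True counts a sample only if every token of the word is right;
    strict=False checks only the final token.
    """
    correct = 0
    for predicted, target in samples:
        if strict:
            correct += predicted == target
        else:
            correct += predicted[-1] == target[-1]
    return correct / len(samples)


samples = [
    ([5, 7], [5, 7]),  # whole word right
    ([5, 8], [4, 8]),  # first subword wrong, last subword right
    ([1], [2]),        # wrong
]
print(lambada_accuracy(samples, strict=True))   # 0.3333333333333333
print(lambada_accuracy(samples, strict=False))  # 0.6666666666666666
```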

We use the following command to run LAMBADA evaluation on a 345M parameter model. Note that the --strict-lambada flag should be used to require whole word matching. Make sure that lambada is part of the file path.

TASK="LAMBADA"

VALID_DATA=<lambada path>.json
VOCAB_FILE=gpt2-vocab.json
MERGE_FILE=gpt2-merges.txt
CHECKPOINT_PATH=checkpoints/gpt2_345m
COMMON_TASK_ARGS=<same as those in WikiText Perplexity Evaluation above>

python tasks/main.py \
       --task $TASK \
       $COMMON_TASK_ARGS \
       --valid-data $VALID_DATA \
       --tokenizer-type GPT2BPETokenizer \
       --strict-lambada \
       --merge-file $MERGE_FILE \
       --load $CHECKPOINT_PATH \
       --micro-batch-size 8 \
       --activations-checkpoint-method uniform \
       --log-interval 10 \
       --no-load-optim \
       --no-load-rng

Further command line arguments are described in the source file tasks/main.py.

BERT Task Evaluation

RACE Evaluation

The following script finetunes the BERT model for evaluation on the RACE dataset. The TRAIN_DATA and VALID_DATA directories contain the RACE dataset as separate .txt files. Note that for RACE, the batch size is the number of RACE queries to evaluate. Since each RACE query has four samples, the effective batch size passed through the model is four times the batch size specified on the command line.

TRAIN_DATA="data/RACE/train/middle"
VALID_DATA="data/RACE/dev/middle \
            data/RACE/dev/high"
VOCAB_FILE=bert-vocab.txt
PRETRAINED_CHECKPOINT=checkpoints/bert_345m
CHECKPOINT_PATH=checkpoints/bert_345m_race
COMMON_TASK_ARGS="--num-layers 24 \
                  --hidden-size 1024 \
                  --num-attention-heads 16 \
                  --seq-length 512 \
                  --max-position-embeddings 512 \
                  --fp16 \
                  --vocab-file $VOCAB_FILE"

COMMON_TASK_ARGS_EXT="--train-data $TRAIN_DATA \
                      --valid-data $VALID_DATA \
                      --pretrained-checkpoint $PRETRAINED_CHECKPOINT \
                      --activations-checkpoint-method uniform \
                      --save-interval 10000 \
                      --save $CHECKPOINT_PATH \
                      --log-interval 100 \
                      --eval-interval 1000 \
                      --eval-iters 10 \
                      --weight-decay 1.0e-1"

python tasks/main.py \
       --task RACE \
       $COMMON_TASK_ARGS \
       $COMMON_TASK_ARGS_EXT \
       --tokenizer-type BertWordPieceLowerCase \
       --epochs 3 \
       --micro-batch-size 4 \
       --lr 1.0e-5 \
       --lr-warmup-fraction 0.06
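Because each RACE query expands into four samples, --micro-batch-size 4 above means 16 sequences go through the model per micro-batch. The sketch below shows how per-sample scores can be regrouped into one prediction per query. The flat, query-by-query ordering of scores is an assumption made for illustration, and `race_predictions` is a hypothetical name.

```python
NUM_CHOICES = 4  # each RACE query has four answer options


def race_predictions(scores, micro_batch_size):
    """Group flat per-sample scores into queries and pick the best option for each.

    Assumes scores has micro_batch_size * NUM_CHOICES entries ordered query by query.
    """
    assert len(scores) == micro_batch_size * NUM_CHOICES
    preds = []
    for q in range(micro_batch_size):
        options = scores[q * NUM_CHOICES:(q + 1) * NUM_CHOICES]
        preds.append(max(range(NUM_CHOICES), key=lambda i: options[i]))
    return preds


print(4 * NUM_CHOICES)  # 16 sequences for --micro-batch-size 4
print(race_predictions([0.1, 0.9, 0.3, 0.2, 2.0, 1.0, 0.5, 0.0], micro_batch_size=2))  # [1, 0]
```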

MNLI Evaluation

The following script finetunes the BERT model for evaluation with the MultiNLI sentence pair corpus. Because the matching tasks are quite similar, the script can be quickly tweaked to work with the Quora Question Pairs (QQP) dataset as well.

TRAIN_DATA="data/glue_data/MNLI/train.tsv"
VALID_DATA="data/glue_data/MNLI/dev_matched.tsv \
            data/glue_data/MNLI/dev_mismatched.tsv"
PRETRAINED_CHECKPOINT=checkpoints/bert_345m
VOCAB_FILE=bert-vocab.txt
CHECKPOINT_PATH=checkpoints/bert_345m_mnli
COMMON_TASK_ARGS=<same as those in RACE Evaluation above>
COMMON_TASK_ARGS_EXT=<same as those in RACE Evaluation above>

python tasks/main.py \
       --task MNLI \
       $COMMON_TASK_ARGS \
       $COMMON_TASK_ARGS_EXT \
       --tokenizer-type BertWordPieceLowerCase \
       --epochs 5 \
       --micro-batch-size 8 \
       --lr 5.0e-5 \
       --lr-warmup-fraction 0.065

Datasets

We do not host any datasets for GPT or BERT training; however, we detail their collection so that our results may be reproduced.

Collecting Wikipedia Training Data

We recommend following the Wikipedia data extraction process specified by Google research: "the recommended pre-processing is to download the latest dump, extract the text with WikiExtractor.py, and then apply any necessary cleanup to convert it into plain text."

We recommend using the --json argument when using WikiExtractor, which will dump the Wikipedia data into loose json format (one json object per line), making it more manageable on the file system and also readily consumable by our codebase. We recommend further preprocessing this json dataset with nltk punctuation standardization. For BERT training, use the --split-sentences flag to preprocess_data.py as described above to include sentence breaks in the produced index. If you'd like to use Wikipedia data for GPT training, you should still clean it with nltk/spacy/ftfy, but do not use the --split-sentences flag.
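Loose json simply means one independent JSON object per line, so it can be streamed without loading the whole dump. The sketch below reads such a file and yields the document text, which is where a cleanup step (nltk, spacy, ftfy) would hook in. The field name `text` is an assumption about the WikiExtractor output, and `iter_texts` is a hypothetical helper.

```python
import json


def iter_texts(lines, key="text"):
    """Yield the `key` field of each non-empty loose-json line (one object per line)."""
    for line_no, line in enumerate(lines, start=1):
        line = line.strip()
        if not line:
            continue  # tolerate blank lines
        record = json.loads(line)
        if key not in record:
            raise KeyError(f"line {line_no}: missing {key!r} field")
        yield record[key]


sample_lines = [
    '{"id": "12", "title": "Anarchism", "text": "Anarchism is a political philosophy."}',
    "",
    '{"id": "25", "title": "Autism", "text": "Autism is a neurodevelopmental condition."}',
]
print(list(iter_texts(sample_lines)))
```

In practice `lines` would be an open file handle, so documents are processed one at a time.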

Collecting GPT Webtext Data

We utilize the publicly available OpenWebText library from jcpeterson and eukaryote31's work to download URLs. We then filtered, cleaned, and deduplicated all downloaded content according to the procedure described in our openwebtext directory. For Reddit URLs corresponding to content up to October 2018, we arrived at approximately 37GB of content.

Reproducibility

Megatron training is intended to be bitwise reproducible. This means that the same training config run twice in the same hardware and software environment should produce identical model checkpoints, losses, and accuracy metric values (iteration time metrics may vary).
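A direct way to check bitwise reproducibility is to hash every file in two checkpoint directories and compare. The sketch below does that with the standard library. The directory layout in the demo is purely illustrative, and `checkpoints_identical` is a hypothetical helper. A byte-level comparison is the strictest possible test; if a checkpoint format ever embedded run-specific metadata, comparing the loaded tensors would be the fallback.

```python
import hashlib
import tempfile
from pathlib import Path


def file_digest(path, chunk_size=1 << 20):
    """SHA-256 of a file, read in chunks so large checkpoint files need not fit in memory."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for block in iter(lambda: f.read(chunk_size), b""):
            h.update(block)
    return h.hexdigest()


def checkpoints_identical(dir_a, dir_b):
    """True if both directories hold the same relative file paths with identical bytes."""
    a, b = Path(dir_a), Path(dir_b)
    files_a = sorted(p.relative_to(a) for p in a.rglob("*") if p.is_file())
    files_b = sorted(p.relative_to(b) for p in b.rglob("*") if p.is_file())
    if files_a != files_b:
        return False
    return all(file_digest(a / rel) == file_digest(b / rel) for rel in files_a)


# Demo with an illustrative layout (not Megatron's actual checkpoint structure).
with tempfile.TemporaryDirectory() as d1, tempfile.TemporaryDirectory() as d2:
    for d in (d1, d2):
        (Path(d) / "step_10").mkdir()
        (Path(d) / "step_10" / "weights.bin").write_bytes(b"\x00\x01weights")
    same = checkpoints_identical(d1, d2)
    (Path(d2) / "step_10" / "weights.bin").write_bytes(b"\x00\x02weights")
    changed = checkpoints_identical(d1, d2)
print(same, changed)  # True False
```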

There are currently three known Megatron optimizations that break reproducibility whilst still producing almost identical training runs. They are only applicable when using NGC containers >=22.05. The following workarounds should be applied in cases where reproducibility is required:

  1. When training using the --bf16 option the backward pass of torch.nn.functional.embedding is non-deterministic. If reproducibility is required you should also use the option --embedding-weights-in-fp32. The speed and memory impact of this change is negligible.
  2. Also when training using --bf16, reproducibility is only obtained when the checkpointing and resume schedule of training is identical. If the checkpointing schedule will change (i.e., checkpointing and resuming occur at different iterations), the option --no-bias-gelu-fusion should be used.
  3. Flash attention is non-deterministic. If reproducibility is required do not use --use-flash-attn.
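The three workarounds above can be turned into a pre-launch check over a training command's arguments. The sketch below is a hypothetical helper, not part of Megatron; it only encodes the flag combinations listed above, which apply to NGC containers >=22.05.

```python
def reproducibility_warnings(args):
    """Return warnings for the known non-deterministic flag combinations listed above.

    args: the training command's arguments as a list of strings. Hypothetical helper.
    """
    flags = set(args)
    warnings = []
    if "--bf16" in flags and "--embedding-weights-in-fp32" not in flags:
        warnings.append("bf16 embedding backward is non-deterministic: add --embedding-weights-in-fp32")
    if "--bf16" in flags and "--no-bias-gelu-fusion" not in flags:
        warnings.append("bf16 runs are reproducible only with an identical checkpoint/resume "
                        "schedule unless --no-bias-gelu-fusion is set")
    if "--use-flash-attn" in flags:
        warnings.append("flash attention is non-deterministic: drop --use-flash-attn")
    return warnings


for w in reproducibility_warnings(["--bf16", "--use-flash-attn", "--lr", "1.0e-5"]):
    print(w)
```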

These sources of non-determinism are under active investigation. If you observe non-determinism in Megatron training under other circumstances please open an issue.
