Pretrain the model

zhezhaoa edited this page Oct 1, 2023 · 6 revisions
usage: pretrain.py [-h] [--dataset_path DATASET_PATH]
                   [--pretrained_model_path PRETRAINED_MODEL_PATH]
                   --output_model_path OUTPUT_MODEL_PATH
                   [--config_path CONFIG_PATH] [--total_steps TOTAL_STEPS]
                   [--save_checkpoint_steps SAVE_CHECKPOINT_STEPS]
                   [--report_steps REPORT_STEPS]
                   [--accumulation_steps ACCUMULATION_STEPS]
                   [--batch_size BATCH_SIZE]
                   [--instances_buffer_size INSTANCES_BUFFER_SIZE]
                   [--labels_num LABELS_NUM] [--dropout DROPOUT] [--seed SEED]
                   [--tokenizer {bert,bpe,char,space,xlmroberta,image,text_image}]
                   [--vocab_path VOCAB_PATH] [--merges_path MERGES_PATH]
                   [--spm_model_path SPM_MODEL_PATH]
                   [--do_lower_case {true,false}]
                   [--vqgan_model_path VQGAN_MODEL_PATH]
                   [--vqgan_config_path VQGAN_CONFIG_PATH]
                   [--tgt_tokenizer {bert,bpe,char,space,xlmroberta}]
                   [--tgt_vocab_path TGT_VOCAB_PATH]
                   [--tgt_merges_path TGT_MERGES_PATH]
                   [--tgt_spm_model_path TGT_SPM_MODEL_PATH]
                   [--tgt_do_lower_case {true,false}]
                   [--embedding {word,pos,seg,sinusoidalpos,patch,speech,word_patch,dual} [{word,pos,seg,sinusoidalpos,patch,speech,word_patch,dual} ...]]
                   [--tgt_embedding {word,pos,seg,sinusoidalpos,patch,speech,word_patch,dual} [{word,pos,seg,sinusoidalpos,patch,speech,word_patch,dual} ...]]
                   [--max_seq_length MAX_SEQ_LENGTH]
                   [--relative_position_embedding] [--share_embedding]
                   [--remove_embedding_layernorm]
                   [--factorized_embedding_parameterization]
                   [--encoder {transformer,rnn,lstm,gru,birnn,bilstm,bigru,gatedcnn,dual}]
                   [--decoder {None,transformer}]
                   [--mask {fully_visible,causal,causal_with_prefix}]
                   [--layernorm_positioning {pre,post}]
                   [--feed_forward {dense,gated}]
                   [--relative_attention_buckets_num RELATIVE_ATTENTION_BUCKETS_NUM]
                   [--remove_attention_scale] [--remove_transformer_bias]
                   [--layernorm {normal,t5}] [--bidirectional]
                   [--parameter_sharing] [--has_residual_attention]
                   [--has_lmtarget_bias]
                   [--target {sp,lm,mlm,bilm,cls,clr} [{sp,lm,mlm,bilm,cls,clr} ...]]
                   [--tie_weights] [--pooling {mean,max,first,last}]
                   [--image_height IMAGE_HEIGHT] [--image_width IMAGE_WIDTH]
                   [--patch_size PATCH_SIZE] [--channels_num CHANNELS_NUM]
                   [--image_preprocess IMAGE_PREPROCESS [IMAGE_PREPROCESS ...]]
                   [--sampling_rate SAMPLING_RATE]
                   [--audio_preprocess AUDIO_PREPROCESS [AUDIO_PREPROCESS ...]]
                   [--max_audio_frames MAX_AUDIO_FRAMES]
                   [--conv_layers_num CONV_LAYERS_NUM]
                   [--audio_feature_size AUDIO_FEATURE_SIZE]
                   [--conv_channels CONV_CHANNELS]
                   [--conv_kernel_sizes CONV_KERNEL_SIZES [CONV_KERNEL_SIZES ...]]
                   [--data_processor {bert,lm,mlm,bilm,albert,mt,t5,cls,prefixlm,gsg,bart,cls_mlm,vit,vilt,clip,s2t,beit,dalle}]
                   [--deep_init] [--whole_word_masking] [--span_masking]
                   [--span_geo_prob SPAN_GEO_PROB]
                   [--span_max_length SPAN_MAX_LENGTH]
                   [--learning_rate LEARNING_RATE] [--warmup WARMUP] [--fp16]
                   [--fp16_opt_level {O0,O1,O2,O3}]
                   [--optimizer {adamw,adafactor}]
                   [--scheduler {linear,cosine,cosine_with_restarts,polynomial,constant,constant_with_warmup}]
                   [--world_size WORLD_SIZE]
                   [--gpu_ranks GPU_RANKS [GPU_RANKS ...]]
                   [--master_ip MASTER_IP] [--backend {nccl,gloo}]
                   [--deepspeed] [--deepspeed_config DEEPSPEED_CONFIG]
                   [--deepspeed_checkpoint_activations]
                   [--deepspeed_checkpoint_layers_num DEEPSPEED_CHECKPOINT_LAYERS_NUM]
                   [--local_rank LOCAL_RANK] [--log_path LOG_PATH]
                   [--log_level {ERROR,INFO,DEBUG,NOTSET}]
                   [--log_file_level {ERROR,INFO,DEBUG,NOTSET}]

Most pre-training models can be divided into five parts: embedding, encoder, target embedding (optional), decoder (optional), and target. TencentPretrain mirrors this structure (--embedding, --encoder, --tgt_embedding, --decoder, --target), and each part offers many modules. Users can build a pre-training model by combining these modules. More use cases can be found in Pretraining model examples. Taking the encoder as an example, TencentPretrain provides many encoder modules, including:

  • lstm: LSTM
  • gru: GRU
  • bilstm: bi-LSTM (different from --encoder lstm with --bidirectional , see the issue for more details)
  • gatedcnn: GatedCNN
  • transformer: supports BERT (--encoder transformer --mask fully_visible), GPT-2 (--encoder transformer --mask causal --layernorm_positioning pre), etc.
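
The flag combinations quoted in the list above can be collected in a small lookup table. This is a minimal sketch: the names MODULE_FLAGS and module_flags are hypothetical and not part of TencentPretrain, and only the flag values quoted in the list are used.

```python
# Hypothetical preset table: maps a model family to the module flags
# quoted in the list above. Not part of TencentPretrain itself.
MODULE_FLAGS = {
    "bert": ["--encoder", "transformer", "--mask", "fully_visible"],
    "gpt2": ["--encoder", "transformer", "--mask", "causal",
             "--layernorm_positioning", "pre"],
}


def module_flags(model_family):
    """Return a copy of the flag list for a known model family."""
    try:
        return list(MODULE_FLAGS[model_family])
    except KeyError:
        raise ValueError(f"no flag preset for {model_family!r}") from None


if __name__ == "__main__":
    # Print the module part of a pretrain.py command line for GPT-2.
    print(" ".join(["python3", "pretrain.py"] + module_flags("gpt2")))
```

The remaining options (paths, targets, embeddings) still have to be supplied as in the full commands on this page.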

The dataset format specified in the pre-training stage (--data_processor) must match the dataset format specified in the pre-processing stage.

In the pre-training stage, we need to provide information about paths, the model, the training environment, etc. For paths, it is usually necessary to provide the dataset path (--dataset_path), the configuration path (--config_path), and the output model path (--output_model_path). Model information is usually placed in the configuration file (--config_path). Options explicitly specified on the command line override those in the configuration file. The training environment is usually specified by --world_size and --gpu_ranks, which are discussed in the rest of this section.
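
The precedence rule above (command line over configuration file) can be pictured as a dictionary merge. This is only a sketch of the idea, not TencentPretrain's actual loading code: merge_settings is a hypothetical helper, and the configuration content is illustrative, using option names from the usage text.

```python
import json


def merge_settings(config_text, cli_overrides):
    """Merge a JSON config with explicitly given command-line values."""
    settings = json.loads(config_text)
    # Only options the user actually passed (not None) replace config values.
    for key, value in cli_overrides.items():
        if value is not None:
            settings[key] = value
    return settings


if __name__ == "__main__":
    # Illustrative config content; a real config file may hold other keys.
    config_text = '{"encoder": "transformer", "mask": "fully_visible", "dropout": 0.1}'
    # "mask" was given on the command line, "dropout" was not.
    print(merge_settings(config_text, {"mask": "causal", "dropout": None}))
```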

There are two strategies for initializing parameters in pre-training: 1) random initialization; 2) loading a pre-trained model.

Random initialization

An example of pre-training on CPU:

python3 pretrain.py --dataset_path dataset.pt --vocab_path models/google_zh_vocab.txt \
                    --config_path models/bert/base_config.json \
                    --output_model_path models/output_model.bin \
                    --data_processor bert \
                    --embedding word pos seg --encoder transformer --mask fully_visible --target mlm sp

The input of pre-training is specified by --dataset_path. An example of pre-training on a single GPU (the GPU ID is 3):

python3 pretrain.py --dataset_path dataset.pt --vocab_path models/google_zh_vocab.txt \
                    --config_path models/bert/base_config.json \
                    --output_model_path models/output_model.bin --gpu_ranks 3 \
                    --data_processor bert \
                    --embedding word pos seg --encoder transformer --mask fully_visible --target mlm sp

An example of pre-training on a single machine with 8 GPUs:

python3 pretrain.py --dataset_path dataset.pt --vocab_path models/google_zh_vocab.txt \
                    --config_path models/bert/base_config.json \
                    --output_model_path models/output_model.bin \
                    --world_size 8 --gpu_ranks 0 1 2 3 4 5 6 7 \
                    --data_processor bert \
                    --embedding word pos seg --encoder transformer --mask fully_visible --target mlm sp

--world_size specifies the number of processes (and GPUs) used for pre-training.
--gpu_ranks specifies the ID of each process and GPU. The IDs range from 0 to n-1, where n is the number of processes used for pre-training. Users can set CUDA_VISIBLE_DEVICES to use only a subset of the GPUs:

CUDA_VISIBLE_DEVICES=1,2,3,5 python3 pretrain.py --dataset_path dataset.pt \
                                                 --vocab_path models/google_zh_vocab.txt \
                                                 --config_path models/bert/base_config.json \
                                                 --output_model_path models/output_model.bin \
                                                 --world_size 4 --gpu_ranks 0 1 2 3 \
                                                 --data_processor bert \
                                                 --embedding word pos seg \
                                                 --encoder transformer --mask fully_visible \
                                                 --target mlm sp

--world_size is set to 4 since only 4 GPUs are used. The IDs of the 4 processes (and GPUs) are 0, 1, 2, and 3, as specified by --gpu_ranks.
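
With CUDA_VISIBLE_DEVICES set, the visible GPUs are renumbered from 0, so the IDs passed to --gpu_ranks are logical indices rather than physical ones. The sketch below makes that mapping explicit for the example above. rank_to_physical_gpu is a hypothetical helper, and it assumes the variable holds comma-separated integer indices (not device UUIDs).

```python
def rank_to_physical_gpu(cuda_visible_devices):
    """Map each logical ID (as used by --gpu_ranks) to a physical GPU index."""
    physical = [int(x) for x in cuda_visible_devices.split(",") if x.strip()]
    # Logical IDs are simply the positions in the CUDA_VISIBLE_DEVICES list.
    return dict(enumerate(physical))


if __name__ == "__main__":
    mapping = rank_to_physical_gpu("1,2,3,5")
    print(mapping)       # logical ID -> physical GPU index
    print(len(mapping))  # value to pass as --world_size on a single machine
```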

An example of pre-training on two machines, each with 8 GPUs (16 GPUs in total). We run pretrain.py on both machines (Node-0 and Node-1). --master_ip specifies the ip:port of the master node, which hosts the process (and GPU) with ID 0.

Node-0 : python3 pretrain.py --dataset_path dataset.pt --vocab_path models/google_zh_vocab.txt \
                             --config_path models/bert/base_config.json \
                             --output_model_path models/output_model.bin \
                             --world_size 16 --gpu_ranks 0 1 2 3 4 5 6 7 \
                             --total_steps 100000 --save_checkpoint_steps 10000 --report_steps 100 \
                             --master_ip tcp://9.73.138.133:12345 \
                             --data_processor bert \
                             --embedding word pos seg --encoder transformer --mask fully_visible --target mlm sp

Node-1 : python3 pretrain.py --dataset_path dataset.pt --vocab_path models/google_zh_vocab.txt \
                             --config_path models/bert/base_config.json \
                             --output_model_path models/output_model.bin \
                             --world_size 16 --gpu_ranks 8 9 10 11 12 13 14 15 \
                             --total_steps 100000 \
                             --master_ip tcp://9.73.138.133:12345 \
                             --data_processor bert \
                             --embedding word pos seg --encoder transformer --mask fully_visible --target mlm sp

The IP of Node-0 is 9.73.138.133.
--total_steps specifies the number of training steps. --save_checkpoint_steps specifies how often the model checkpoint is saved. We don't need to specify --save_checkpoint_steps on Node-1 since only the master node saves the pre-trained model. --report_steps specifies how often pre-training information is reported. We don't need to specify --report_steps on Node-1 since the information only appears on the master node.
Note that the port given in --master_ip must not be occupied by another program.
With random initialization, pre-training usually requires a larger learning rate. We recommend --learning_rate 1e-4. The default value is 2e-5.
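
Since the port given in --master_ip must not already be in use, it can help to check it on the master node before launching. The following is a minimal sketch using only the Python standard library. port_is_free is a hypothetical helper, and the result is only a snapshot: another program could still take the port between the check and the launch.

```python
import socket


def port_is_free(port, host="0.0.0.0"):
    """Return True if a TCP socket can be bound to (host, port) right now."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        try:
            sock.bind((host, port))
        except OSError:
            # Typically "address already in use" or a permission error.
            return False
    return True


if __name__ == "__main__":
    # 12345 is the port used in the --master_ip examples on this page.
    print(port_is_free(12345))
```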

Load the pre-trained model

We recommend loading a pre-trained model, which is specified by --pretrained_model_path. Examples of pre-training on CPU and on a single GPU:

python3 pretrain.py --dataset_path dataset.pt --vocab_path models/google_zh_vocab.txt \
                    --pretrained_model_path models/google_zh_model.bin \
                    --config_path models/bert/base_config.json \
                    --output_model_path models/output_model.bin

python3 pretrain.py --dataset_path dataset.pt --vocab_path models/google_zh_vocab.txt \
                    --pretrained_model_path models/google_zh_model.bin \
                    --config_path models/bert/base_config.json \
                    --output_model_path models/output_model.bin --gpu_ranks 3

An example of pre-training on a single machine with 8 GPUs:

python3 pretrain.py --dataset_path dataset.pt --vocab_path models/google_zh_vocab.txt \
                    --pretrained_model_path models/google_zh_model.bin \
                    --config_path models/bert/base_config.json \
                    --output_model_path models/output_model.bin \
                    --world_size 8 --gpu_ranks 0 1 2 3 4 5 6 7

An example of pre-training on two machines, each with 8 GPUs (16 GPUs in total):

Node-0: python3 pretrain.py --dataset_path dataset.pt --vocab_path models/google_zh_vocab.txt \
                            --pretrained_model_path models/google_zh_model.bin \
                            --config_path models/bert/base_config.json \
                            --output_model_path models/output_model.bin \
                            --world_size 16 --gpu_ranks 0 1 2 3 4 5 6 7 \
                            --master_ip tcp://9.73.138.133:12345

Node-1: python3 pretrain.py --dataset_path dataset.pt --vocab_path models/google_zh_vocab.txt \
                            --pretrained_model_path models/google_zh_model.bin \
                            --config_path models/bert/base_config.json \
                            --output_model_path models/output_model.bin \
                            --world_size 16 --gpu_ranks 8 9 10 11 12 13 14 15 \
                            --master_ip tcp://9.73.138.133:12345

An example of pre-training on three machines, each with 8 GPUs (24 GPUs in total):

Node-0: python3 pretrain.py --dataset_path dataset.pt --vocab_path models/google_zh_vocab.txt \
                            --pretrained_model_path models/google_zh_model.bin \
                            --config_path models/bert/base_config.json \
                            --output_model_path models/output_model.bin \
                            --world_size 24 --gpu_ranks 0 1 2 3 4 5 6 7 \
                            --master_ip tcp://9.73.138.133:12345

Node-1: python3 pretrain.py --dataset_path dataset.pt --vocab_path models/google_zh_vocab.txt \
                            --pretrained_model_path models/google_zh_model.bin \
                            --config_path models/bert/base_config.json \
                            --output_model_path models/output_model.bin \
                            --world_size 24 --gpu_ranks 8 9 10 11 12 13 14 15 \
                            --master_ip tcp://9.73.138.133:12345

Node-2: python3 pretrain.py --dataset_path dataset.pt --vocab_path models/google_zh_vocab.txt \
                            --pretrained_model_path models/google_zh_model.bin \
                            --config_path models/bert/base_config.json \
                            --output_model_path models/output_model.bin \
                            --world_size 24 --gpu_ranks 16 17 18 19 20 21 22 23 \
                            --master_ip tcp://9.73.138.133:12345
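
The --world_size and --gpu_ranks values in the multi-machine commands above follow a simple pattern. The sketch below reproduces it. node_arguments is a hypothetical helper, and it assumes every machine has the same number of GPUs.

```python
def node_arguments(node_index, nodes_num, gpus_per_node):
    """Return (world_size, gpu_ranks) for one machine in a multi-node run."""
    if not 0 <= node_index < nodes_num:
        raise ValueError("node_index must be in [0, nodes_num)")
    # world_size counts all processes across all machines.
    world_size = nodes_num * gpus_per_node
    # Each machine takes a contiguous block of IDs.
    first = node_index * gpus_per_node
    return world_size, list(range(first, first + gpus_per_node))


if __name__ == "__main__":
    for node in range(3):
        world_size, ranks = node_arguments(node, nodes_num=3, gpus_per_node=8)
        ranks_text = " ".join(str(r) for r in ranks)
        print(f"Node-{node}: --world_size {world_size} --gpu_ranks {ranks_text}")
```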

Word-based pre-training model

UER-py supports word-based pre-training models. We download cluecorpussmall_word_bert_base_seq512_model.bin and its vocabulary cluecorpussmall_word_vocab.txt. The model was pre-trained on the CLUECorpusSmall corpus (jieba is used as the word segmentation tool and words are separated by spaces) in two stages, the second with --seq_length 512:

python3 preprocess.py --corpus_path corpora/CLUECorpusSmall_word_jieba_bert.txt \
                      --tokenizer space --vocab_path models/cluecorpussmall_word_vocab.txt \
                      --dataset_path dataset.pt \
                      --processes_num 8 --dynamic_masking \
                      --data_processor bert

python3 pretrain.py --dataset_path dataset.pt \
                    --tokenizer space --vocab_path models/cluecorpussmall_word_vocab.txt \
                    --config_path models/bert/base_config.json \
                    --output_model_path models/cluecorpussmall_word_bert_base_seq128_model.bin \
                    --world_size 8 --gpu_ranks 0 1 2 3 4 5 6 7 \
                    --total_steps 1000000 --save_checkpoint_steps 100000 --report_steps 50000 \
                    --learning_rate 1e-4 --batch_size 64 \
                    --data_processor bert

python3 preprocess.py --corpus_path corpora/CLUECorpusSmall_word_jieba_bert.txt \
                      --tokenizer space --vocab_path models/cluecorpussmall_word_vocab.txt \
                      --dataset_path dataset.pt \
                      --processes_num 8 --dynamic_masking \
                      --data_processor bert --seq_length 512

python3 pretrain.py --dataset_path dataset.pt \
                    --pretrained_model_path models/cluecorpussmall_word_bert_base_seq128_model.bin-1000000 \
                    --tokenizer space --vocab_path models/cluecorpussmall_word_vocab.txt \
                    --config_path models/bert/base_config.json \
                    --output_model_path models/cluecorpussmall_word_bert_base_seq512_model.bin \
                    --world_size 8 --gpu_ranks 0 1 2 3 4 5 6 7 \
                    --total_steps 250000 --save_checkpoint_steps 50000 --report_steps 10000 \
                    --learning_rate 5e-5 --batch_size 16 \
                    --data_processor bert
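
The second stage above loads models/cluecorpussmall_word_bert_base_seq128_model.bin-1000000, which suggests that checkpoints are written as the output path followed by a hyphen and the step number. The sketch below lists the checkpoint names one would expect under that assumption. checkpoint_paths is a hypothetical helper, and both the naming pattern and the "save at every multiple of --save_checkpoint_steps" rule are inferred from the commands rather than confirmed by the source.

```python
def checkpoint_paths(output_model_path, total_steps, save_checkpoint_steps):
    """List expected checkpoint files, assuming a '<path>-<step>' naming pattern."""
    # Assumption: one checkpoint at each multiple of save_checkpoint_steps.
    steps = range(save_checkpoint_steps, total_steps + 1, save_checkpoint_steps)
    return [f"{output_model_path}-{step}" for step in steps]


if __name__ == "__main__":
    paths = checkpoint_paths(
        "models/cluecorpussmall_word_bert_base_seq128_model.bin",
        total_steps=1000000,
        save_checkpoint_steps=100000,
    )
    for path in paths:
        print(path)
```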

The following steps show an example of incremental pre-training on top of cluecorpussmall_word_bert_base_seq512_model.bin.
Suppose that the training corpus is corpora/book_review.txt. First we segment it into words and obtain book_review_seg.txt, which is in the MLM target format with words separated by spaces. Then we build a vocabulary from the corpus:

python3 scripts/build_vocab.py --corpus_path corpora/book_review_seg.txt \
                               --output_path models/book_review_word_vocab.txt \
                               --delimiter space --workers_num 8 --min_count 5
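
To illustrate what a frequency threshold such as --min_count does, here is a minimal sketch of vocabulary building on a space-separated corpus. It is not the code of scripts/build_vocab.py: build_vocab is a hypothetical function, the "at least min_count occurrences" rule and the ordering are assumptions, and the real script may also add special tokens.

```python
from collections import Counter


def build_vocab(lines, min_count):
    """Collect words that occur at least min_count times (assumed rule)."""
    counts = Counter()
    for line in lines:
        # Words are separated by whitespace, as in the segmented corpus.
        counts.update(line.split())
    kept = [word for word, count in counts.items() if count >= min_count]
    # Most frequent first; ties broken alphabetically for a stable order.
    return sorted(kept, key=lambda word: (-counts[word], word))


if __name__ == "__main__":
    corpus = ["a good book", "a good story", "a bad ending"]
    print(build_vocab(corpus, min_count=2))
```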

Then we adapt the pre-trained model cluecorpussmall_word_bert_base_seq512_model.bin. The embedding layer and the output layer before softmax are adapted according to the differences between the old and new vocabularies. The embeddings of new words are randomly initialized. The adapted model is compatible with the new vocabulary:
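
The adaptation described above can be sketched on plain Python lists: rows for words shared by both vocabularies are copied, and rows for new words are randomly initialized. This is an illustration of the idea only, not the code of scripts/dynamic_vocab_adapter.py. adapt_embedding is a hypothetical function, and the Gaussian initialization with scale 0.02 is an assumption.

```python
import random


def adapt_embedding(old_vocab, new_vocab, old_embedding, rng=None, scale=0.02):
    """Build an embedding table for new_vocab from one indexed by old_vocab."""
    rng = rng or random.Random(0)
    dim = len(old_embedding[0])
    old_index = {word: i for i, word in enumerate(old_vocab)}
    new_embedding = []
    for word in new_vocab:
        if word in old_index:
            # Shared word: reuse the trained vector.
            new_embedding.append(list(old_embedding[old_index[word]]))
        else:
            # New word: random initialization (assumed Gaussian, std = scale).
            new_embedding.append([rng.gauss(0.0, scale) for _ in range(dim)])
    return new_embedding


if __name__ == "__main__":
    old_vocab = ["[PAD]", "book", "good"]
    new_vocab = ["[PAD]", "good", "review"]
    old_embedding = [[0.0, 0.0], [0.1, 0.2], [0.3, 0.4]]
    for word, row in zip(new_vocab, adapt_embedding(old_vocab, new_vocab, old_embedding)):
        print(word, row)
```

The same row mapping would apply to the output layer before softmax, since it is also indexed by vocabulary position.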

python3 scripts/dynamic_vocab_adapter.py --old_model_path models/cluecorpussmall_word_bert_base_seq512_model.bin \
                                         --old_vocab_path models/cluecorpussmall_word_vocab.txt \
                                         --new_vocab_path models/book_review_word_vocab.txt \
                                         --new_model_path models/book_review_word_model.bin

Finally, we do incremental pre-training on the adapted model book_review_word_model.bin. The MLM target is used:

python3 preprocess.py --corpus_path corpora/book_review_seg.txt \
                      --vocab_path models/book_review_word_vocab.txt --tokenizer space \
                      --dataset_path book_review_word_dataset.pt \
                      --processes_num 8 --seq_length 128 --dynamic_masking \
                      --data_processor mlm

python3 pretrain.py --dataset_path book_review_word_dataset.pt \
                    --vocab_path models/book_review_word_vocab.txt \
                    --pretrained_model_path models/book_review_word_model.bin \
                    --config_path models/bert/base_config.json \
                    --output_model_path models/output_model.bin \
                    --world_size 8 --gpu_ranks 0 1 2 3 4 5 6 7 \
                    --total_steps 20000 --save_checkpoint_steps 10000 --report_steps 1000 \
                    --data_processor mlm \
                    --embedding word pos seg --encoder transformer --mask fully_visible --target mlm

In addition, we can use SentencePiece to obtain a word-based pre-training model:

python3 preprocess.py --corpus_path corpora/book_review.txt \
                      --spm_model_path models/cluecorpussmall_spm.model \
                      --dataset_path book_review_word_dataset.pt \
                      --processes_num 8 --seq_length 128 --dynamic_masking \
                      --data_processor mlm

python3 pretrain.py --dataset_path book_review_word_dataset.pt \
                    --spm_model_path models/cluecorpussmall_spm.model \
                    --config_path models/bert/base_config.json \
                    --output_model_path models/output_model.bin \
                    --world_size 8 --gpu_ranks 0 1 2 3 4 5 6 7 \
                    --total_steps 20000 --save_checkpoint_steps 10000 --report_steps 1000 \
                    --learning_rate 1e-4 \
                    --data_processor mlm \
                    --embedding word pos seg --encoder transformer --mask fully_visible --target mlm

--spm_model_path specifies the path of the SentencePiece model. models/cluecorpussmall_spm.model is the SentencePiece model trained on the CLUECorpusSmall corpus.
