[Source Code Analysis] Model-Parallel Distributed Training with Megatron (2) --- Overall Architecture
0x00 Abstract
NVIDIA Megatron is a PyTorch-based distributed training framework for training very large Transformer language models. It reproduces GPT-3 by combining data parallelism, tensor parallelism, and pipeline parallelism, and the mechanisms behind it are well worth a close look.
This series will consist of roughly 6 to 7 articles that study Megatron through its papers and source code. This article walks through Megatron's basic architecture.
Other articles in this series:
[Source Code Analysis] Model-Parallel Distributed Training with Megatron (1) --- Papers & Fundamentals
0x01 Launch
1.1 Distributed launch
The launch script is examples/pretrain_bert_distributed.sh; it uses torch.distributed.launch to spawn multiple processes. The actual training entry point is pretrain_bert.py.
Because GPUS_PER_NODE is 8, nproc_per_node is also 8, so eight processes are started on the local machine, each holding a part of the model. The rank of each process is assigned automatically by torch.distributed.launch (via the elastic machinery).
#!/bin/bash

GPUS_PER_NODE=8
# Change for multinode config
MASTER_ADDR=localhost
MASTER_PORT=6000
NNODES=1
NODE_RANK=0
WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))

DATA_PATH=<Specify path and file prefix>_text_sentence
CHECKPOINT_PATH=<Specify path>

DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"

python -m torch.distributed.launch $DISTRIBUTED_ARGS \
       pretrain_bert.py \
       --num-layers 24 \
       --hidden-size 1024 \
       --num-attention-heads 16 \
       --micro-batch-size 4 \
       --global-batch-size 32 \
       --seq-length 512 \
       --max-position-embeddings 512 \
       --train-iters 1000000 \
       --save $CHECKPOINT_PATH \
       --load $CHECKPOINT_PATH \
       --data-path $DATA_PATH \
       --vocab-file bert-vocab.txt \
       --data-impl mmap \
       --split 949,50,1 \
       --distributed-backend nccl \
       --lr 0.0001 \
       --lr-decay-style linear \
       --min-lr 1.0e-5 \
       --lr-decay-iters 990000 \
       --weight-decay 1e-2 \
       --clip-grad 1.0 \
       --lr-warmup-fraction .01 \
       --log-interval 100 \
       --save-interval 10000 \
       --eval-interval 1000 \
       --eval-iters 10 \
       --fp16
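For readers unfamiliar with torch.distributed.launch, here is a minimal sketch (not Megatron code) of what each of the eight launched worker processes conceptually does: it receives a --local_rank argument from the launcher, binds to its own GPU, and joins the default process group through the env:// initialization method (MASTER_ADDR, MASTER_PORT, RANK, and WORLD_SIZE are supplied by the launcher as environment variables).

# Minimal sketch of a worker started by torch.distributed.launch (illustrative, not Megatron code).
import argparse
import torch

parser = argparse.ArgumentParser()
# torch.distributed.launch injects --local_rank=<i> for the i-th process on this node.
parser.add_argument("--local_rank", type=int, default=0)
args = parser.parse_args()

torch.cuda.set_device(args.local_rank)            # one process <-> one GPU
torch.distributed.init_process_group(             # env:// reads MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE
    backend="nccl", init_method="env://")

print(f"global rank {torch.distributed.get_rank()} of {torch.distributed.get_world_size()} is ready")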
1.2 Building blocks
pretrain_bert.py calls pretrain to run pre-training.
if __name__ == "__main__":

    pretrain(train_valid_test_datasets_provider, model_provider,
             ModelType.encoder_or_decoder,
             forward_step, args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
1.2.1 Getting the model
model_provider returns the vanilla version of the model. By vanilla we mean a plain CPU model, without fp16 or DDP, but one that Megatron has already converted into its parallel form.
def model_provider(pre_process=True, post_process=True):
    """Build the model."""

    print_rank_0('building BERT model ...')

    args = get_args()
    num_tokentypes = 2 if args.bert_binary_head else 0
    model = BertModel(
        num_tokentypes=num_tokentypes,
        add_binary_head=args.bert_binary_head,
        parallel_output=True,
        pre_process=pre_process,
        post_process=post_process)

    return model
1.2.2 Getting the datasets
train_valid_test_datasets_provider takes the sizes of the train/valid/test datasets and returns the train, valid, and test datasets.
def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build train, valid, and test datasets."""
    args = get_args()

    print_rank_0('> building train, validation, and test datasets '
                 'for BERT ...')
    train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
        data_prefix=args.data_path,
        data_impl=args.data_impl,
        splits_string=args.split,
        train_valid_test_num_samples=train_val_test_num_samples,
        max_seq_length=args.seq_length,
        masked_lm_prob=args.mask_prob,
        short_seq_prob=args.short_seq_prob,
        seed=args.seed,
        skip_warmup=(not args.mmap_warmup),
        binary_head=args.bert_binary_head)
    print_rank_0("> finished creating BERT datasets ...")

    return train_ds, valid_ds, test_ds
1.2.3 The step function
forward_step takes a data iterator and the model, and returns a loss scalar together with a dictionary whose key:value pairs are the quantities to monitor during training, e.g. 'lm loss': value. The function is also required to add the 'batch generator' timer to the timers object.
def forward_step(data_iterator, model):
    """Forward step."""
    args = get_args()

    # Get the batch.
    tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = get_batch(
        data_iterator)

    if not args.bert_binary_head:
        types = None

    # Forward pass through the model.
    output_tensor = model(tokens, padding_mask, tokentype_ids=types,
                          lm_labels=lm_labels)

    return output_tensor, partial(loss_func, loss_mask, sentence_order)
1.2.3.1 Broadcasting data
forward_step calls get_batch to obtain a batch. Internally, get_batch pulls the data from the iterator and then uses the broadcast_data function to broadcast the input from rank 0 to all the other ranks of the same tensor-model-parallel group.
Note: data parallelism loads different data onto different ranks, whereas every rank inside a tensor-model-parallel group loads the same data.
def get_batch(data_iterator):
    """Build the batch."""

    # Items and their type.
    keys = ['text', 'types', 'labels', 'is_random', 'loss_mask', 'padding_mask']
    datatype = torch.int64

    # Broadcast data.
    if data_iterator is not None:
        data = next(data_iterator)  # fetch the data
    else:
        data = None
    data_b = mpu.broadcast_data(keys, data, datatype)  # broadcast the data to the GPUs of the group

    # Unpack.
    tokens = data_b['text'].long()
    types = data_b['types'].long()
    sentence_order = data_b['is_random'].long()
    loss_mask = data_b['loss_mask'].float()
    lm_labels = data_b['labels'].long()
    padding_mask = data_b['padding_mask'].long()

    return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask
broadcast_data sends the data from rank 0 of each model parallel group to the other members of the same group.
def broadcast_data(keys, data, datatype):
    """Broadcast data from rank zero of each model parallel group to the
    members of the same model parallel group.

    Arguments:
        keys: list of keys in the data dictionary to be broadcasted
        data: data dictionary of string keys and cpu tensor values.
        datatype: torch data type of all tensors in data associated
                  with keys.
    """
    # Build (key, size) and (key, number of elements) dictionaries along
    # with the total number of elements on all ranks.
    key_size, key_numel, total_numel = _build_key_size_numel_dictionaries(keys,
                                                                          data)

    # Pack on rank zero.
    if get_tensor_model_parallel_rank() == 0:  # only rank 0 packs
        # Check that all keys have the same data type.
        _check_data_types(keys, data, datatype)
        # Flatten the data associated with the keys
        flatten_data = torch.cat(
            [data[key].contiguous().view(-1) for key in keys], dim=0).cuda()
    else:
        flatten_data = torch.empty(total_numel,
                                   device=torch.cuda.current_device(),
                                   dtype=datatype)

    # Broadcast
    torch.distributed.broadcast(flatten_data, get_tensor_model_parallel_src_rank(),
                                group=get_tensor_model_parallel_group())

    # Unpack
    output = {}
    offset = 0
    for key in keys:
        size = key_size[key]
        numel = key_numel[key]
        output[key] = flatten_data.narrow(0, offset, numel).view(size)
        offset += numel

    return output
get_tensor_model_parallel_src_rank computes the global rank that corresponds to the first local rank in the tensor model parallel group.
def get_tensor_model_parallel_src_rank():
    """Calculate the global rank corresponding to the first local rank
    in the tensor model parallel group."""
    global_rank = torch.distributed.get_rank()
    local_world_size = get_tensor_model_parallel_world_size()
    return (global_rank // local_world_size) * local_world_size
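A quick worked example: with a tensor-model-parallel world size of 2, global ranks 0 and 1 form one group and both map to source rank 0; ranks 2 and 3 map to source rank 2, and so on. The same integer arithmetic, stand-alone:

# Stand-alone illustration of the src-rank arithmetic above (a tensor-parallel world size of 2 is assumed).
def tp_src_rank(global_rank, tp_world_size=2):
    return (global_rank // tp_world_size) * tp_world_size

print([tp_src_rank(r) for r in range(8)])   # [0, 0, 2, 2, 4, 4, 6, 6]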
The logic diagram is as follows: the three functions each supply a different functional input to pre-training, which keeps the concerns decoupled.
0x02 Pretrain
BERT training consists of two main steps:
Pre-train: pre-training is the foundation of transfer learning; it trains token-level semantic understanding.
Fine-tuning: on top of the already-trained language model, add parameters for a specific domain (e.g. finance or healthcare) and retrain. For a classification task, for instance, you can put a softmax on top of the pre-trained model and fine-tune it on the target corpus (see the sketch after this list).
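As a toy illustration of that fine-tuning idea (plain PyTorch, not Megatron's fine-tuning code; the encoder interface below is an assumption), a classification head can be stacked on a pre-trained encoder like this:

# Toy fine-tuning head on top of a pre-trained encoder (illustrative only).
import torch.nn as nn

class BertClassifier(nn.Module):
    def __init__(self, pretrained_encoder, hidden_size, num_classes):
        super().__init__()
        self.encoder = pretrained_encoder                       # weights come from pre-training
        self.classifier = nn.Linear(hidden_size, num_classes)   # new layer, trained during fine-tuning

    def forward(self, tokens, attention_mask):
        pooled = self.encoder(tokens, attention_mask)           # assumed to return a pooled [batch, hidden] vector
        return self.classifier(pooled)                          # logits; softmax/cross-entropy applied in the loss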
Pre-training mainly works as follows:
The code is as follows:
def pretrain(train_valid_test_dataset_provider,
             model_provider,
             model_type,
             forward_step_func,
             extra_args_provider=None,
             args_defaults={}):
    """Main training program.

    This function will run the following in the order provided:
        1) initialize Megatron.
        2) setup model, optimizer and lr schedule using the model_provider.
        3) call train_val_test_data_provider to get train/val/test datasets.
        4) train the model using the forward_step_func.
    """

    # Initialize and get arguments, timers, and Tensorboard writer.
    initialize_megatron(extra_args_provider=extra_args_provider,
                        args_defaults=args_defaults)

    # Adjust the startup time so it reflects the largest value.
    # This will be closer to what scheduler will see (outside of
    # image ... launches.
    global _TRAIN_START_TIME
    start_time_tensor = torch.cuda.DoubleTensor([_TRAIN_START_TIME])
    torch.distributed.all_reduce(start_time_tensor,
                                 op=torch.distributed.ReduceOp.MIN)
    _TRAIN_START_TIME = start_time_tensor.item()

    args = get_args()
    timers = get_timers()

    # Model, optimizer, and learning rate.
    # Use model_provider to set up the model, the optimizer, and the lr schedule.
    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider,
                                                               model_type)

    # Data stuff.
    # Call train_valid_test_dataset_provider to get the train/valid/test datasets.
    if args.virtual_pipeline_model_parallel_size is not None:
        all_data_iterators = [
            build_train_valid_test_data_iterators(train_valid_test_dataset_provider)
            for _ in range(len(model))
        ]
        train_data_iterator = [data_iterators[0] for data_iterators in all_data_iterators]
        valid_data_iterator = [data_iterators[1] for data_iterators in all_data_iterators]
        test_data_iterator = [data_iterators[2] for data_iterators in all_data_iterators]
    else:
        train_data_iterator, valid_data_iterator, test_data_iterator \
            = build_train_valid_test_data_iterators(
                train_valid_test_dataset_provider)

    iteration = 0
    if args.do_train and args.train_iters > 0:
        iteration = train(forward_step_func,  # train the model
                          model, optimizer, lr_scheduler,
                          train_data_iterator, valid_data_iterator)

    if