[Source Code Analysis] Model-Parallel Distributed Training with Megatron (2) --- Overall Architecture


0x00 Abstract

NVIDIA Megatron is a PyTorch-based distributed training framework for training very large Transformer language models. It reproduces GPT-3 by combining data parallelism, tensor parallelism, and pipeline parallelism, so the mechanisms behind it are well worth a deep analysis.

This series will consist of roughly 6~7 articles that study Megatron through its papers and source code. This article gives an overview of Megatron's basic architecture.

Other articles in this series:

[Source Code Analysis] Model-Parallel Distributed Training with Megatron (1) --- Papers & Foundations

0x01 Startup

1.1 Distributed Launch

The launch script is examples/pretrain_bert_distributed.sh, which uses torch.distributed.launch to start multiple processes. The actual training code is pretrain_bert.py.

Since GPUS_PER_NODE is 8, nproc_per_node is also 8, so 8 processes are started on the local machine, each holding one part of the model. Each process's rank is assigned automatically by the elastic mechanism that torch.distributed.launch invokes.

#!/bin/bash

GPUS_PER_NODE=8
# Change for multinode config
MASTER_ADDR=localhost
MASTER_PORT=6000
NNODES=1
NODE_RANK=0
WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))

DATA_PATH=<Specify path and file prefix>_text_sentence
CHECKPOINT_PATH=<Specify path>

DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"

python -m torch.distributed.launch $DISTRIBUTED_ARGS \
       pretrain_bert.py \
       --num-layers 24 \
       --hidden-size 1024 \
       --num-attention-heads 16 \
       --micro-batch-size 4 \
       --global-batch-size 32 \
       --seq-length 512 \
       --max-position-embeddings 512 \
       --train-iters 1000000 \
       --save $CHECKPOINT_PATH \
       --load $CHECKPOINT_PATH \
       --data-path $DATA_PATH \
       --vocab-file bert-vocab.txt \
       --data-impl mmap \
       --split 949,50,1 \
       --distributed-backend nccl \
       --lr 0.0001 \
       --lr-decay-style linear \
       --min-lr 1.0e-5 \
       --lr-decay-iters 990000 \
       --weight-decay 1e-2 \
       --clip-grad 1.0 \
       --lr-warmup-fraction .01 \
       --log-interval 100 \
       --save-interval 10000 \
       --eval-interval 1000 \
       --eval-iters 10 \
       --fp16
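
For reference, here is a minimal sketch (not Megatron code) of what each of those 8 launched processes typically does first: read the rank-related environment variables set by the PyTorch launcher and join the NCCL process group. The function name init_worker is ours; the LOCAL_RANK environment variable is exported by torchrun and by torch.distributed.launch with --use_env, while older launch versions pass --local_rank as a CLI argument instead.

import os
import torch

def init_worker():
    # RANK / WORLD_SIZE (plus MASTER_ADDR / MASTER_PORT) are set by the launcher
    # for every spawned process.
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ.get("LOCAL_RANK", 0))

    # Bind this process to its own GPU, then join the NCCL process group.
    torch.cuda.set_device(local_rank)
    torch.distributed.init_process_group(backend="nccl",
                                         world_size=world_size,
                                         rank=rank)
    return rank, local_rank, world_size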

1.2 Building Blocks

pretrain_bert.py calls pretrain to run the pre-training.

if __name__ == "__main__":

    pretrain(train_valid_test_datasets_provider, model_provider,
             ModelType.encoder_or_decoder,
             forward_step, args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})

1.2.1 Getting the Model

model_provider returns the vanilla version of the model. By "vanilla" we mean a plain CPU model without fp16 or DDP, but one that Megatron has already transformed into its parallel version.

def model_provider(pre_process=True, post_process=True):
    """Build the model."""
    print_rank_0('building BERT model ...')

    args = get_args()
    num_tokentypes = 2 if args.bert_binary_head else 0
    model = BertModel(
        num_tokentypes=num_tokentypes,
        add_binary_head=args.bert_binary_head,
        parallel_output=True,
        pre_process=pre_process,
        post_process=post_process)

    return model

1.2.2 Getting the Datasets

train_valid_test_datasets_provider takes the sizes of the train/valid/test datasets and returns the train, valid, and test datasets.

def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build train, valid, and test datasets."""
    args = get_args()
    print_rank_0('> building train, validation, and test datasets '
                 'for BERT ...')
    train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
        data_prefix=args.data_path,
        data_impl=args.data_impl,
        splits_string=args.split,
        train_valid_test_num_samples=train_val_test_num_samples,
        max_seq_length=args.seq_length,
        masked_lm_prob=args.mask_prob,
        short_seq_prob=args.short_seq_prob,
        seed=args.seed,
        skip_warmup=(not args.mmap_warmup),
        binary_head=args.bert_binary_head)
    print_rank_0("> finished creating BERT datasets ...")

    return train_ds, valid_ds, test_ds

1.2.3 The Step Function

forward_step takes a data iterator and the model, and returns a loss scalar together with a dictionary whose key/value pairs are the quantities to monitor during training, such as 'lm loss: value'. The function is also required to add 'batch generator' to the timers class.

def forward_step(data_iterator, model):
    """Forward step."""
    args = get_args()

    # Get the batch.
    tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = get_batch(
        data_iterator)

    if not args.bert_binary_head:
        types = None

    # Forward pass through the model.
    output_tensor = model(tokens, padding_mask, tokentype_ids=types,
                          lm_labels=lm_labels)

    return output_tensor, partial(loss_func, loss_mask, sentence_order)

1.2.3.1 Broadcasting Data

forward_step calls get_batch to obtain the batch. Internally it pulls data from the iterator and then uses broadcast_data to broadcast the input data from rank 0 to all the other ranks in the tensor-model-parallel group.

Note that data parallelism loads different data onto different ranks, whereas every rank within a tensor model parallel group loads the same data.

def get_batch(data_iterator):
    """Build the batch."""

    # Items and their type.
    keys = ['text', 'types', 'labels', 'is_random', 'loss_mask', 'padding_mask']
    datatype = torch.int64

    # Broadcast data.
    if data_iterator is not None:
        data = next(data_iterator)  # Fetch the data
    else:
        data = None
    data_b = mpu.broadcast_data(keys, data, datatype)  # Broadcast the data to every GPU

    # Unpack.
    tokens = data_b['text'].long()
    types = data_b['types'].long()
    sentence_order = data_b['is_random'].long()
    loss_mask = data_b['loss_mask'].float()
    lm_labels = data_b['labels'].long()
    padding_mask = data_b['padding_mask'].long()

    return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask

broadcast_data sends data from rank 0 of each model parallel group to the other members of the same group.

def broadcast_data(keys, data, datatype):
    """Broadcast data from rank zero of each model parallel group to the
    members of the same model parallel group.

    Arguments:
        keys: list of keys in the data dictionary to be broadcasted
        data: data dictionary of string keys and cpu tensor values.
        datatype: torch data type of all tensors in data associated
                  with keys.
    """
    # Build (key, size) and (key, number of elements) dictionaries along
    # with the total number of elements on all ranks.
    key_size, key_numel, total_numel = _build_key_size_numel_dictionaries(keys,
                                                                          data)

    # Pack on rank zero.
    if get_tensor_model_parallel_rank() == 0:  # Only rank 0 packs the data
        # Check that all keys have the same data type.
        _check_data_types(keys, data, datatype)
        # Flatten the data associated with the keys
        flatten_data = torch.cat(
            [data[key].contiguous().view(-1) for key in keys], dim=0).cuda()
    else:
        flatten_data = torch.empty(total_numel,
                                   device=torch.cuda.current_device(),
                                   dtype=datatype)

    # Broadcast
    torch.distributed.broadcast(flatten_data, get_tensor_model_parallel_src_rank(),
                                group=get_tensor_model_parallel_group())

    # Unpack
    output = {}
    offset = 0
    for key in keys:
        size = key_size[key]
        numel = key_numel[key]
        output[key] = flatten_data.narrow(0, offset, numel).view(size)
        offset += numel

    return output

get_tensor_model_parallel_src_rank computes the global rank corresponding to the first local rank in the tensor model parallel group.

def get_tensor_model_parallel_src_rank():
    """Calculate the global rank corresponding to the first local rank
    in the tensor model parallel group."""
    global_rank = torch.distributed.get_rank()
    local_world_size = get_tensor_model_parallel_world_size()
    return (global_rank // local_world_size) * local_world_size
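
A quick worked example of that formula, assuming a hypothetical configuration of 16 GPUs with a tensor model parallel size of 4: ranks 0~3, 4~7, 8~11 and 12~15 each form one tensor parallel group, and the "source" rank of each group is its first member.

def tensor_model_parallel_src_rank(global_rank, tp_world_size):
    # Same arithmetic as above, written out for a concrete (hypothetical) setup.
    return (global_rank // tp_world_size) * tp_world_size

for rank in range(16):
    print(rank, '->', tensor_model_parallel_src_rank(rank, 4))
# Ranks 0..3 map to 0, 4..7 map to 4, 8..11 map to 8, 12..15 map to 12.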

The overall logic is as follows: the three functions each provide a different functional input to pre-training, keeping those concerns decoupled from one another.

0x02 Pretrain

BERT training consists of two main steps:

  • Pre-train: pre-training is the foundation of transfer learning; it trains token-level semantic understanding.
  • Fine-tuning: on top of an already trained language model, add parameters for a specific domain (such as finance or healthcare) and retrain. For example, for a classification problem you can add a softmax on top of the pre-trained model and fine-tune it with the task corpus.

Pre-training mainly does the following:

  • Initialize Megatron.

  • Set up the model, optimizer, and lr schedule using model_provider.

  • Call train_val_test_data_provider to get the train/val/test datasets.

  • Train the model using forward_step_func.

The code is as follows:

def pretrain(train_valid_test_dataset_provider,
             model_provider,
             model_type,
             forward_step_func,
             extra_args_provider=None,
             args_defaults={}):
    """Main training program.

    This function will run the followings in the order provided:
        1) initialize Megatron.
        2) setup model, optimizer and lr schedule using the model_provider.
        3) call train_val_test_data_provider to get train/val/test datasets.
        4) train the model using the forward_step_func.
    """

    # Initialize and get arguments, timers, and Tensorboard writer.
    initialize_megatron(extra_args_provider=extra_args_provider,
                        args_defaults=args_defaults)

    # Adjust the startup time so it reflects the largest value.
    # This will be closer to what scheduler will see (outside of
    # image ... launches.
    global _TRAIN_START_TIME
    start_time_tensor = torch.cuda.DoubleTensor([_TRAIN_START_TIME])
    torch.distributed.all_reduce(start_time_tensor,
                                 op=torch.distributed.ReduceOp.MIN)
    _TRAIN_START_TIME = start_time_tensor.item()

    args = get_args()
    timers = get_timers()

    # Model, optimizer, and learning rate.
    # Set up the model, optimizer and lr schedule using model_provider.
    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider,
                                                               model_type)

    # Data stuff.
    # Call train_valid_test_dataset_provider to get the train/val/test datasets.
    if args.virtual_pipeline_model_parallel_size is not None:
        all_data_iterators = [
            build_train_valid_test_data_iterators(train_valid_test_dataset_provider)
            for _ in range(len(model))
        ]
        train_data_iterator = [data_iterators[0] for data_iterators in all_data_iterators]
        valid_data_iterator = [data_iterators[1] for data_iterators in all_data_iterators]
        test_data_iterator = [data_iterators[2] for data_iterators in all_data_iterators]
    else:
        train_data_iterator, valid_data_iterator, test_data_iterator \
            = build_train_valid_test_data_iterators(
                train_valid_test_dataset_provider)

    iteration = 0
    if args.do_train and args.train_iters > 0:
        iteration = train(forward_step_func,  # Train the model
                          model, optimizer, lr_scheduler,
                          train_data_iterator, valid_data_iterator)

    # ... (the remainder of pretrain -- validation, checkpoint saving, and test
    # evaluation -- is omitted here.)