简明Pytorch分布式训练 — DistributedDataParallel 实践

上一次的Pytorch单机多卡训练主要介绍了Pytorch里分布式训练的基本原理,DP和DDP的大致过程,以及二者的区别,并分别写了一个小样作为参考。小样毕竟还是忽略了很多细节和工程实践时的一些处理方式的。实践出真知,今天(简单)写一个实际可用的 DDP 训练中样,检验一下自己对 DDP 的了解程度,也作为以后分布式训练的模板,逐渐完善。

DDP 流程回顾

Pytorch单机多卡训练中主要从DDP的原理层面进行了介绍,这次从实践的角度来看看。

试问把大象装进冰箱需要几步?三步,先打开冰箱,再把大象装进去,最后关上冰箱门。那实践中DDP又分为几步呢?

  1. 开始训练/预测前的准备工作。
  2. 开始训练/验证!

看,DDP比把大象装进冰箱简单多了!

一些前置

虽然DDP很简单,但是得先了解一些前置知识。有一些DDP相关的函数是必须要了解的1,主要是和分布式通信相关的函数:

  • torch.distributed.init_process_group2。初始化分布式进程组。由于是单机多卡看的分布式,因此其实是将任务分布在多个进程中。
  • torch.distributed.barrier()3。用于进程同步。每个进程进入这个函数后都会被阻塞,当所有进程都进入这个函数后,阻塞解除,继续执行后续的代码。
  • torch.distributed.all_gather4。收集不同进程中的tensor。这个用到的比较多的,但有一点不太好理解。解释一下:某些变量是不同进程中都会有的,比如loss,如果要把各个进程中计算得到的loss汇总到一起,就需要进程间通信,把loss收集起来,all_gather干的就是收集不同进程中的某个tensor的事儿。
  • local_rank。节点上device的标识。在单机多卡的模式下,每个device的local_rank是唯一的,一般我们都在0号设备上操作。该参数不需要手动传参,在执行DDP时会自动设置该参数。当为非DDP时,该参数值为-1。
  • torch.nn.parallel.DistributedDataParallel5。在初始化好的分布式环境中,创建分布式的模型,负责。

好了,你已经具备DDP实践的主要前置知识了,一起开启DDP实践吧!

DDP 实践

训练前

训练前的准备工作,主要包含:

  • 初始化分布式进程组,使后续各个进程间能够通信。
  • 准备好数据和模型。
torch.cuda.set_device(args.local_rank)
device = torch.device("cuda", args.local_rank)
torch.distributed.init_process_group(backend="nccl", timeout=datetime.timedelta(seconds=3600))
args.device = device

X, Y = make_classification(n_samples=25000, n_features=N_DIM, n_classes=2, random_state=args.seed)  # 每个进程都会拥有这样一份数据集
X = torch.tensor(X).float()
Y = torch.tensor(Y).float()

train_X, train_Y = X[:20000], Y[:20000]
val_X, val_Y = X[20000:], Y[20000:]
train_dataset = SimpleDataset(train_X, train_Y)
val_dataset = SimpleDataset(val_X, val_Y)

model = SimpleModel(N_DIM)
if args.local_rank not in (-1, 0):
    torch.distributed.barrier()
else:
    # 只需要在 0 号设备上加载预训练好的参数即可,后续调用DDP时会自动复制到其他卡上
    if args.ckpt_path is not None:
        model.load_state_dict(torch.load(args.ckpt_path))
    torch.distributed.barrier()

训练时

与单卡训练不同,DDP训练主要有三个改动地方:

  • 创建数据加载器时,需要创建特定的采样器,以保证每个进程使用的数据是不重复的
  • 创建分布式的模型。
  • 收集一些辅助信息时,例如loss、prediction等,需要all_gather将不同进程中的数据收集到一起。
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=False)

train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.batch_size)

...

def gather_tensors(target):
    """
    target should be a tensor
    """
    target_list = [torch.ones_like(target) for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(target_list, target)
    ret = torch.hstack(target_list)
    
    return ret

for epoch in train_iterator:
    if args.local_rank != -1:
        train_dataloader.sampler.set_epoch(epoch)  # 必须设置的!确保每一轮的随机种子是固定的!

    model.zero_grad()
    model.train()

    pbar = tqdm(train_dataloader, desc="Training", disable=args.local_rank not in [-1, 0])
    step_loss, step_label, step_pred = [], [], []
    for step, batch in enumerate(pbar):
        data, label = batch
        data, label = data.to(args.device), label.to(args.device)

        prediction = model(data)
        loss = loss_func(prediction, label.unsqueeze(1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        global_step += 1

        step_loss.append(loss.item())
        step_label.extend(label.squeeze().cpu().tolist())
        step_pred.extend(prediction.squeeze().detach().cpu().tolist())

        if step % args.logging_step == 0:
            if torch.distributed.is_initialized():  # 兼容非 DDP 的模式
                step_loss = torch.tensor(step_loss).to(args.device)
                step_label = torch.tensor(np.array(step_label
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值