一直以来都是用的单机单卡训练模型,虽然很多情况下已经足够了,但总有一些情况得上分布式训练:
- 模型大到一张卡放不下;
- 单张卡batch size不敢设太大,训练速度慢;
- 当你有好几张卡,不想浪费;
展示一下技术。
由于还没遇到过一张显卡放不下整个模型的情况,本文的分布式训练仅限数据并行。主要从数据并行的原理和一些简单的实践例子进行说明。
原理介绍
与每个step一个batch数据相比,数据并行是指每个step用更多的数据(多个batch)进行计算——即多个batch的数据并行进行前向计算。既然是并行,那么就涉及到多张卡一起计算。单卡和多卡训练过程如下图1所示,主要有三个过程:
- 各卡分别计算损失和梯度,即图中红线部分;
- 所以梯度整合到主device,即图中蓝线部分;
- 主device进行参数更新,并将新模型拷贝到其他device上,即图中绿线部分。

如果不使用数据并行,在显存足够的情况下,我们可以将batch_size设大,这和数据并行的区别在哪呢?如果只将batch_size设大,计算还是在一张卡上完成,速度相对来说是不如将数据均分后放在不同卡上并行计算的。当然,考虑到卡之间的通信问题,要发挥多卡并行的力量需要进行一定权衡。
torch中主要有两种数据并行方式:DP和DDP。
DataParallel
DP是较简单的一种数据并行方式,直接将模型复制到多个GPU上并行计算,每个GPU计算batch中的一部分数据,各自完成前向和反向后,将梯度汇总到主GPU上。其基本流程:
- 加载模型、数据至内存;
- 创建DP模型;
- DP模型的forward过程:
- 一个batch的数据均分到不同device上;
- 为每个device复制一份模型;
- 至此,每个device上有模型和一份数据,并行进行前向传播;
- 收集各个device上的输出;
- 每个device上的模型反向传播后,收集梯度到主device上,更新主device上的模型,将模型广播到其他device上;
- 3-4循环。
在DP中,只有一个主进程,主进程下有多个线程,每个线程管理一个device的训练。因此,DP中内存中只存在一份数据,各个线程间是共享这份数据的。DP和Parameter Server的方式很像。
小样
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
# 假设我们有一个简单的数据集类
class SimpleDataset(Dataset):
def __init__(self, data, target):
self.data = data
self.target = target
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx], self.target[idx]
# 假设我们有一个简单的神经网络模型
class SimpleModel(nn.Module):
def __init__(self, input_dim):
super(SimpleModel, self).__init__()
self.fc = nn.Linear(input_dim, 1)
def forward(self, x):
return torch.sigmoid(self.fc(x))
# 假设我们有一些数据
n_sample = 100
n_dim = 10
batch_size = 10
X = torch.randn(n_sample, n_dim)
Y = torch.randint(