上一次的Pytorch单机多卡训练主要介绍了Pytorch里分布式训练的基本原理,DP和DDP的大致过程,以及二者的区别,并分别写了一个小样作为参考。小样毕竟还是忽略了很多细节和工程实践时的一些处理方式的。实践出真知,今天(简单)写一个实际可用的 DDP 训练中样,检验一下自己对 DDP 的了解程度,也作为以后分布式训练的模板,逐渐完善。
DDP 流程回顾
Pytorch单机多卡训练中主要从DDP的原理层面进行了介绍,这次从实践的角度来看看。
试问把大象装进冰箱需要几步?三步,先打开冰箱,再把大象装进去,最后关上冰箱门。那实践中DDP又分为几步呢?
- 开始训练/预测前的准备工作。
- 开始训练/验证!
看,DDP比把大象装进冰箱简单多了!
一些前置
虽然DDP很简单,但是得先了解一些前置知识。有一些DDP相关的函数是必须要了解的1,主要是和分布式通信相关的函数:
torch.distributed.init_process_group
2。初始化分布式进程组。由于是单机多卡看的分布式,因此其实是将任务分布在多个进程中。torch.distributed.barrier()
3。用于进程同步。每个进程进入这个函数后都会被阻塞,当所有进程都进入这个函数后,阻塞解除,继续执行后续的代码。torch.distributed.all_gather
4。收集不同进程中的tensor。这个用到的比较多的,但有一点不太好理解。解释一下:某些变量是不同进程中都会有的,比如loss,如果要把各个进程中计算得到的loss汇总到一起,就需要进程间通信,把loss收集起来,all_gather
干的就是收集不同进程中的某个tensor的事儿。local_rank
。节点上device的标识。在单机多卡的模式下,每个device的local_rank
是唯一的,一般我们都在0号设备上操作。该参数不需要手动传参,在执行DDP时会自动设置该参数。当为非DDP时,该参数值为-1。torch.nn.parallel.DistributedDataParallel
5。在初始化好的分布式环境中,创建分布式的模型,负责。
好了,你已经具备DDP实践的主要前置知识了,一起开启DDP实践吧!
DDP 实践
训练前
训练前的准备工作,主要包含:
- 初始化分布式进程组,使后续各个进程间能够通信。
- 准备好数据和模型。
torch.cuda.set_device(args.local_rank)
device = torch.device("cuda", args.local_rank)
torch.distributed.init_process_group(backend="nccl", timeout=datetime.timedelta(seconds=3600))
args.device = device
X, Y = make_classification(n_samples=25000, n_features=N_DIM, n_classes=2, random_state=args.seed) # 每个进程都会拥有这样一份数据集
X = torch.tensor(X).float()
Y = torch.tensor(Y).float()
train_X, train_Y = X[:20000], Y[:20000]
val_X, val_Y = X[20000:], Y[20000:]
train_dataset = SimpleDataset(train_X, train_Y)
val_dataset = SimpleDataset(val_X, val_Y)
model = SimpleModel(N_DIM)
if args.local_rank not in (-1, 0):
torch.distributed.barrier()
else:
# 只需要在 0 号设备上加载预训练好的参数即可,后续调用DDP时会自动复制到其他卡上
if args.ckpt_path is not None:
model.load_state_dict(torch.load(args.ckpt_path))
torch.distributed.barrier()
训练时
与单卡训练不同,DDP训练主要有三个改动地方:
- 创建数据加载器时,需要创建特定的采样器,以保证每个进程使用的数据是不重复的。
- 创建分布式的模型。
- 收集一些辅助信息时,例如loss、prediction等,需要
all_gather
将不同进程中的数据收集到一起。
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=False)
train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.batch_size)
...
def gather_tensors(target):
"""
target should be a tensor
"""
target_list = [torch.ones_like(target) for _ in range(torch.distributed.get_world_size())]
torch.distributed.all_gather(target_list, target)
ret = torch.hstack(target_list)
return ret
for epoch in train_iterator:
if args.local_rank != -1:
train_dataloader.sampler.set_epoch(epoch) # 必须设置的!确保每一轮的随机种子是固定的!
model.zero_grad()
model.train()
pbar = tqdm(train_dataloader, desc="Training", disable=args.local_rank not in [-1, 0])
step_loss, step_label, step_pred = [], [], []
for step, batch in enumerate(pbar):
data, label = batch
data, label = data.to(args.device), label.to(args.device)
prediction = model(data)
loss = loss_func(prediction, label.unsqueeze(1))
optimizer.zero_grad()
loss.backward()
optimizer.step()
global_step += 1
step_loss.append(loss.item())
step_label.extend(label.squeeze().cpu().tolist())
step_pred.extend(prediction.squeeze().detach().cpu().tolist())
if step % args.logging_step == 0:
if torch.distributed.is_initialized(): # 兼容非 DDP 的模式
step_loss = torch.tensor(step_loss).to(args.device)
step_label = torch.tensor(np.array(step_label