Distributed Training with torch
Distributed training is PyTorch's key capability for accelerating large-scale training jobs and improving model performance and scalability, and it is especially relevant in multi-GPU, multi-node environments. Below is a comprehensive overview of the mainstream approaches, their typical use cases, and how to implement them.
🔧 1. Core Components and Backends
Communication backends:
Backend | Use case | Multi-node support |
---|---|---|
nccl | GPU training (recommended) | ✅ |
gloo | CPU training & debugging | ✅ |
mpi | HPC environments | ✅ |
For typical training, `nccl` (NVIDIA Collective Communications Library) is the best choice.
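For example, a small helper (purely illustrative, not part of any PyTorch API) can pick the backend based on whether CUDA is available:

```python
import torch
import torch.distributed as dist

def pick_backend() -> str:
    # Hypothetical helper: nccl when GPUs are present, gloo on CPU-only machines.
    return "nccl" if torch.cuda.is_available() else "gloo"

# dist.init_process_group(backend=pick_backend(), init_method="env://")
```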
🚀 2. Comparison of Mainstream Distributed Approaches
Method | Description | Advantages | Typical function/module |
---|---|---|---|
DataParallel | Single-machine multi-GPU data-parallel wrapper | Simple, easy to use | torch.nn.DataParallel |
DistributedDataParallel (DDP) | The recommended approach; suited to multi-node, multi-GPU setups with excellent performance (the focus of this post) | Efficient, natively supported | torch.nn.parallel.DistributedDataParallel |
FSDP (Fully Sharded Data Parallel) | Shards model parameters/gradients to save GPU memory | Supports much larger models | torch.distributed.fsdp |
RPC / Pipeline Parallel | Cross-layer parallelism or task-level distribution | Supports heterogeneous computation | torch.distributed.rpc |
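For contrast with DDP, which the rest of this post focuses on, here is a minimal single-process DataParallel sketch (the tiny Sequential model is just a placeholder):

```python
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 5)).to(device)
if torch.cuda.device_count() > 1:
    # DataParallel splits each input batch across all visible GPUs inside one process.
    model = nn.DataParallel(model)
outputs = model(torch.randn(20, 10).to(device))
```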
DDP (DistributedDataParallel)
1. Fundamentals
🔑 Key mechanism: AllReduce
- In DDP, every process (typically one per GPU) holds a complete replica of the model.
- The forward pass runs independently; each process works on its own mini-batch.
- During the backward pass, DDP triggers an AllReduce operation that globally synchronizes the gradients across processes, so that all model replicas apply the same update and their parameters stay consistent.
📌 What is AllReduce?
AllReduce is a communication primitive whose goal is:
all participating nodes exchange data, aggregate it (e.g. sum or average), and every node receives the aggregated result.
🧠 How does AllReduce keep parameters in sync?
❗Note: AllReduce synchronizes gradients, not the parameters themselves.
Why?
Because during training the parameters are updated from the gradients (via SGD, Adam, etc.). As long as every process (or GPU) sees identical gradients, applying the same optimizer update leaves every model replica with identical parameters.
🔁 The synchronization flow, step by step
Suppose two processes (two GPUs) are training the same model, each on its own data:
Process 0: holds model replica A (parameters θ₀)
Process 1: holds model replica B (parameters θ₁)
🧩 Step-by-step details:
① Forward pass (no communication):
- Each process runs forward on its own slice of data and computes the loss.
② Backward pass (the key step):
- Each process calls `.backward()`; DDP triggers an AllReduce as soon as each parameter's `.grad` has been computed.
③ AllReduce (triggered automatically):
- Take parameter `W1` as an example:
  - Process 0 computes `grad_W1_0`
  - Process 1 computes `grad_W1_1`
- AllReduce combines them as follows (averaging by default):
grad_W1 = (grad_W1_0 + grad_W1_1) / 2
- Every process receives this averaged gradient `grad_W1` and writes it into its own `param.grad`.
④ Optimizer step (local):
- Each process runs `optimizer.step()` with its own optimizer; since the gradients match, the updated parameters stay identical:
θ_new = θ_old - lr * grad_W1
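The same averaging can be reproduced by hand with `dist.all_reduce`. Below is a minimal, self-contained sketch (CPU-only `gloo` backend, two processes spawned with `torch.multiprocessing`; the tensor merely stands in for a per-rank gradient):

```python
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank: int, world_size: int):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Pretend this is grad_W1 computed locally on this rank.
    grad = torch.tensor([float(rank + 1)])        # rank 0 -> 1.0, rank 1 -> 2.0
    dist.all_reduce(grad, op=dist.ReduceOp.SUM)   # every rank now holds the sum (3.0)
    grad /= world_size                            # divide to get the average (1.5), as DDP does
    print(f"rank {rank}: averaged grad = {grad.item()}")

    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)
```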
2. Experiment
To use DDP, you’ll need to spawn multiple processes and create a single instance of DDP per process.
Each process will have its own copy of the model, but they’ll all work together to train the model as if it were on a single machine.
DDP registers an autograd hook for each parameter in the model. When the backward pass is run, this hook fires and triggers gradient synchronization across all processes. This ensures that each process has the same gradients, which are then used to update the model.
🧱 Step 1: Write a basic Dockerfile
Prerequisite:
⚠️ Pulling from Docker Hub may require a registry mirror to be configured in advance: edit Docker's config file /etc/docker/daemon.json (see https://dockerproxy.net/docs).
Then restart Docker:
sudo systemctl daemon-reexec
sudo systemctl restart docker
# Dockerfile.cpu
FROM python:3.10-slim
# Use the Tsinghua PyPI mirror as the default index, otherwise pip installs may fail
RUN mkdir -p /root/.pip && \
    echo "[global]\nindex-url = https://pypi.tuna.tsinghua.edu.cn/simple" > /root/.pip/pip.conf
# Install system dependencies
RUN apt-get update && \
    apt-get install -y --no-install-recommends \
    build-essential \
    git \
    wget \
    vim \
    htop \
    iputils-ping \
    net-tools \
    procps \
    netcat-openbsd \
    && apt-get clean && rm -rf /var/lib/apt/lists/*
# Install PyTorch (CPU build)
RUN pip install --upgrade pip && \
    pip install torch torchvision torchaudio
# Install other Python packages
RUN pip install tensorboard tqdm
# Set the working directory
WORKDIR /workspace
Then, run:
docker build -t torch_ddp_cpu_env -f Dockerfile.cpu .
🌐 Step 2: Create a Docker network for communication
docker network create ddp-net
🚀 Step 3: Start two containers (simulating two machines)
We'll spin up two virtual nodes:
# node0
docker run -dit \
--name node0 \
--hostname node0 \
--network ddp-net \
-v $PWD:/workspace \
torch_ddp_cpu_env
# node1
docker run -dit \
--name node1 \
--hostname node1 \
--network ddp-net \
-v $PWD:/workspace \
torch_ddp_cpu_env
⌨️ Step 4: Show the Code
- Write a demo script:
dist.init_process_group(
    backend="gloo",
    init_method="env://",
    timeout=timedelta(seconds=600),  # default is 1800; increase if needed
    rank=rank,
    world_size=world_size
)
backend: the underlying communication mechanism;
init_method: how the process group is initialized; the recommended choice is ✅ "env://", which passes parameters through environment variables (a good fit for multi-node torchrun)
timeout: how long collective operations wait for lagging ranks before erroring out; the default is 30 minutes, adjust it to how far apart your nodes may start (e.g. timedelta(seconds=600))
rank: index of the current process, in the range 0 ~ world_size - 1; must be unique
world_size: total number of processes, equal to number of nodes × processes per node
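With torchrun you normally do not pass these values by hand: torchrun exports RANK, WORLD_SIZE, LOCAL_RANK, MASTER_ADDR and MASTER_PORT to every worker it spawns. A minimal sketch of reading them (assuming the env:// init method):

```python
import os
from datetime import timedelta
import torch.distributed as dist

# torchrun sets these variables for each process it launches.
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])

dist.init_process_group(
    backend="gloo",
    init_method="env://",          # MASTER_ADDR / MASTER_PORT come from the environment
    timeout=timedelta(seconds=600),
    rank=rank,
    world_size=world_size,
)
```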
# main.py
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import typer
from pydantic import BaseModel

app = typer.Typer()

class Config(BaseModel):
    mode: str = "train"  # or "infer"
    backend: str = "gloo"
    checkpoint_path: str = "model.pt"
    epochs: int = 5

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))

def setup(rank: int, world_size: int, backend: str):
    os.environ["MASTER_ADDR"] = os.getenv("MASTER_ADDR", "localhost")
    os.environ["MASTER_PORT"] = os.getenv("MASTER_PORT", "12355")
    dist.init_process_group(backend, rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

def train(rank: int, world_size: int, cfg: Config):
    print(f"Rank {rank}: Starting training with backend={cfg.backend}")
    setup(rank, world_size, cfg.backend)
    model = ToyModel()
    ddp_model = DDP(model)
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.01)
    loss_fn = nn.MSELoss()
    for epoch in range(cfg.epochs):
        inputs = torch.randn(20, 10)
        labels = torch.randn(20, 5)
        optimizer.zero_grad()
        outputs = ddp_model(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()
        print(f"Rank {rank} Epoch {epoch} Loss: {loss.item()}")
    if rank == 0:
        torch.save(ddp_model.state_dict(), cfg.checkpoint_path)
        print(f"Saved model to {cfg.checkpoint_path}")
    cleanup()

def infer(rank: int, world_size: int, cfg: Config):
    print(f"Rank {rank}: Running inference")
    setup(rank, world_size, cfg.backend)
    model = ToyModel()
    ddp_model = DDP(model)
    map_location = {'cpu': 'cpu'}
    ddp_model.load_state_dict(torch.load(cfg.checkpoint_path, map_location=map_location))
    ddp_model.eval()
    inputs = torch.randn(10, 10)
    with torch.no_grad():
        outputs = ddp_model(inputs)
    print(f"🔍 Rank {rank} Inference Output:\n{outputs}")
    cleanup()

@app.command()
def main(
    rank: int = typer.Option(..., help="Rank of the current process"),
    world_size: int = typer.Option(..., help="Total number of processes"),
    mode: str = "train",
    backend: str = "gloo",
    checkpoint_path: str = "model.pt",
    epochs: int = 5,
):
    cfg = Config(mode=mode, backend=backend, checkpoint_path=checkpoint_path, epochs=epochs)
    if cfg.mode == "train":
        train(rank, world_size, cfg)
    elif cfg.mode == "infer":
        infer(rank, world_size, cfg)
    else:
        raise ValueError("Mode must be either 'train' or 'infer'")

if __name__ == "__main__":
    app()
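The demo above trains on random tensors, so no data sharding is needed. With a real dataset you would normally give each rank a disjoint slice via DistributedSampler. The sketch below is illustrative only (the TensorDataset is a stand-in, and it must run after setup() so the process group exists):

```python
import torch
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.data.distributed import DistributedSampler

# Toy dataset standing in for real data.
dataset = TensorDataset(torch.randn(200, 10), torch.randn(200, 5))

# DistributedSampler hands each rank a disjoint shard of the dataset
# (it queries dist.get_rank() / dist.get_world_size() internally).
sampler = DistributedSampler(dataset)
loader = DataLoader(dataset, batch_size=20, sampler=sampler)

for epoch in range(5):
    sampler.set_epoch(epoch)  # reshuffle differently each epoch
    for inputs, labels in loader:
        ...  # forward / backward / step exactly as in train() above
```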
- Write a launch script:
--nproc_per_node: number of processes per machine
--nnodes: total number of machines
--node_rank: index of the current machine
--master_addr: address of the master node (IP/hostname)
--master_port: port opened on the master node
#!/bin/bash
# run_ddp.sh
set -e  # exit immediately on any error
set -x  # echo every command (useful for debugging)

# Parameters
MASTER_ADDR="node0"
MASTER_PORT=12345
NPROC_PER_NODE=1
NNODES=2
SCRIPT="main.py"  # or demo_torchrun.py

# Launch node0 (rank 0)
docker exec -d node0 bash -c "
  cd /workspace && \
  torchrun \
    --nproc_per_node=$NPROC_PER_NODE \
    --nnodes=$NNODES \
    --node_rank=0 \
    --master_addr=$MASTER_ADDR \
    --master_port=$MASTER_PORT \
    $SCRIPT --rank=0 --world-size=2 --mode=train > node0.log 2>&1
"

# Launch node1 (rank 1)
docker exec -d node1 bash -c "
  cd /workspace && \
  torchrun \
    --nproc_per_node=$NPROC_PER_NODE \
    --nnodes=$NNODES \
    --node_rank=1 \
    --master_addr=$MASTER_ADDR \
    --master_port=$MASTER_PORT \
    $SCRIPT --rank=1 --world-size=2 --mode=train > node1.log 2>&1
"
Running this script starts the job.
What if GPUs are available? Suppose we have two nodes, each with two GPUs.
# demo_torchrun_gpu.py
import os
from typing import Optional

import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from pydantic import BaseModel
import typer

app = typer.Typer()

class Config(BaseModel):
    rank: int
    world_size: int
    mode: str = "train"  # train or infer
    backend: str = "nccl"
    save_path: str = "./model.pt"

class ToyModel(nn.Module):
    # Same toy network as in the CPU demo above
    def __init__(self):
        super().__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))

def setup(rank, world_size, backend="nccl"):
    os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR", "localhost")
    os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "12345")
    dist.init_process_group(backend, rank=rank, world_size=world_size)
    torch.cuda.set_device(rank % torch.cuda.device_count())

def cleanup():
    dist.destroy_process_group()

def train(args: Config):
    setup(args.rank, args.world_size, args.backend)
    device = torch.device(f"cuda:{args.rank % torch.cuda.device_count()}")
    model = ToyModel().to(device)
    ddp_model = DDP(model, device_ids=[device])
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
    loss_fn = nn.MSELoss()

    optimizer.zero_grad()
    inputs = torch.randn(20, 10).to(device)
    labels = torch.randn(20, 5).to(device)
    outputs = ddp_model(inputs)
    loss = loss_fn(outputs, labels)
    loss.backward()
    optimizer.step()

    if args.rank == 0:
        torch.save(ddp_model.state_dict(), args.save_path)
    cleanup()
    print(f"[Rank {args.rank}] Training complete.")

def infer(args: Config):
    setup(args.rank, args.world_size, args.backend)
    device = torch.device(f"cuda:{args.rank % torch.cuda.device_count()}")
    model = ToyModel().to(device)
    ddp_model = DDP(model, device_ids=[device])
    map_location = {"cuda:0": f"cuda:{args.rank % torch.cuda.device_count()}"}
    ddp_model.load_state_dict(torch.load(args.save_path, map_location=map_location))
    inputs = torch.randn(5, 10).to(device)
    with torch.no_grad():
        outputs = ddp_model(inputs)
    print(f"[Rank {args.rank}] Inference output:\n", outputs)
    cleanup()

@app.command()
def main(
    rank: Optional[int] = typer.Option(None, help="Global rank; defaults to the RANK env var set by torchrun"),
    world_size: Optional[int] = typer.Option(None, help="Total processes; defaults to the WORLD_SIZE env var"),
    mode: str = "train",
    save_path: str = "./model.pt",
):
    # torchrun exports RANK / WORLD_SIZE for every process it spawns,
    # so the launch script below does not need to pass them explicitly.
    rank = int(os.environ["RANK"]) if rank is None else rank
    world_size = int(os.environ["WORLD_SIZE"]) if world_size is None else world_size
    args = Config(rank=rank, world_size=world_size, mode=mode, save_path=save_path)
    if args.mode == "train":
        train(args)
    else:
        infer(args)

if __name__ == "__main__":
    app()
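One detail worth noting: with several GPUs per node, the idiomatic way to pick the device under torchrun is the LOCAL_RANK environment variable (the process index within its own node) rather than deriving it from the global rank as the demo does. A minimal sketch:

```python
import os
import torch

# torchrun sets LOCAL_RANK to the process index within its node,
# so it maps directly onto a GPU id regardless of how many nodes there are.
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
device = torch.device(f"cuda:{local_rank}")
# model = ToyModel().to(device); ddp_model = DDP(model, device_ids=[local_rank])
```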
Launch script:
#!/bin/bash
# Shared parameters
NUM_NODES=2
NPROC_PER_NODE=2
MASTER_ADDR="node0_ip"  # the actual IP address of node0
MASTER_PORT=12345
SCRIPT="demo_torchrun_gpu.py"

# Launch on node0 (this machine)
torchrun --nproc_per_node=$NPROC_PER_NODE \
    --nnodes=$NUM_NODES \
    --node_rank=0 \
    --master_addr=$MASTER_ADDR \
    --master_port=$MASTER_PORT \
    $SCRIPT &

# Launch on node1 remotely (passwordless SSH)
ssh user@node1_ip "cd /your/project/path && \
    torchrun --nproc_per_node=$NPROC_PER_NODE \
    --nnodes=$NUM_NODES \
    --node_rank=1 \
    --master_addr=$MASTER_ADDR \
    --master_port=$MASTER_PORT \
    $SCRIPT" &
As the code above shows, DDP hides the low-level details: for a simple use case you only need to initialize the process group, and DDP takes care of gradient synchronization on its own, keeping the model replicas' parameters consistent.