Implementing DenseNet in PyTorch (using the classic DenseNet-121 as the example) comes down to implementing its "dense connectivity" mechanism: every layer is directly connected to all preceding layers through channel-wise concatenation. We will build the network from its basic modules, step by step, so the role of each part is clear.
I. The core structure of DenseNet
The architecture of DenseNet can be summarized as:
Input (224×224 color image) →
initial convolution → initial pooling →
4 dense blocks (each containing several densely connected convolutional layers) →
a transition layer after each dense block (downsampling + channel compression) →
global average pooling → fully connected layer (1000-class output)
The dense block (with its dense connections) and the transition layer are the core components.
II. Implementing DenseNet in PyTorch, step by step
Step 1: Import the required libraries
As with the other CNNs we have implemented before, first get the tools ready:
import torch                                    # core library
import torch.nn as nn                           # neural network layers
import torch.optim as optim                     # optimizers
from torch.utils.data import DataLoader         # data loader
from torchvision import datasets, transforms    # image datasets and preprocessing
Step 2: Implement DenseNet's basic building block, the bottleneck layer (Bottleneck)
Every layer in DenseNet uses the "bottleneck" design: a 1×1 convolution first reduces the channel count, so that the ever-growing number of concatenated channels does not blow up the computation:
class Bottleneck(nn.Module):
    def __init__(self, in_channels, growth_rate, dropout_rate=0.0):
        super(Bottleneck, self).__init__()
        # Bottleneck structure: BN → ReLU → 1×1 Conv → BN → ReLU → 3×3 Conv
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.relu = nn.ReLU(inplace=True)
        # 1×1 convolution: reduce to 4×growth_rate channels (as recommended in the paper)
        self.conv1 = nn.Conv2d(
            in_channels, 4 * growth_rate,
            kernel_size=1, stride=1, padding=0, bias=False
        )
        self.bn2 = nn.BatchNorm2d(4 * growth_rate)
        # 3×3 convolution: output growth_rate channels (the number of new channels this layer adds)
        self.conv2 = nn.Conv2d(
            4 * growth_rate, growth_rate,
            kernel_size=3, stride=1, padding=1, bias=False
        )
        self.dropout = nn.Dropout(dropout_rate) if dropout_rate > 0 else None

    def forward(self, x):
        # x is the concatenation of all preceding layers' outputs (the densely connected input)
        out = self.bn1(x)
        out = self.relu(out)
        out = self.conv1(out)   # 1×1 dimensionality reduction
        if self.dropout is not None:
            out = self.dropout(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv2(out)   # 3×3 feature extraction
        if self.dropout is not None:
            out = self.dropout(out)
        # Key step: concatenate this layer's output with its input (the heart of dense connectivity)
        # The input x is already the concatenation of all preceding layers; we append this layer's output
        return torch.cat([x, out], dim=1)
In plain terms:
- growth_rate (the growth rate k) is DenseNet's core hyperparameter: it controls how many new channels each layer adds (e.g. with k=32, every layer outputs 32 new channels);
- the 1×1 convolution first reduces the input to 4×k channels (dimensionality reduction, less computation), and the 3×3 convolution then outputs k channels;
- finally, torch.cat concatenates this layer's output with its input (the features of all preceding layers), which is exactly the dense connection (a quick code check follows below).
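A small sanity check makes the channel bookkeeping concrete. The snippet below is a minimal sketch using assumed example values (64 input channels, k=32, a 56×56 feature map) and only verifies that one Bottleneck adds exactly growth_rate channels:

# Example values only: one Bottleneck should add exactly growth_rate channels
x = torch.randn(2, 64, 56, 56)        # batch of 2, 64 channels, 56×56 feature map
block = Bottleneck(in_channels=64, growth_rate=32)
print(block(x).shape)                 # torch.Size([2, 96, 56, 56]) → 64 + 32 channels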
Step 3: Implement the dense block (Dense Block)
A dense block consists of several Bottleneck layers chained together through dense connections:
def _make_dense_block(in_channels, num_layers, growth_rate, dropout_rate):
    """
    Build one dense block.
    in_channels: number of input channels
    num_layers: number of Bottleneck layers in the block
    growth_rate: the growth rate k
    dropout_rate: dropout probability passed to each Bottleneck
    """
    layers = []
    for _ in range(num_layers):
        # Each Bottleneck's input channel count grows with depth (dense connections keep concatenating)
        layers.append(Bottleneck(in_channels, growth_rate, dropout_rate))
        # Update the input channel count (add the growth_rate new channels)
        in_channels += growth_rate
    return nn.Sequential(*layers), in_channels
Example (the same arithmetic is verified in code right after):
With 64 input channels, num_layers=6 and growth_rate=32:
- layer 1 input = 64 → after concatenation = 64+32 = 96
- layer 2 input = 96 → after concatenation = 96+32 = 128
- ...
- layer 6 input = 64+5×32 = 224 → after concatenation = 224+32 = 256
So the dense block's final output has 256 channels.
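A minimal sketch with the example values above (64 input channels, 6 layers, k=32) confirms both the returned channel count and the actual tensor shape:

block, out_channels = _make_dense_block(in_channels=64, num_layers=6,
                                        growth_rate=32, dropout_rate=0.0)
print(out_channels)                   # 256 = 64 + 6 × 32
x = torch.randn(2, 64, 56, 56)
print(block(x).shape)                 # torch.Size([2, 256, 56, 56])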
Step 4: Implement the transition layer (Transition Layer)
The transition layer connects two dense blocks; its job is "downsampling (halve the spatial size) + channel compression":
class Transition(nn.Module):
    def __init__(self, in_channels, out_channels, dropout_rate=0.0):
        super(Transition, self).__init__()
        # Transition structure: BN → ReLU → 1×1 Conv (channel compression) → 2×2 AvgPool (downsampling)
        self.bn = nn.BatchNorm2d(in_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=1, stride=1, padding=0, bias=False
        )
        self.pool = nn.AvgPool2d(kernel_size=2, stride=2)
        self.dropout = nn.Dropout(dropout_rate) if dropout_rate > 0 else None

    def forward(self, x):
        out = self.bn(x)
        out = self.relu(out)
        out = self.conv(out)    # channel compression
        if self.dropout is not None:
            out = self.dropout(out)
        out = self.pool(out)    # downsampling (halve the spatial size)
        return out
Channel compression logic:
The transition layer's output channel count = input channel count × compression factor θ (θ=0.5 in the paper). For example, with 256 input channels the transition layer outputs 128 channels (256×0.5). A quick shape check is sketched below.
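A minimal sketch with assumed example values (256 input channels, a 56×56 feature map) confirms that the transition layer halves both the channel count and the spatial resolution:

trans = Transition(in_channels=256, out_channels=int(256 * 0.5))
x = torch.randn(2, 256, 56, 56)
print(trans(x).shape)                 # torch.Size([2, 128, 28, 28])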
Step 5: Assemble the full DenseNet (using DenseNet-121 as the example)
DenseNet-121 consists of 4 dense blocks with 6, 12, 24 and 16 Bottleneck layers respectively, and growth rate k=32:
class DenseNet(nn.Module):
    def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
                 num_classes=1000, dropout_rate=0.0, compression=0.5):
        super(DenseNet, self).__init__()
        # Initial convolution: output channels = 2×growth_rate (as recommended in the paper)
        in_channels = 2 * growth_rate
        self.features = nn.Sequential(
            nn.Conv2d(3, in_channels, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)  # initial pooling
        )
        # Build the 4 dense blocks and the transition layers
        for i, num_layers in enumerate(block_config):
            # 1. Add a dense block
            dense_block, in_channels = _make_dense_block(
                in_channels, num_layers, growth_rate, dropout_rate
            )
            self.features.add_module(f'denseblock{i+1}', dense_block)
            # 2. Add a transition layer (there is no transition after the last dense block)
            if i != len(block_config) - 1:
                # Transition output channels = input channels × compression factor
                out_channels = int(in_channels * compression)
                trans = Transition(in_channels, out_channels, dropout_rate)
                self.features.add_module(f'transition{i+1}', trans)
                in_channels = out_channels  # update the channel count
        # Final BN and ReLU
        self.features.add_module('bn_final', nn.BatchNorm2d(in_channels))
        self.features.add_module('relu_final', nn.ReLU(inplace=True))
        # Global average pooling and fully connected classifier
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Linear(in_channels, num_classes)

    def forward(self, x):
        features = self.features(x)          # all dense blocks and transition layers
        out = self.global_pool(features)     # global pooling
        out = out.view(out.size(0), -1)      # flatten to a vector
        out = self.classifier(out)           # fully connected output
        return out
Architecture notes (a quick shape and parameter check follows below):
- block_config=(6, 12, 24, 16): the layer counts of the 4 dense blocks, 6+12+24+16=58 Bottlenecks in total; each Bottleneck contains 2 convolutions (116 layers), and adding the initial convolution, the 3 transition convolutions and the final classifier gives 121 layers, which is where the name DenseNet-121 comes from;
- compression=0.5: the transition layers' compression factor θ, which controls how strongly the channel count is reduced;
- every dense block except the last is followed by a transition layer, so the feature maps shrink step by step: 56×56 → 28×28 → 14×14 → 7×7.
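As a sanity check (a minimal sketch; the dummy batch and 224×224 resolution are just example values), we can instantiate the model with its default 1000-class head, push a random batch through it, and count the parameters, which should land close to the roughly 8 million usually quoted for DenseNet-121:

net = DenseNet()                      # defaults: k=32, block_config=(6, 12, 24, 16), 1000 classes
x = torch.randn(2, 3, 224, 224)
print(net(x).shape)                   # torch.Size([2, 1000])
num_params = sum(p.numel() for p in net.parameters())
print(f'{num_params / 1e6:.1f}M parameters')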
Step 6: Prepare the data (demonstrated on CIFAR-10)
DenseNet is aimed at high-accuracy classification, so we demonstrate it on CIFAR-10 (10 classes), with the inputs resized to 224×224:
# Preprocessing: resize + crop + flip + normalize
transform = transforms.Compose([
    transforms.Resize(256),               # resize to 256×256 (CIFAR images are square)
    transforms.RandomCrop(224),           # random crop to 224×224
    transforms.RandomHorizontalFlip(),    # random horizontal flip (data augmentation)
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet normalization
])
# Load the CIFAR-10 dataset
train_dataset = datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform
)
test_dataset = datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform
)
# Batch loading (DenseNet is memory-hungry, so keep batch_size moderate)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)
Step 7: Initialize the model, loss function and optimizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# DenseNet-121 configuration: growth rate 32, 4 dense blocks, 10 output classes (CIFAR-10)
model = DenseNet(
    growth_rate=32,
    block_config=(6, 12, 24, 16),
    num_classes=10,
    dropout_rate=0.2  # optional: dropout against overfitting
).to(device)
criterion = nn.CrossEntropyLoss()  # cross-entropy loss
# Optimizer: Adam with learning rate 0.001 works well here
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
Step 8: Training and test functions
DenseNet uses a lot of memory during training; the training loop is the same as before, but keep an eye on GPU memory:
def train(model, train_loader, criterion, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()               # clear gradients
        output = model(data)                # forward pass
        loss = criterion(output, target)    # compute loss
        loss.backward()                     # backpropagation
        optimizer.step()                    # update parameters
        # Print progress
        if batch_idx % 50 == 0:
            print(f'Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.6f}')

def test(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            _, predicted = torch.max(output, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    print(f'Test Accuracy: {100 * correct / total:.2f}%')
Step 9: Train and evaluate
DenseNet converges fairly slowly; 30-50 epochs are recommended:
for epoch in range(1, 31):
    train(model, train_loader, criterion, optimizer, epoch)
    test(model, test_loader)
On CIFAR-10, a thoroughly trained DenseNet-121 can exceed 94% accuracy, well above MobileNet and ShuffleNet, which reflects its strong feature-fusion ability.
III. Complete code
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# 1. Bottleneck layer
class Bottleneck(nn.Module):
    def __init__(self, in_channels, growth_rate, dropout_rate=0.0):
        super(Bottleneck, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.relu = nn.ReLU(inplace=True)
        # 1×1 convolution: reduce to 4×growth_rate channels
        self.conv1 = nn.Conv2d(
            in_channels, 4 * growth_rate,
            kernel_size=1, stride=1, padding=0, bias=False
        )
        self.bn2 = nn.BatchNorm2d(4 * growth_rate)
        # 3×3 convolution: output growth_rate channels
        self.conv2 = nn.Conv2d(
            4 * growth_rate, growth_rate,
            kernel_size=3, stride=1, padding=1, bias=False
        )
        self.dropout = nn.Dropout(dropout_rate) if dropout_rate > 0 else None

    def forward(self, x):
        out = self.bn1(x)
        out = self.relu(out)
        out = self.conv1(out)
        if self.dropout is not None:
            out = self.dropout(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv2(out)
        if self.dropout is not None:
            out = self.dropout(out)
        # Dense connection: concatenate the input and this layer's output
        return torch.cat([x, out], dim=1)
# 2. Dense block
def _make_dense_block(in_channels, num_layers, growth_rate, dropout_rate):
    layers = []
    for _ in range(num_layers):
        layers.append(Bottleneck(in_channels, growth_rate, dropout_rate))
        in_channels += growth_rate  # update the channel count (accumulate the growth rate)
    return nn.Sequential(*layers), in_channels
# 3. Transition layer
class Transition(nn.Module):
    def __init__(self, in_channels, out_channels, dropout_rate=0.0):
        super(Transition, self).__init__()
        self.bn = nn.BatchNorm2d(in_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=1, stride=1, padding=0, bias=False
        )
        self.pool = nn.AvgPool2d(kernel_size=2, stride=2)  # downsampling
        self.dropout = nn.Dropout(dropout_rate) if dropout_rate > 0 else None

    def forward(self, x):
        out = self.bn(x)
        out = self.relu(out)
        out = self.conv(out)    # channel compression
        if self.dropout is not None:
            out = self.dropout(out)
        out = self.pool(out)    # halve the spatial size
        return out
# 4. Full DenseNet (DenseNet-121)
class DenseNet(nn.Module):
    def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
                 num_classes=1000, dropout_rate=0.0, compression=0.5):
        super(DenseNet, self).__init__()
        # Initial convolution
        in_channels = 2 * growth_rate
        self.features = nn.Sequential(
            nn.Conv2d(3, in_channels, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )
        # Build the 4 dense blocks and the transition layers
        for i, num_layers in enumerate(block_config):
            # Add a dense block
            dense_block, in_channels = _make_dense_block(
                in_channels, num_layers, growth_rate, dropout_rate
            )
            self.features.add_module(f'denseblock{i+1}', dense_block)
            # Add a transition layer (except after the last block)
            if i != len(block_config) - 1:
                out_channels = int(in_channels * compression)
                trans = Transition(in_channels, out_channels, dropout_rate)
                self.features.add_module(f'transition{i+1}', trans)
                in_channels = out_channels
        # Final BN and ReLU
        self.features.add_module('bn_final', nn.BatchNorm2d(in_channels))
        self.features.add_module('relu_final', nn.ReLU(inplace=True))
        # Classifier head
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Linear(in_channels, num_classes)

    def forward(self, x):
        features = self.features(x)
        out = self.global_pool(features)
        out = out.view(out.size(0), -1)  # flatten the features
        out = self.classifier(out)
        return out
# 5. Prepare the CIFAR-10 data
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
train_dataset = datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform
)
test_dataset = datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform
)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)
# 6. Initialize the model, loss function and optimizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = DenseNet(
    growth_rate=32,
    block_config=(6, 12, 24, 16),
    num_classes=10,
    dropout_rate=0.2
).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
# 7. Training function
def train(model, train_loader, criterion, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 50 == 0:
            print(f'Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.6f}')
# 8. Test function
def test(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            _, predicted = torch.max(output, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    print(f'Test Accuracy: {100 * correct / total:.2f}%')
# 9. Train and evaluate
for epoch in range(1, 31):
    train(model, train_loader, criterion, optimizer, epoch)
    test(model, test_loader)
IV. Key takeaways
- Core mechanism: dense connectivity uses torch.cat to concatenate each layer's output with the features of all preceding layers, so features are reused efficiently; this is the key to DenseNet's accuracy;
- Bottleneck layer: the 1×1 convolution first reduces the channel count (to 4×k) before the 3×3 convolution, preventing dense connectivity from blowing up the channel count and greatly reducing computation;
- Transition layer: a 1×1 convolution compresses the channels (×0.5) and 2×2 average pooling downsamples, keeping the overall model size under control;
- Key hyperparameters: growth_rate (k) is the number of new channels per layer, with k=32 the standard DenseNet-121 setting; block_config gives the layer counts of the 4 dense blocks, and (6, 12, 24, 16) corresponds to 121 layers;
- Trade-offs: high accuracy with relatively few parameters, but training uses a lot of memory (many intermediate feature maps must be stored, see the sketch below) and inference is somewhat slower.
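If GPU memory becomes the limiting factor, a common remedy, similar in spirit to torchvision's memory-efficient DenseNet option, is gradient checkpointing: activations inside the dense blocks are recomputed during the backward pass instead of being stored. The sketch below is a hypothetical variant (not part of the code above) and assumes a recent PyTorch version for the use_reentrant argument:

from torch.utils.checkpoint import checkpoint

class CheckpointedBottleneck(Bottleneck):
    """Bottleneck variant that recomputes its forward pass in backward to save memory."""
    def forward(self, x):
        if self.training:
            # Trade extra compute for lower memory: do not cache intermediate activations
            return checkpoint(super().forward, x, use_reentrant=False)
        return super().forward(x)

Swapping this class in for Bottleneck inside _make_dense_block slows training somewhat but can noticeably reduce peak GPU memory.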
With this code you can build this "feature-fusion master" yourself and experience first-hand the representational power that dense connectivity provides.