在PyTorch中构建神经网络时,我们通常需要将多个层(layers)组合起来。nn.Sequential 和 nn.ModuleList 是两种最常用、也最容易混淆的模块容器。理解它们的区别与适用场景,是构建我们所需要的模型的关键。
本文将深入探讨这两个容器:
nn.Sequential:如何自动化构建模型。
nn.ModuleList:为何可以构建复杂模型。何时以及如何选择最适合需求的容器。
一、 nn.Sequential:
想象一下工厂的流水线,产品按顺序流过。nn.Sequential 就是这样一个类似流水线一样的容器,它将我们传入的模块(层)严格按照顺序串联起来。
核心特性:
-
自动化 forward:这是 nn.Sequential 最大的特点。我们不需要手动编写 forward 函数,它会自动生成一个,让数据依次流经所有内部模块。
-
顺序固定:模块的执行顺序在定义时就已确定,无法在运行时动态改变。
-
简洁高效:对于结构简单,线性的网络(如标准的多层感知机MLP、VGG等),使用 nn.Sequential 会让代码干净、易读。
工作原理与示例
我们来看看 nn.Sequential
关键部分的代码,以理解其工作原理。
importtorch.nn as nn
fromcollections importOrderedDict
classSequential(nn.Module):
def__init__(self, *args):
super().__init__()
# 如果传入有序字典
iflen(args) ==1andisinstance(args[0], OrderedDict):
forkey, module inargs[0].items():
self.add_module(key, module)
# 如果传入模块
else:
foridx, module inenumerate(args):
self.add_module(str(idx), module) # 模块会自动命名为 "0", "1"...
defforward(self, input):
# 遍历所有子模块
formodule inself: # self 像列表迭代
# 将上一个模块的输出作为下一个模块的输入
input=module(input)
returninput
从源码可以看出:
1.__init__
负责将你传入的模块注册到容器中。你可以传入零散的模块,也可以传入一个有序字典来给每个模块命名。
2.forward
方法的实现极为简单:一个 for 循环,依次调用每个子模块。
使用方法举例:
importtorch
importtorch.nn as nn
# 一:直接传入模块
model_1 =nn.Sequential(
nn.Conv2d(1, 20, 5),
nn.ReLU(),
nn.Conv2d(20, 64, 5),
nn.ReLU()
)
# 二:使用 OrderedDict 给模块命名
model_2 =nn.Sequential(OrderedDict([
('conv1', nn.Conv2d(1, 20, 5)),
('relu1', nn.ReLU()),
('conv2', nn.Conv2d(20, 64, 5)),
('relu2', nn.ReLU())
]))
# 两种方式效果一样,但方法2的可读性更好,可以通过名字访问层
# print(model_2.conv1)
# 使用模型
input =torch.randn(1, 1, 28, 28)
output =model_1(input)
print(output.shape)
我个人认为当模型是线性传递时,nn.Sequential 非常合适。
二、 nn.ModuleList:
如果 nn.Sequential 是流水线,那么 nn.ModuleList 更像一个保存模块工具箱。它帮我们保管好所有的模块,但具体哪个模块先用、哪个后用、哪个工具用几次,完全由我们自己决定。
nn.ModuleList 表现得像一个Python的列表(list),但它有一个至关重要的特性:它能正确地向PyTorch注册内部的所有模块。
核心特性:
不实现 forward:它本身没有 forward 方法。需要在自己的模型中手动编写 forward 逻辑,并从中调用模块。
参数可见性:这是它与普通Python列表的根本区别。被 nn.ModuleList 包含的模块,其参数(权重和偏置)能够被PyTorch自动识别和管理。这意味着当调用 model.parameters() 或 model.to(device) 时,PyTorch会正确地处理这些模块。
高度灵活:正因为需要手动控制,你可以实现任何复杂的逻辑,如跳跃连接(Skip Connection)、分支网络、循环、条件判断等。
为什么不能用普通的Python列表?
我用一个简单的例子来证明:
importtorch
importtorch.nn as nn
classMyModelWithPythonList(nn.Module):
def__init__(self):
super().__init__()
# 使用Python 列表
self.layers =[nn.Linear(10, 10) for_ inrange(5)]
defforward(self, x):
forlayer inself.layers:
x =layer(x)
return x
classMyModelWithModuleList(nn.Module):
def__init__(self):
super().__init__()
# 使用 nn.ModuleList
self.layers =nn.ModuleList(
[nn.Linear(10, 10) for_ inrange(5)]
)
defforward(self, x):
forlayer inself.layers:
x =layer(x)
return x
model_py_list =MyModelWithPythonList()
model_module_list =MyModelWithModuleList()
# 检查参数
print("Parameters in
MyModelWithPythonList:", list(model_py_list.parameters()))
# 输出:[] 空的
print("Parameters in
MyModelWithModuleList:", len(list(model_module_list.parameters())))
# 输出:10 (5个Linear层,每层有权重和偏置,5X2=10)
这个对比清晰地表明:普通列表中的模块对PyTorch是“隐形”的,它们的参数不会被收集,因此在训练时也无法被优化器更新。
使用示例:实现残差连接
假设我们要构建一个模型,其中包含非线性的数据流,比如残差连接(Residual Connection)。这是 nn.Sequential 无法直接完成的任务。
classComplexModel(nn.Module):
def__init__(self, n_layers=6):
super().__init__()
self.layers =nn.ModuleList([nn.Linear(20, 20) for_ inrange(n_layers)])
self.activation =nn.ReLU()
defforward(self, x):
# 在 forward 中实现复杂的逻辑
fori, layer inenumerate(self.layers):
x =layer(x)
x =self.activation(x)
# 每两层添加一个残差连接
ifi %2==1:
x =x +x # 模拟残差
returnx
complex_model =ComplexModel()
dummy_input =torch.randn(1, 20)
output =complex_model(dummy_input)
print(len(list(complex_model.parameters())))
print(output.shape)
nn.ModuleList让我们保存模块,然后在 forward 函数中自由地组合排序它们,实现任意复杂的计算图。
三、 对比总结:如何选择?
选择指南:
特性 |
nn.Sequential |
nn.ModuleList |
forward 实现 |
自动,线性的数据流 |
手动,需要在父模块中定义 |
灵活性 |
低,顺序固定 |
高,支持任意逻辑(分支、跳跃、循环) |
主要用途 |
简单的线性堆叠模型或模块块 |
需要复杂数据流的定制化模型 |
模块访问 |
可通过索引或(若已命名)名字访问 |
类似列表,通过索引访问 (model.layers[i]) |
代码简洁度 |
对于简单模型,非常简洁 |
需要编写更多 forward 代码 |
首选 nn.Sequential:如果模型是一个简单的、从头到尾的线性堆叠,这能让代码更简洁、意图更明确。
必须用 nn.ModuleList 的情况:
1.跳跃/残差连接:如ResNet,数据流需要在不同层之间交互。
2.多输入/多输出/分支结构:如Inception网络,输入被送到不同的处理路径,最后再合并。
3.循环使用模块:需要在 forward 中多次、以不同方式使用同一个模块。
4.动态控制:需要根据输入的某些特性来决定调用哪个模块。
最后,两者并非互斥,而是可以组合使用。在一个复杂的模型中,完全可以使用nn.ModuleList 来管理主要的模块块(Blocks),而每一个独立的块内部则可以使用 nn.Sequential 来定义。这是一种非常常见且高效的方法。