PyTorch核心容器:nn.Sequential与nn.ModuleList深度解析

在PyTorch中构建神经网络时,我们通常需要将多个层(layers)组合起来。nn.Sequential 和 nn.ModuleList 是两种最常用、也最容易混淆的模块容器。理解它们的区别与适用场景,是构建我们所需要的模型的关键。

本文将深入探讨这两个容器:

nn.Sequential:如何自动化构建模型。

nn.ModuleList:为何可以构建复杂模型。何时以及如何选择最适合需求的容器。

一、 nn.Sequential: 

想象一下工厂的流水线,产品按顺序流过。nn.Sequential 就是这样一个类似流水线一样的容器,它将我们传入的模块(层)严格按照顺序串联起来。

核心特性:

  1. 自动化 forward:这是 nn.Sequential 最大的特点。我们不需要手动编写 forward 函数,它会自动生成一个,让数据依次流经所有内部模块。

  2. 顺序固定:模块的执行顺序在定义时就已确定,无法在运行时动态改变。

  3. 简洁高效:对于结构简单,线性的网络(如标准的多层感知机MLP、VGG等),使用 nn.Sequential 会让代码干净、易读。

工作原理与示例

我们来看看 nn.Sequential关键部分的代码,以理解其工作原理。

importtorch.nn as nn

fromcollections importOrderedDict

 

classSequential(nn.Module):

    def__init__(self, *args):

        super().__init__()

        # 如果传入有序字典

        iflen(args) ==1andisinstance(args[0], OrderedDict):

            forkey, module inargs[0].items():

                self.add_module(key, module)

        # 如果传入模块

        else:

            foridx, module inenumerate(args):

                self.add_module(str(idx), module) # 模块会自动命名为 "0", "1"...

 

    defforward(self, input):

        # 遍历所有子模块

        formodule inself: # self 像列表迭代

            # 将上一个模块的输出作为下一个模块的输入

            input=module(input)

        returninput

从源码可以看出:

1.__init__负责将你传入的模块注册到容器中。你可以传入零散的模块,也可以传入一个有序字典来给每个模块命名。

2.forward方法的实现极为简单:一个 for 循环,依次调用每个子模块。

使用方法举例:

importtorch

importtorch.nn as nn

 

# 一:直接传入模块

model_1 =nn.Sequential(

    nn.Conv2d(1, 20, 5),

    nn.ReLU(),

    nn.Conv2d(20, 64, 5),

    nn.ReLU()

)

 

# 二:使用 OrderedDict 给模块命名

model_2 =nn.Sequential(OrderedDict([

    ('conv1', nn.Conv2d(1, 20, 5)),

    ('relu1', nn.ReLU()),

    ('conv2', nn.Conv2d(20, 64, 5)),

    ('relu2', nn.ReLU())

]))

 

# 两种方式效果一样,但方法2的可读性更好,可以通过名字访问层

# print(model_2.conv1)

 

# 使用模型

input =torch.randn(1, 1, 28, 28)

output =model_1(input)

print(output.shape)

我个人认为当模型是线性传递时,nn.Sequential 非常合适。

二、 nn.ModuleList:

如果 nn.Sequential 是流水线,那么 nn.ModuleList 更像一个保存模块工具箱。它帮我们保管好所有的模块,但具体哪个模块先用、哪个后用、哪个工具用几次,完全由我们自己决定。

nn.ModuleList 表现得像一个Python的列表(list),但它有一个至关重要的特性:它能正确地向PyTorch注册内部的所有模块。

核心特性:

不实现 forward:它本身没有 forward 方法。需要在自己的模型中手动编写 forward 逻辑,并从中调用模块。

参数可见性:这是它与普通Python列表的根本区别。被 nn.ModuleList 包含的模块,其参数(权重和偏置)能够被PyTorch自动识别和管理。这意味着当调用      model.parameters() 或 model.to(device) 时,PyTorch会正确地处理这些模块。

高度灵活:正因为需要手动控制,你可以实现任何复杂的逻辑,如跳跃连接(Skip Connection)、分支网络、循环、条件判断等。

为什么不能用普通的Python列表?

我用一个简单的例子来证明:

importtorch

importtorch.nn as nn

classMyModelWithPythonList(nn.Module):

    def__init__(self):

        super().__init__()

        # 使用Python 列表

        self.layers =[nn.Linear(10, 10) for_ inrange(5)]

    defforward(self, x):

        forlayer inself.layers:

            x =layer(x)

        return x

 
classMyModelWithModuleList(nn.Module):

    def__init__(self):

        super().__init__()

        # 使用 nn.ModuleList

        self.layers =nn.ModuleList(

[nn.Linear(10, 10) for_ inrange(5)]

)

 
    defforward(self, x):

        forlayer inself.layers:

            x =layer(x)

        return x

 
model_py_list =MyModelWithPythonList()

model_module_list =MyModelWithModuleList()


# 检查参数

print("Parameters in
  MyModelWithPythonList:", list(model_py_list.parameters()))

# 输出:[]  空的

print("Parameters in
  MyModelWithModuleList:", len(list(model_module_list.parameters())))

# 输出:10 (5个Linear层,每层有权重和偏置,5X2=10)

这个对比清晰地表明:普通列表中的模块对PyTorch是“隐形”的,它们的参数不会被收集,因此在训练时也无法被优化器更新。

使用示例:实现残差连接

假设我们要构建一个模型,其中包含非线性的数据流,比如残差连接(Residual Connection)。这是 nn.Sequential 无法直接完成的任务。

classComplexModel(nn.Module):

    def__init__(self, n_layers=6):

        super().__init__()

        self.layers =nn.ModuleList([nn.Linear(20, 20) for_ inrange(n_layers)])

        self.activation =nn.ReLU()

defforward(self, x):

        # 在 forward 中实现复杂的逻辑

        fori, layer inenumerate(self.layers):

            x =layer(x)

            x =self.activation(x)

             

            # 每两层添加一个残差连接 

            ifi %2==1:

                x =x +x  # 模拟残差

        returnx

complex_model =ComplexModel()

dummy_input =torch.randn(1, 20)

output =complex_model(dummy_input)

 

print(len(list(complex_model.parameters())))

print(output.shape)

nn.ModuleList让我们保存模块,然后在 forward 函数中自由地组合排序它们,实现任意复杂的计算图。

三、 对比总结:如何选择?

选择指南:

特性

nn.Sequential

nn.ModuleList

forward 实现

自动,线性的数据流

手动,需要在父模块中定义

灵活性

低,顺序固定

高,支持任意逻辑(分支、跳跃、循环)

主要用途

简单的线性堆叠模型或模块块

需要复杂数据流的定制化模型

模块访问

可通过索引或(若已命名)名字访问

类似列表,通过索引访问 (model.layers[i])

代码简洁度

对于简单模型,非常简洁

需要编写更多 forward 代码

首选 nn.Sequential:如果模型是一个简单的、从头到尾的线性堆叠,这能让代码更简洁、意图更明确。

必须用 nn.ModuleList 的情况:

1.跳跃/残差连接:如ResNet,数据流需要在不同层之间交互。

2.多输入/多输出/分支结构:如Inception网络,输入被送到不同的处理路径,最后再合并。

3.循环使用模块:需要在 forward 中多次、以不同方式使用同一个模块。

4.动态控制:需要根据输入的某些特性来决定调用哪个模块。

最后,两者并非互斥,而是可以组合使用。在一个复杂的模型中,完全可以使用nn.ModuleList 来管理主要的模块块(Blocks),而每一个独立的块内部则可以使用 nn.Sequential 来定义。这是一种非常常见且高效的方法。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值