前言
在 PyTorch 中,nn.Container 是一个基类,它并不直接提供实例化对象的功能,而是作为其他容器类(如 nn.Module, nn.Sequential, nn.ModuleList, nn.ModuleDict)的基础。这些容器类在构建和组织神经网络时起着至关重要的作用。本文将详细讲解这些容器的原理、原型、运行实例,并最后进行总结。
函数
函数原理
nn.Module
nn.Module 是所有神经网络模块的基类。当你定义一个神经网络模型时,你需要继承这个类。nn.Module 包含了一些基本属性和方法,比如 parameters() 和 buffers(),用于管理模型中的参数和缓冲区(如批量归一化层的均值和方差)。此外,它还定义了 forward() 方法,这是实现前向传播的关键。
nn.Sequential
nn.Sequential 是一个有序容器,可以包含多个子模块(layer),并自动实现这些子模块的前向传播。你只需要按照顺序添加模块,nn.Sequential 会按照添加的顺序自动调用它们的 forward() 方法。这使得模型构建变得非常简单和直观。
nn.ModuleList
nn.ModuleList 是一个持有子模块的列表,它本身也是一个 nn.Module。与普通的 Python 列表不同,nn.ModuleList 中的模块会被正确地注册,并能在整个模块树中可见。这使得在模型中包含多个相同或不同的模块时更加方便。
nn.ModuleDict
nn.ModuleDict 是一个持有子模块的字典,类似于 nn.ModuleList,但提供了通过键来访问子模块的功能。这对于需要参数化模型某些部分的场景非常有用,比如动态选择激活函数。
原型
nn.Module
class torch.nn.Module(object):
def __init__(self):
pass
def forward(self, *input):
raise NotImplementedError
nn.Sequential
class torch.nn.Sequential(Module):
def __init__(self, *args):
super(Sequential, self).__init__()
if len(args) == 1 and isinstance(args[0], OrderedDict):
for key, module in args[0].items():
self.add_module(key, module)
else:
for idx, module in enumerate(args):
s