使用add_module添加网络子模块

add_module的功能为Module添加一个子module,对应名字为name。使用方式如下:

add_module(name, module)

其中name为子模块的名字,使用这个名字可以访问特定的子module,module为我们自定义的子module。

一般情况下子module都是在A.init(self)中定义的,比如A中一个卷积子模块self.conv1 = torch.nn.Conv2d(…)。此时,这个卷积模块在A的名字其实是’conv1’。

对比之下,add_module()函数就可以在A.init(self)以外定义A的子模块。如定义同样的卷积子模块,可以通过A.add_module(‘conv1’, torch.nn.Conv2d(…))。

直接看代码会更直观,举例如下:

普通add_module举例

from torch import nn
from torchsummary import summary

class Net1(nn.Module):
    def __init__(self):
        super(Net1,self).__init__()
        self.conv_1 = nn.Conv2d(3,6,3)
        self.add_module('conv_2', nn.Conv2d(6,12,3))
        self.conv_3 = nn.Conv2d(12,24,3)
        
    def forward(self,x):
          x = self.conv_1(x)
          x = self.conv_2(x) # 使用name访问
          x = self.conv_3(x)
          return x
    
model = Net1()
print(model)
model.to('cuda')
summary(model,(3,128,128)) 

输出:

Net1(
  (conv_dd): Conv2d(3, 6, kernel_size=(3, 3), stride=(1, 1))
  (conv_2): Conv2d(6, 12, kernel_size=(3, 3), stride=(1, 1))
  (conv_3): Conv2d(12, 24, kernel_size=(3, 3), stride=(1, 1))
)
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1          [-1, 6, 126, 126]             168
            Conv2d-2         [-1, 12, 124, 124]             660
            Conv2d-3         [-1, 24, 122, 122]           2,616
================================================================
Total params: 3,444
Trainable params: 3,444
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.19
Forward/backward pass size (MB): 4.86
Params size (MB): 0.01
Estimated Total Size (MB): 5.06
----------------------------------------------------------------
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

BILLY BILLY

你的奖励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值