Pytorch 中的 torch.nn.Identity( ) 的作用

torch.nn.Identity()作为恒等函数,用于在深度学习模型中保持输入输出不变。在示例中,它替换预训练模型的最后一层,允许提取特征并自定义新的分类层。这在迁移学习场景中常见,以利用预训练模型的特征表示。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

1 torch.nn.Identity( ) 的作用

torch.nn.Identity( ) 相当于一个恒等函数

f(x) = x

  • 这个函数相当于输入什么就输出什么, 可以用在对已经设计好模型结构的修改, 比如模型的最后一层是 1000 分类, 我们可以将最后一层用 nn.Identity( ) 替换掉, 得到它之前学习的特征, 然后再自己设计最后一层的结构

  • 在迁移学习中经常使用

2 示例

import torch
from torch import nn
from torch.nn import NLLLoss
import timm


class MiniModel(nn.Module):
    
    def __init__(self, backbone, num_class, pretrained=False, backbone_ckpt=None):
        
        super().__init__()
        self.backbone = timm.crear_model(backbone, pretrained=pretrained, checkpoint_path=backbone_ckpt)
        self.head = nn.Linear(self.backbone.get_classifier().in_features, num_class)
        # 替代最后一层的全连接网络
        self.backbone.head.fc = nn.Identity()
        self.loss_fn = NLLLoss()
        
    def forward(self, image, label):
        embed = self.backbone(image)
        logit = self.head(embed)
        
        if label is not None:
            logit_logsoftmax = torch.log_softmax(logit, 1)
            loss = self.loss_fn(logit_logsoftmax, label)
            return {"prediction": logit, "loss": loss}
        return {"prediction": logit}

### 解释 `nn.Identity` 的含义和用法 在 PyTorch 中,`nn.Identity` 是一个特殊的层,它不会对输入数据执行任何操作,而是简单地返回未修改的输入作为输出。这看似无意义的操作实际上有其特定用途。 当构建神经网络模型时,有时需要在网络结构中保留某个位置用于可能的未来扩展或者条件分支逻辑。此时可以暂时放置 `nn.Identity` 层来占位,在后续开发过程中再替换为实际的功能模块[^1]。 此外,如果希望绕过某些层而不改变整体架构,则可以用 `nn.Identity` 来实现这一点。例如,在迁移学习场景下调整预训练模型的部分组件时非常有用[^2]。 对于给定代码片段中的情况: ```python if isinstance(c2, list): m_ = m m_.backbone = True else: m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) # module t = str(m)[8:-2].replace('__main__.', '') # module type m.np = sum(x.numel() for x in m_.parameters()) # number params m_.i, m_.f, m_.type = i + 4 if backbone else i, f, t # attach index, 'from' index, type ``` 这里并没有直接展示 `nn.Identity` 的实例化过程,但如果在这个上下文中想要跳过某一层处理或将一部分网络设置为空操作,就可以引入 `nn.Identity()` 实现该目的。 #### 示例代码说明如何使用 `nn.Identity` 假设有一个简单的卷积神经网络定义如下: ```python import torch.nn as nn class SimpleCNN(nn.Module): def __init__(self): super(SimpleCNN, self).__init__() self.conv1 = nn.Conv2d(3, 6, kernel_size=5) self.pool = nn.MaxPool2d(kernel_size=2, stride=2) self.fc1 = nn.Linear(6 * 14 * 14, 10) # 使用 Identity 占位符代替其他更复杂的层 self.skip_layer = nn.Identity() def forward(self, x): x = self.pool(F.relu(self.conv1(x))) x = self.skip_layer(x) # 这里会直接传递 x 不做额外变换 x = x.view(-1, 6 * 14 * 14) x = F.relu(self.fc1(x)) return x ``` 此例子展示了如何利用 `nn.Identity` 创建一个不进行任何运算但仍保持连接性的层,这对于调试或简化现有复杂模型特别方便。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值