模型优化之模型剪枝

一、概述
模型剪枝按照结构划分,主要包括结构化剪枝和非结构化剪枝:
(1)结构化剪枝:剪掉神经元节点之间的不重要的连接。相当于把权重矩阵中的单个权重值设置为0。
在这里插入图片描述
(2)非结构化剪枝:把权重矩阵中某个神经元节点去掉,则和神经元相连接的突触也要全部去除。相当于同时去除权重矩阵中的某一行和列。如何判断神经元节点的重要程度呢?可以通过计算神经元对应的行和列的权重值的平方和的根的大小进行排序,把排序在后面一定比例的神经元节点去掉
在这里插入图片描述
二、pytorch中模型剪枝:
Pytorch中模型的剪枝方法有三种,局部剪枝、全局剪枝和自定义剪枝。与剪枝有关的接口封装在torch.nn.utils.prune中。接下来开始演示三种剪枝在LeNet网络中的应用效果,我们首先给出LeNet网络结构。

import torch
from torch import nn

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.fc1 = nn.Linear(16 * 5 * 5, 120) 
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

(1)局部剪枝
在本人理解就是一层一层的单独剪枝,下面代码还附有多参数多网络结构剪枝:

def part_cut(model):
    '''
    ######################################局部剪枝#########################################
    剪枝之后会产生一个mask
    剪枝api:prune.random_unstructured(layer1, name="weight", amount=0.3)
            amount:剪枝的比例
            layer1:需要剪的层对象
            name:指定剪的权重还是偏执
    剪枝固化api:prune.remove(layer1, 'weight')
            参数不用过多介绍,功能是剪枝后的模型固化(永久化)
    '''
    layer1 = model.conv1
    print("--------------------------------------剪枝前----------------------------------")
    # print(list(layer1.named_parameters()))
    # print(list(layer1.named_buffers()))
    prune.random_unstructured(layer1, name="weight", amount=0.3)
    print("--------------------------------------剪枝后-----------------------------------")
    # print(list(layer1.named_parameters()))
    # print(list(layer1.named_buffers()))
    prune.remove(layer1, 'weight')
    print("-------------------------------------模型固化后---------------------------------")
    # print(list(layer1.named_parameters()))
    # print(list(layer1.named_buffers()))


    '''-------------------------------------多参数多网络结构剪枝---------------------------------'''
    for name, module in model.named_modules():
        print(name,module)
        # prune 20% of connections in all 2D-conv layers
        if isinstance(module, torch.nn.Conv2d):
            prune.l1_unstructured(module, name='weight', amount=0.2)
            prune.remove(module, 'weight')
        # prune 40% of connections in all linear layers
        elif isinstance(module, torch.nn.Linear):
            prune.l1_unstructured(module, name='weight', amount=0.4)
            prune.remove(module, 'weight')

    print(dict(model.named_buffers()).keys())  # to verify that all masks exist
    return 0

(2)全局剪枝:
剪枝所占比例是按照所有参数来算的,不是按照每层的数量来算的,剪枝时候也按整体来算。

def glob_cut(model):
    '''
    全局剪枝:
    '''
    parameters_to_prune = (
        (model.conv1, 'weight'),
        (model.conv2, 'weight'),
        (model.fc1, 'weight'),
        (model.fc2, 'weight'),
        (model.fc3, 'weight'),
    )

    prune.global_unstructured(
        parameters_to_prune,
        pruning_method=prune.L1Unstructured,
        amount=0.6,
    )
    print(list(model.named_parameters()))
    print(list(model.named_buffers()))

(3)自定义剪枝
该方法不说了,饿了,要吃饭了,急的话参考下官方教程,最后有。

评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

小树苗m

您的打赏,是我的动力。

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值