环境
- python 3.9
- numpy 1.24.1
- pytorch 2.0.0+cu117
一、pruning参考
要实现自己的修剪函数,可以通过将BasePruningMethod基类子类化来扩展nn.utils.prune模块,方法与所有其他修剪方法相同。基类为您实现以下方法:__call__、apply_mask、apply、prune和remove。除了一些特殊情况外,您不应该为新的修剪技术重新实现这些方法。然而,您必须实现__init__(构造函数)和compute_mask(关于如何根据修剪技术的逻辑计算给定张量的mask的说明)。此外,您必须指定该技术实现的修剪类型(支持的选项包括全局、结构化和非结构化)。在迭代应用修剪的情况下,需要这一点来确定如何组合mask。换句话说,当修剪预编辑参数时,当前修剪技术预期作用于参数的未修剪部分。指定PRUNING_TYPE将使PruningContainer(处理修剪mask的迭代应用)能够正确标识要修剪的参数的切片。
BasePruningMethod:文档地址
BasePruningMethod.compute_mask: