pytorch节省显存小技巧

本文分享了使用PyTorch进行文本多分类时遇到的显存溢出问题及解决策略,包括inplace操作、float16混合精度计算、模型拆分及checkpoint特性,有效降低显存消耗。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

使用pytorch进行文本多分类问题时遇到了显存out of memory的情况,实验了多种方法,主要比较有效的有两种:
1、尽可能使用inplace操作,比如relu可以使用inplace=True
进一步将BN归一化和激活层Relu打包成inplace,在BP的时候再重新计算。
代码与论文参考
mapillary/inplace_abn
efficient_densenet_pytorch
效果:可以减少一半的显存
原理:
在大多数的深度网络的前向传播中,都有BN-Activation-Conv这样的网络结构,就必须要存储归一化的输入和全卷机层的输入。这是有必要的,因为反向传播需要输入计算梯度。通过重写BN的反向传播步骤,使用ABN代替BN-Activation序列,可以不用存储BN的输入(可通过激活函数的输出即全卷机层的输入反推),节省50%的显存。
2、使用float16精度混合计算,利用NVIDIA 的apex,也能减少50%的显存。但是有一些操作不安全如mean、sum等
官方代码参考
NVIDIA apex
3、pytorch1.0提供了模型拆分成2部分在2张卡上运行的方案
pytorch官网多卡例子
4、使用pytorch1.0的checkpoint特性,可以减少90%的显存
ckeckpoint通过交换计算内存来工作。而不是存储整个计算图的所有中间激活用于向后计算。ckeckpoint不会保存中间的激活参数,而是通过反向传播时重新计算他们。
具体可见我的GitHub一个文本分类项目
我的GitHub

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值