PyTorch Variable与动态图

本文介绍在使用PyTorch进行模型训练时如何正确地累加损失。特别指出,由于PyTorch采用动态图机制,若不正确处理损失的累加,会导致计算图不断增大并占用大量显存。文章提供了正确的代码示例。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

https://mp.weixin.qq.com/s/OMjfck4jlMneGZ1NJxbjKQ

for data, label in trainloader:
    ......
    out = model(data)
    loss = criterion(out, label)
    loss_sum += loss     # <--- 这里
    ......
    # 正确的写法:loss_sum += loss.data[0]

这是因为输出的loss的数据类型是Variable。

而PyTorch的动态图机制就是通过Variable来构建图。主要是使用Variable计算的时候,会记录下新产生的Variable的运算符号,在反向传播求导的时候进行使用。

如果这里直接将loss加起来,系统会认为这里也是计算图的一部分,也就是说网络会一直延伸变大~那么消耗的显存也就越来越大~~

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值