PyTorch深度学习:模型优化与进阶技巧探索

PyTorch模型进阶技巧

1. 自定义损失函数

PyTorch在torch.nn模块提供了许多常用的损失函数,如:MSELoss、L1Loss、BCELoss等,但是随着深度学习的发展或研究的需求,需要提出一些新的损失函数,因此我们需要知道如何自定义损失函数。

关于损失函数的定义,可以:

  • 以函数方式定义
  • 以类方式定义,推荐使用
1.1 以函数方式定义损失函数

损失函数本质上就是Python中的函数,因此可以直接定义一个函数即可,如:

def my_loss(output, target):
	loss = torch.mean((output - target) ** 2)
    return loss

1.2 以类方式定义

在实际工程中,使用面向对象的思想可以提高代码的可用性,因此以类方式定义损失函数更加常用,因此这里观察torch中的Loss函数的继承关系,可以发现它们部分继承自
_loss
,部分继承自
_WeightedLoss
,而
_WeightedLoss
继承自
_loss

_loss
继承自nn.Module。

因此,本质上可以把Loss函数当作神经网络的一层对待,同样地,自定义损失函数类需要继承nn.Module,下面以DiceLoss讲解自定义损失函数。

Dice Loss是一种在分割领域常见的损失函数,定义如下:

D

S

C

=

2

X

Y

X

Y

DSC = \frac{2\left|X \cap Y\right|}{|X|+|Y|}

D

SC

=

X

Y

2

X

Y

实现代码如下:

class DiceLoss(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(DiceLoss, self).__init__()
    
    def forward(self, inputs, targets, smooth=1):
        # smooth is used to prevent denominator from being 0
        i
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值