PyTorch模型进阶技巧
1. 自定义损失函数
PyTorch在torch.nn模块提供了许多常用的损失函数,如:MSELoss、L1Loss、BCELoss等,但是随着深度学习的发展或研究的需求,需要提出一些新的损失函数,因此我们需要知道如何自定义损失函数。
关于损失函数的定义,可以:
- 以函数方式定义
- 以类方式定义,推荐使用
1.1 以函数方式定义损失函数
损失函数本质上就是Python中的函数,因此可以直接定义一个函数即可,如:
def my_loss(output, target):
loss = torch.mean((output - target) ** 2)
return loss
1.2 以类方式定义
在实际工程中,使用面向对象的思想可以提高代码的可用性,因此以类方式定义损失函数更加常用,因此这里观察torch中的Loss函数的继承关系,可以发现它们部分继承自
_loss
,部分继承自
_WeightedLoss
,而
_WeightedLoss
继承自
_loss
,
_loss
继承自nn.Module。
因此,本质上可以把Loss函数当作神经网络的一层对待,同样地,自定义损失函数类需要继承nn.Module,下面以DiceLoss讲解自定义损失函数。
Dice Loss是一种在分割领域常见的损失函数,定义如下:
D
S
C
=
2
∣
X
∩
Y
∣
∣
X
∣
∣
Y
∣
DSC = \frac{2\left|X \cap Y\right|}{|X|+|Y|}
D
SC
=
∣
X
∣
∣
Y
∣
2
∣
X
∩
Y
∣
实现代码如下:
class DiceLoss(nn.Module):
def __init__(self, weight=None, size_average=True):
super(DiceLoss, self).__init__()
def forward(self, inputs, targets, smooth=1):
# smooth is used to prevent denominator from being 0
i