目录
DBLoss由三种loss组成。
1. balance bce loss
balance_bce_loss(pred=pred_prob, gt=gt_shrink, mask=gt_shrink_mask)
def balance_bce_loss(self, pred, gt, mask):
"""
pred: (b, w,h) 预测分数image, (0, 1)之间;
gt: (b, w,h) label image, 值0或者1;
mask:(b, w,h) 屏蔽背景位置,其中屏蔽位置为0,保留位置为1
"""
positive = (gt * mask) # positive前景位置,前景位置值为1,其他位置为0(利用mask屏蔽不需要训练的位置,比如字符检测时,模糊不清的字符)
negative = ((1 - gt) * mask) # negative背景位置,(1-gt是背景位置,再利用mask屏蔽不训练的背景位置