0. Quick Roadmap
- Theory track: empirical risk minimization → maximum-likelihood / Bayesian view → task-specific likelihood assumptions
- Engineering track: selection checklist → numerical stability → class imbalance / multi-task weighting → monitoring and diagnostics
1. Foundations: Empirical Risk Minimization (ERM)
Given a dataset $\mathcal{D}=\{(x_i,y_i)\}_{i=1}^n$, a model $f_\theta(x)$ outputs a prediction $\hat{y}$.
The empirical risk is
$$\min_\theta \; \mathcal{R}(\theta)=\frac{1}{n}\sum_{i=1}^{n}\mathcal{L}\big(\hat{y}_i, y_i\big)$$
Gradient descent (Adam, for example) minimizes $\mathcal{R}(\theta)$. The core decision: choose an $\mathcal{L}$ that matches the task and the noise in the data.
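A minimal sketch of a single ERM step with Adam (the model, loss, and data here are placeholders):

```python
import torch
from torch import nn

model = nn.Linear(3, 1)                                   # f_theta
criterion = nn.MSELoss()                                  # the chosen L
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

x, y = torch.randn(32, 3), torch.randn(32, 1)
loss = criterion(model(x), y)                             # empirical risk over the mini-batch
optimizer.zero_grad(); loss.backward(); optimizer.step()  # one descent step
```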
2. The Probabilistic View: From Maximum Likelihood to Common Losses
2.1 Regression: Gaussian noise → MSE
Assumption:
$$y \mid x \sim \mathcal{N}\big(\mu=f_\theta(x),\,\sigma^2\big)$$
Negative log-likelihood (NLL):
$$-\log p(y \mid x)=\frac{(y-\mu)^2}{2\sigma^2}+\frac{1}{2}\log(2\pi\sigma^2)$$
Treating $\sigma$ as a constant, minimizing the NLL is equivalent to minimizing MSE:
$$\mathcal{L}_{\text{MSE}}=\tfrac{1}{2}(y-\hat{y})^2$$
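A minimal numeric check of this equivalence, assuming a fixed unit variance (`GaussianNLLLoss` with `full=False` drops the constant $\tfrac{1}{2}\log 2\pi$):

```python
import torch
from torch import nn

y_hat = torch.tensor([0.3, 1.2, -0.5])
y     = torch.tensor([0.0, 1.0,  0.0])
var   = torch.ones_like(y_hat)                        # fixed sigma^2 = 1

nll = nn.GaussianNLLLoss(full=False)(y_hat, y, var)   # 0.5 * mean((y - y_hat)^2)
mse = 0.5 * nn.MSELoss()(y_hat, y)
print(nll.item(), mse.item())                         # both ≈ 0.0633
```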
2.2 Regression: Laplace noise → MAE / quantile regression
If $y \mid x \sim \text{Laplace}(\mu,b)$:
$$-\log p(y \mid x)=\frac{|y-\mu|}{b}+\log(2b)$$
which yields MAE: $\mathcal{L}_{\text{MAE}}=|y-\hat{y}|$.
More generally, quantile regression (pinball loss):
$$\mathcal{L}_\tau(\hat{y},y)=\max\big(\tau(y-\hat{y}),\,(\tau-1)(y-\hat{y})\big)$$
With $\tau=0.5$ this reduces to MAE (up to a factor of $\tfrac{1}{2}$), and the conditional median is optimal.
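A minimal pinball-loss sketch (the name `pinball_loss` is just for illustration):

```python
import torch

def pinball_loss(y_hat, y, tau: float = 0.5):
    """Quantile (pinball) loss; tau = 0.5 gives half of the MAE."""
    diff = y - y_hat
    return torch.max(tau * diff, (tau - 1) * diff).mean()

y_hat = torch.tensor([1.0, 2.0, 3.0])
y     = torch.tensor([1.5, 1.0, 3.0])
print(pinball_loss(y_hat, y, tau=0.9))   # tau > 0.5 penalizes under-prediction more heavily
```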
2.3 Huber / Smooth L1 (robust regression)
$$\mathcal{L}_\delta(e)=\begin{cases}\tfrac{1}{2}e^2, & |e|\le \delta\\ \delta\big(|e|-\tfrac{1}{2}\delta\big), & |e|>\delta\end{cases},\qquad e=y-\hat{y}$$
Behaves like MSE for small errors and like MAE for large ones (robust to outliers).
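PyTorch ships this as `nn.HuberLoss(delta=...)`; `nn.SmoothL1Loss(beta=...)` is the same curve divided by $\delta$. A quick check against the piecewise formula:

```python
import torch
from torch import nn

e = torch.tensor([0.2, 0.8, 3.0])        # residuals e = y - y_hat
delta = 1.0

manual = torch.where(e.abs() <= delta,
                     0.5 * e ** 2,
                     delta * (e.abs() - 0.5 * delta)).mean()
builtin = nn.HuberLoss(delta=delta)(e, torch.zeros_like(e))
print(manual.item(), builtin.item())     # identical, ≈ 0.947
```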
2.4 Binary classification: Bernoulli likelihood → BCE
Let $y\in\{0,1\}$ and $p=\sigma(z)$ with logit $z=f_\theta(x)$:
$$\mathcal{L}_{\text{BCE}}=-\big[y\log p + (1-y)\log(1-p)\big]$$
In practice, use BCEWithLogits (it takes logits directly and is numerically more stable).
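A small sketch of why the logits API matters: with large logits the sigmoid saturates to exactly 1.0 in float32 and the plain BCE degenerates, while the with-logits version stays exact:

```python
import torch
import torch.nn.functional as F

z = torch.tensor([40.0, -40.0])   # extreme logits
y = torch.tensor([0.0, 1.0])

naive  = F.binary_cross_entropy(torch.sigmoid(z), y)    # log(1 - 1.0): clamped, no longer the true NLL
stable = F.binary_cross_entropy_with_logits(z, y)       # exact, ≈ 40
print(naive.item(), stable.item())
```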
2.5 Multi-class: multinomial likelihood → cross-entropy
Softmax $p_k=\frac{e^{z_k}}{\sum_j e^{z_j}}$, with a one-hot ground-truth label:
$$\mathcal{L}_{\text{CE}}=-\sum_{k}y_k\log p_k = -\log p_{y}$$
Log-Sum-Exp stabilization:
$$\log\sum_j e^{z_j}=m+\log\sum_j e^{z_j-m},\qquad m=\max_j z_j$$
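A tiny demonstration of the max-shift trick (compare with `torch.logsumexp`, which does the same thing internally):

```python
import torch

z = torch.tensor([1000.0, 999.0, 998.0])

naive  = torch.log(torch.exp(z).sum())                        # exp overflows → inf
stable = z.max() + torch.log(torch.exp(z - z.max()).sum())    # ≈ 1000.4076
print(naive.item(), stable.item(), torch.logsumexp(z, dim=0).item())
```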
2.6 Count regression: Poisson / negative binomial
Poisson regression ($\lambda=e^{f_\theta(x)}$):
$$\mathcal{L}_{\text{Poisson}}=\lambda - y\log\lambda + \log(y!)$$
Under overdispersion, use the negative binomial (which introduces an extra dispersion parameter).
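A minimal sketch with the built-in `PoissonNLLLoss` (with `log_input=True` the model outputs $\log\lambda$, and the constant $\log(y!)$ term is dropped by default):

```python
import torch
from torch import nn

log_lambda = torch.tensor([0.1, 1.2, -0.3])   # f_theta(x) = log(lambda)
y          = torch.tensor([1.0, 3.0, 0.0])

# loss = exp(log_lambda) - y * log_lambda, averaged
loss = nn.PoissonNLLLoss(log_input=True)(log_lambda, y)
print(loss.item())
```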
3. Classification Extensions: Focal / Label Smoothing / Brier
- Focal Loss (upweights hard examples; suited to long-tailed data):
$$\mathcal{L}_{\text{Focal}}=-(1-p_t)^\gamma \log p_t,\qquad p_t=\begin{cases}p & y=1\\ 1-p & y=0\end{cases}$$
  Typically $\gamma\in[1,2]$; it can be combined with an $\alpha$ class weight.
- Label Smoothing: replace the one-hot $y$ with $\tilde{y}=(1-\varepsilon)\cdot \text{one\_hot} + \varepsilon/K$; mitigates overfitting and improves calibration.
- Brier score (calibration-friendly):
$$\mathcal{L}_{\text{Brier}}=\sum_k (p_k - y_k)^2$$
  It emphasizes probability calibration rather than pushing margins to extremes.
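A focal-loss implementation appears in Section 11, and label smoothing is just the `label_smoothing` argument of `CrossEntropyLoss`. A minimal Brier-score sketch for multi-class probabilities (`brier_score` is an illustrative helper, not a library API):

```python
import torch
import torch.nn.functional as F

def brier_score(logits, targets, num_classes):
    """Mean squared error between predicted probabilities and the one-hot target."""
    probs = torch.softmax(logits, dim=-1)
    onehot = F.one_hot(targets, num_classes).float()
    return ((probs - onehot) ** 2).sum(dim=-1).mean()

logits  = torch.tensor([[2.0, 0.0, -1.0]])
targets = torch.tensor([0])
print(brier_score(logits, targets, num_classes=3))
```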
4. Ranking and Metric Learning
4.1 Hinge / SVM
Binary-classification margin loss:
$$\mathcal{L}_{\text{hinge}}=\max(0, 1 - y\cdot z),\qquad y\in\{-1,+1\}$$
It encourages $y\cdot z \ge 1$.
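A minimal sketch (`hinge_loss` is an illustrative helper):

```python
import torch

def hinge_loss(z, y):
    """z: raw scores; y in {-1, +1}."""
    return torch.clamp(1 - y * z, min=0).mean()

z = torch.tensor([0.3, 2.0, -1.5])
y = torch.tensor([1.0, 1.0, -1.0])
print(hinge_loss(z, y))   # only the first sample (margin < 1) contributes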
4.2 Pairwise / Listwise Ranking
- RankNet (pairwise logistic; sketch after this list):
$$\mathcal{L}=\log\big(1+\exp(-(s^+ - s^-))\big)$$
- LambdaRank / LambdaMART: NDCG-based gradient approximations, widely used for learning to rank in practice.
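A minimal RankNet sketch (`softplus(x)` computes $\log(1+e^{x})$ stably):

```python
import torch
import torch.nn.functional as F

def ranknet_loss(s_pos, s_neg):
    """Pairwise logistic loss: log(1 + exp(-(s_pos - s_neg)))."""
    return F.softplus(-(s_pos - s_neg)).mean()

s_pos = torch.tensor([2.0, 0.1])
s_neg = torch.tensor([1.0, 0.5])
print(ranknet_loss(s_pos, s_neg))
```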
4.3 Contrastive / Metric Learning
- Contrastive (sketch after this list):
$$\mathcal{L}=y\cdot d^2 + (1-y)\cdot \max(0, m-d)^2$$
- Triplet:
$$\mathcal{L}=\max\big(0, d(a,p) - d(a,n) + m\big)$$
- InfoNCE (common in contrastive self-supervised learning):
$$\mathcal{L}=-\log \frac{\exp(\text{sim}(q,k^+)/\tau)}{\sum_{k\in\mathcal{K}}\exp(\text{sim}(q,k)/\tau)}$$
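Triplet and InfoNCE implementations appear in Section 11; a minimal sketch of the contrastive pair loss (the margin value is illustrative):

```python
import torch
import torch.nn.functional as F

def contrastive_loss(x1, x2, y, margin: float = 1.0):
    """y = 1: same class, pull together; y = 0: different class, push beyond the margin."""
    d = F.pairwise_distance(x1, x2)
    return (y * d.pow(2) + (1 - y) * F.relu(margin - d).pow(2)).mean()

x1, x2 = torch.randn(4, 16), torch.randn(4, 16)
y = torch.tensor([1.0, 0.0, 1.0, 0.0])
print(contrastive_loss(x1, x2, y))
```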
5. Dense Prediction (Segmentation / Detection / Keypoints)
5.1 Semantic Segmentation
- Dice (soft):
$$\text{Dice}= \frac{2\sum (p\cdot g)}{\sum p + \sum g},\qquad \mathcal{L}=1-\text{Dice}$$
- Jaccard / IoU (soft): $\text{IoU}=\frac{\sum(p\cdot g)}{\sum p + \sum g - \sum(p\cdot g)}$
- Tversky (controls the FP/FN trade-off):
$$\text{TV}=\frac{\sum TP}{\sum TP + \alpha\sum FP + \beta\sum FN},\qquad \mathcal{L}=1-\text{TV}$$
- Focal Tversky: $\mathcal{L}=(1-\text{TV})^\gamma$
A common practical choice: BCEWithLogits + λ·Dice (more stable for small objects).
5.2 Object Detection (box regression)
- Smooth L1 / Huber: the classic, stable choice
- IoU / GIoU / DIoU / CIoU (directly optimize overlap and box geometry):
  - $\mathcal{L}_{\text{IoU}}=1-\text{IoU}$
  - GIoU: adds the smallest enclosing box $C$: $\mathcal{L}=1-\text{IoU}+\frac{|C\setminus (A\cup B)|}{|C|}$
  - DIoU / CIoU: also account for center distance and aspect ratio, for faster and more stable convergence
5.3 Keypoints
- Heatmap MSE (density regression) or Smooth L1 (coordinate regression)
- Wing / L1-Laplace: damp large and small errors differently (sketch below)
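A sketch of the Wing-loss piecewise form as usually described (logarithmic near zero, linear far away); the `omega` / `epsilon` values here are only illustrative defaults:

```python
import math
import torch

def wing_loss(pred, target, omega: float = 10.0, epsilon: float = 2.0):
    """omega * ln(1 + |e|/epsilon) for small errors, |e| - C (linear) for large ones."""
    e = (pred - target).abs()
    C = omega - omega * math.log(1 + omega / epsilon)   # keeps the two pieces continuous
    return torch.where(e < omega,
                       omega * torch.log(1 + e / epsilon),
                       e - C).mean()

pred, target = torch.tensor([0.5, 12.0]), torch.tensor([0.0, 0.0])
print(wing_loss(pred, target))
```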
6. Generative Modeling and Bayesian Objectives
- VAE: $\text{ELBO}=\underbrace{\mathbb{E}_{q}[\log p(x|z)]}_{\text{reconstruction}} - \underbrace{D_{\text{KL}}(q(z|x)\,\|\,p(z))}_{\text{prior regularizer}}$
  The reconstruction term is BCE or MSE depending on the assumed observation distribution (a KL sketch follows this list).
- GAN:
  - Original minimax objective: $\min_G \max_D \; \mathbb{E}_{x}[\log D(x)] + \mathbb{E}_{z}[\log(1-D(G(z)))]$
  - Non-saturating / hinge / WGAN-GP variants (more stable gradients and convergence)
- Diffusion models: the training objective is usually an MSE on the noise ($\epsilon$-prediction) or on the $v$-prediction.
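For a diagonal Gaussian posterior and a standard normal prior, the KL term has a closed form. A minimal negative-ELBO sketch, assuming a BCE reconstruction term (i.e., binarized observations):

```python
import torch
import torch.nn.functional as F

def vae_loss(x_recon_logits, x, mu, log_var):
    """Negative ELBO = reconstruction NLL + KL(q(z|x) || N(0, I))."""
    recon = F.binary_cross_entropy_with_logits(x_recon_logits, x, reduction="sum")
    kl = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return (recon + kl) / x.size(0)   # average over the batch
```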
7. Uncertainty and Calibration
7.1 Heteroscedastic regression (learning $\sigma(x)$)
$$\mathcal{L}=\frac{(y-\mu(x))^2}{2\sigma(x)^2}+\frac{1}{2}\log \sigma(x)^2$$
The model outputs both $\mu$ and $\log\sigma^2$, letting it reweight samples adaptively.
7.2 Multi-task uncertainty weighting (Kendall & Gal)
Two-task example (regression / classification):
$$\mathcal{L}=\frac{1}{2\sigma_1^2}\mathcal{L}_1 + \frac{1}{\sigma_2^2}\mathcal{L}_2 + \log\sigma_1 + \log\sigma_2$$
The $\sigma_i$ are learned by the network and automatically balance the scales.
8. Numerical Stability and Engineering Tips
- Prefer the logits-based APIs: BCEWithLogitsLoss / CrossEntropyLoss
- Use Log-Sum-Exp and clamping to avoid $\log(0)$: add an $\epsilon\in[10^{-12},10^{-7}]$
- Class imbalance: class weights / Focal / resampling (weight by the "effective number of samples")
- Normalization and target scale: standardize regression targets; learning rates become easier to tune
- Gradient clipping and mixed precision: prevent explosions, improve throughput
- Label-level augmentation: Label Smoothing / Mixup / CutMix (regularization in probability space)
9. Selection Cheat Sheet
Task | First choice | Alternatives / add-ons | Notes |
---|---|---|---|
Standard regression | MSE | Huber / MAE | Outliers → Huber |
Binary classification | BCEWithLogits | Focal, class weights | Long tail → Focal (γ≈2) |
Multi-class | CrossEntropy | Label Smoothing | Better calibration |
Multi-label | BCEWithLogits | Focal + α | Independent Bernoullis |
Semantic segmentation | BCE + Dice | Tversky / Focal Tversky | Friendly to small objects |
Detection box regression | Smooth L1 | G/D/CIoU | Stable convergence and localization |
Ranking | RankNet / pairwise hinge | LambdaRank | Track the evaluation metric (NDCG) |
Metric learning / retrieval | Triplet / InfoNCE | Contrastive / N-Pairs | Combine with hard-example mining |
Count regression | Poisson / NB | — | Overdispersion → NB |
Distillation | KL (with temperature T) | CE + temperature | Remember the $T^2$ factor (sketch below) |
Multi-task | Weighted sum | Uncertainty weighting | Different scales need reweighting |
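A minimal knowledge-distillation sketch of the KL term the table refers to (the temperature `T` is a hyperparameter; the $T^2$ factor keeps its gradient magnitude comparable to a hard-label CE term):

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, T: float = 4.0):
    """KL(teacher_T || student_T) * T^2 with temperature-softened distributions."""
    log_p_student = F.log_softmax(student_logits / T, dim=-1)
    p_teacher     = F.softmax(teacher_logits / T, dim=-1)
    return F.kl_div(log_p_student, p_teacher, reduction="batchmean") * (T ** 2)

student, teacher = torch.randn(8, 10), torch.randn(8, 10)
print(distillation_loss(student, teacher))
```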
10. A Worked Example (by hand)
Multi-class cross-entropy with logits $z=[2,0,-1]$, true class 0.
Softmax:
$\max z=2$; numerator $e^{2-2}=1$; denominator $1+e^{-2}+e^{-3}\approx 1+0.1353+0.0498=1.1851$
$p_0=1/1.1851\approx 0.8438$
CE $=-\log 0.8438\approx 0.1698$
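A one-line verification of the hand calculation (`F.cross_entropy` applies log-softmax internally):

```python
import torch
import torch.nn.functional as F

z = torch.tensor([[2.0, 0.0, -1.0]])
y = torch.tensor([0])
print(F.cross_entropy(z, y).item())   # ≈ 0.1698, matching the hand calculation
```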
11. PyTorch Stable-Implementation Collection (ready to use)
Note: prefer the official `*WithLogits` / `CrossEntropyLoss` APIs. The implementations below cover the usual engineering pitfalls.
```python
# losses_zoo.py
import math

import torch
import torch.nn.functional as F
from torch import nn


# --------- Basic stable wrappers ---------
class BCEWithLogitsStable(nn.Module):
    def __init__(self, pos_weight=None, reduction="mean"):
        super().__init__()
        self.pos_weight = pos_weight
        self.reduction = reduction

    def forward(self, logits, targets):
        return F.binary_cross_entropy_with_logits(
            logits, targets, pos_weight=self.pos_weight, reduction=self.reduction
        )


class CrossEntropyStable(nn.Module):
    def __init__(self, label_smoothing=0.0, weight=None):
        super().__init__()
        self.label_smoothing = label_smoothing
        self.weight = weight

    def forward(self, logits, targets):
        return F.cross_entropy(
            logits, targets, label_smoothing=self.label_smoothing, weight=self.weight
        )
# --------- Focal (BCE version: multi-label / binary) ---------
class FocalLossBCE(nn.Module):
    def __init__(self, gamma=2.0, alpha=None, reduction="mean"):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha  # float or per-class tensor
        self.reduction = reduction

    def forward(self, logits, targets):
        # logits, targets: (N, C) or (N,)
        bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
        p = torch.sigmoid(logits)
        pt = targets * p + (1 - targets) * (1 - p)          # probability of the true class
        loss = (1 - pt).pow(self.gamma) * bce
        if self.alpha is not None:
            alpha_t = targets * self.alpha + (1 - targets) * (1 - self.alpha)
            loss = alpha_t * loss
        return loss.mean() if self.reduction == "mean" else loss.sum()


# --------- Focal (softmax version: multi-class) ---------
class FocalLossSoftmax(nn.Module):
    def __init__(self, gamma=2.0, weight=None, reduction="mean"):
        super().__init__()
        self.gamma = gamma
        self.weight = weight
        self.reduction = reduction

    def forward(self, logits, targets):
        # targets: (N,) int64 class indices
        logp = F.log_softmax(logits, dim=-1)
        p = logp.exp()
        nll = F.nll_loss(logp, targets, weight=self.weight, reduction="none")
        pt = p.gather(dim=-1, index=targets.unsqueeze(1)).squeeze(1)
        loss = (1 - pt).pow(self.gamma) * nll
        return loss.mean() if self.reduction == "mean" else loss.sum()
# --------- Dice / Tversky / Focal Tversky (segmentation) ---------
def _flatten_probs_targets(logits, targets):
    # Supports (N,1,H,W) or (N,C,H,W) for the binary case (C == 1 means a single foreground logit)
    if logits.shape[1] == 1:
        probs = torch.sigmoid(logits)
    else:
        probs = torch.softmax(logits, dim=1)[:, 1:2]  # take the foreground channel of a 2-class softmax
    return probs.contiguous().view(logits.size(0), -1), targets.contiguous().view(logits.size(0), -1)


class DiceLoss(nn.Module):
    def __init__(self, eps=1e-7):
        super().__init__()
        self.eps = eps

    def forward(self, logits, targets):
        probs, targets = _flatten_probs_targets(logits, targets)
        inter = (probs * targets).sum(dim=1)
        den = probs.sum(dim=1) + targets.sum(dim=1) + self.eps
        dice = 1 - (2 * inter / den)
        return dice.mean()


class TverskyLoss(nn.Module):
    def __init__(self, alpha=0.5, beta=0.5, eps=1e-7):
        super().__init__()
        self.alpha, self.beta, self.eps = alpha, beta, eps

    def forward(self, logits, targets):
        probs, targets = _flatten_probs_targets(logits, targets)
        tp = (probs * targets).sum(dim=1)
        fp = (probs * (1 - targets)).sum(dim=1)
        fn = ((1 - probs) * targets).sum(dim=1)
        tv = tp / (tp + self.alpha * fp + self.beta * fn + self.eps)
        return (1 - tv).mean()


class FocalTverskyLoss(nn.Module):
    def __init__(self, alpha=0.7, beta=0.3, gamma=1.33, eps=1e-7):
        super().__init__()
        self.tversky = TverskyLoss(alpha, beta, eps)
        self.gamma = gamma

    def forward(self, logits, targets):
        tv_complement = self.tversky(logits, targets)  # TverskyLoss already returns (1 - TV)
        return tv_complement.pow(self.gamma)


# --------- BCE + Dice combination ---------
class BCEDiceLoss(nn.Module):
    def __init__(self, dice_weight=1.0):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss()
        self.dice = DiceLoss()
        self.dw = dice_weight

    def forward(self, logits, targets):
        return self.bce(logits, targets) + self.dw * self.dice(logits, targets)
# --------- IoU/GIoU/DIoU/CIoU (concise version, xyxy) ---------
def bbox_iou(box1, box2, eps=1e-7):
    # box: (..., 4) as [x1, y1, x2, y2]
    x1 = torch.max(box1[..., 0], box2[..., 0])
    y1 = torch.max(box1[..., 1], box2[..., 1])
    x2 = torch.min(box1[..., 2], box2[..., 2])
    y2 = torch.min(box1[..., 3], box2[..., 3])
    inter = (x2 - x1).clamp(min=0) * (y2 - y1).clamp(min=0)
    area1 = (box1[..., 2] - box1[..., 0]).clamp(min=0) * (box1[..., 3] - box1[..., 1]).clamp(min=0)
    area2 = (box2[..., 2] - box2[..., 0]).clamp(min=0) * (box2[..., 3] - box2[..., 1]).clamp(min=0)
    union = area1 + area2 - inter + eps
    return inter / union


def giou_loss(box1, box2, eps=1e-7):
    iou = bbox_iou(box1, box2, eps)
    # union, recomputed explicitly
    ix1 = torch.max(box1[..., 0], box2[..., 0]); iy1 = torch.max(box1[..., 1], box2[..., 1])
    ix2 = torch.min(box1[..., 2], box2[..., 2]); iy2 = torch.min(box1[..., 3], box2[..., 3])
    inter = (ix2 - ix1).clamp(min=0) * (iy2 - iy1).clamp(min=0)
    area1 = (box1[..., 2] - box1[..., 0]).clamp(min=0) * (box1[..., 3] - box1[..., 1]).clamp(min=0)
    area2 = (box2[..., 2] - box2[..., 0]).clamp(min=0) * (box2[..., 3] - box2[..., 1]).clamp(min=0)
    union = area1 + area2 - inter + eps
    # smallest enclosing box C
    cx1 = torch.min(box1[..., 0], box2[..., 0]); cy1 = torch.min(box1[..., 1], box2[..., 1])
    cx2 = torch.max(box1[..., 2], box2[..., 2]); cy2 = torch.max(box1[..., 3], box2[..., 3])
    c_area = (cx2 - cx1) * (cy2 - cy1) + eps
    giou = iou - (c_area - union) / c_area
    return (1 - giou).mean()


def ciou_loss(box1, box2, eps=1e-7):
    iou = bbox_iou(box1, box2, eps)
    # center distance
    c1x = (box1[..., 0] + box1[..., 2]) / 2; c1y = (box1[..., 1] + box1[..., 3]) / 2
    c2x = (box2[..., 0] + box2[..., 2]) / 2; c2y = (box2[..., 1] + box2[..., 3]) / 2
    cw = torch.max(box1[..., 2], box2[..., 2]) - torch.min(box1[..., 0], box2[..., 0])
    ch = torch.max(box1[..., 3], box2[..., 3]) - torch.min(box1[..., 1], box2[..., 1])
    c2 = cw ** 2 + ch ** 2 + eps                 # squared diagonal of the enclosing box
    rho2 = (c1x - c2x) ** 2 + (c1y - c2y) ** 2
    # aspect-ratio consistency term
    w1 = (box1[..., 2] - box1[..., 0]).clamp(min=eps); h1 = (box1[..., 3] - box1[..., 1]).clamp(min=eps)
    w2 = (box2[..., 2] - box2[..., 0]).clamp(min=eps); h2 = (box2[..., 3] - box2[..., 1]).clamp(min=eps)
    v = (4 / (math.pi ** 2)) * (torch.atan(w2 / h2) - torch.atan(w1 / h1)) ** 2
    with torch.no_grad():
        alpha = v / (1 - iou + v + eps)
    ciou = iou - (rho2 / c2 + alpha * v)
    return (1 - ciou).mean()
# --------- Triplet / InfoNCE ---------
class TripletLoss(nn.Module):
    def __init__(self, margin=0.2, p=2):
        super().__init__()
        self.margin, self.p = margin, p

    def forward(self, a, pos, neg):
        d_ap = F.pairwise_distance(a, pos, p=self.p)
        d_an = F.pairwise_distance(a, neg, p=self.p)
        return F.relu(d_ap - d_an + self.margin).mean()


class InfoNCELoss(nn.Module):
    def __init__(self, temperature=0.07):
        super().__init__()
        self.t = temperature

    def forward(self, q, k_pos, k_all):
        # q, k_pos: (N, D); k_all: (K, D) negatives (the positive is prepended as column 0 below)
        q = F.normalize(q, dim=-1)
        k_pos = F.normalize(k_pos, dim=-1)
        k_all = F.normalize(k_all, dim=-1)
        pos = torch.sum(q * k_pos, dim=-1, keepdim=True) / self.t
        neg = q @ k_all.t() / self.t
        logits = torch.cat([pos, neg], dim=1)
        labels = torch.zeros(q.size(0), dtype=torch.long, device=q.device)
        return F.cross_entropy(logits, labels)
# --------- Heteroscedastic regression NLL ---------
class GaussianNLLHetero(nn.Module):
    def forward(self, mu, log_var, y):
        # log_var is unconstrained; var = exp(log_var)
        return 0.5 * (torch.exp(-log_var) * (y - mu) ** 2 + log_var).mean()


# --------- Multi-task uncertainty weighting ---------
class MultiTaskUncertaintyLoss(nn.Module):
    def __init__(self, task_num):
        super().__init__()
        self.log_vars = nn.Parameter(torch.zeros(task_num))

    def forward(self, losses):
        # losses: List[Tensor], each entry is a single task's loss.mean()
        total = 0.0
        for i, li in enumerate(losses):
            inv = torch.exp(-self.log_vars[i])
            total += inv * li + self.log_vars[i]
        return total
```
A generic training-loop example:
```python
import torch
from torch import nn


def train_epoch(model, loader, criterion, optimizer, device="cuda"):
    model.train()
    total = 0.0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        out = model(x)
        loss = criterion(out, y)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 5.0)   # gradient clipping
        optimizer.step()
        total += loss.item() * x.size(0)
    return total / len(loader.dataset)
```
12. Diagnostics Checklist (crucial in practice)
- Start with constant inputs / a tiny batch: check that the loss is finite and can decrease
- Log the components: for combined losses, log each term separately (e.g., BCE and Dice; see the sketch after this list)
- Class imbalance: look at per-class recall / PR curves, not just the total loss
- Learning rate / normalization: standardizing regression targets often helps immediately
- Gradient / weight distributions: on NaNs → lower the LR, clip gradients, check that you are using the logits-based APIs
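A minimal sketch of a combined loss that also returns its components for separate logging, reusing `DiceLoss` from the `losses_zoo.py` listing above (the class name and the returned dict are illustrative choices):

```python
import torch.nn as nn

from losses_zoo import DiceLoss   # the DiceLoss defined in Section 11


class BCEDiceWithLogging(nn.Module):
    """Same combination as BCEDiceLoss, but also returns each term for logging."""
    def __init__(self, dice_weight=1.0):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss()
        self.dice = DiceLoss()
        self.dw = dice_weight

    def forward(self, logits, targets):
        bce = self.bce(logits, targets)
        dice = self.dice(logits, targets)
        total = bce + self.dw * dice
        return total, {"bce": bce.detach(), "dice": dice.detach()}
```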