1. 模型蒸馏的三种形式
(1) 软标签蒸馏(Soft Labels Distillation)
-
核心思想:
通过教师模型生成的类别概率分布(软标签)作为监督信号,让学生模型学习类别间的相对关系(如"猫 vs 狗"的相似性高于"猫 vs 汽车")。- 核心优势:
- 保留教师模型的泛化能力(Dark Knowledge)。
- 适用于细粒度分类等需要类别间关系信息的任务。
- 核心优势:
-
核心公式(完整对数展开):
Lsoft=T2⋅KL(σ(zt/T)∥σ(zs/T))L_{soft}=T^2⋅KL(σ(z_t/T)∥σ(z_s/T))Lsoft=T2⋅KL(σ(zt/T)∥σ(zs/T))
-
KL散度展开:
KL(P∥Q)=∑iP(i)logP(i)Q(i)=∑iP(i)logP(i)−∑iP(i)logQ(i)KL(P \parallel Q) = \sum_{i} P(i) \log \frac{P(i)}{Q(i)} = \sum_{i} P(i) \log P(i) - \sum_{i} P(i) \log Q(i)KL(P∥Q)=∑iP(i)logQ(i)P(i)=∑iP(i)logP(i)−∑iP(i)logQ(i)
教师熵(常数)对应的是 ∑iP(i)logP(i)\sum_{i} P(i) \log P(i)∑iP(i)logP(i),交叉熵项对应的是 −∑iP(i)logQ(i)-\sum_{i} P(i) \log Q(i)−∑iP(i)logQ(i)。
实际优化时仅需计算交叉熵项:
Lsoft=−T2⋅∑iσ(zt/T)ilogσ(zs/T)iL_{soft}=−T^2⋅∑_iσ(z_t/T)_ilogσ(z_s/T)_iLsoft=−T2⋅∑iσ(zt/T)ilogσ(zs/T)i
-
温度参数 TTT:
-
高温(T↑T↑T↑):概率分布平滑(软标签),如 T=2时:
z=[3.0,1.0]⇒σ(z/2)≈[0.71,0.29]z=[3.0,1.0]⇒σ(z/2)≈[0.71,0.29]z=[3.0,1.0]⇒σ(z/2)≈[0.71,0.29]
-
低温(T↓T↓T↓):趋近one-hot(硬标签),如 T=0.1T=0.1T=0.1 时:
σ(z/0.1)≈[0.99,0.01]σ(z/0.1)≈[0.99,0.01]σ(z/0.1)≈[0.99,0.01]
-
-
梯度补偿:
对学生的logitszs(i)logits z_s^{(i)}logitszs(i) 求偏导时:
$ \frac{∂L_{soft}}{∂z_s^{(i)}} $
当 TTT增大时,梯度幅度会缩小 1T\frac{1}{T}T1 倍,因此需乘以 T2T^2T2 以保持梯度量级稳定。
补充目的:抵消温度TTT对梯度幅度的缩放效应,维持训练稳定性。
-
-
代码实现:
def soft_label_loss(teacher_logits, student_logits, T=2.0): # 教师概率(停止梯度) p_teacher = F.softmax(teacher_logits.detach() / T, dim=-1) # 学生对数概率 log_p_student = F.log_softmax(student_logits / T, dim=-1) # KL散度损失(需手动乘T^2) loss = F.kl_div(log_p_student, p_teacher, reduction='batchmean') * (T ** 2) return loss # 调用示例 teacher_logits = torch.tensor([[3.0, 1.0, 0.5]]) student_logits = torch.tensor([[2.0, 0.8, 0.3]], requires_grad=True) loss = soft_label_loss(teacher_logits, student_logits, T=2.0)
-
注意事项:
- 温度 TTT 需调参(通常 T∈[1,5]T∈[1,5]T∈[1,5])。
- 教师logits需
.detach()
避免错误梯度传播。
(2) 硬标签蒸馏(Hard Labels Distillation)
-
核心思想:
直接使用真实标签(Ground Truth)的one-hot编码作为监督信号,强制学生模型学习与真实标签完全匹配的决策边界。 -
具体实现:
-
标准交叉熵损失:
Lhard=−∑i=1Cytrue(i)logσ(zs(i))L_{hard}=−∑_{i=1}^Cy_{true}^{(i)}logσ(z_s^{(i)})Lhard=−∑i=1Cytrue(i)logσ(zs(i))
- CCC:类别总数。
- ytruey_{true}ytrue:真实标签的one-hot编码(如
[0, 1, 0]
)。 - σ(zs)σ(z_s)σ(zs):学生模型的softmax输出。
-
标签平滑(Label Smoothing):
为避免过拟合,可对硬标签做平滑处理(如将one-hot的1替换为0.9,其余类分配0.1/(C-1)):ysmooth(i)={0.9if i=true class,0.1/(C−1)otherwise.y^{(i)}_{\text{smooth}} = \begin{cases} 0.9 & \text{if } i = \text{true class}, \\ 0.1/(C - 1) & \text{otherwise}. \end{cases}ysmooth(i)={0.90.1/(C−1)if i=true class,otherwise.
-
-
梯度计算:
对学生的logitszs(i)logits z_s^{(i)}logitszs(i)的梯度为:$ \frac{∂L_{hard}}{∂z_s^{(i)}} $
- 当学生预测概率 zs(i)z_s^{(i)}zs(i) 接近真实标签 ytrue(i)y_{true}^{(i)}ytrue(i) 时,梯度趋近0。
-
代码实现:
# 标准硬标签损失 hard_loss = F.cross_entropy(student_logits, true_labels) # 标签平滑版本 smooth_labels = true_labels * 0.9 + 0.1 / num_classes hard_loss_smooth = F.cross_entropy(student_logits, smooth_labels)
-
注意事项:
- 硬标签缺乏类别间关系信息,可能限制模型泛化能力。
- 标签平滑可缓解过confidence问题,但需谨慎设置平滑参数(如0.1)。
(3) 中间层特征蒸馏(Intermediate Features Distillation)
-
核心思想:
让学生模型模仿教师模型的中间层特征表示(如隐藏层输出、注意力权重),而不仅是最终输出。 -
具体实现:
-
隐藏层输出对齐:
Lfeature=∥ϕ(ht)−hs∥22=∑k(ϕ(ht)k−hs(k))2L_{feature}=∥ϕ(h_t)−h_s∥_2^2=∑_k(ϕ(h_t)_k−h_s^{(k)})^2Lfeature=∥ϕ(ht)−hs∥22=∑k(ϕ(ht)k−hs(k))2
-
ht,hsh_t,h_sht,hs:教师和学生的中间层特征。
-
ϕϕϕ:适配层(如线性投影),用于匹配特征维度。
-
梯度计算(适配层为 ϕ(x)=Wx+bϕ(x)=Wx+bϕ(x)=Wx+b时):
∂Lfeature∂hs=2W⊤(ϕ(ht)−hs)\frac{∂L_{feature}}{∂h_s}=2W^⊤(ϕ(h_t)−h_s)∂hs∂Lfeature=2W⊤(ϕ(ht)−hs)
-
-
注意力矩阵对齐(针对Transformer):
Lattn=1L∑l=1L∥At(l)−As(l)∥F2L_{attn}=\frac{1}{L}∑_{l=1}^L∥A_t^{(l)}−A_s^{(l)}∥_F^2Lattn=L1∑l=1L∥At(l)−As(l)∥F2
- At(l),As(l)A_t^{(l)},A_s^{(l)}At(l),As(l):教师和学生第 lll 层的注意力矩阵。
-
-
代码实现(隐藏层对齐):
class IntermediateDistiller(nn.Module): def __init__(self, student_dim, teacher_dim): super().__init__() self.adapter = nn.Linear(student_dim, teacher_dim) # 维度适配 def forward(self, h_s, h_t): adapted_h_s = self.adapter(h_s) return F.mse_loss(adapted_h_s, h_t)
-
注意事项:
- 选择具有高语义信息的中间层(如深层卷积或Transformer后几层)。
- 若同时使用多层级损失,需合理分配权重。
2. 温度参数 TT 的命名与物理类比
- 统计力学背景:
- 玻尔兹曼分布 p(E)∝e−E/(kT)p(E)∝e^{−E/(kT)}p(E)∝e−E/(kT),TTT 控制能量分布的平坦度。
- 高温:粒子能量分布均匀 → 概率分布软。
- 低温:粒子集中基态 → 概率分布尖锐。
- 蒸馏中的对应:
- T↑T↑T↑:教师输出保留更多类别间关系(如“猫”与“狗”的相似性)。
- T↓T↓T↓:教师输出趋近真实标签的one-hot编码。
3. 经典蒸馏的联合损失函数
Ltotal=α⋅Lhard+(1−α)⋅LsoftL_{total}=α⋅L_{hard}+(1−α)⋅L_{soft}Ltotal=α⋅Lhard+(1−α)⋅Lsoft
-
权重 ααα:典型值 α∈[0.1,0.5]α∈[0.1,0.5]α∈[0.1,0.5]。
-
完整展开:
Ltotal=α(−∑iytrue(i)logσ(zs(i)))+(1−α)(−T2∑iσ(zt/T)ilogσ(zs/T)i)L_{total}=α(−∑_iy_{true}^{(i)}logσ(z_s^{(i)}))+(1−α)(−T^2∑_iσ(z_t/T)_ilogσ(z_s/T)_i)Ltotal=α(−∑iytrue(i)logσ(zs(i)))+(1−α)(−T2∑iσ(zt/T)ilogσ(zs/T)i)
4. KL散度方向性与实现细节
-
数学定义:
KL(P∥Q)=∑iP(i)logP(i)Q(i)KL(P∥Q)=∑_iP(i)log\frac{P(i)}{Q(i)}KL(P∥Q)=∑iP(i)logQ(i)P(i)(不可逆)
- 教师分布 P=σ(zt/TP=σ(z_t/TP=σ(zt/T为目标。
- 学生分布 Q=σ(zs/T)Q=σ(z_s/T)Q=σ(zs/T) 为优化对象。
-
PyTorch实现:
# 正确顺序:KL(target || input) loss = F.kl_div( input=F.log_softmax(student_logits / T, dim=-1), # 需log概率 target=F.softmax(teacher_logits / T, dim=-1), # 直接softmax reduction='batchmean' ) * (T ** 2)
5. 关键注意事项(无新增内容)
- 温度 TTT 的选择:
- 过高(如 T=10T=10T=10)导致教师知识模糊,过低(如 T=0.1T=0.1T=0.1)退化为硬标签。
- 硬标签定义:
- 必须使用真实标签 ytruey_{true}ytrue,误用教师伪标签属于伪标签方法。
- 符号 ||:
- 仅在KL散度中表示方向分隔,编程中为逻辑或(Python用
or
)。
- 仅在KL散度中表示方向分隔,编程中为逻辑或(Python用
6. 完整代码示例(与历史问答严格一致)
import torch
import torch.nn.functional as F
def distillation_loss(teacher_logits, student_logits, true_labels, T=2.0, alpha=0.3):
# 软标签损失(KL散度,教师分布为目标)
soft_loss = F.kl_div(
input=F.log_softmax(student_logits / T, dim=-1),
target=F.softmax(teacher_logits / T, dim=-1),
reduction='batchmean'
) * (T ** 2)
# 硬标签损失(交叉熵,真实标签)
hard_loss = F.cross_entropy(student_logits, true_labels)
# 联合损失
total_loss = alpha * hard_loss + (1 - alpha) * soft_loss
return total_loss
# 示例数据
teacher_logits = torch.tensor([[3.0, 1.0, 0.5]])
student_logits = torch.tensor([[2.0, 0.8, 0.3]], requires_grad=True)
true_labels = torch.tensor([0]) # 真实类别索引
loss = distillation_loss(teacher_logits, student_logits, true_labels, T=2.0, alpha=0.3)