知识蒸馏 Knowledge Distillation 论文 Generalized Knowledge Distillation (GKD) 目标函数的演化
flyfish
On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes
目标函数(Objective Function) 是衡量模型预测结果与真实结果之间差异的函数,其核心作用是为模型的参数优化提供 “指导信号”—— 通过最小化(或最大化)目标函数的值,让模型逐渐学习到更优的参数,从而提升预测性能。
深度学习模型的训练本质是 “参数优化”:模型通过输入数据生成预测结果后,目标函数会计算预测值与真实标签(或期望输出)的 “差距”,得到一个量化的 “损失值”。这个损失值越大,说明模型当前的预测效果越差;反之则越好。
优化算法(如梯度下降)会基于目标函数的梯度信息,调整模型参数(如权重、偏置),最终使目标函数的值达到最小(或最大,视任务而定),此时模型的预测结果与真实结果最接近。
其他名字
损失函数(Loss Function):很多时候人们会直接说 “损失函数” 来指代目标函数,尤其在单样本或简化场景中。
代价函数(Cost Function):和损失函数类似,在不少教材或论文里,这两个词和目标函数几乎同义,只是 “代价” 更强调 “为错误付出的成本”。
优化目标(Optimization Objective):因为目标函数的核心是被模型 “优化” 的对象(比如最小化它),所以也常被称为 “优化目标”。
准则函数(Criterion Function):“准则” 就是 “判断标准”,这个词更强调它是衡量模型好坏的 “标准”,和目标函数含义一致。
这些名字都是用来指导模型优化的函数,只是在不同语境里习惯用不同的词,不用太纠结,知道它们说的是一回事儿就行
演化链
- SFT / MLE(one-hot“老师”)
→ 2) 监督 KD(固定数据 + 前向 KL,对齐软标签)
→ 3) On-Policy KD(在学生自己会走到的状态上学,消除分布失配)
→ 4) 换散度:JSD(β)/反向 KL(容量受限时更模式寻求,少幻觉)
→ 5) GKD( λ \lambda λ 控制样本来源, D D D 自由可选,统一一切)
→ 6) RL + On-Policy GKD(奖励最大化 + 蒸馏约束,同训同收)。
0. 基础:自回归分解与“逐 token”散度
自回归 LM 满足 p ( y ∣ x ) = ∏ n = 1 L y p ( y n ∣ y < n , x ) p(y|x)=\prod_{n=1}^{L_y}p(y_n|y_{<n},x) p(y∣x)=∏n=1Lyp(yn∣y<n,x)。因此一切“序列级”的目标,都能拆成逐 token的和。论文把“老师 p T p_T pT”与“学生 p S p_S pS”在序列 y y y 上的分布差异定义为(式 (2)):
D ( p T ∥ p S θ ) ( y ∣ x ) = 1 L y ∑ n = 1 L y D ( p T ( ⋅ ∣ y < n , x ) ∥ p S θ ( ⋅ ∣ y < n , x ) ) . D(p_T\!\parallel p_S^\theta)(y|x)=\frac{1}{L_y}\sum_{n=1}^{L_y} D\!\left(p_T(\cdot|y_{<n},x)\parallel p_S^\theta(\cdot|y_{<n},x)\right). D(pT∥pSθ)(y∣x)=Ly1n=1∑LyD(pT(⋅∣y<n,x)∥pSθ(⋅∣y<n,x)).
这一步把任何“散度 D D D”都落到了token 级,后续所有方法只需选择:用什么 D D D,在什么数据上评价它。
1. 监督式微调(SFT):极大似然是前向 KL 的特例
只有人工标注 ( X , Y ) (X,Y) (X,Y) 而没有老师时,最简单是最小化负对数似然:
L SFT ( θ ) = E ( x , y ) ∼ ( X , Y ) [ − log p S θ ( y ∣ x ) ] . L_{\text{SFT}}(\theta)=\mathbb{E}_{(x,y)\sim(X,Y)}[-\log p_S^\theta(y|x)]. LSFT(θ)=E(x,y)∼(X,Y)[−logpSθ(y∣x)].
它等价于把每个目标 token 看成 one-hot“老师”,即在式 (2) 中令 D = D K L D= D_{\mathrm{KL}} D=DKL 且 p T p_T pT 为真分布的“δ-分布”。
梯度形态(单步 token)
对学生 softmax 的对数几率
z
z
z 来说,前向 KL / 交叉熵的梯度是
∇
z
=
p
S
−
p
T
\nabla_z = p_S - p_T
∇z=pS−pT。这说明 SFT/前向 KL 本质是让学生概率向目标分布对齐。
局限:只在固定数据上学,推理时学生会走到自己没见过的前缀状态,产生训练-推理分布失配(exposure bias)。
2. 监督式 KD(Supervised KD):用“软标签”的前向 KL
有了老师 p T p_T pT(可给每个 token 的全分布),就把 SFT 的 one-hot 换成老师分布,得到(式 (3)):
L SD ( θ ) = E ( x , y ) ∼ ( X , Y ) [ D K L ( p T ∥ p S θ ) ( y ∣ x ) ] . L_{\text{SD}}(\theta)=\mathbb{E}_{(x,y)\sim(X,Y)} \Big[D_{\mathrm{KL}}\!\big(p_T\parallel p_S^\theta\big)(y|x)\Big]. LSD(θ)=E(x,y)∼(X,Y)[DKL(pT∥pSθ)(y∣x)].
好处:利用“软标签”提供的暗知识(非目标 token 的相对概率)。梯度仍是 p S − p T p_S-p_T pS−pT 的形态,但 p T p_T pT 不再是 one-hot。
局限:仍在固定序列上训练(可能是人工真值或老师生成的序列),分布失配依旧。
3. On-Policy KD:让学生在自己生成的序列上学
为解决分布失配,论文把“期望”从固定 ( X , Y ) (X,Y) (X,Y) 换成学生策略下的输出序列,得到(式 (4)):
L OD ( θ ) = E x ∼ X [ E y ∼ p S ( ⋅ ∣ x ) [ D K L ( p T ∥ p S θ ) ( y ∣ x ) ] ] . L_{\text{OD}}(\theta)=\mathbb{E}_{x\sim X}\Big[\mathbb{E}_{y\sim p_S(\cdot|x)} \big[D_{\mathrm{KL}}(p_T\parallel p_S^\theta)(y|x)\big]\Big]. LOD(θ)=Ex∼X[Ey∼pS(⋅∣x)[DKL(pT∥pSθ)(y∣x)]].
关键实现细节:不对学生的采样分布反传(只在内层 KL 里更新 θ \theta θ),可避免 REINFORCE 式高方差,训练更稳定高效。直观地说:学生先按当前策略走一遍,把“走错的地方”交给老师打分,再按 KL 梯度把这些状态上的 logits 拉回去。
4. 选择更合适的“散度” D D D:从 KL 到广义 JSD
前向 KL 要求学生“覆盖老师的全部支持集”,容量不足时会把概率“摊薄”到老师几乎不选的 token 上,易造成幻觉;反向 KL 则更“择众”,只贴老师高概率 token,减少“离谱”但可能牺牲多样性。论文采用广义 JSD 在两者间连续插值(式 (1)):
D J S D ( β ) ( P ∥ Q ) = β D K L ( P ∥ β P + ( 1 − β ) Q ) + ( 1 − β ) D K L ( Q ∥ β P + ( 1 − β ) Q ) . D_{\mathrm{JSD}(\beta)}(P\parallel Q)= \beta\, D_{\mathrm{KL}}\!\big(P\parallel \beta P+(1-\beta)Q\big)+ (1-\beta)\, D_{\mathrm{KL}}\!\big(Q\parallel \beta P+(1-\beta)Q\big). DJSD(β)(P∥Q)=βDKL(P∥βP+(1−β)Q)+(1−β)DKL(Q∥βP+(1−β)Q).
当 β → 0 \beta\to 0 β→0 时, 1 β D J S D ( β ) → D K L ( P ∥ Q ) \tfrac{1}{\beta}D_{\mathrm{JSD}(\beta)} \to D_{\mathrm{KL}}(P\parallel Q) β1DJSD(β)→DKL(P∥Q); β → 1 \beta\to 1 β→1 时更接近反向 KL 的行为。这样就能按任务、温度、容量调节“覆盖 vs. 模式寻求”的折衷。
经验指引:不同任务/采样温度的最优散度不同;且很多实验里,**纯 on-policy(学生样本占比 100%)**的效果最好。
5. GKD:统一“在哪些序列上学”和“用什么散度学”
把“数据来源”(固定数据 vs 学生自采样)与“散度种类”(前向/反向 KL、JSD(β)…)统一起来,得到广义 KD(式 (GKD)):
L GKD ( θ ) = ( 1 − λ ) E ( x , y ) ∼ ( X , Y ) [ D ( p T ∥ p S θ ) ( y ∣ x ) ] + λ E x ∼ X E y ∼ p S ( ⋅ ∣ x ) [ D ( p T ∥ p S θ ) ( y ∣ x ) ] . \!\!\!\!L_{\text{GKD}}(\theta) =(1-\lambda)\,\mathbb{E}_{(x,y)\sim(X,Y)}\!\!\big[D(p_T\parallel p_S^\theta)(y|x)\big] +\lambda\,\mathbb{E}_{x\sim X}\mathbb{E}_{y\sim p_S(\cdot|x)}\!\!\big[D(p_T\parallel p_S^\theta)(y|x)\big]. LGKD(θ)=(1−λ)E(x,y)∼(X,Y)[D(pT∥pSθ)(y∣x)]+λEx∼XEy∼pS(⋅∣x)[D(pT∥pSθ)(y∣x)].
- λ = 0 \lambda=0 λ=0:退化为监督 KD; λ = 1 \lambda=1 λ=1:纯 on-policy KD;
-
D
D
D 可取前向/反向 KL 或 JSD(
β
\beta
β)。
实现上,对采样过程仍不反传。论文还给了Algorithm 1:每步按 λ \lambda λ 抛硬币决定用哪类数据,再最小化该 batch 的散度。
一眼看懂的梯度形态(抽象)
忽略采样反传后,
∇ θ L GKD = E [ 1 L y ∑ n ∇ θ D ( p T ( ⋅ ) ∥ p S θ ( ⋅ ) ) ⏟ 如前向 KL 时 ∝ p S − p T ] , \nabla_\theta L_{\text{GKD}} = \mathbb{E}\!\left[\frac{1}{L_y}\sum_n \underbrace{\nabla_\theta D\!\big(p_T(\cdot)\parallel p_S^\theta(\cdot)\big)}_{\text{如前向 KL 时 }\;\propto\;p_S-p_T}\right], ∇θLGKD=E Ly1n∑如前向 KL 时 ∝pS−pT ∇θD(pT(⋅)∥pSθ(⋅)) ,
唯一差别在于这个期望是对哪种序列分布取(由 λ \lambda λ 和是否 on-policy 决定),以及** D D D** 的具体形式。
6. 再进一步:把 RL 目标和 On-Policy GKD 并列优化
很多真实目标(如事实一致性)是不可导/非似然的。论文把策略梯度的奖励项与蒸馏正则项并列,给出(式 (5)):
E x ∼ X [ ( 1 − α ) E y ∼ p S θ [ r ( y ) ] − α E y ∼ p S ( ⋅ ∣ x ) D ( p T ∥ p S θ ) ( y ∣ x ) ] . \mathbb{E}_{x\sim X}\Big[(1-\alpha)\,\mathbb{E}_{y\sim p_S^\theta}[r(y)] -\alpha\,\mathbb{E}_{y\sim p_S(\cdot|x)}D(p_T\parallel p_S^\theta)(y|x)\Big]. Ex∼X[(1−α)Ey∼pSθ[r(y)]−αEy∼pS(⋅∣x)D(pT∥pSθ)(y∣x)].
- 第一项:标准 REINFORCE/策略梯度的RL 目标;
- 第二项:on-policy 蒸馏正则,把策略往老师靠,以防 RL 走偏;
- α \alpha α 控制 RL 与蒸馏的权衡( α = 1 \alpha=1 α=1 退化为仅蒸馏)。这与 RLHF 里常见的“KL 正则”相似,但这里是向老师而不是向初始策略收缩。