知识蒸馏 Knowledge Distillation 论文 Generalized Knowledge Distillation (GKD) 目标函数的演化

知识蒸馏 Knowledge Distillation 论文 Generalized Knowledge Distillation (GKD) 目标函数的演化

flyfish

代码实践

On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes

目标函数(Objective Function) 是衡量模型预测结果与真实结果之间差异的函数,其核心作用是为模型的参数优化提供 “指导信号”—— 通过最小化(或最大化)目标函数的值,让模型逐渐学习到更优的参数,从而提升预测性能。

深度学习模型的训练本质是 “参数优化”:模型通过输入数据生成预测结果后,目标函数会计算预测值与真实标签(或期望输出)的 “差距”,得到一个量化的 “损失值”。这个损失值越大,说明模型当前的预测效果越差;反之则越好。
优化算法(如梯度下降)会基于目标函数的梯度信息,调整模型参数(如权重、偏置),最终使目标函数的值达到最小(或最大,视任务而定),此时模型的预测结果与真实结果最接近。

其他名字

损失函数(Loss Function):很多时候人们会直接说 “损失函数” 来指代目标函数,尤其在单样本或简化场景中。
代价函数(Cost Function):和损失函数类似,在不少教材或论文里,这两个词和目标函数几乎同义,只是 “代价” 更强调 “为错误付出的成本”。
优化目标(Optimization Objective):因为目标函数的核心是被模型 “优化” 的对象(比如最小化它),所以也常被称为 “优化目标”。
准则函数(Criterion Function):“准则” 就是 “判断标准”,这个词更强调它是衡量模型好坏的 “标准”,和目标函数含义一致。
这些名字都是用来指导模型优化的函数,只是在不同语境里习惯用不同的词,不用太纠结,知道它们说的是一回事儿就行

演化链

  1. SFT / MLE(one-hot“老师”)
    → 2) 监督 KD(固定数据 + 前向 KL,对齐软标签)
    → 3) On-Policy KD(在学生自己会走到的状态上学,消除分布失配)
    → 4) 换散度:JSD(β)/反向 KL(容量受限时更模式寻求,少幻觉)
    → 5) GKD λ \lambda λ 控制样本来源, D D D 自由可选,统一一切)
    → 6) RL + On-Policy GKD(奖励最大化 + 蒸馏约束,同训同收)。

0. 基础:自回归分解与“逐 token”散度

自回归 LM 满足 p ( y ∣ x ) = ∏ n = 1 L y p ( y n ∣ y < n , x ) p(y|x)=\prod_{n=1}^{L_y}p(y_n|y_{<n},x) p(yx)=n=1Lyp(yny<n,x)。因此一切“序列级”的目标,都能拆成逐 token的和。论文把“老师 p T p_T pT”与“学生 p S p_S pS”在序列 y y y 上的分布差异定义为(式 (2)):

D ( p T  ⁣ ∥ p S θ ) ( y ∣ x ) = 1 L y ∑ n = 1 L y D  ⁣ ( p T ( ⋅ ∣ y < n , x ) ∥ p S θ ( ⋅ ∣ y < n , x ) ) . D(p_T\!\parallel p_S^\theta)(y|x)=\frac{1}{L_y}\sum_{n=1}^{L_y} D\!\left(p_T(\cdot|y_{<n},x)\parallel p_S^\theta(\cdot|y_{<n},x)\right). D(pTpSθ)(yx)=Ly1n=1LyD(pT(y<n,x)pSθ(y<n,x)).

这一步把任何“散度 D D D”都落到了token 级,后续所有方法只需选择:用什么 D D D,在什么数据上评价它。

1. 监督式微调(SFT):极大似然是前向 KL 的特例

只有人工标注 ( X , Y ) (X,Y) (X,Y)没有老师时,最简单是最小化负对数似然:

L SFT ( θ ) = E ( x , y ) ∼ ( X , Y ) [ − log ⁡ p S θ ( y ∣ x ) ] . L_{\text{SFT}}(\theta)=\mathbb{E}_{(x,y)\sim(X,Y)}[-\log p_S^\theta(y|x)]. LSFT(θ)=E(x,y)(X,Y)[logpSθ(yx)].

它等价于把每个目标 token 看成 one-hot“老师”,即在式 (2) 中令 D = D K L D= D_{\mathrm{KL}} D=DKL p T p_T pT 为真分布的“δ-分布”。

梯度形态(单步 token)
对学生 softmax 的对数几率 z z z 来说,前向 KL / 交叉熵的梯度是 ∇ z = p S − p T \nabla_z = p_S - p_T z=pSpT。这说明 SFT/前向 KL 本质是让学生概率向目标分布对齐

局限:只在固定数据上学,推理时学生会走到自己没见过的前缀状态,产生训练-推理分布失配(exposure bias)。

2. 监督式 KD(Supervised KD):用“软标签”的前向 KL

有了老师 p T p_T pT(可给每个 token 的全分布),就把 SFT 的 one-hot 换成老师分布,得到(式 (3)):

L SD ( θ ) = E ( x , y ) ∼ ( X , Y ) [ D K L  ⁣ ( p T ∥ p S θ ) ( y ∣ x ) ] . L_{\text{SD}}(\theta)=\mathbb{E}_{(x,y)\sim(X,Y)} \Big[D_{\mathrm{KL}}\!\big(p_T\parallel p_S^\theta\big)(y|x)\Big]. LSD(θ)=E(x,y)(X,Y)[DKL(pTpSθ)(yx)].

好处:利用“软标签”提供的暗知识(非目标 token 的相对概率)。梯度仍是 p S − p T p_S-p_T pSpT 的形态,但 p T p_T pT 不再是 one-hot。

局限:仍在固定序列上训练(可能是人工真值或老师生成的序列),分布失配依旧。

3. On-Policy KD:让学生在自己生成的序列上学

为解决分布失配,论文把“期望”从固定 ( X , Y ) (X,Y) (X,Y) 换成学生策略下的输出序列,得到(式 (4)):

L OD ( θ ) = E x ∼ X [ E y ∼ p S ( ⋅ ∣ x ) [ D K L ( p T ∥ p S θ ) ( y ∣ x ) ] ] . L_{\text{OD}}(\theta)=\mathbb{E}_{x\sim X}\Big[\mathbb{E}_{y\sim p_S(\cdot|x)} \big[D_{\mathrm{KL}}(p_T\parallel p_S^\theta)(y|x)\big]\Big]. LOD(θ)=ExX[EypS(x)[DKL(pTpSθ)(yx)]].

关键实现细节不对学生的采样分布反传(只在内层 KL 里更新 θ \theta θ),可避免 REINFORCE 式高方差,训练更稳定高效。直观地说:学生先按当前策略走一遍,把“走错的地方”交给老师打分,再按 KL 梯度把这些状态上的 logits 拉回去。

4. 选择更合适的“散度” D D D:从 KL 到广义 JSD

前向 KL 要求学生“覆盖老师的全部支持集”,容量不足时会把概率“摊薄”到老师几乎不选的 token 上,易造成幻觉;反向 KL 则更“择众”,只贴老师高概率 token,减少“离谱”但可能牺牲多样性。论文采用广义 JSD 在两者间连续插值(式 (1)):

D J S D ( β ) ( P ∥ Q ) = β   D K L  ⁣ ( P ∥ β P + ( 1 − β ) Q ) + ( 1 − β )   D K L  ⁣ ( Q ∥ β P + ( 1 − β ) Q ) . D_{\mathrm{JSD}(\beta)}(P\parallel Q)= \beta\, D_{\mathrm{KL}}\!\big(P\parallel \beta P+(1-\beta)Q\big)+ (1-\beta)\, D_{\mathrm{KL}}\!\big(Q\parallel \beta P+(1-\beta)Q\big). DJSD(β)(PQ)=βDKL(PβP+(1β)Q)+(1β)DKL(QβP+(1β)Q).

β → 0 \beta\to 0 β0 时, 1 β D J S D ( β ) → D K L ( P ∥ Q ) \tfrac{1}{\beta}D_{\mathrm{JSD}(\beta)} \to D_{\mathrm{KL}}(P\parallel Q) β1DJSD(β)DKL(PQ) β → 1 \beta\to 1 β1 时更接近反向 KL 的行为。这样就能按任务、温度、容量调节“覆盖 vs. 模式寻求”的折衷。

经验指引:不同任务/采样温度的最优散度不同;且很多实验里,**纯 on-policy(学生样本占比 100%)**的效果最好。

5. GKD:统一“在哪些序列上学”和“用什么散度学”

把“数据来源”(固定数据 vs 学生自采样)与“散度种类”(前向/反向 KL、JSD(β)…)统一起来,得到广义 KD(式 (GKD)):

 ⁣ ⁣ ⁣ ⁣ L GKD ( θ ) = ( 1 − λ )   E ( x , y ) ∼ ( X , Y )  ⁣ ⁣ [ D ( p T ∥ p S θ ) ( y ∣ x ) ] + λ   E x ∼ X E y ∼ p S ( ⋅ ∣ x )  ⁣ ⁣ [ D ( p T ∥ p S θ ) ( y ∣ x ) ] . \!\!\!\!L_{\text{GKD}}(\theta) =(1-\lambda)\,\mathbb{E}_{(x,y)\sim(X,Y)}\!\!\big[D(p_T\parallel p_S^\theta)(y|x)\big] +\lambda\,\mathbb{E}_{x\sim X}\mathbb{E}_{y\sim p_S(\cdot|x)}\!\!\big[D(p_T\parallel p_S^\theta)(y|x)\big]. LGKD(θ)=(1λ)E(x,y)(X,Y)[D(pTpSθ)(yx)]+λExXEypS(x)[D(pTpSθ)(yx)].

  • λ = 0 \lambda=0 λ=0:退化为监督 KD λ = 1 \lambda=1 λ=1纯 on-policy KD
  • D D D 可取前向/反向 KL 或 JSD( β \beta β)
    实现上,对采样过程仍不反传。论文还给了Algorithm 1:每步按 λ \lambda λ 抛硬币决定用哪类数据,再最小化该 batch 的散度。

一眼看懂的梯度形态(抽象)
忽略采样反传后,

∇ θ L GKD = E  ⁣ [ 1 L y ∑ n ∇ θ D  ⁣ ( p T ( ⋅ ) ∥ p S θ ( ⋅ ) ) ⏟ 如前向 KL 时     ∝    p S − p T ] , \nabla_\theta L_{\text{GKD}} = \mathbb{E}\!\left[\frac{1}{L_y}\sum_n \underbrace{\nabla_\theta D\!\big(p_T(\cdot)\parallel p_S^\theta(\cdot)\big)}_{\text{如前向 KL 时 }\;\propto\;p_S-p_T}\right], θLGKD=E Ly1n如前向 KL  pSpT θD(pT()pSθ()) ,

唯一差别在于这个期望是对哪种序列分布取(由 λ \lambda λ 和是否 on-policy 决定),以及** D D D** 的具体形式。

6. 再进一步:把 RL 目标和 On-Policy GKD 并列优化

很多真实目标(如事实一致性)是不可导/非似然的。论文把策略梯度的奖励项蒸馏正则项并列,给出(式 (5)):

E x ∼ X [ ( 1 − α )   E y ∼ p S θ [ r ( y ) ] − α   E y ∼ p S ( ⋅ ∣ x ) D ( p T ∥ p S θ ) ( y ∣ x ) ] . \mathbb{E}_{x\sim X}\Big[(1-\alpha)\,\mathbb{E}_{y\sim p_S^\theta}[r(y)] -\alpha\,\mathbb{E}_{y\sim p_S(\cdot|x)}D(p_T\parallel p_S^\theta)(y|x)\Big]. ExX[(1α)EypSθ[r(y)]αEypS(x)D(pTpSθ)(yx)].

  • 第一项:标准 REINFORCE/策略梯度的RL 目标
  • 第二项:on-policy 蒸馏正则,把策略往老师靠,以防 RL 走偏;
  • α \alpha α 控制 RL 与蒸馏的权衡 α = 1 \alpha=1 α=1 退化为仅蒸馏)。这与 RLHF 里常见的“KL 正则”相似,但这里是向老师而不是向初始策略收缩。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

二分掌柜的

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值