这篇文章提出了广义知识蒸馏(Generalized Knowledge Distillation, GKD),旨在解决自回归语言模型(如T5)在知识蒸馏过程中存在的训练-推理分布不匹配问题。以下是文章的主要内容总结:
1. 问题背景
-
**知识蒸馏(KD)**通常用于压缩大模型(教师模型)为小模型(学生模型),以减少推理成本和内存占用。
-
传统的KD方法依赖于固定的输出序列(如教师生成的序列或真实数据),导致学生在推理时生成的序列与训练时看到的序列分布不一致,影响生成质量。
2. 核心贡献
-
广义知识蒸馏(GKD):通过在线策略的方式,使用学生模型自身生成的序列进行蒸馏,教师模型对这些序列提供反馈,从而减少训练与推理的分布差异。
-
灵活性:GKD允许使用不同的散度度量(如前向KL、反向KL、Jensen-Shannon散度等),并根据任务需求选择最优的散度。
-
与强化学习结合:GKD可以与强化学习微调(如RLHF、RLAIF)无缝结合,在优化任务奖励的同时进行蒸馏,提升模型性能。
3. 实验验证
-
任务特定蒸馏:在摘要生成(XSum)、机器翻译(WMT)和算术推理(GSM8K)任务上,GKD显著优于传统的KD方法(如SeqKD、Supervised KD)。
-
任务无关蒸馏:在指令微调任务上,GKD在MMLU和BBH基准测试中表现优异,尤其是在使用反向KL时效果最佳。
-
数据效率:GKD在少量数据下表现良好,甚至在使用5%的数据时,性能优于使用完整数据集的传统方法。
4. 关键发现
-
在线策略数据的重要性:使用学生自身生成的序列进行蒸馏,能够显著提升模型性能,尤其是在推理时生成高质量文本。
-
散度选择的任务依赖性:不同任务需要不同的散度度量,前向KL在贪婪采样下表现良好,而反向KL在指令微调中效果更佳。
-
与RL的结合:GKD与强化学习结合后,能够在优化任务奖励的同时提升模型的事实一致性和生成质量。
5. 未来方向
-
将GKD扩展到其他自回归序列模型,如音频、视频和文本到图像生成模型,以进一步提升生成模型的效率和性能。
GKD通过在线策略的蒸馏方法,有效解决了自回归语言模型在知识蒸馏中的分布不匹配问题,并在多个任务上展示了显著的性能提升。其灵活性和与强化学习的结合,为模型压缩和优化提供了新的思路。这里是自己的论文阅读记录,感兴趣的话可以参考一下,如果需要阅读原文的话可以看这里,如下所示:
摘要:
知识蒸馏(KD)广泛用于压缩教师模型,以减少其推理成本和内存占用,通过训练一个较小的学生模型来实现。然而,当前用于自回归序列模型的KD方法在训练期间看到的输出序列与学生推理期间生成的序列之间存在分布不匹配的问题。为了解决这个问题,我们引入了广义知识蒸馏(GKD)。GKD不依赖于固定的输出序列集,而是通过利用教师模型对学生自我生成的输出序列的反馈来训练学生。与监督KD方法不同,GKD还提供了在学生和教师之间使用替代损失函数的灵活性,这在学生缺乏模仿教师分布的表达能力时可能有用。此外,GKD促进了蒸馏与语言模型强化学习(RL)微调的无缝集成。我们展示了GKD在任务特定蒸馏(如摘要、翻译和推理任务)和任务无关蒸馏(如指令微调)中的有效性。
图1:在不同学生模型大小下比较GKD与KD方法。 我们使用经过监督微调(FT)的T5模型(Raffel et al., 2020)作为学生模型,并使用监督微调的T5-XL(约30亿参数)作为教师模型,其性能由水平线表示。监督KD和FT使用真实输出序列进行训练,而SeqKD使用教师生成的输出序列进行训练。在线策略GKD则使用从学生模型采样的输出序列进行训练。对于GKD,我们在WMT任务上使用JSD(0.1),在其他任务上使用前向KL。在评估时,XSum和GSM8K任务使用贪婪采样,WMT任务使用束搜索。
1. 引言
自回归序列模型,如语言模型(LMs),在许多任务中展示了令人印象深刻的能力,其成功的关键通常在于扩展训练数据的数量以及模型参数的数量(Kaplan et al., 2020)。然而,扩展参数数量是有代价的,这些模型的部署受到其推理成本或内存占用的限制。因此,实际使用大型模型的一个关键目标是通过减少参数数量来压缩它们,同时尽可能保留其性能。
模型压缩的流行技术之一是知识蒸馏(Hinton et al., 2015)。蒸馏是训练一个模型(学生)以复制另一个模型(教师)在特定任务上的知识的过程。通常,学生的参数比教师少,因此蒸馏可以在保持较低推理成本和内存占用的同时提高任务特定性能。当前用于自回归序列模型的蒸馏方法要么需要从教师模型生成固定的输出序列集(Kim & Rush, 2016),这可能很昂贵,要么需要一个固定的序列数据集,教师可以通过分配词元级概率来标记这些序列(Sanh et al., 2019)。然而,使用固定数据集可能导致训练期间看到的输出序列与学生自回归推理期间生成的序列之间的分布不匹配,这是模仿学习中的一个众所周知的问题(Pomerleau, 1991; Ross & Bagnell, 2010)。此外,蒸馏的常见目标是最小化教师和学生分布之间的前向KL。然而,学生可能没有足够的表达能力来拟合教师的分布,这可能导致学生生成的样本不太可能由教师生成(例如,图A.16)。
在本文中,我们提出了广义KD(GKD)来缓解上述问题。首先,我们认识到自回归序列模型的KD可以看作是一个具有交互专家的模仿学习问题(Ross et al., 2011)。利用这一见解,GKD训练学生使用其自我生成的序列,这些序列是在线策略的,而不是固定的输出序列集,使用教师概率作为这些序列的专家标签。我们的想法进一步得到了最近在大型语言模型上微调其自身输出序列的成功支持(Ouyang et al., 2022; Singh et al., 2023)。此外,GKD提供了优化替代散度度量的灵活性,例如反向KL和广义JSD(第2节),这些度量可以利用学生的有限容量来生成在教师下可能生成的样本。
GKD统一了一些现有的自回归LM的KD方法,同时实例化了新的在线策略方法,这些方法显著优于流行的方法。在线策略GKD相对于初始学生的性能提升,在不同大小的T5学生模型上平均,我们看到在摘要任务上相对提升了2.1倍,在机器翻译任务上提升了1.7倍,在算术推理任务上提升了1.9倍,相比于基线KD方法实现的性能提升(图1)。此外,我们展示了GKD在任务无关蒸馏中的有效性,在保留的BBH和MMLU基准套件上分别实现了2%和1%的绝对准确率提升(图10)。
我们的主要贡献是:
-
为了解决自回归LM在训练和推理期间的差异,我们提出了GKD,它利用在线策略的学生生成输出进行蒸馏,并由教师在这些输出上的词元级概率指导。GKD在任务特定(图1)和任务无关KD(图10)中显著优于常用方法。
-
我们展示了在线策略GKD可以与语言模型的RL微调(例如,RLAIF)无缝结合,这种结合以前没有被探索过(图5)。
-
通过对GKD中设计选择的系统评估,我们提供了关于在蒸馏期间使用学生生成的在线策略输出序列的重要性和学生与教师之间最佳散度的任务依赖性的实用见解。
2. 预备知识
3. 自回归序列模型的蒸馏
其中期望是对数据集中的样本进行的。这个监督目标通过利用教师的完整词元级分布提供了丰富的训练信号。
广义知识蒸馏(GKD)
如上所述,常用的KD方法使用固定的输出序列数据集,要么使用真实目标,要么使用教师生成的序列。然而,使用这种方法蒸馏自回归学生模型会导致训练-推理分布不匹配。这是因为学生在推理期间的自回归生成阶段遇到的部分序列可能与训练期间看到的序列大不相同。由于自回归模型中的任何步骤的预测都取决于先前的步骤,这种不匹配可能会产生级联效应,其中早期步骤的预测错误会影响未来的预测,导致文本生成质量差。为了解决这种不匹配,我们大量借鉴了模仿学习(IL)。特别是,在线策略模仿方法(例如Ross et al., 2011)迭代地使用学生策略收集序列,获取这些序列的专家标签,然后在这个数据集上重新训练学生。尽管它们在机器人和深度RL中很受欢迎(Parisotto et al., 2015; Kelly et al., 2019; Agarwal et al., 2022),但在线策略方法通常不用于蒸馏自回归模型。
其中我们不通过学生的采样分布pS(⋅∣x)反向传播,类似于在线策略模仿。不通过采样反向传播使训练稳定且计算高效。在在线策略KD中,训练是在学生可能生成的输出序列上进行的。在训练期间,我们使用温度γ=1以鼓励学生生成序列的多样性。此外,给定未标记的输入提示,使用学生生成序列比使用教师生成序列计算成本更低,因为它们的模型大小不同。
在在线策略KD的基础上,我们进一步统一了监督和在线策略方法,并提出了一种更通用的方法,我们称之为广义KD(GKD)。在GKD中,我们可以选择优化的散度以及训练的输出序列。具体来说,我们可以优化教师和学生词元级概率分布之间的任何散度。对于输出序列,GKD使用固定数据集(教师生成或真实数据)和在线策略学生生成序列的混合。抽象地,GKD最小化以下形式的目标:
其中D(p⊤,pS)(y∣x)是教师和学生分布之间的散度(方程2),λ∈[0,1]是一个超参数,控制学生数据分数,即在线策略学生生成输出的比例。类似于在线策略KD,我们不通过学生的采样过程反向传播梯度。在线策略和监督KD是GKD的实例化,散度D设置为前向KL,学生数据分数λ分别设置为1和0。也就是说,GKD允许选择其他分数λ和散度,我们在这项工作中探索了这些选择。
备注。 与随机初始化的学生不同,我们假设访问一个可以生成足够质量序列的学生,教师可以提供反馈。在我们的实验中,我们从经过监督微调的学生模型开始。这类似于广泛用于LM的两阶段RLHF训练,我们首先运行SFT,然后进行在线RL微调。因此,GKD可以利用RLHF的超参数调整见解,并且可以与RLHF结合,计算开销小且无需额外的超参数。
GKD中散度的选择 虽然前向KL通常用于蒸馏,但它要求学生覆盖教师词元级分布pT(⋅∣y<n,x)的整个支持。在这样做时,学生可能会将概率质量分配给在pT(⋅∣y<n,x)下概率较低的词元vv,这可能导致幻觉和低质量生成。当学生的模型容量远低于教师时,这个问题在温度采样时可能会发生(例如,图16)。或者,模式寻求散度,如反向KL,优先考虑教师分配高概率的词元,这可以避免低质量生成,但代价是给定输入的生成多样性减少。我们的实验表明,最佳散度似乎是任务依赖的。总的来说,在为特定任务选择GKD散度时,需要考虑多样性和性能之间的权衡(例如,图4, 10)。
RL微调 + 在线策略GKD
在某些任务中,从教师模型蒸馏可能只提供我们主要目标的代理,该目标也可能是不可微的。我们可以直接使用强化学习(RL)优化这个目标。方便的是,在线策略GKD可以很容易地与来自人类(RLHF)或AI反馈(RLAIF)的RL微调结合,因为它只需要来自学生的输出样本。确实,考虑一个人想要优化学生策略以获得标量奖励r,同时保持接近教师策略,那么我们得到一个正则化的RL微调目标,形式如下:
其中α∈[0,1]控制蒸馏损失与RL目标的强度。当α=1时,它将仅执行蒸馏。上述目标允许我们在通过蒸馏改进其他模型能力的同时最大化奖励,这可能减少在将语言模型与人类偏好对齐时“对齐税”导致的通用模型能力下降(Ouyang et al., 2022)。我们应用上述想法,使用RLAIF减轻幻觉,同时通过蒸馏提高下游性能(图5)。
备注。 在RLHF或RLAIF中,我们通常使用反向KL来约束学习策略以保持接近初始策略。如果一个人只想对现有的RL微调工作流程进行轻微修改,我们建议在将GKD与RL集成时使用反向KL或JSD (0.9)。
4. 实验
在本节中,我们评估了GKD在摘要生成、机器翻译和算术推理等任务中蒸馏语言模型的效果。
学生/教师模型。 我们的实验从不同大小的学生和教师模型开始,具体来说,是开源的T5模型(Raffel et al., 2020),这些模型在相同的数据集上进行了预训练。我们使用经过监督微调的T5-XL(约3B参数)作为教师。对于学生,我们使用T5-small(77M参数)、T5-base(250M参数)和T5-large(800M参数),这些模型分别比教师小38倍、12倍和3.8倍。更多细节见附录A.2。
案例研究:摘要生成
我们首先在生成摘要的任务上评估GKD,该任务旨在生成捕捉输入文档要点的摘要。为此,我们使用XSum数据集(Narayan et al., 2018),该数据集包含新闻文章和人工撰写的摘要。遵循PaLM(Chowdhery et al., 2022),我们使用ROUGE-2分数(Lin, 2004)在XSum的验证集上评估预测摘要的性能,但在ROUGE-L和ROUGE-1上也观察到类似的趋势。我们使用在XSum上监督微调的T5模型作为蒸馏的学生,而微调的T5-XL作为教师。更多实验细节见附录A.3。
与基线的比较。 首先,我们探讨了GKD与广泛使用的KD方法(即SeqKD和Supervised KD)在不同学生模型大小上的比较。如图1所示,我们观察到GKD的持续改进,这展示了GKD在学生容量方面的可扩展性。值得注意的是,GKD允许我们使用比PaLM(540B)小7000倍的T5模型超越其少样本性能。我们还将GKD变体与ImitKD和f-distill进行比较,并在图2中评估了贪婪采样和温度采样(γ=1)下的性能。使用JSD (0.9)的在线策略GKD在这两种情况下都优于这些额外基线。
图4:散度对性能与多样性的影响。 通过使用不同散度的在线策略GKD,我们评估了蒸馏学生生成质量与多样性之间的权衡,并通过调整采样温度来量化多样性。我们使用Self-BLEU(Zhu et al., 2018)来度量多样性,其中得分为100表示完全确定的输出,0表示最大多样性。从前向KL过渡到反向KL,通过广义JSD,会导致多样性降低,这归因于散度的模式寻求特性增强。模式寻求散度通常能生成更高质量的文本,尤其是在高温(γ = 1)下。降低温度会减少多样性,同时缩小不同散度之间的性能差异。
数据效率和扩展。 为了评估GKD的效率和可扩展性,我们使用子采样的XSum训练数据集(1K(0.5%)、10K(5%)和50K(25%)样本)蒸馏了T5-XL教师。我们使用T5-small作为学生,并在图3中报告了数据扩展曲线。值得注意的是,在5%的子采样数据集上使用在线策略GKD,没有任何真实摘要,优于使用整个训练数据集和真实摘要的监督KD和ImitKD。
GKD消融。 我们在图A.12和A.13中对不同学生大小的GKD进行了不同散度和学生数据分数的消融。在线策略和混合变体始终优于监督变体。模式寻求散度在使用温度采样进行评估时表现更好,而在使用贪婪采样时,散度的选择对性能影响不大。
选择GKD散度。 蒸馏中选择的散度在确定摘要质量和多样性之间的权衡方面至关重要。由于采样温度也可以调整以平衡摘要质量和多样性,最佳散度的选择是温度依赖的。为了理解这种依赖性,我们评估了使用不同散度的在线策略GKD蒸馏的T5-small。如图4所示,某些散度,如JSD (0.5)和JSD (0.9),在高温下提供更好的质量但多样性较少。然而,随着温度的降低,散度之间的质量差异缩小,同时多样性也下降。
在线策略GKD与RL。 在摘要生成中,我们希望模型生成的摘要与其输入文档在事实上一致。然而,仅靠蒸馏可能不会提高事实一致性,因为即使是大型模型也会产生幻觉并生成不一致的摘要。最近,Roit et al. (2023)通过使用RL与文本蕴含反馈作为奖励(RLEF)来缓解摘要任务中的幻觉,因为忠实的摘要必须从输入文档中文本蕴含。受其成功的启发,我们探索了将使用类似REINFORCE目标的RL微调与在线策略GKD结合,如第3.2节所述。如图5所示,GKD与RL微调相比教师模型显著提高了事实一致性,同时在蒸馏学生模型的摘要质量上取得了大幅提升。
机器翻译
为了评估GKD在摘要生成之外的效果,我们考虑了使用WMT14 en-de(Bojar et al., 2014)将英语翻译成德语的任务。我们使用BLEU分数在验证集上报告性能,该分数衡量机器翻译文本与高质量参考翻译的相似性。
图5:RLAIF + 在线策略GKD。我们展示了在XSum上奖励最大化和摘要性能之间的权衡。我们报告了相对于原始T5-base学生的改进。遵循Roit et al. (2023),我们使用来自T5-XXL NLI分类器的文本蕴含分数作为奖励。α控制使用JSD (0.9)的在线策略GKD损失的强度。随着α的增加,ROUGE-2增加,而事实一致性的改进减少。为了比较,我们展示了比学生大12倍的T5-XL教师的相对性能。RLEF对应于Roit et al. (2023)的RLAIF方法,其中学生被正则化到原始学生模型本身而不是教师。在线策略GKD + RL相比RLEF实现了更高的ROUGE-2,同时生成比教师更事实一致的摘要。
图6:在WMT en → de上改变学生数据分数和GKD散度。为了评估,我们使用束搜索并报告蒸馏学生相对于原始学生的BLEU分数改进。结果在三个种子中取平均值。我们观察到仅使用学生生成的输出样本优于其他GKD变体。我们使用在WMT上监督微调的T5-XL(约38参数)作为教师,其获得28的BLEU分数。(左)我们使用T5-small(77M参数)作为学生,其获得25.58的BLEU分数。(右)学生对应于T5-base(250M参数),其获得26.98的BLEU分数。
算术推理
Wei et al. (2022) 表明,推理能力仅在具有至少数十亿参数的LLM中出现,这使得KD对于提高较小模型的推理能力变得重要。为此,我们在GSM8K(Cobbe et al., 2021)上评估GKD,这是一个高质量的数学应用题数据集,需要多步逻辑推理。在这里,我们探索GKD与链式思维(CoT)(Wei et al., 2022)的结合,这是一种通过提示LLM生成中间推理步骤来提高其推理能力的常见方法。
设置。 我们通过在GSM8K中的数学问题前添加Wei et al. (2022)中的前4个CoT输入-输出示例来进行少样本提示。为了评估,我们通过检查目标答案是否与使用外部计算器给出的最终答案匹配来报告测试集上的准确性,类似于Cobbe et al. (2021)。对于监督训练,我们使用Magister et al. (2022)生成的CoT输出,导致在GSM8K的原始训练集中大约有5.3K(问题,CoTs)对。我们使用在上述CoT数据集上进行了10K步监督微调的Flan-T5模型作为蒸馏的起点。我们使用微调的FLAN T5-XL作为教师,其在测试集上的准确率为27.9。更多实验细节见附录A.4。
结果。 我们首先对GKD变体进行了消融实验,结果如图7和A.14所示。我们观察到,当仅使用固定的CoT数据集或将其与学生生成的CoT混合时,性能始终不如仅使用学生生成的CoT。此外,前向KL表现良好,类似于我们在XSum上使用贪婪采样的发现。值得注意的是,反向KL也表现良好,尤其是在使用固定数据集进行训练时。此外,图8显示,当学生生成数据的比例超过25%时,性能通常会提高。此外,我们展示了在线策略GKD在所有学生模型大小上均优于基线KD方法,如图9所示。最后,我们在GSM8K上展示了GKD自蒸馏的潜力,结果见附录A.1。
任务无关蒸馏:指令微调
虽然任务特定蒸馏为预定义任务提供了优化的性能,这在部署时通常至关重要,但任务无关蒸馏在任务性质未知且可能在部署期间变化的情况下提供了一个有吸引力的替代方案。正如Sanh et al. (2019)所强调的,任务无关蒸馏的吸引力在于其效率:一旦蒸馏完成,模型可以通过提示或微调重新用于多个下游任务。
设置。 为了研究任务无关KD,我们专注于指令微调(Chung et al., 2022)。我们的目标是增强蒸馏模型处理以指令形式呈现的多样化任务的能力。为此,我们使用FLAN T5-XL模型作为教师,并将其知识蒸馏到FLAN T5-Base中,如Chung et al. (2022)所介绍。我们的蒸馏过程利用了全面的FLAN2021指令微调数据集,该数据集包含5.36百万个示例,涵盖62种不同的语言理解和生成任务。超参数细节见表A.4。
评估。 为了衡量任务无关模型的通用性,必须在一组多样化任务上对其进行测试。遵循Chung et al. (2022),我们在两个保留的基准套件上评估了蒸馏后的T5-base学生模型:(1)MMLU(大规模多任务语言理解)包括来自57个任务的考试问题,如数学、历史、法律和医学;(2)BBH(BIG-Bench Hard)包括23个来自BIG-Bench的任务,这些任务上PaLM 540B(Chowdhery et al., 2022)的表现低于人类评分者的平均水平。对于性能,我们报告了蒸馏模型通过标准少样本提示直接预测答案的能力,在MMLU和BBH中的任务上取平均值。
结果。 我们在图10中报告了各种方法在50K训练步后获得的蒸馏检查点的性能。我们发现,使用反向KL的在线策略GKD显著优于监督KD和ImitKD。值得注意的是,在指令微调的背景下,我们发现使用反向KL比前向KL表现更好。我们假设反向KL在指令微调中的有效性可能源于其模式寻求特性,因为它确保模型专注于指令指定的主要意图或行为。因此,模型可能会优先考虑核心行为而不是不太相关的细节,从而在保留任务上表现更好。
5. 相关工作
知识蒸馏。 监督KD(Bucilua et al., 2006; Hinton et al., 2015)是一种经典方法,已成功用于蒸馏自回归模型(Sanh et al., 2019)。另一种蒸馏此类模型的方法是序列级KD(Kim and Rush, 2016)。在线策略GKD显著优于监督KD和SeqKD(图1)。其他KD方法训练学生以匹配从教师获得的不同量,例如隐藏状态(Jiao et al., 2020)或注意力分数(Wang et al., 2020)。然而,这些方法都没有将蒸馏与模仿学习联系起来,纯粹的监督方法可能会受到训练-推理不匹配的影响,也称为暴露偏差(Ranzato et al., 2015; Bengio et al., 2015)。虽然He et al. (2019)认为这种不匹配可能并不关键,但几篇论文表明,暴露偏差会导致文本生成质量差(Zhang et al., 2019; Chiang and Chen, 2021; Arora et al., 2022)。
ImitKD(Lin et al., 2020)通过从学生和固定数据集中采样序列来识别这种联系,但没有进一步推进这一想法。与GKD不同,ImitKD没有探索纯粹的在线策略数据收集,也没有集成RL微调。此外,ImitKD在词元级别保持前向KL,这在可以访问教师的log-probabilities而不仅仅是样本时是不必要的。此外,GKD展示了该想法的可扩展性,处理了比ImitKD探索的学生模型大约26倍大的模型。ImitKD可以看作是具有前向KL和非递增λλ调度的GKD,一个简单的选择是λ=0.5。最近,f-distill(Wen et al., 2023)将序列级KD公式化为最小化f-散度,并提出了基于学生和教师词元级分布之间总变差距离的可处理目标。本质上,ImitKD和f-distill都是GKD的特定实例,我们展示了它们比在线策略GKD导致更差的实证结果(图2, 9)。
与MiniLLM(Gu et al., 2023)的并发工作也利用了模仿的链接,并将蒸馏框架为RL问题。特别是,MiniLLM在序列级别优化教师和学生之间的反向KL(而似然最大化是前向的)使用策略梯度方法。然而,我们认为GKD更简单且更稳定,更接近监督训练,因为它不通过学生的采样过程反向传播。确实,MiniLLM依赖于许多稳定技巧,以应对高方差、奖励黑客和生成长度偏差。GKD也更通用,因为它也可以与其他散度(如前向KL或JSD)一起使用,这些散度可能比反向KL表现更好(图6, 7)。
RL微调。 现在有许多语言模型通过RL进行微调的例子,无论是优化某些指标的奖励(Wu et al., 2018),还是使用人类反馈学习的奖励(Ouyang et al., 2022)。在这些方法中,通常将RL微调模型正则化到初始(通常是监督微调的)模型。然而,据我们所知,我们是第一个同时进行蒸馏和RL微调的人(图5)。如果这看起来自然,但从优化角度来看,它相当不同,因为它将正则化从初始策略改为教师策略,我们通过实证表明这是一种可行的方法。
带有推理轨迹或理由的蒸馏。 链式思维提示(Nye et al., 2021; Wei et al., 2022)最近表明,LLM可以通过提示逐步解决复杂的推理任务。这个想法很快被适应到KD中,通过扩展教师数据集以包含CoT提示来微调学生(Magister et al., 2022; Ho et al., 2022; Hsieh et al., 2023)。蒸馏仍然以监督方式进行,可以考虑其他类型的增强提示(Li et al., 2022; Mukherjee et al., 2023)。我们采用了相同的方法,但将其与具有各种散度的在线策略蒸馏结合。它展示了GKD的通用性,并改进了纯粹的监督方法,正如我们在GSM8K上的结果所示(图9)。
应用于推测解码。 Zhou et al. (2023)和Liu et al. (2023)应用GKD来改进草稿和目标模型之间的对齐,以提高推测解码的推理速度。
6. 结论
在这项工作中,我们提出了GKD来解决蒸馏自回归语言模型时的训练-推理分布不匹配问题。GKD在三个语言生成任务上始终优于常用的知识蒸馏方法:摘要生成、机器翻译和算术推理。我们进一步展示了GKD可以与强化学习结合,以优化序列级奖励,同时蒸馏大型教师模型的知识,我们相信这可以改进广泛使用的RLHF训练阶段。未来工作的一个有趣方向是将GKD扩展到音频(Radford et al., 2023)、视频(Villegas et al., 2022)和文本到图像生成(Yu et al., 2022)的自回归序列模型。我们希望我们的工作对致力于提高生成自回归序列模型性能和效率的研究人员和从业者有价值。
附录A 附录
A.1 自蒸馏
自蒸馏。 我们研究了GKD是否适用于自蒸馏(Yim et al., 2017),即我们希望将知识从教师模型转移到具有相同架构和大小的学生模型。为了研究这一点,我们考虑了在GSM8K上使用FLAN-T5 large作为学生和教师的自蒸馏,其中教师在GSM8K上进行了监督微调。如图11所示,自蒸馏学生在测试集上的表现超过了教师。此外,使用学生生成数据进行蒸馏优于监督KD,在线策略GKD表现最佳。
A.2 T5模型
作为基础检查点,我们从LM-adapted T5v1.1模型开始。这些LM-adapted模型从T5v1.1初始化,并在T5论文(Raffel et al., 2020)中讨论的LM目标上进行了额外的100K步训练。这些检查点在https://console.cloud.google.com/storage/browser/t5-data/pretrained_models上开源。
在我们的实验中,我们通过在原始训练数据集上运行进一步的监督微调来初始化蒸馏的学生和教师模型,如下所述:
-
XSum。 对于small、base、large和XL模型,我们分别使用LM-Adapted T5v1.1模型进行了100K、50K、38K和8K步的监督微调。
-
WMT。 对于small、base、large和XL模型,我们分别使用LM-Adapted T5v1.1模型进行了250K、250K、110K和50K步的监督微调。
-
GSM8K。 所有模型都从FLAN-T5模型开始,在Palm-540生成的CoT数据集上进行了10K步的监督微调。
与T5和FLAN-T5类似,我们的实验使用Adafactor优化器(Shazeer & Stern, 2018)。
GKD的计算成本。 所有方法,包括基线,都从监督微调的学生检查点开始,这需要在最小的TPUv3(8核)上进行几个小时的训练。在GSM8K上,学生采样的计算开销大约是使用固定输出数据集采样的1.8倍、2倍和2.2倍,学生与教师的比例分别为38倍、12倍和3.8倍。对于RLHF + GKD,计算开销相对较小,因为我们只运行推理以获取教师logits而不是学生logits。
此外,在现实世界用例中,大部分成本是由于推理时的服务成本,而不是微调成本。具体来说,如果在微调期间从学生采样成本太高,那么向用户提供此模型的成本也可能太高(可能从数万到数十亿)。总的来说,在线策略GKD的性能优势可能值得计算成本,尤其是与RLHF结合时。
A.3 XSum
学习率扫描。 我们在{0.0001, 0.0003, 0.001}上进行了学习率扫描,发现0.0003对于T5-base和T5-large最佳,而0.001对于T5-small最佳。因此,默认情况下,我们使用0.0003的学习率,除非报告T5-small的结果,此时使用0.001。我们发现反向KL对较高的学习率更敏感,因此在使用反向KL时,我们默认对所有模型使用0.0003。
教师Softmax温度。 当使用贪婪采样进行评估时,我们将教师温度设置为1。然而,当报告学生使用温度采样(γ=1)的性能时,如图2和图3所示,我们将教师温度设置为0.1。
GKD消融实验。 我们在图A.12和A.13中对不同学生模型大小的GKD进行了不同散度和学生数据分数的消融实验。在线策略和混合变体始终优于监督变体。模式寻求散度在使用温度采样进行评估时表现更好,而在使用贪婪采样时,散度的选择对性能影响不大。
A.4 GSM8K
训练设置。 我们使用Magister et al. (2022)从Palm-540B生成的CoT输出进行训练。我们在GSM8K数据集的原始测试集上报告准确性(Cobbe et al., 2021)。我们使用蒸馏训练结束后的检查点报告结果,结果在3个种子中取平均值。
少样本CoT提示。 以下是实验中使用的4样本CoT提示:
问题: 树林中有15棵树。今天树林工人将在树林中种植树木。完成后,树林中将有21棵树。今天树林工人种植了多少棵树?
答案: 原本有15棵树。种植后共有21棵树。因此,工人种植了21−15=6棵树。答案是6。
问题: 如果停车场有3辆车,又有2辆车到达,停车场有多少辆车?
答案: 原本有3辆车。又有2辆车到达。3+2=5。答案是5。
问题: Leah有32块巧克力,她的妹妹有42块。如果她们吃了35块,她们总共还剩多少块?
答案: 原本Leah有32块巧克力,她的妹妹有42块。因此,她们总共有32+42=74块。吃了35块后,她们剩下74−35=39块。答案是39。
问题: Jason有20根棒棒糖。他给了Denny一些棒棒糖。现在Jason有12根棒棒糖。Jason给了Denny多少根棒棒糖?
答案: Jason原本有20根棒棒糖。给了Denny一些后,他有12根。因此,他给了Denny20−12=8根。答案是8。
A.5 WMT
评估设置。 我们使用与Raffel et al. (2020)相同的束搜索超参数进行评估。我们报告训练结束后最终检查点的性能。为了减少结果的方差,我们报告了3个种子的平均值。
A.6 指令微调
超参数细节。 表A.4中列出了FLAN指令微调的超参数细节。
A.7 模式寻求与模式覆盖KL
模式寻求与模式覆盖KL的对比。 图A.16展示了在容量不匹配的情况下,最小化前向KL和反向KL时学习到的分布。反向KL是模式寻求的,因为它迫使
在PP为零的地方为零,从而使其集中在其中一个模式上(最后一图)。然而,前向KL是模式覆盖的,因为它确保在P有质量的地方,
也有一定的质量。参见Le (2017)以复现此图。
总结
本文提出的广义知识蒸馏(GKD)通过在线策略的学生生成序列进行蒸馏,有效解决了自回归语言模型在训练和推理期间的分布不匹配问题。GKD在摘要生成、机器翻译和算术推理等任务上显著优于现有的知识蒸馏方法,并且可以与强化学习微调无缝结合,进一步提升模型性能。未来,GKD有望扩展到音频、视频和文本到图像生成等领域的自回归序列模型中,为生成模型的压缩和优化提供新的思路。