REFT: Reasoning with REinforced Fine-Tuning
研究内容为,通过监督微调(SFT)和强化学习RL(PPO算法)结合,来提高大语言模型在解决数学问题方面的推理能力。
ReFT 由两个阶段组成:预热(Warm-up)阶段和强化学习RL阶段。首先使用 SFT 对模型进行预热,然后采用在线强化学习(在该工作中具体是 PPO 算法)进行优化。
预热阶段(Warm-up)
Warm-up是ReFT的初始步骤,其目的是为模型提供一个基础,使其能够生成对数学问题的基本正确响应。这个阶段使用监督式微调SFT实现:
这一阶段使用包含“Question”和“思维链CoT”元组的数据集:(x, e)。模型在这些“Question-CoT”对上进行微调,通常持续1-2个epoch。这个过程将模型的预测能力调整到能够生成适当的响应。
RL阶段
在预热阶段之后,模型进入强化学习阶段,这个阶段使用PPO(Proximal Policy Optimization)算法来进一步提升模型的性能。
这一阶段使用包含“Question”和“Answer”元组(x,y)组成的数据集。
具体来说,模型通过反复生成多种可能的CoT推理路径,还有一个评估器,专门评估响应的答案正确性,生成reward信号反馈。正确答案会给予正奖励,错误答案则不给予奖励。
这个过程,类似于AlphaZero在围棋领域的自对弈(self-play)学习。
从结果上看,ReFT在所有数据集上都显示出比SFT更好的性能,特别是在CodeLLAMA模型上,ReFT在GSM8K数据集上的准确率比SFT提高了近10个百分点。
论文地址:https://arxiv.org/pdf/2401.08967