2023-12-19论文笔记
如何更加高效地训练有偏好的 LLMs
之前提到了三种改进 LLMs 对齐的方法:使用 AI 代替人类,优化微调数据以及优化训练流程。本周就这三个方向分别选出一篇代表性论文进行介绍(RLAIF、DPO、LIMA三篇论文之前分别进行过总结,因此本周介绍另外三篇)。
方式一:使用AI替代人类——RRHF(Rank Response to align Human Feedback)
底层思想:这一类方法使用AI模型来替换人工标注偏好数据,或者指导模型训练,代表工作——RLAIF。
现有方法的缺陷
现有的方法是 openai 提出的 SFT —> RM —> PPO 流程。但是其中 PPO 算法对于超参数敏感,训练难以收敛,在训练中需要同时部署多个模型,让普通组织难以自己进行对齐,因此作者提出了新的对齐方法来替代原有的基于 RM 和 PPO 的方法。
方法
该方法的核心思想是直接在 RM 数据上优化 LLMs,让产生被接受回答的概率大于被拒绝回答的概率。具体来说,首先从多种来源收集某个 prompt 的响应(模型本身,chatgpt,GPT-4,人为撰写等),在训练过程中,同样可以改变采样响应的来源,因为 RRHF 本身可以使用任何响应来展示人类的偏好,相比之下 PPO 算法就必须要使用自己产生的响应才行。接着基于对数概率对响应进行评分,然后通过排名损失来将这些分数与人类偏好奖励模型或人类偏好标签的分数进行匹配排序。其中对响应评分的公式、排名损失以及为了实践中的崩溃现象而加入的 SFT loss 的定义分别如下:
值得注意的是在这里的评分中,作者注意到了回复的长度对评分的影响,因此使用了长度的归一化来尽量削弱这种影响。
下图中对比了 PPO 算法和 RRHF 算法的区别:
相比 SFT->RM->PPO 流程,RRHF的优点有以下几个方面:
- 仅需要1到2个模型,而PPO需要4个模型,因此 RRHF 算法更加简单高效。
- 监督微调(SFT)可以被看作是 RRHF 算法的一种特殊形式——只有一个p。
- RRHF 算法可以同时被用作语言模型和奖励模型,直接使用对数概率对响应进行评分则可以实现直接在 RM 数据上优化 LLMs,使用人类标注的分数则可以训练一个奖励模型。
- RRHF 算法可以在较低的训练难度下拟合奖励模型的偏好,达到 PPO 算法的效果,并且避免了 PPO 算法中的复杂性和不稳定性问题。
方法效果
在 Helpful and Harmless 数据集上的测试表明 RRHF 方法和 PPO 方法得到的模型的性能相近。
方式二:优化微调数据——MAYBE ONLY 0.5% DATA IS NEEDED
底层思想:该类方法的核心在于仅仅通过优质数据集的获取和产生,以训练得到一个效果较好的 SFT 模型,而无需进行 RM 和 PPO 的训练,代表工作——LIMA。
工作的动机
现有的大模型都是基于无比庞大的数据进行训练的,就像是图像领域的 backbone 需要在 ImageNet 上进行大量预训练一样,各基座 LLMs 也都是在大规模无标注语料上进行的预训练,这已经极度费时费力费钱了,那么能否如同 CV 领域一样,将预训练交给有钱的大公司来做,后续的 SFT、instruction tuning 等阶段只使用相对少量的数据,让所有组织和个人都能够自行训练呢?这就是这篇论文试图实现的工作。
方法
这篇工作关注的是特定任务的指令微调,因此主要是通过减少那些指令的多样性来实现 Low Training Data 。其核心思想是从现有数据中识别出最有价值、最有代表性的核心样本来帮助模型学习执行下游任务所需的知识,从而仅用少量数据就可以实现跟在全部数据上微调不相上下甚至更好的性能。
首先作者先简单介绍了 sft 和 instruction tuning 之间的区别,如下图所示:
方法流程下图所示,潜在空间用三个矩形表示,一个颜色系列代表一个任务。具有相同色系但不同色调的点,对应于来自同一任务但来自不同数据集的数据,如 NLI 任务有 5 个数据集,因此有 5 种不同的色调。主要分为以下几步:
- 将每个句子编码成embedding向量,并进行均值池化和 L2 归一化的预处理。
- 在潜在空间中,将所有样本点聚类成几个类别。在这一步骤中,作者采用的是嵌入空间中的K-means无监督聚类算法(因为NLP任务本身具有任务边界模糊性,使用标签进行有监督训练不太好),通过这个算法获得每个样本和它对应的集群标签的映射。接着,作者考察一个下游任务的样本在几个集群中出现的频率,并选择出现频率最高的集群的中心点作为该下游任务的分布中心点。然后计算该下游任务中所有样本跟该分布中心点的余弦相似度,并找到与中心点最相似的一个任务样本作为该下游任务的任务中心点。(分布中心点是这个任务数据在嵌入空间的中心,可能并不存在于任务数据中,而任务中心点是一个来自这个任务数据与分布中心点余弦相似度最大的样本)
- 从这些聚类样本中进行采样,找到原始分布中的核心样本。在这一步骤中,作者采用了一种核心集算法——KCentergreedy,该算法的目标是选择K个中心点,然后让数据样本与中心点之间的最大距离最小化。具体来说,作者使用任务样本中心点作为初始中心,输入前面步骤中获得的任务样本的所有句子嵌入,找到距离中心点距离最大的样本作为第二中心点,以此类推,不断选择数据样本中距离中心点集最小距离最大的样本作为新的元素加入中心点集,直到达到目标的K个点。
- 使用这些检索到的样本来进行 instruction tuning。
方法效果
对于特定任务的模型,使用该方法只需要使用不到 0.5% 的数据就可以实现 2% 的性能提升。
方式三:优化训练流程——RAFT(Reward rAnked FineTuning)
底层思想:该类方法通常通过改造模型的训练方式(如只保留 SFT 和 RM),以提高训练效率并减少训练成本,代表工作——DPO。
现有方法的缺陷
现有的 LLMs 都需要在大规模无监督训练数据上进行预训练,这样就隐含了一个潜在的问题——这样的基座模型容易受到数据中隐含的偏见的影响,导致生成低质、失真的结果。因此对齐过程是必不可少的,但是现有的对齐过程通常采用强化学习方法,往往低效且训练不稳定。因此作者提出了 RAFT ,通过筛选出高质量样本进行训练来提高模型的性能。
方法
该算法的伪代码如下图所示,整个流程分为三个阶段:数据收集、数据排序、模型微调;这三个阶段可以单独实施和执行。 因此,只要计算资源和显存允许在某些特定模型上进行 SFT,对齐过程就可以使用 RAFT 完成。在数据收集阶段,需要从 prompt 集合中采样一批样本,然后对于每个 prompt,让大模型分别生成响应,并计算这些样本的 reward。在数据排序阶段,要对这些样本排序,并选择指定百分比的具有最高奖励的样本作为训练样本。最后,在模型微调阶段,要使用这些筛选过的样本对模型进行微调。
不难发现,在这个过程中,采样训练数据的过程和模型训练是完全解耦的;并且抽样过程不需要任何梯度计算,可以方便地进行采样训练期间计算资源和内存管理。更进一步地,RAFT的三个步骤,即数据收集、数据排序、模型微调,可以单独实施和执行。 因此,只要计算资源和显存允许在某些特定模型上进行 SFT,对齐过程可以使用 RAFT 完成,并且可以使用批量推理和模型并行来加速模型的训练。
相比 PPO 算法的优势所在
- RAFT 更像是 SFT 的训练,超参数更少且训练更稳定
- 降低了内存负担,因为数据生成和模型微调是脱离的
- 如果有奖励模型作为质量的评判者,那么该方法可以灵活地训练任意生成模型,包括 LLMs 和 sd 模型
- 该方法在优化的时候更看重对样本的偏好顺序,而非具体的奖励值,避免了模型通过一些技巧来欺骗奖励函数以获取更高的奖励的现象(这被称为奖励欺骗)。