通过TRL使用DDPO来微调稳定扩散模型
利用TRL测试DDPO技术优化稳定扩散模型
简介
扩散模型(例如,DALL-E 2,稳定扩散)是一类生成模型,它们在生成图片方面取得了广泛成功,尤其是逼真的图片。然而,这些模型生成的图像可能并不总是符合人类的喜好或意图。因此就出现了对齐问题,即如何确保模型的输出与人类喜好(如“质量”)或难以通过提示表达的意图保持一致?这就是强化学习的应用之处。
在大型语言模型(LLM)领域,强化学习(RL)已经证明是将这些模型与人类喜好对齐的非常有效的工具。它是系统如ChatGPT优秀性能的主要因素之一。更准确地说,RL是从人类反馈中进行强化学习(RLHF)的关键因素,使ChatGPT能够像人类一样进行对话。
在Training Diffusion Models with Reinforcement Learning, Black等文章中,展示了如何通过一种名为去噪扩散策略优化(DDPO)的方法,对扩散模型进行改进,利用RL进行微调。
在本文中,我们将讨论DDPO的产生过程,对其工作原理进行简要介绍,以及如何将DDPO纳入RLHF工作流程,以实现与人类审美更加一致的模型输出。然后我们将迅速切入话题,讨论如何使用新集成的DDPOTrainer
从trl
库中将DDPO应用于您的模型,并讨论我们在稳定扩散上运行DDPO的发现。
DDPO的优势
DDPO并不是唯一试图通过RL对扩散模型进行微调的有效方法。
在深入研究之前,有两个关键点需要记住,以理解一种RL解决方案胜过其他解决方案的优势:
- 计算效率至关重要。随着数据分布变得越复杂,计算成本也越高。
- 近似是不错的,但由于近似并非真实的东西,相关误差会不断积累。
在DDPO之前,奖励加权回归(RWR)是一种使用强化学习对扩散模型进行微调的成熟方法。RWR重复使用扩散模型的去噪损失函数,以及从模型本身采样的训练数据和与最终样本关联的奖励相关的每个样本损失加权。该算法忽略了中间的去噪步骤/样本。尽管此方法可行,但需要注意以下两点:
- 通过加权关联损失进行优化,即最大似然目标,是一种近似优化方法。
- 相关的损失不是精确的最大似然目标,而是从重新加权变分约束推导出的近似解。
这两个近似顺序对性能和处理复杂目标的能力都有重大影响。
DDPO以此方法为起点。DDPO不仅关注最终样本,而是将整个去噪过程视为多步马尔可夫决策过程(MDP),在最后接收奖励。这种表述方式,再加上使用固定采样器,使代理策略成为一个各向同性的高斯分布,而不是任意复杂分布。因此,与路径RWR方法使用的最终样本的近似似然度量不同,在这里可以精确计算每个去噪步骤的似然度量,而这非常容易计算( ℓ(μ,σ2;x)=−n2log(2π)−n2log(σ2)−12σ2∑i=1n(xi−μ)2 \ell(\mu, \sigma^2; x) = -\frac{n}{2} \log(2\pi) – \frac{n}{2} \log(\sigma^2) – \frac{1}{2\sigma^2} \sum_{i=1}^n (x_i – \mu)^2 ℓ(μ,σ2;x)=−2nlog(2π)−2nlog(σ2)−2σ21∑i=1n(xi−μ)2 )。
如果您对了解DDPO的详细信息感兴趣,我们鼓励您查阅原始论文和附带的博客文章。
DDPO算法简介
鉴于MDP框架用于建模去噪过程的连续性质和随之而来的其他考虑因素,选择解决优化问题的工具是策略梯度方法。具体来说,是选择了Proximal Policy Optimization (PPO)。整个DDPO算法基本上与Proximal Policy Optimization (PPO)相同,但突出表现出高度定制化的部分是PPO的轨迹采集部分。
下面是一个概括流程的图示:
DDPO和RLHF: 结合以强化美感
RLHF的一般训练方面可以大致分为以下步骤:
- 有监督的微调“基本”模型以适应一些新数据的分布
- 收集喜好数据并使用它训练奖励模型。
- 使用奖励模型作为信号,通过强化学习对模型进行微调。
值得注意的是,在RLHF的上下文中,喜好数据是捕捉人类反馈的主要来源。
当我们将DDPO加入到混合中时,工作流将变为以下形式:
- 从预训练的扩散模型开始
- 收集喜好数据并使用它训练奖励模型。
- 使用奖励模型作为信号,通过DDPO对模型进行微调。
请注意,在后面的步骤列表中,与一般的RLHF工作流程相比,缺少第3步,这是因为经验上已经显示出(您将亲眼所见)不需要该步骤。
为了使扩散模型更符合人类对美的感知概念,我们按照以下步骤进行:
- 从预训练的稳定扩散(SD)模型开始
- 在Aesthetic Visual Analysis(AVA)数据集上使用具有可训练回归头的冻结CLIP模型,以预测人们对输入图像的平均喜好程度。
- 使用美学预测模型作为奖励信号来通过DDPO对稳定性扩散模型进行微调
在继续进入实际运行这些步骤之前,请记住这些步骤,下面的章节将对其进行描述。
使用DDPO训练稳定扩散
设置
要开始,就硬件方面和DDPO实现而言,至少需要访问A100 NVIDIA GPU以成功进行训练。低于此GPU类型的硬件将很快遇到内存不足的问题。
使用pip安装trl
库
pip install trl[diffusers]
这将安装主要库。以下依赖项用于跟踪和图像记录。在安装wandb
后,一定要登录以保存结果到个人帐户
pip install wandb torchvision
注意:您可以选择使用tensorboard
而不是wandb
,您可以通过pip
安装tensorboard
包。
演练
trl
库中负责DDPO训练的主要类是DDPOTrainer
和DDPOConfig
类。有关DDPOTrainer
和DDPOConfig
的更一般信息,请参见文档。在trl
repo中有一个示例训练脚本。它在与默认实现的必需输入和默认参数一起使用这两个类来微调默认预训练的稳定扩散模型。
这个示例脚本使用wandb
进行日志记录,并使用一个美学奖励模型,其权重从公共的HuggingFace仓库读取(因此已经为您完成了数据收集和训练美学奖励模型)。使用的默认提示数据集是动物名称的列表。
用户需要满足一个命令行标志参数,才能启动并运行该脚本。此外,用户需要拥有一个huggingface用户访问令牌,该令牌将用于将模型经过微调后上传到HuggingFace hub。
以下的bash命令可以让事情正常运行:
python stable_diffusion_tuning.py --hf_user_access_token <token>
以下表格包含与正面结果直接相关的关键超参数:
提供的脚本只是一个起点。可以随意调整超参数,甚至重构脚本以适应不同的目标函数。例如,可以整合一个评估JPEG可压缩性的函数,或者使用多模态模型评估视觉文本对齐的函数,还有其他可能性。
经验教训
- 尽管训练提示数量很少,结果似乎可以普遍适用于各种各样的提示。这已经通过奖励美学的目标函数进行了彻底验证。
- 尝试通过增加训练提示数量和变化提示来显式地进行泛化,似乎会减慢收敛速度,而学到的泛化行为几乎无法察觉,如果有的话。
- 尽管LoRA被推荐并多次经过测试,但非LoRA也是值得考虑的,从经验证据中得知,相对于LoRA,非LoRA似乎可以产生更加复杂的图像。然而,为稳定的非LoRA运行获得合适的超参数要困难得多。
- 对于非LoRA的配置参数建议如下:将学习率设置相对较低,大约为
1e-5
,将mixed_precision
设置为None
。
结果
以下是预先微调(左侧)和微调后(右侧)的输出结果,对应于提示bear
,heaven
和dune
(每一行都是一个提示的输出):
限制
- 目前
trl
的DDPOTrainer只能用于微调纯净的SD模型; - 在我们的实验中,我们主要关注的是LoRA,它的效果非常好。我们进行了一些完整训练的实验,可以获得更高质量的结果,但找到合适的超参数更具挑战性。
结论
稳定扩散等扩散模型,在使用DDPO进行微调后,可以显著提高生成图像的质量,这是人类或任何其他度量所感知的,只要适当地将它们概念化为目标函数
DDPO的计算效率以及在优化时不依赖逼近的能力,尤其是相对于实现相同目标的先前方法,使其成为适用于微调稳定扩散等扩散模型的合适选择
trl
库的DDPOTrainer
实现了对SD模型进行微调的DDPO。
我们的实验结果强调了DDPO的强大泛化能力,尽管通过变化提示进行显式泛化的尝试效果不一。找到非LoRA设置的合适超参数的困难也是一个重要的发现。
DDPO是一种将扩散模型与任何奖励函数对齐的有前途的技术,我们希望通过在TRL中发布,使其更加 accessible 可以让整个社区都能使用!
致谢
感谢Chunte Lee为本博文创建的缩略图。