使用 LoRA 进行高效稳定的扩散微调
'Efficient and stable fine-tuning of diffusion using LoRA.'
LoRA:低秩适应大型语言模型是微软研究人员提出的一种新技术,用于解决大型语言模型微调的问题。具有数十亿参数的强大模型,如GPT-3,在微调以适应特定任务或领域方面成本太高。LoRA建议在每个Transformer块中冻结预训练模型的权重,并注入可训练的层(秩分解矩阵)。这大大减少了可训练参数和GPU内存需求,因为大多数模型权重不需要计算梯度。研究人员发现,通过专注于大型语言模型的Transformer注意力块,使用LoRA进行微调的质量与完全模型微调相当,同时速度更快,计算需求更少。
LoRA适用于扩散器 🧨
尽管LoRA最初是针对大型语言模型并在Transformer块上进行演示的技术,但该技术也可以应用于其他领域。在稳定扩散微调的情况下,LoRA可以应用于将图像表示与描述图像的提示之间相关的交叉注意力层。下图的细节(取自稳定扩散论文)不重要,只需注意黄色块负责建立图像和文本表示之间的关系。
据我们所知,Simo Ryu(@cloneofsimo
)是首个开发适用于稳定扩散的LoRA实现的人。请大家查看他们的GitHub项目,了解示例和许多有趣的讨论和见解。
为了在模型中注入LoRA可训练矩阵,就像在交叉注意力层中一样深入,人们过去需要以富有想象力(但脆弱)的方式修改扩散器的源代码。如果稳定扩散向我们展示了一件事,那就是社区总能想出用于创造性目的的方式来弯曲和调整模型,我们喜欢这一点!提供操纵交叉注意力层的灵活性可能有很多其他好处,比如更容易采用如xFormers等优化技术。Prompt-to-Prompt等创意项目可能需要一种简单的方式来访问这些层,因此我们决定为用户提供一种通用方法。自去年12月末以来,我们一直在测试该拉取请求,并在昨天与我们的扩散器发布正式推出。
我们一直在与@cloneofsimo
合作,在扩散器中提供LoRA训练支持,包括Dreambooth和完全微调方法!这些技术提供以下好处:
- 培训速度更快,如前面讨论的。
- 计算需求更低。我们可以使用2080 Ti和11 GB VRAM创建一个完全微调的模型!
- 训练权重要小得多。因为原始模型被冻结,我们注入新的可训练层,可以将新层的权重保存为一个约3 MB大小的单个文件。这比UNet模型的原始大小小一千倍!
我们对最后一点特别兴奋。为了让用户共享他们精心微调或Dreambooth的模型,他们必须共享最终模型的完整副本。其他想要尝试这些模型的用户必须在其喜欢的用户界面下载微调的权重,从而导致了大量的存储和下载成本。截至今天,在Dreambooth概念库中注册了约1,000个Dreambooth模型,可能还有更多未在库中注册。
使用LoRA,现在可以发布一个仅为3.29 MB的单个文件,让他人使用您的微调模型。
(感谢@mishig25
,我在正常对话中第一次听到使用dreamboothing这个词作为动词的人)。
LoRA微调
稳定扩散的完全模型微调曾经缓慢且困难,这也是Dreambooth或文本逆转等轻量级方法变得如此流行的原因之一。使用LoRA,可以更容易地在自定义数据集上微调模型。
现在,Diffusers提供了一个LoRA微调脚本,可以在低至11 GB的GPU RAM中运行,而无需使用8位优化器等技巧。以下是您使用它在Lambda Labs Pokémon数据集上微调模型的方法:
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export OUTPUT_DIR="/sddata/finetune/lora/pokemon"
export HUB_MODEL_ID="pokemon-lora"
export DATASET_NAME="lambdalabs/pokemon-blip-captions"
accelerate launch --mixed_precision="fp16" train_text_to_image_lora.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--dataset_name=$DATASET_NAME \
--dataloader_num_workers=8 \
--resolution=512 --center_crop --random_flip \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--max_train_steps=15000 \
--learning_rate=1e-04 \
--max_grad_norm=1 \
--lr_scheduler="cosine" --lr_warmup_steps=0 \
--output_dir=${OUTPUT_DIR} \
--push_to_hub \
--hub_model_id=${HUB_MODEL_ID} \
--report_to=wandb \
--checkpointing_steps=500 \
--validation_prompt="Totoro" \
--seed=1337
需要注意的一点是学习率为1e-4
,远大于常规微调的学习率(通常在~1e-6
的数量级)。这是上一次运行的W&B仪表盘,使用了2080 Ti GPU(11 GB RAM)运行了约5小时。我没有尝试优化超参数,所以你可以自由尝试!Sayak在T4上进行了另一次运行(16 GB RAM),这是他的最终模型,这是使用它的演示空间。
有关扩散器中LoRA支持的详细信息,请参阅我们的文档 – 它将始终与实现保持最新。
推理
正如我们所讨论的,LoRA的主要优势之一是,通过训练比原始模型大小少几个数量级的权重,您可以获得出色的结果。我们设计了一个推理过程,允许在未修改的Stable Diffusion模型权重之上加载附加权重。让我们看看它是如何工作的。
首先,我们将使用Hub API自动确定用于微调LoRA模型的基础模型。从Sayak的模型开始,我们可以使用以下代码:
from huggingface_hub import model_info
# LoRA权重约为3 MB
model_path = "sayakpaul/sd-model-finetuned-lora-t4"
info = model_info(model_path)
model_base = info.cardData["base_model"]
print(model_base) # CompVis/stable-diffusion-v1-4
这段代码将打印他用于微调的模型,即CompVis/stable-diffusion-v1-4
。在我的情况下,我从Stable Diffusion的1.5版本开始训练我的模型,所以如果您使用我的LoRA模型运行相同的代码,您将看到输出为runwayml/stable-diffusion-v1-5
。
关于基础模型的信息是由我们在前一节中看到的微调脚本自动填充的,如果您使用--push_to_hub
选项。这将记录为模型存储库的README
文件中的元数据标签,您可以在这里查看。
在确定了我们用于LoRA微调的基础模型之后,我们加载一个普通的Stable Diffusion管道。我们将使用DPMSolverMultistepScheduler
进行自定义,以实现非常快速的推理:
import torch
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
pipe = StableDiffusionPipeline.from_pretrained(model_base, torch_dtype=torch.float16)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
这里是魔法发生的地方。我们从Hub加载LoRA权重并放在常规模型权重之上,将管道移到cuda设备上并进行推理:
pipe.unet.load_attn_procs(model_path)
pipe.to("cuda")
image = pipe("绿色的有威胁表情的宝可梦", num_inference_steps=25).images[0]
image.save("green_pokemon.png")
使用 LoRA 进行 Dreamboothing
Dreambooth 允许您向稳定扩散模型“教授”新概念。LoRA 与 Dreambooth 兼容,其过程类似于微调,并具有一些优势:
- 训练速度更快。
- 我们只需要一些主题图像来进行训练(通常 5 或 10 张足够)。
- 如果需要,我们可以调整文本编码器,以增加对主题的准确性。
要使用 LoRA 训练 Dreambooth,您需要使用这个扩散器脚本。请查看 README、文档和我们的超参数探索博文以获取详细信息。
想要使用 LoRA 快速、廉价和简单地训练 Dreambooth 模型,请查看由 hysts
创建的这个空间。您需要复制它并分配一个 GPU 以便快速运行。这个过程将使您无需设置自己的训练环境,并且能够在几分钟内训练您的模型!
其他方法
寻找简单微调的方法并非新鲜事。除了 Dreambooth,文本反演是另一种常用的方法,试图将新概念教授给经过训练的稳定扩散模型。使用文本反演的主要原因之一是训练后的权重也很小且易于共享。然而,它们仅适用于单个主题(或少数几个主题),而 LoRA 可用于通用微调,这意味着它可以适应新的领域或数据集。
Pivotal Tuning 是一种尝试将文本反演与 LoRA 结合的方法。首先,使用文本反演技术教授模型一个新概念,获取用于表示它的新标记嵌入。然后,使用 LoRA 训练该标记嵌入,以兼顾两者的优点。
我们尚未探索使用 LoRA 进行 Pivotal Tuning。谁愿意接受这个挑战呢?🤗