批注扩散模型
批注扩散模型 (Annotation Diffusion Model)
在本博客文章中,我们将深入研究去噪扩散概率模型(也称为DDPM,扩散模型,基于分数的生成模型或简单的自编码器),研究人员已经能够在(非)条件图像/音频/视频生成方面取得了显著的结果。目前流行的示例包括OpenAI的GLIDE和DALL-E 2,University of Heidelberg的Latent Diffusion和Google Brain的ImageGen。
我们将逐步介绍PyTorch中的原始DDPM论文,根据Phil Wang的实现进行实现,该实现本身基于原始的TensorFlow实现。请注意,生成建模的扩散思想实际上已经在2015年的Sohl-Dickstein等人的论文中引入。然而,直到2019年的Song等人(斯坦福大学)和2020年的Ho等人(Google Brain)独立改进了这种方法。
请注意,有关扩散模型存在几个视角。在这里,我们采用离散时间(潜变量模型)的视角,但还请务必查看其他视角。
好的,让我们开始吧!
from IPython.display import Image
Image(filename='assets/78_annotated-diffusion/ddpm_paper.png')
首先,我们将安装和导入所需的库(假设您已经安装了PyTorch)。
!pip install -q -U einops datasets matplotlib tqdm
import math
from inspect import isfunction
from functools import partial
%matplotlib inline
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from einops import rearrange, reduce
from einops.layers.torch import Rearrange
import torch
from torch import nn, einsum
import torch.nn.functional as F
什么是扩散模型?
与其他生成模型(如归一化流、GAN或VAE)相比,(去噪)扩散模型并不复杂:它们都将噪声从某个简单的分布转换为数据样本。在这里,情况也是如此,一个神经网络逐渐学习去噪数据,从纯噪声开始。
稍微详细介绍一下图片的情况,设置包括2个过程:
- 我们选择的固定(或预定义的)前向扩散过程 q q q,逐渐向图像添加高斯噪声,直到最终得到纯噪声
- 一个学习到的反向去噪扩散过程 p θ p_\theta p θ ,其中一个神经网络被训练成从纯噪声开始逐渐去噪图像,直到最终得到实际图像。
正向和反向过程都在某个有限时间步 t t t(DDPM作者使用 T = 1000 T=1000 T = 1 0 0 0 )内发生。您从 t = 0 t=0 t = 0 开始,从数据分布中采样一个真实图像 x 0 \mathbf{x}_0 x 0 (假设是ImageNet上的一张猫的图像),正向过程在每个时间步 t t t 中从高斯分布中采样一些噪声,将其添加到上一个时间步的图像中。通过在每个时间步骤中以适当的噪声添加计划,使用足够大的 T T T,您最终通过渐进的过程得到一个称为各向同性高斯分布的图像,即 t = T t=T t = T。
更具数学形式的表达
让我们更正式地写下来,因为我们最终需要一个可计算的损失函数,我们的神经网络需要优化。
设 q ( x 0 ) q(\mathbf{x}_0) q ( x 0 ) 是实际数据分布,比如“真实图像”。我们可以从该分布中采样得到一张图像,x 0 ∼ q ( x 0 ) \mathbf{x}_0 \sim q(\mathbf{x}_0) x 0 ∼ q ( x 0 ) 。我们定义前向扩散过程 q ( x t ∣ x t − 1 ) q(\mathbf{x}_t | \mathbf{x}_{t-1}) q ( x t ∣ x t − 1 ) ,它在每个时间步 t t t 中根据已知的方差计划 0 < β 1 < β 2 < . . . < β T < 1 0 < \beta_1 < \beta_2 < … < \beta_T < 1 0 < β 1 < β 2 < . . . < β T < 1 向图像添加高斯噪声,即 q ( x t ∣ x t − 1 ) = N ( x t ; 1 − β t x t − 1 , β t I ) 。 q(\mathbf{x}_t | \mathbf{x}_{t-1}) = \mathcal{N}(\mathbf{x}_t; \sqrt{1 – \beta_t} \mathbf{x}_{t-1}, \beta_t \mathbf{I}). q ( x t ∣ x t − 1 ) = N ( x t ; 1 − β t x t − 1 , β t I ) 。
回想一下,正态分布(也称为高斯分布)由两个参数定义:均值 μ \mu μ 和方差 σ 2 ≥ 0 \sigma^2 \geq 0 σ 2 ≥ 0 。基本上,每个时间步 t t t 的新图像(稍微有些噪声)都是从一个具有 μ t = 1 − β t x t − 1 \mathbf{\mu}_t = \sqrt{1 – \beta_t} \mathbf{x}_{t-1} μ t = 1 − β t x t − 1 和 σ t 2 = β t \sigma^2_t = \beta_t σ t 2 = β t 的条件高斯分布中抽取的,我们可以通过采样 ϵ ∼ N ( 0 , I ) \mathbf{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I}) ϵ ∼ N ( 0 , I ) 然后设置 x t = 1 − β t x t − 1 + β t ϵ \mathbf{x}_t = \sqrt{1 – \beta_t} \mathbf{x}_{t-1} + \sqrt{\beta_t} \mathbf{\epsilon} x t = 1 − β t x t − 1 + β t ϵ 来实现。
请注意,β t \beta_t β t 在每个时间步 t t t 上不是常数(因此有下标)— 实际上,我们定义了一个所谓的“方差进度表”,可以是线性、二次、余弦等等,我们将在稍后看到(有点像学习率进度表)。
所以从 x 0 \mathbf{x}_0 x 0 开始,我们最终得到 x 1 , . . . , x t , . . . , x T \mathbf{x}_1, …, \mathbf{x}_t, …, \mathbf{x}_T x 1 , . . . , x t , . . . , x T ,其中如果我们适当地设置进度表,x T \mathbf{x}_T x T 就是纯高斯噪声。
现在,如果我们知道条件分布 p ( x t − 1 ∣ x t ) p(\mathbf{x}_{t-1} | \mathbf{x}_t) p ( x t − 1 ∣ x t ) ,那么我们可以反向运行该过程:通过采样一些随机高斯噪声 x T \mathbf{x}_T x T ,然后逐渐“去噪”,使得最终得到一个来自真实分布 x 0 \mathbf{x}_0 x 0 的样本。
然而,我们不知道 p ( x t − 1 ∣ x t ) p(\mathbf{x}_{t-1} | \mathbf{x}_t) p ( x t − 1 ∣ x t ) 。这是不可计算的,因为它要求知道所有可能图像的分布,才能计算这个条件概率。因此,我们将利用神经网络来近似(学习)这个条件概率分布,我们将其称为 p θ ( x t − 1 ∣ x t ) p_\theta (\mathbf{x}_{t-1} | \mathbf{x}_t) p θ ( x t − 1 ∣ x t ) ,其中 θ \theta θ 是神经网络的参数,通过梯度下降进行更新。
好的,所以我们需要一个神经网络来表示反向过程的(条件)概率分布。如果我们假设这个反向过程也是高斯的,那么回想一下,任何高斯分布都由两个参数定义:
- 一个由 μ θ \mu_\theta μ θ 参数化的均值;
- 一个由 Σ θ \Sigma_\theta Σ θ 参数化的方差;
因此,我们可以将该过程参数化为 p θ ( x t − 1 ∣ x t ) = N ( x t − 1 ; μ θ ( x t , t ) , Σ θ ( x t , t ) ) p_\theta (\mathbf{x}_{t-1} | \mathbf{x}_t) = \mathcal{N}(\mathbf{x}_{t-1}; \mu_\theta(\mathbf{x}_{t},t), \Sigma_\theta (\mathbf{x}_{t},t)) p θ ( x t − 1 ∣ x t ) = N ( x t − 1 ; μ θ ( x t , t ) , Σ θ ( x t , t ) ) ,其中均值和方差还受到噪声水平 t t t 的影响。
因此,我们的神经网络需要学习/表示均值和方差。然而,DDPM的作者决定保持方差不变,并让神经网络只学习(表示)这个条件概率分布的均值 μ θ \mu_\theta μ θ 。根据论文:
首先,我们将 Σ θ ( x t , t ) = σ t 2 I \Sigma_\theta ( \mathbf{x}_t, t) = \sigma^2_t \mathbf{I} Σ θ ( x t , t ) = σ t 2 I 设置为未经训练的时间相关常数。实验上,两个参数 σ t 2 = β t \sigma^2_t = \beta_t σ t 2 = β t 和 σ t 2 = β ~ t \sigma^2_t = \tilde{\beta}_t σ t 2 = β ~ t (见论文)具有类似的结果。
这在改进的扩散模型论文中得到了改进,其中神经网络除了学习均值之外,还学习了这个反向过程的方差。
因此,我们继续假设我们的神经网络只需要学习/表示这个条件概率分布的均值。
通过重新参数化均值定义目标函数
为了得到一个学习反向过程均值的目标函数,作者们观察到 q q q 和 p θ p_\theta p θ 的组合可以被视为变分自动编码器(VAE)(Kingma等人,2013)。因此,可以使用变分下界(也称为ELBO)来最小化关于真实数据样本 x 0 \mathbf{x}_0 x 0 的负对数似然(有关ELBO的详细信息,请参考VAE论文)。事实证明,这个过程的ELBO是每个时间步 t t t 的损失之和,L = L 0 + L 1 + . . . + L T L = L_0 + L_1 + … + L_T L = L 0 + L 1 + . . . + L T 。通过正向 q q q 过程和反向过程的构建,每个损失项(除了 L 0 L_0 L 0 )实际上是两个高斯分布之间的KL散度,可以明确地写成关于均值的L2损失!
正向过程 q q q 的构造的一个直接结果,正如Sohl-Dickstein等人所示,我们可以在任意噪声水平下以 x 0 \mathbf{x}_0 x 0 为条件进行 x t \mathbf{x}_t x t 的采样(因为高斯分布的和也是高斯分布)。这非常方便:我们不需要重复应用 q q q 以采样 x t \mathbf{x}_t x t 。我们有 q ( x t ∣ x 0 ) = N ( x t ; α ˉ t x 0 , ( 1 − α ˉ t ) I ) q(\mathbf{x}_t | \mathbf{x}_0) = \cal{N}(\mathbf{x}_t; \sqrt{\bar{\alpha}_t} \mathbf{x}_0, (1- \bar{\alpha}_t) \mathbf{I}) q ( x t ∣ x 0 ) = N ( x t ; α ˉ t x 0 , ( 1 − α ˉ t ) I )
其中 α t : = 1 − β t \alpha_t := 1 – \beta_t α t : = 1 − β t 和 α ˉ t : = Π s = 1 t α s \bar{\alpha}_t := \Pi_{s=1}^{t} \alpha_s α ˉ t : = Π s = 1 t α s 。让我们将这个方程称为“良好的特性”。这意味着我们可以采样高斯噪声并适当缩放并将其添加到 x 0 \mathbf{x}_0 x 0 中,直接得到 x t \mathbf{x}_t x t 。注意,α ˉ t \bar{\alpha}_t α ˉ t 是已知的 β t \beta_t β t 方差调度函数的函数,因此也是已知的并且可以预先计算。这样,在训练过程中,我们可以优化损失函数 L L L 的随机项(或者换句话说,在训练过程中随机采样 t t t 并优化 L t L_t L t )。
这个属性的另一个优点是,正如Ho等人所示,通过一些数学计算(我们将读者引用到这篇优秀的博客文章中),可以将均值重新参数化,使神经网络学习(预测)添加的噪声(通过网络ϵθ(xt, t))t t t噪声水平在构成损失的KL项中。这意味着我们的神经网络变成了一个噪声预测器,而不是(直接的)均值预测器。均值可以计算如下:
μθ(xt, t) = 1/αt(xt-βt(1-ᾱt)ϵθ(xt, t))
最终的目标函数Lt则如下所示(给定随机时间步t和ϵ∼N(0, I)):
∥ϵ-ϵθ(xt, t)∥2 = ∥ϵ-ϵθ(ᾱtxt + (1-ᾱt)ϵ, t)∥2
这里,x0是初始(真实的、未损坏的)图像,我们看到由固定的前向过程给出的直接噪声水平ttt样本。ϵ是在时间步ttt采样的纯噪声,而ϵθ(xt, t)是我们的神经网络。神经网络使用真实和预测的高斯噪声之间的均方误差(MSE)进行优化。
现在的训练算法如下:
换句话说:
- 我们从真实未知且可能复杂的数据分布q(x0)中随机抽取一个样本x0
- 我们在1到T之间均匀地随机抽取一个噪声水平t(即一个随机时间步)
- 我们从高斯分布中采样一些噪声,并通过该噪声在水平t上损坏输入(使用上面定义的好的属性)
- 神经网络通过损坏的图像xt(即在已知计划βt上施加在x0上的噪声)来预测这个噪声
实际上,所有这些都是在数据批次上完成的,因为人们使用随机梯度下降来优化神经网络。
神经网络
神经网络需要在特定时间步骤接收一个带噪声的图像,并返回预测的噪声。需要注意的是,预测的噪声是一个与输入图像具有相同大小/分辨率的张量。因此从技术上讲,网络接收和输出具有相同形状的张量。我们可以使用什么类型的神经网络来实现这个功能呢?
在这里通常使用的是与自编码器非常相似的方法,你可能还记得典型的“深度学习入门”教程中的自编码器。自编码器在编码器和解码器之间有一个所谓的“瓶颈”层。编码器首先将图像编码为一个较小的隐藏表示,称为“瓶颈”,然后解码器将该隐藏表示解码回实际图像。这迫使网络只保留瓶颈层中最重要的信息。
在网络架构方面,DDPM作者选择了一种名为U-Net的模型,该模型由Ronnerberger等人于2015年提出(在当时,该模型已经在医学图像分割方面取得了最先进的结果)。这个网络,像任何自编码器一样,由一个中间的瓶颈层组成,确保网络只学习最重要的信息。重要的是,它在编码器和解码器之间引入了残差连接,极大地改善了梯度流动(受到He等人在2015年的ResNet的启发)。
如图所示,U-Net模型首先对输入进行下采样(即在空间分辨率上使输入变小),然后进行上采样。
下面,我们逐步实现这个网络。
网络辅助函数
首先,我们定义一些辅助函数和类,这些函数和类将在实现神经网络时使用。重要的是,我们定义了一个Residual
模块,它简单地将输入添加到特定函数的输出上(换句话说,为特定函数添加了一个残差连接)。
我们还为上采样和下采样操作定义了别名。
def exists(x):
return x is not None
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
def num_to_groups(num, divisor):
groups = num // divisor
remainder = num % divisor
arr = [divisor] * groups
if remainder > 0:
arr.append(remainder)
return arr
class Residual(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, x, *args, **kwargs):
return self.fn(x, *args, **kwargs) + x
def Upsample(dim, dim_out=None):
return nn.Sequential(
nn.Upsample(scale_factor=2, mode="nearest"),
nn.Conv2d(dim, default(dim_out, dim), 3, padding=1),
)
def Downsample(dim, dim_out=None):
# No More Strided Convolutions or Pooling
return nn.Sequential(
Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2),
nn.Conv2d(dim * 4, default(dim_out, dim), 1),
)
位置嵌入
由于神经网络的参数在时间(噪声水平)上是共享的,作者使用了受Transformer(Vaswani等人,2017年)启发的正弦位置嵌入来对t t t进行编码。这使得神经网络“知道”它正在操作的特定时间步骤(噪声水平),对于批量中的每个图像都是如此。
SinusoidalPositionEmbeddings
模块接收一个形状为(batch_size, 1)
的张量作为输入(即批量中几个带噪声图像的噪声水平),并将其转换为形状为(batch_size, dim)
的张量,其中dim
是位置嵌入的维度。然后将其添加到每个残差块中,我们将在后面看到。
class SinusoidalPositionEmbeddings(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, time):
device = time.device
half_dim = self.dim // 2
embeddings = math.log(10000) / (half_dim - 1)
embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
embeddings = time[:, None] * embeddings[None, :]
embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
return embeddings
ResNet块
接下来,我们定义U-Net模型的核心构建块。DDPM的作者使用了一个Wide ResNet块(Zagoruyko等,2016年),但Phil Wang将标准卷积层替换为“加权标准化”版本,这与组归一化结合效果更好(有关详细信息,请参阅Kolesnikov等,2019年)。
class WeightStandardizedConv2d(nn.Conv2d):
"""
https://arxiv.org/abs/1903.10520
据称加权标准化与组归一化协同工作
"""
def forward(self, x):
eps = 1e-5 if x.dtype == torch.float32 else 1e-3
weight = self.weight
mean = reduce(weight, "o ... -> o 1 1 1", "mean")
var = reduce(weight, "o ... -> o 1 1 1", partial(torch.var, unbiased=False))
normalized_weight = (weight - mean) * (var + eps).rsqrt()
return F.conv2d(
x,
normalized_weight,
self.bias,
self.stride,
self.padding,
self.dilation,
self.groups,
)
class Block(nn.Module):
def __init__(self, dim, dim_out, groups=8):
super().__init__()
self.proj = WeightStandardizedConv2d(dim, dim_out, 3, padding=1)
self.norm = nn.GroupNorm(groups, dim_out)
self.act = nn.SiLU()
def forward(self, x, scale_shift=None):
x = self.proj(x)
x = self.norm(x)
if exists(scale_shift):
scale, shift = scale_shift
x = x * (scale + 1) + shift
x = self.act(x)
return x
class ResnetBlock(nn.Module):
"""https://arxiv.org/abs/1512.03385"""
def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
super().__init__()
self.mlp = (
nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2))
if exists(time_emb_dim)
else None
)
self.block1 = Block(dim, dim_out, groups=groups)
self.block2 = Block(dim_out, dim_out, groups=groups)
self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
def forward(self, x, time_emb=None):
scale_shift = None
if exists(self.mlp) and exists(time_emb):
time_emb = self.mlp(time_emb)
time_emb = rearrange(time_emb, "b c -> b c 1 1")
scale_shift = time_emb.chunk(2, dim=1)
h = self.block1(x, scale_shift=scale_shift)
h = self.block2(h)
return h + self.res_conv(x)
注意力模块
接下来,我们定义了注意力模块,DDPM的作者将其添加在卷积块之间。注意力是著名的Transformer架构的基本构建块(Vaswani等,2017年),在AI的各个领域(从自然语言处理和视觉到蛋白质折叠)都取得了巨大成功。Phil Wang采用了2种注意力的变体:一种是常规的多头自注意力(与Transformer中使用的相同),另一种是线性注意力变体(Shen等,2018年),其时间和内存需求与序列长度线性相关,而不是常规注意力的平方关系。
关于注意力机制的详细解释,请参考Jay Allamar的精彩博文。
class Attention(nn.Module):
def __init__(self, dim, heads=4, dim_head=32):
super().__init__()
self.scale = dim_head**-0.5
self.heads = heads
hidden_dim = dim_head * heads
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
def forward(self, x):
b, c, h, w = x.shape
qkv = self.to_qkv(x).chunk(3, dim=1)
q, k, v = map(
lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
)
q = q * self.scale
sim = einsum("b h d i, b h d j -> b h i j", q, k)
sim = sim - sim.amax(dim=-1, keepdim=True).detach()
attn = sim.softmax(dim=-1)
out = einsum("b h i j, b h d j -> b h i d", attn, v)
out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)
return self.to_out(out)
class LinearAttention(nn.Module):
def __init__(self, dim, heads=4, dim_head=32):
super().__init__()
self.scale = dim_head**-0.5
self.heads = heads
hidden_dim = dim_head * heads
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1),
nn.GroupNorm(1, dim))
def forward(self, x):
b, c, h, w = x.shape
qkv = self.to_qkv(x).chunk(3, dim=1)
q, k, v = map(
lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
)
q = q.softmax(dim=-2)
k = k.softmax(dim=-1)
q = q * self.scale
context = torch.einsum("b h d n, b h e n -> b h d e", k, v)
out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)
return self.to_out(out)
组归一化
DDPM的作者将U-Net的卷积/注意力层与组归一化(Wu等人,2018年)交叉连接。下面,我们定义一个PreNorm
类,它将用于在注意力层之前应用组归一化,正如我们将在后面看到的那样。需要注意的是,关于在Transformer中是在注意力之前还是之后应用归一化的问题一直存在争议。
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.fn = fn
self.norm = nn.GroupNorm(1, dim)
def forward(self, x):
x = self.norm(x)
return self.fn(x)
条件U-Net
现在,我们已经定义了所有的构建模块(位置嵌入、ResNet块、注意力和组归一化),现在是时候定义整个神经网络了。回想一下,网络 ϵ θ ( x t , t ) \mathbf{\epsilon}_\theta(\mathbf{x}_t, t) ϵ θ ( x t , t ) 的任务是接收一批带有噪声的图像及其各自的噪声级别,并输出添加到输入的噪声。更正式地说:
- 网络接收形状为
(batch_size, num_channels, height, width)
的一批带噪声的图像和形状为(batch_size, 1)
的噪声级别的批次作为输入,并返回形状为(batch_size, num_channels, height, width)
的张量
网络的构建如下:
- 首先,在一批带有噪声的图像上应用卷积层,并为噪声级别计算位置嵌入
- 接下来,应用一系列的下采样阶段。每个下采样阶段由2个ResNet块+组归一化+注意力+残差连接+下采样操作组成
- 在网络的中间,再次应用ResNet块,交替使用注意力
- 接下来,应用一系列的上采样阶段。每个上采样阶段由2个ResNet块+组归一化+注意力+残差连接+上采样操作组成
- 最后,应用一个ResNet块,然后是一个卷积层。
最终,神经网络将层堆叠起来,就像乐高积木一样(但重要的是要理解它们是如何工作的)。
class Unet(nn.Module):
def __init__(
self,
dim,
init_dim=None,
out_dim=None,
dim_mults=(1, 2, 4, 8),
channels=3,
self_condition=False,
resnet_block_groups=4,
):
super().__init__()
# 确定维度
self.channels = channels
self.self_condition = self_condition
input_channels = channels * (2 if self_condition else 1)
init_dim = default(init_dim, dim)
self.init_conv = nn.Conv2d(input_channels, init_dim, 1, padding=0) # 从7,3更改为1和0
dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
block_klass = partial(ResnetBlock, groups=resnet_block_groups)
# 时间嵌入
time_dim = dim * 4
self.time_mlp = nn.Sequential(
SinusoidalPositionEmbeddings(dim),
nn.Linear(dim, time_dim),
nn.GELU(),
nn.Linear(time_dim, time_dim),
)
# 层
self.downs = nn.ModuleList([])
self.ups = nn.ModuleList([])
num_resolutions = len(in_out)
for ind, (dim_in, dim_out) in enumerate(in_out):
is_last = ind >= (num_resolutions - 1)
self.downs.append(
nn.ModuleList(
[
block_klass(dim_in, dim_in, time_emb_dim=time_dim),
block_klass(dim_in, dim_in, time_emb_dim=time_dim),
Residual(PreNorm(dim_in, LinearAttention(dim_in))),
Downsample(dim_in, dim_out)
if not is_last
else nn.Conv2d(dim_in, dim_out, 3, padding=1),
]
)
)
mid_dim = dims[-1]
self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
is_last = ind == (len(in_out) - 1)
self.ups.append(
nn.ModuleList(
[
block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim),
block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim),
Residual(PreNorm(dim_out, LinearAttention(dim_out))),
Upsample(dim_out, dim_in)
if not is_last
else nn.Conv2d(dim_out, dim_in, 3, padding=1),
]
)
)
self.out_dim = default(out_dim, channels)
self.final_res_block = block_klass(dim * 2, dim, time_emb_dim=time_dim)
self.final_conv = nn.Conv2d(dim, self.out_dim, 1)
def forward(self, x, time, x_self_cond=None):
if self.self_condition:
x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x))
x = torch.cat((x_self_cond, x), dim=1)
x = self.init_conv(x)
r = x.clone()
t = self.time_mlp(time)
h = []
for block1, block2, attn, downsample in self.downs:
x = block1(x, t)
h.append(x)
x = block2(x, t)
x = attn(x)
h.append(x)
x = downsample(x)
x = self.mid_block1(x, t)
x = self.mid_attn(x)
x = self.mid_block2(x, t)
for block1, block2, attn, upsample in self.ups:
x = torch.cat((x, h.pop()), dim=1)
x = block1(x, t)
x = torch.cat((x, h.pop()), dim=1)
x = block2(x, t)
x = attn(x)
x = upsample(x)
x = torch.cat((x, r), dim=1)
x = self.final_res_block(x, t)
return self.final_conv(x)
定义前向扩散过程
前向扩散过程逐步向图像从真实分布中添加噪声,总共进行 T 个时间步骤。这是根据一个方差调度进行的。原始的 DDPM 作者采用了一个线性调度:
我们将前向过程的方差设置为从 β 1 = 1 0 − 4 \beta_1 = 10^{−4} β 1 = 1 0 − 4 增加线性到 β T = 0.02 \beta_T = 0.02 β T = 0 . 0 2 。
然而,(Nichol et al., 2021) 表明,当采用余弦调度时,可以获得更好的结果。
下面,我们为 T 个时间步骤定义了各种调度(稍后我们将选择其中一个)。
def cosine_beta_schedule(timesteps, s=0.008):
"""
根据 https://arxiv.org/abs/2102.09672 提议的余弦调度
"""
steps = timesteps + 1
x = torch.linspace(0, timesteps, steps)
alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
return torch.clip(betas, 0.0001, 0.9999)
def linear_beta_schedule(timesteps):
beta_start = 0.0001
beta_end = 0.02
return torch.linspace(beta_start, beta_end, timesteps)
def quadratic_beta_schedule(timesteps):
beta_start = 0.0001
beta_end = 0.02
return torch.linspace(beta_start**0.5, beta_end**0.5, timesteps) ** 2
def sigmoid_beta_schedule(timesteps):
beta_start = 0.0001
beta_end = 0.02
betas = torch.linspace(-6, 6, timesteps)
return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
首先,让我们使用线性调度进行 T = 300 T=300 T = 3 0 0 的时间步骤,并定义我们将需要的各种 β t \beta_t β t 相关的变量,比如方差的累积积 α ˉ t \bar{\alpha}_t α ˉ t 。下面的每个变量都是存储从 t t t 到 T T T 的值的一维张量。重要的是,我们还定义了一个 extract
函数,它将允许我们提取批量索引的适当 t t t 索引。
timesteps = 300
# 定义 beta 调度
betas = linear_beta_schedule(timesteps=timesteps)
# 定义 alphas
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
# 对扩散 q(x_t | x_{t-1}) 和其他计算
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)
# 对后验 q(x_{t-1} | x_t, x_0) 的计算
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
def extract(a, t, x_shape):
batch_size = t.shape[0]
out = a.gather(-1, t.cpu())
return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)
我们将以一张猫的图片说明在扩散过程的每个时间步骤中如何添加噪声。
from PIL import Image
import requests
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw) # PIL 图片,形状为 HWC
image
噪声是添加到PyTorch张量上,而不是Pillow图像。我们首先定义图像转换,使我们能够从PIL图像转换为PyTorch张量(在其上可以添加噪声),反之亦然。
这些转换非常简单:我们首先通过除以255 255 2 5 5(使其在[ 0 , 1 ] [0,1] [ 0 , 1 ]范围内)来标准化图像,然后确保它们在[ − 1 , 1 ] [-1, 1] [ − 1 , 1 ]范围内。来自DPPM论文:
我们假设图像数据由线性缩放到 [ − 1 , 1 ] [-1, 1] [ − 1 , 1 ]的{ 0 , 1 , . . . , 255 } \{0, 1, … , 255\} { 0 , 1 , . . . , 2 5 5 }中的整数组成。这确保了神经网络反向过程始于标准正态先验 p ( x T ) p(\mathbf{x}_T ) p ( x T ) 的统一缩放输入。
from torchvision.transforms import Compose, ToTensor, Lambda, ToPILImage, CenterCrop, Resize
image_size = 128
transform = Compose([
Resize(image_size),
CenterCrop(image_size),
ToTensor(), # 转换为形状为 CHW 的 torch Tensor,除以 255
Lambda(lambda t: (t * 2) - 1),
])
x_start = transform(image).unsqueeze(0)
x_start.shape
我们还定义了逆转换,它接受包含值在[ − 1 , 1 ] [-1, 1] [ − 1 , 1 ]范围内的PyTorch张量,并将它们转换回PIL图像:
import numpy as np
reverse_transform = Compose([
Lambda(lambda t: (t + 1) / 2),
Lambda(lambda t: t.permute(1, 2, 0)), # CHW to HWC
Lambda(lambda t: t * 255.),
Lambda(lambda t: t.numpy().astype(np.uint8)),
ToPILImage(),
])
让我们验证一下:
reverse_transform(x_start.squeeze())
现在我们可以按照论文定义正向扩散过程:
# 正向扩散(使用好的性质)
def q_sample(x_start, t, noise=None):
if noise is None:
noise = torch.randn_like(x_start)
sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, t, x_start.shape)
sqrt_one_minus_alphas_cumprod_t = extract(
sqrt_one_minus_alphas_cumprod, t, x_start.shape
)
return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise
让我们在特定的时间步骤上进行测试:
def get_noisy_image(x_start, t):
# 添加噪声
x_noisy = q_sample(x_start, t=t)
# 转换回PIL图像
noisy_image = reverse_transform(x_noisy.squeeze())
return noisy_image
# 选择时间步骤
t = torch.tensor([40])
get_noisy_image(x_start, t)
让我们对不同的时间步骤进行可视化:
import matplotlib.pyplot as plt
# 使用种子以便重现结果
torch.manual_seed(0)
# 来源:https://pytorch.org/vision/stable/auto_examples/plot_transforms.html#sphx-glr-auto-examples-plot-transforms-py
def plot(imgs, with_orig=False, row_title=None, **imshow_kwargs):
if not isinstance(imgs[0], list):
# 即使只有1行,也要创建一个2D网格
imgs = [imgs]
num_rows = len(imgs)
num_cols = len(imgs[0]) + with_orig
fig, axs = plt.subplots(figsize=(200,200), nrows=num_rows, ncols=num_cols, squeeze=False)
for row_idx, row in enumerate(imgs):
row = [image] + row if with_orig else row
for col_idx, img in enumerate(row):
ax = axs[row_idx, col_idx]
ax.imshow(np.asarray(img), **imshow_kwargs)
ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
if with_orig:
axs[0, 0].set(title='原始图像')
axs[0, 0].title.set_size(8)
if row_title is not None:
for row_idx in range(num_rows):
axs[row_idx, 0].set(ylabel=row_title[row_idx])
plt.tight_layout()
plot([get_noisy_image(x_start, torch.tensor([t])) for t in [0, 50, 100, 150, 199]])
这意味着我们现在可以根据模型定义以下损失函数:
def p_losses(denoise_model, x_start, t, noise=None, loss_type="l1"):
if noise is None:
noise = torch.randn_like(x_start)
x_noisy = q_sample(x_start=x_start, t=t, noise=noise)
predicted_noise = denoise_model(x_noisy, t)
if loss_type == 'l1':
loss = F.l1_loss(noise, predicted_noise)
elif loss_type == 'l2':
loss = F.mse_loss(noise, predicted_noise)
elif loss_type == "huber":
loss = F.smooth_l1_loss(noise, predicted_noise)
else:
raise NotImplementedError()
return loss
上述denoise_model
将是我们上面定义的U-Net模型。我们将使用真实噪声和预测噪声之间的Huber损失。
定义PyTorch数据集+数据加载器
在这里,我们定义一个常规的PyTorch数据集。该数据集只包含来自真实数据集(如Fashion-MNIST、CIFAR-10或ImageNet)的图像,线性缩放到[−1,1]范围内。
每个图像都被调整为相同的大小。有趣的是,图像还会随机水平翻转。从论文中可以看到:
我们在CIFAR10的训练过程中使用了随机水平翻转;我们尝试了使用和不使用翻转进行训练,并发现翻转可以稍微提高样本质量。
在这里,我们使用🤗 Datasets库来方便地从hub加载Fashion MNIST数据集。该数据集由分辨率相同的图像组成,即28×28。
from datasets import load_dataset
# 从hub加载数据集
dataset = load_dataset("fashion_mnist")
image_size = 28
channels = 1
batch_size = 128
接下来,我们定义一个函数,它将在整个数据集上进行即时应用。我们使用with_transform
功能实现。该函数只是应用了一些基本的图像预处理:随机水平翻转、缩放,并将值调整到[−1,1]范围内。
from torchvision import transforms
from torch.utils.data import DataLoader
# 定义图像转换(例如使用torchvision)
transform = Compose([
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Lambda(lambda t: (t * 2) - 1)
])
# 定义函数
def transforms(examples):
examples["pixel_values"] = [transform(image.convert("L")) for image in examples["image"]]
del examples["image"]
return examples
transformed_dataset = dataset.with_transform(transforms).remove_columns("label")
# 创建数据加载器
dataloader = DataLoader(transformed_dataset["train"], batch_size=batch_size, shuffle=True)
batch = next(iter(dataloader))
print(batch.keys())
采样
由于我们将在训练过程中从模型中进行采样(以跟踪进展),因此我们在下面定义了相关代码。采样在论文中被总结为算法2:
从扩散模型中生成新图像是通过逆转扩散过程实现的:我们从T开始,从高斯分布中采样纯噪声,然后使用我们的神经网络逐渐去噪声(使用其学习到的条件概率),直到我们最终到达时间步t = 0。如上所示,我们可以通过插入均值的重参数化,使用我们的噪声预测器,得到稍微不太去噪的图像xt−1。请记住,方差是提前已知的。
理想情况下,我们得到的图像看起来像是来自真实数据分布。
下面的代码实现了这个过程。
@torch.no_grad()
def p_sample(model, x, t, t_index):
betas_t = extract(betas, t, x.shape)
sqrt_one_minus_alphas_cumprod_t = extract(
sqrt_one_minus_alphas_cumprod, t, x.shape
)
sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape)
# 论文中的方程式11
# 使用我们的模型(噪声预测器)预测均值
model_mean = sqrt_recip_alphas_t * (
x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t
)
if t_index == 0:
return model_mean
else:
posterior_variance_t = extract(posterior_variance, t, x.shape)
noise = torch.randn_like(x)
# 算法2第4行:
return model_mean + torch.sqrt(posterior_variance_t) * noise
# 算法2(包括返回所有图像)
@torch.no_grad()
def p_sample_loop(model, shape):
device = next(model.parameters()).device
b = shape[0]
# 从纯噪声开始(对于批次中的每个示例)
img = torch.randn(shape, device=device)
imgs = []
for i in tqdm(reversed(range(0, timesteps)), desc='sampling loop time step', total=timesteps):
img = p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long), i)
imgs.append(img.cpu().numpy())
return imgs
@torch.no_grad()
def sample(model, image_size, batch_size=16, channels=3):
return p_sample_loop(model, shape=(batch_size, channels, image_size, image_size))
请注意上面的代码是原始实现的简化版本。我们发现我们的简化版本(与论文中的算法2一致)与原始的更复杂的实现一样有效,原始实现使用了裁剪。
训练模型
接下来,我们按照常规的PyTorch方式训练模型。我们还使用上面定义的sample
方法定期保存生成的图像。
from pathlib import Path
def num_to_groups(num, divisor):
groups = num // divisor
remainder = num % divisor
arr = [divisor] * groups
if remainder > 0:
arr.append(remainder)
return arr
results_folder = Path("./results")
results_folder.mkdir(exist_ok = True)
save_and_sample_every = 1000
下面,我们定义模型,并将其移动到GPU上。我们还定义了一个标准的优化器(Adam)。
from torch.optim import Adam
device = "cuda" if torch.cuda.is_available() else "cpu"
model = Unet(
dim=image_size,
channels=channels,
dim_mults=(1, 2, 4,)
)
model.to(device)
optimizer = Adam(model.parameters(), lr=1e-3)
让我们开始训练!
from torchvision.utils import save_image
epochs = 6
for epoch in range(epochs):
for step, batch in enumerate(dataloader):
optimizer.zero_grad()
batch_size = batch["pixel_values"].shape[0]
batch = batch["pixel_values"].to(device)
# 算法1第3行:为批次中的每个示例均匀采样t
t = torch.randint(0, timesteps, (batch_size,), device=device).long()
loss = p_losses(model, batch, t, loss_type="huber")
if step % 100 == 0:
print("Loss:", loss.item())
loss.backward()
optimizer.step()
# 保存生成的图像
if step != 0 and step % save_and_sample_every == 0:
milestone = step // save_and_sample_every
batches = num_to_groups(4, batch_size)
all_images_list = list(map(lambda n: sample(model, batch_size=n, channels=channels), batches))
all_images = torch.cat(all_images_list, dim=0)
all_images = (all_images + 1) * 0.5
save_image(all_images, str(results_folder / f'sample-{milestone}.png'), nrow = 6)
采样(推理)
要从模型中进行采样,我们只需使用上面定义的样本函数:
# 采样64张图像
samples = sample(model, image_size=image_size, batch_size=64, channels=channels)
# 显示其中一张随机图像
random_index = 5
plt.imshow(samples[-1][random_index].reshape(image_size, image_size, channels), cmap="gray")
看起来模型能够生成一个不错的T恤!请记住,我们训练的数据集分辨率相对较低(28×28)。
我们还可以创建一个去噪过程的gif:
import matplotlib.animation as animation
random_index = 53
fig = plt.figure()
ims = []
for i in range(timesteps):
im = plt.imshow(samples[i][random_index].reshape(image_size, image_size, channels), cmap="gray", animated=True)
ims.append([im])
animate = animation.ArtistAnimation(fig, ims, interval=50, blit=True, repeat_delay=1000)
animate.save('diffusion.gif')
plt.show()
请注意,DDPM论文表明扩散模型是(非)条件图像生成的一个有前途的方向。从那以后,已经在这个领域取得了巨大的进展,尤其是在文本条件图像生成方面。下面,我们列出了一些重要的(但远非详尽)后续工作:
- 改进的去噪扩散概率模型(Nichol等,2021年):发现学习条件分布的方差(除了均值)有助于提高性能
- 级联扩散模型用于高保真度图像生成(Ho等,2021年):引入级联扩散,由多个扩散模型组成的流水线,用于生成分辨率逐渐增加的图像,以实现高保真度图像合成
- 扩散模型在图像合成方面胜过GAN(Dhariwal等,2021年):通过改进U-Net架构以及引入分类器指导,展示了扩散模型能够达到优于当前最先进生成模型的图像样本质量
- 无分类器扩散指导(Ho等,2021年):通过联合训练条件和无条件扩散模型,使用单个神经网络展示了扩散模型无需分类器即可进行指导
- 具有CLIP潜变量的分层文本条件图像生成(DALL-E 2)(Ramesh等,2022年):使用先验将文本标题转换为CLIP图像嵌入向量,然后使用扩散模型将其解码为图像
- 具有深度语言理解的照片逼真的文本到图像扩散模型(ImageGen)(Saharia等,2022年):展示了将大型预训练语言模型(例如T5)与级联扩散结合使用对于文本到图像合成效果很好
请注意,此列表仅包括截至撰写时间2022年6月7日的重要作品。
目前看来,扩散模型的主要(也许是唯一的)缺点是需要多次前向传递来生成图像(而生成模型如GAN则不需要)。然而,目前有正在进行的研究使得在仅经过10个降噪步骤后就能实现高保真度的生成。