在TensorFlow中使用GAN生成图像
介绍
在本文中,我们将探讨在TensorFlow中使用GAN生成独特手写数字的应用。GAN框架由生成器和鉴别器两个关键组件组成。生成器以随机方式生成新的图像,而鉴别器则用于区分真实和伪造的图像。通过GAN训练,我们可以获得一系列与手写数字密切相似的图像。本文的主要目标是概述使用MNIST数据集构建和评估GAN的过程。
学习目标
- 本文全面介绍了生成对抗网络(GANs)并探索了它们在图像生成中的应用。
- 本教程的主要目标是指导读者逐步构建使用TensorFlow库的GAN的过程。它涵盖了在MNIST数据集上训练GAN以生成新的手写数字图像的过程。
- 本文讨论了GAN的架构和组件,包括生成器和鉴别器,以增强读者对其基本工作原理的理解。
- 为了帮助学习,本文包含了代码示例,演示了各种任务,如读取和预处理MNIST数据集,构建GAN架构,计算损失函数,训练网络和评估结果。
- 此外,本文还探讨了GAN的预期结果,即一系列与手写数字非常相似的图像。
本文作为数据科学博客马拉松的一部分发表。
我们要构建什么?
使用现有图像数据库生成新颖的图像是一种称为生成对抗网络(GANs)的专门模型的显著特点。GAN在利用各种图像数据集生成无监督或半监督图像方面表现出色。
本文利用GAN的图像生成潜力创建手写数字。该方法涉及对手写数字数据库进行网络训练。在这篇教程中,我们将使用Tensorflow库构建一个基本的GAN,对MNIST数据集进行训练,并生成新的手写数字图像。
我们如何设置这个?
本文的主要重点是利用GAN的图像生成潜力。该过程始于加载和预处理图像数据库,以便促进GAN训练过程。一旦数据成功加载,我们就开始构建GAN模型并开发必要的训练和测试代码。在接下来的部分,我们提供了详细的说明,介绍了实现这个功能并使用MNIST数据库生成新图像的详细说明。
模型构建
我们要构建的GAN模型由两个重要组件组成:
- 生成器:负责生成新的图像。
- 鉴别器:评估生成图像的质量。
我们将开发的生成图像的一般架构如下图所示。下面的部分简要描述了如何读取数据库、创建所需的架构、计算损失函数和训练网络。此外,提供了用于检查网络和生成新图像的代码。
读取数据集
MNIST数据集在计算机视觉领域具有重要意义,由大量28×28像素的手写数字组成。由于其灰度、单通道图像格式,该数据集非常适合我们的GAN实现。
下面的代码片段演示了使用Tensorflow的内置函数加载MNIST数据集。成功加载后,我们将对图像进行归一化和重塑,以便在GAN架构中高效处理2D图像数据。此外,为训练和验证数据分配内存。
每个图像的形状被定义为一个28x28x1矩阵,其中最后一个维度表示图像中的通道数。由于MNIST数据集包含灰度图像,我们只有一个通道。
在这个特定的例子中,我们将潜在空间的大小(表示为“zsize”)设置为100。根据具体要求或偏好,可以调整此值。
from __future__ import print_function, division
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam, SGD
import matplotlib.pyplot as plt
import sys
import numpy as np
num_rows = 28
num_cols = 28
num_channels = 1
input_shape = (num_rows, num_cols, num_channels)
z_size = 100
(train_ims, _), (_, _) = mnist.load_data()
train_ims = train_ims / 127.5 - 1.
train_ims = np.expand_dims(train_ims, axis=3)
valid = np.ones((batch_size, 1))
fake = np.zeros((batch_size, 1))
定义生成器
生成器(D)在生成对抗网络(GAN)中起着至关重要的作用,它负责生成能够欺骗判别器的逼真图像。它是GAN中图像生成的主要组件。在本研究中,我们使用了一种特定的生成器架构,它包含一个全连接(FC)层,并使用了Leaky ReLU激活函数。然而,值得注意的是,生成器的最后一层使用了Tanh激活函数,而不是LeakyReLU。这个调整是为了确保生成的图像与原始的MNIST数据库在相同的区间(-1, 1)内。
def build_generator():
gen_model = Sequential()
gen_model.add(Dense(256, input_dim=z_size))
gen_model.add(LeakyReLU(alpha=0.2))
gen_model.add(BatchNormalization(momentum=0.8))
gen_model.add(Dense(512))
gen_model.add(LeakyReLU(alpha=0.2))
gen_model.add(BatchNormalization(momentum=0.8))
gen_model.add(Dense(1024))
gen_model.add(LeakyReLU(alpha=0.2))
gen_model.add(BatchNormalization(momentum=0.8))
gen_model.add(Dense(np.prod(input_shape), activation='tanh'))
gen_model.add(Reshape(input_shape))
gen_noise = Input(shape=(z_size,))
gen_img = gen_model(gen_noise)
return Model(gen_noise, gen_img)
定义判别器
在生成对抗网络(GAN)中,判别器(D)通过评估图像的真实性和可能性,对真实图像和生成图像进行区分,起着关键的任务。这个组件可以看作是一个二分类问题。为了解决这个任务,我们可以使用一个简化的网络架构,包括全连接层(FC),Leaky ReLU激活函数和Dropout层。值得注意的是,判别器的最后一层包含一个FC层,后面是Sigmoid激活函数。Sigmoid激活函数产生所需的分类概率。
def build_discriminator():
disc_model = Sequential()
disc_model.add(Flatten(input_shape=input_shape))
disc_model.add(Dense(512))
disc_model.add(LeakyReLU(alpha=0.2))
disc_model.add(Dense(256))
disc_model.add(LeakyReLU(alpha=0.2))
disc_model.add(Dense(1, activation='sigmoid'))
disc_img = Input(shape=input_shape)
validity = disc_model(disc_img)
return Model(disc_img, validity)
计算损失函数
为了确保GAN中良好的图像生成过程,确定评估其性能的适当指标是非常重要的。通过损失函数来定义该参数。
判别器负责将生成的图像划分为真实或假的,并给出真实的概率。为了实现这个区别,判别器在给出真实图像时会最大化函数D(x),在给出假图像时会最小化D(G(z))。
另一方面,生成器的目标是通过创建一个可以被错误解释的逼真图像来欺骗判别器。从数学上讲,这涉及到缩放D(G(z))。然而,仅仅依靠这个组件作为损失函数可能会导致网络对错误结果过于自信。为了解决这个问题,我们使用损失函数的对数(D(G(z)))。
生成图像的 GAN 的整体代价函数可以被表达为一个极小极大博弈:
min_G max_D V(D,G) = E(xp_data(x))(log(D(x))] + E(zp(z))(log(1 – D(G(z)))])
这样的 GAN 训练需要一个微妙的平衡,并且可以看作是两个对手之间的比赛。双方都试图通过玩 MinMax 游戏来影响和超过对方。
我们可以使用二元交叉熵损失来实现生成器和判别器。
对于生成器和判别器的实现,我们可以利用二元交叉熵损失。
# 判别器
disc= build_discriminator()
disc.compile(loss='binary_crossentropy',
optimizer='sgd',
metrics=['accuracy'])
z = Input(shape=(z_size,))
# 生成器
img = generator(z)
disc.trainable = False
validity = disc(img)
# 组合模型
combined = Model(z, validity)
combined.compile(loss='binary_crossentropy', optimizer='sgd')
优化损失
为了促进网络的训练,我们的目标是让 GAN 参与到一个极小极大博弈中。这个学习过程通过使用梯度下降来优化网络权重。为了加速学习过程并防止收敛到次优的损失空间,采用随机梯度下降 (SGD)。
鉴于判别器和生成器有不同的损失,单一的损失函数无法同时优化两个系统。因此,为每个系统使用单独的损失函数。
def intialize_model():
disc= build_discriminator()
disc.compile(loss='binary_crossentropy',
optimizer='sgd',
metrics=['accuracy'])
generator = build_generator()
z = Input(shape=(z_size,))
img = generator(z)
disc.trainable = False
validity = disc(img)
combined = Model(z, validity)
combined.compile(loss='binary_crossentropy', optimizer='sgd')
return disc, Generator, and combined
在指定所有所需特性之后,我们可以训练系统并优化损失。训练 GAN 生成图像的步骤如下:
- 加载图像并生成与加载图像大小相同的随机声音。
- 区分上传的图像和生成的声音,并考虑真实或虚假的可能性。
- 产生另一个相同大小的随机噪声,并作为输入提供给生成器。
- 训练生成器一段特定时间。
- 重复这些步骤,直到图像令人满意。
def train(epochs, batch_size=128, sample_interval=50):
# 加载图像
(train_ims, _), (_, _) = mnist.load_data()
# 预处理
train_ims = train_ims / 127.5 - 1.
train_ims = np.expand_dims(train_ims, axis=3)
valid = np.ones((batch_size, 1))
fake = np.zeros((batch_size, 1))
# 训练循环
for epoch in range(epochs):
batch_index = np.random.randint(0, train_ims.shape[0], batch_size)
imgs = train_ims[batch_index]
# 创建噪声
noise = np.random.normal(0, 1, (batch_size, z_size))
# 使用生成器预测
gen_imgs = gen.predict(noise)
# 计算损失函数
real_disc_loss = disc.train_on_batch(imgs, valid)
fake_disc_loss = disc.train_on_batch(gen_imgs, fake)
disc_loss_total = 0.5 * np.add(real_disc_loss, fake_disc_loss)
noise = np.random.normal(0, 1, (batch_size, z_size))
g_loss = full_model.train_on_batch(noise, valid)
# 每隔几个周期保存输出
if epoch % sample_interval == 0:
one_batch(epoch)
生成手写数字
使用 MNIST 数据集,我们可以创建一个实用函数,使用生成器生成一组图像的预测结果。这个函数生成一个随机的声音,提供给生成器,运行它来显示生成的图像,并将其保存在一个特殊的文件夹中。建议定期运行这个实用函数,例如每 200 个周期,以监视网络的进展情况。实现如下:
def one_batch(epoch):
r, c = 5, 5
noise_model = np.random.normal(0, 1, (r * c, z_size))
gen_images = gen.predict(noise_model)
# 将图像重新缩放为 0 - 1
gen_images = gen_images*(0.5) + 0.5
fig, axs = plt.subplots(r, c)
cnt = 0
for i in range(r):
for j in range(c):
axs[i,j].imshow(gen_images[cnt, :,:,0], cmap='gray')
axs[i,j].axis('off')
cnt += 1
fig.savefig("images/%d.png" % epoch)
plt.close()
在我们的实验中,我们使用批量大小为32,大约训练了10,000个epoch的GAN。为了跟踪训练的进展,我们每200个epoch保存一次生成的图像,并将它们存储在一个名为“images”的指定文件夹中。
disc, gen, full_model = intialize_model()
train(epochs=10000, batch_size=32, sample_interval=200)
现在,让我们来看看不同阶段的GAN模拟结果:初始化、400个epoch、5000个epoch以及10000个epoch的最终结果。
最初,我们将随机噪声作为输入传递给生成器。
经过400个epoch的训练,我们可以观察到一些进展,尽管生成的图像仍然与真实数字有明显差异。
经过5000个epoch的训练,我们可以观察到生成的图形开始类似于MNIST数据集。
完成整个10,000个epoch的训练,我们得到了以下输出。
这些生成的图像与用于训练网络的手写数字数据非常相似。重要的是要注意,这些图像不是训练集的一部分,完全由网络生成。
下一步
既然我们在GAN的图像生成方面取得了良好的结果,我们可以进一步改进它。在本讨论的范围内,我们可以考虑尝试不同的参数。以下是一些建议:
- 尝试不同的潜在空间变量z_size的值,看是否能提高效率。
- 将训练epoch数增加到10,000以上。将训练时间加倍或三倍可能会展现出改进或降低的结果。
- 尝试使用不同的数据集,如时尚MNIST或移动MNIST。由于这些数据集与MNIST具有相同的结构,我们可以调整现有代码。
- 考虑尝试其他架构,如CycleGun、DCGAN等。修改生成器和判别器函数可能足以探索这些模型。
通过实施这些改变,我们可以进一步增强GAN的能力,并在图像生成方面探索新的可能性。
这些生成的图像与用于训练网络的手写数字数据非常相似。这些图像不是训练集的一部分,完全由网络生成。
结论
总之,GAN是一种强大的机器学习模型,能够基于现有数据库生成新的图像。在本教程中,我们展示了如何使用Tensorflow库设计和训练一个简单的GAN,并使用了MNIST数据库作为示例。
要点
- GAN由两个重要组件组成:生成器负责从随机输入生成新的图像,判别器旨在区分真实和伪造的图像。
- 通过学习过程,我们成功创建了一组非常类似于手写数字的图像,如示例图像所示。
- 为了优化GAN的性能,我们提供了匹配度指标和损失函数,帮助区分真实和伪造的图像。通过在未见数据上评估GAN并使用生成器,我们可以生成新的、以前未见过的图像。
- 总的来说,GAN在图像生成方面提供了有趣的可能性,并在机器学习和计算机视觉等多个领域具有巨大潜力。
常见问题
本文中显示的媒体不归Analytics Vidhya所有,仅由作者自行决定使用。