[读论文] Generative Adversarial Text to Image Synthesis

Generative Adversarial Text to Image Synthesis
Scott Reed, Zeynep Akata, Xinchen Yan, Lajanugen Logeswaran, Bernt Schiele, Honglak Lee.

[ICML2016] arXiv: 1605.05396[2016]

Intro

前面的几篇文章介绍了几个GAN的变种,但是那些文章始终围绕着“生成高质量图像”这个topic。如何让模型按照我们“复杂的需求”生成图像呢?这就是这篇文章想要解决的问题。

这篇文章介绍了一种能够将人工编写一句描述性文本直接转换成为图像。比如,“this small bird has a short, pointy orage beak and white belly”就应该能生成下面这些图片:

这是一个看起来很fancy的工作,因为它适用范围相当广,但这个问题显然还是挺困难的,所以还没有得到非常良好的解决。

这项工作主要面临两大挑战:

  1. 学习到能够捕捉到重要的视觉细节的文本特征表达 (learn a ext feature representation that captures the important visual details)
  2. 使用这些特征来合成一些让人们误以为真的图片 (use these features to synthesize a compelling image that a human might mistake for real).

这两项挑战虽然具有难度,但幸运的是,由于近年来深度学习的兴起,这两项挑战的子问题“自然语言表达”和“图像合成”已经得到了一定程度的解决。

但是,仍然存在着一个没能很好解决的问题:按描述生成图片(即以文本描述为条件的图像概率分布)是一个非常多模态的问题,也就是说,很有可能会有很多图片都能套用到相同的解释之上。

如果将图像生成反过来、即进行“图像到文本”的caption工作,这个情况依然是个麻烦。然而由于可以使用链式法则将一个序列化的问题解构并最终让这个task可解。例如可以训练一个模型,给定这张图片和之前所有的token来预测下一个token是什么,换句话说,这个模型其实就是一个well-defined的预测问题。

作者认为可以用GAN很自然地来解决这个conditional multi-modality问题。这篇文章的主要贡献就是实现了一个简单高效的GAN架构和训练策略,使得从人工编写的描述文本合成鸟与花的图片成为可能。

Method

文章中用到的主要是在以前的文章中曾经介绍过的DCGAN模型。本文中的应用是以hybird character-level convolutional RNN encode的text feature作为输入条件。生成器\(G\)和判别器\(D\)在前向inference的时候都以这个text feature作为条件。

Network Architecture。

网络的整个框架如下:

如图所示,网络如同其他GAN一样,分为生成器\(G\)和判别器\(D\)两个部分。由于是条件GAN,所以生成器的输入不止有随机采样的噪声\(z\sim \mathcal{N}(0,1)\),更有一个text feature(即途中蓝色部分)。这个text feature是由text encoder \(\phi\) 生成的,如果text query为 \(t\) 则这个text feature就是 \(\phi(t)\)。

通常需要将描述文本使用一个全连接层压缩到一个较小维度之后(一般是128维),并使用leaky ReLU再与噪声向量\(z\)拼接在一起,作为生成器的整体的输入。

紧接着,后续的前向推断过程就是一个解卷积网络:即一张合成的图片\(\hat{x}\)通过\(\hat{x}\leftarrow G(z,\phi(t))\)被生成出来。

再来看判别器D。前面几层使用了stride-2卷积层并使用了spatial batch normalization和leaky ReLU。并且仍然像以前一样使用一个全连接层接上一个rectifier来减少描述文本embedding \(\phi(t)\)的维度。当判别器的维度是\(4\times 4\)时,将这个描述文本的embedding复制多份,并在深度上与图片进行拼接。然后在拼接之后的新tensor上继续执行一个卷积操作,然后再计算得分。这便是整个框架在inference的流程。

Matching-aware discriminator (GAN-CLS)

说完了inference,就该说说train了。该如何去train这个模型?训练条件GAN的最直接的方法就是将图片和description embedding看作是联合的样本,通过观察判别器判断“生成的图片+文本”这个整体是真的还是假的。不过这个方法有些naive:并没有在训练中给判别器提供“是否按照描述正确生成了图像”的信息。

但事实上,这种CGAN的训练过程会与非CGAN训练过程有些不同。在训练的初期,由于生成的图片大多不靠谱,所以判别器将会拒绝大部分生成的图片,这也就相当于无视了condition的存在。然而一旦生成器学习到了如何产生靠谱的图片,生成器也一定会学习到如何生成“符合条件”的图片。然后判别器也会去学习去判断生成的内容是否符合条件限制。

在这naive GAN的情况下,判别其将会观察到两种不同的输入:正确的图片并且配上了正确的文本,以及错误的图片配上了随意的文本。所以,需要分开记录这两种不同的错误来源:不真实的图片(配上任何文本),以及真实图片但条件信息匹配错误。

作者修改了GAN的训练算法来将这两种不同的错误分开。除了real/fake这两种输入以外,作者又增加了第三种输入“真实的图片配上错误的文本”,而判别器也必须要能把这种错误给区分出来。

下面就是整个训练的算法:

这个伪代码其实也是通俗易懂: 首先,将正确的和错误的文本信息encode成\(h\)和\(\hat{h}\),然后采样出一个随机噪声,然后用正确的encode和随机噪声,产生一组fake图片。然后分别计算三个不同的判别其判别的结果:真实图片正确文本、真实图片错误文本、生成图片正确文本。然后再通过后面的公式计算判别器的损失函数,进而更新判别器。然后再计算生成器的损失函数,并更新生成器。最终完成整个流程。

Learning with manifold interpolation (GAN-INT)

上面介绍的GAN-CLS方法是训练这个模型的方法之一,文中作者提出了另一种训练方法。

深度学习的目的在于学习一个良好的特征表达。一些文章证明,在embedding pairs之间的插值如果在数据流形附近,就说明这个特征选得好。所以我们可以创建大量的额外的text embedding,这些text embedding其实是通过训练集标注的text
embedding通过插值得到的。换句话说,这些插值得到的embedding是无法直接对应到人工文本标注上的,所以这一部分数据时不需要标注的。想要利用这些数据,只需要在生成器的目标函数上增加这样一项: $$ \mathbb{E}_{t_1,t_2\sim p_{\text{data}}}\left[\log(1-D(G(z,\beta t_1+(1-\beta)t_2)))\right] $$

这个公式相当于综合考虑了两个text embedding \(t_1\)和\(t_2\)的插值点。通常在实际应用中使用\(\beta=0.5\)效果就不错了。

Experiment

实验部分看着张图就好了:

后面作者还写了关于风格迁移的一个实验,不过跟文章主旨并不直接相关,所以就不展开介绍了。

Friskit

继续阅读此作者的更多文章