[读论文] Conditional Generative Adversarial Nets

Conditional Generative Adversarial Nets
Mehdi Mirza, Simon Osindero
arXiv:1411.1784 [2014]

Intro

前面我们曾经介绍过生成式对抗网络(Generative Adversarial Nets, GANs)。这个框架包括一个生成器(Generator, G)和一个判别器(Discriminator, D)两个部分。生成器输入一段随机产生的噪声,生成一张尽可能“逼真”的图片。而判别器则输入一张图片,输出判断这张图片是生成出来的还是真实的。

原作者Goodfellow在最早提出的这篇文章的最后,介绍了几个这个模型可能改进的方向。其中一个就是:

A conditional generative model \(p(x|c)\) can be obtained by adding \(c\) as input to both \(G\) and \(D\).

确实,GAN的原始模型有很多可以改进的缺点,首当其中就是“模型不可控”。从上面对GAN的介绍能够看出,模型以一个随机噪声为输入。显然,我们很难对输出的结构进行控制。例如,使用纯粹的GAN,我们可以训练出一个生成器:输入随机噪声,产生一张写着0-9某一个数字的图片。然而,在现实应用中,我们往往想要生成“指定”的一张图片。

最直观的想法就是在GAN上增加一个额外的输入。也就是说,以前我们的生成模型是\(p_g(x)\),现在,我们的生成模型是在一个条件\(c\)的控制下产生:\(p_g(x|c)\)。而这个\(c\)就是我们用来控制模型的额外的输入。

\(c\)可以是表示我们意图的一串编码,例如我们想要做0-9的手写数字生成,则\(c\)可以是一个10维的one-hot向量。则在训练过程中,我们将这些label加入到训练数据中,从而得到一个按照我们需求产生图片的生成器。

这就是Conditional Generative Adversarial Nets最基本的想法。这里要注意的是,这个\(c\)不但附加在了生成器上,同时也附加在了判别器上,相当于给了判别器一个额外的信息:现在这个图片是以条件\(c\)生成的?还是以条件\(c\)控制下的真正的图片?

Model Structure

对于GAN来说,我们训练的目标是:

$$ \mathop{\min}_{G}\mathop{\max}_{D}V(D,G)=\mathbb{E}_{\boldsymbol{x}\sim p_{\text{data}}}\left[\log D(\boldsymbol{x})\right]+\mathbb{E}_{\boldsymbol{z}\sim p_z(\boldsymbol{z})}\left[\log(1-D(G(\boldsymbol{z})))\right]. $$

而对于Conditional的GAN来说,训练目标只需要变成:

$$ \mathop{\min}_{G}\mathop{\max}_{D}V(D,G)=\mathbb{E}_{\boldsymbol{x}\sim p_{\text{data}}}\left[\log D(\boldsymbol{x}|\boldsymbol{y})\right]+\mathbb{E}_{\boldsymbol{z}\sim p_z(\boldsymbol{z})}\left[\log(1-D(G(\boldsymbol{z}|\boldsymbol{y})|\boldsymbol{y}))\right]. $$

(原文中的公式有误,后面一项的判别器D中忘了加以y为条件的概率)

其实这个改动形象一些表示就是将原来只接受一个输入\(z\)的生成器变成接受两个输入(\(z\)和\(y\)),将原来只接受一个输入\(x\)的判别器变成接受两个输入(\(x\)和\(y\))。再具体一些,就是下面这张图:

右边绿色的部分就是条件,记为\(y\)。

Experiment

这个模型的思想还是很简单的,基本上一句话就讲明白了。紧接着就是实验部分了。

实验分成了两种,一种是单模态,一种是多模态。

单模态的试验以MNIST为数据集,控制输入\(y\)是label的one-hot表示。效果如图

多模态的实验比较复杂,作者做了一个图像自动标注。在这个Conditional GAN的框架下,模型有好多个部分。首先在完整的ImageNet数据集(21,000个label)上训练了一个卷积模型作为特征提取器。对于词语表达(原文中是world representation,个人认为是笔误,应该是word representation),作者使用YFCC100M数据集中的user-tags, titles和descriptions,利用skip-gram训练了一个200维的词向量。训练中忽略了词频小于200的词,最终词典大小是247465。

在实验过程中,固定了这个卷积模型和词向量。生成器的输入\(z\)为噪声,输入\(y\)为经过卷积网络提取后得到的图片的特征(文中称是一个经卷积展开全连接隐层产生的4096维的向量)。判别器输入\(x\)为一个词向量,判别器的功能是判断这个词向量是否是对图片的正确标注。

实验效果如下:

从实验结果来看,生成的标签还是很靠谱的。

Friskit

继续阅读此作者的更多文章