再读WGAN

最近看了台湾大学李宏毅老师关于深度学习的系列教程(Machine Learning and Having It Deep and Structured),收获颇多。教程主要介绍了深度学习的基本知识,并且介绍了大量比较新的技术,例如Seq2Seq、Attention、Generative Model、Language Model等等。与以往其他“以公式为中心”的教程不同,这个教程从“定性”的角度介绍了这些技术,能让我们详细了解到这些技术存在的合理性。同时又有着详细的公式推导,并且将事实之间的逻辑关系梳理的非常清晰。

上一篇文章主要介绍了原生GAN的基本内容。而这篇文章将会主要介绍GAN的一个增强的变体:WGAN。


复习一下GAN

开讲之前首先我们复习一下GAN。

在应用中我们有一个需求:让机器生成数据。例如让机器通过看大量的图片,从而学习绘画。实现这个功能的方法有很多,无非是构造一个生成模型,然后用一个loss function去指导生成器让它能生成高质量的图片。采用不同的模型、不同的loss function就能得到不同的方法。而GAN则是将“生成数据”交给一个被称为生成器的神经网络进行、将“甄别数据真伪”的工作交给一个被称为判别器的另一个神经网络进行。

一般来说,我们用\(P_{data}\)来描述真实数据的分布。即每一个真实数据都应该是从该分布中采样得到的。“更真实”的数据被采样得到的几率要大一些,例如Gaussian分布的中心,“不那么真实”的数据被采样得到的几率要小一些。

拿李宏毅老师Slides里的图示来说明:

整个空间是整个图像空间,Task是生成动漫人脸。整个图像空间中的蓝色部分是真实人脸的数据分布。也就是说,这个分布内对应的图像的向量应该是“真实”的。而区域外的图像向量,虽然也是一张图像,但它们看上去并不真实。反过来说,当我们有一个图像向量\(x\),如果它真的是一张“合理”的动漫人物人脸,那么\(P_{data}(x)\)应该会比较高。反之则会比较低。

生成器\(G\)是一个神经网络。这个神经网络略有不同。传统的神经网络以“输入一个x,得到一个y为目标”,而这个生成器输入实际上是一个分布,而输出是另一个分布。(有些疑问对不对?一个神经网络如何输入一个分布而不是一个向量或者其他什么东西呢?我们后面会介绍通过采样的方式来实现这个效果。)。

如果说生成器的输出是一个由\(\theta_g\)控制的概率分布(由一个神经网络形成),那么我们完全可以通过控制参数\(\theta_g\)来让\(P_G\)尽可能与\(P_{data}\)接近!这样我们就能用\(P_G\)来代替\(P_{data}\)成为一个能够生成真实数据的生成模型!

但是,如果我们的生成器的输入是一个分布,输出是经过变换的另一个分布的时候,我们无法再设计输入一个数据的本身\(x\)来计算\(P_G(x)\)。即,我们可以输入一个\(P_{prior}(x)\),输出一个\(P_G(x)\)(它是一个随机变量x的分布),但我们没法输入\(x\)本身(它是一个样本),并且得到\(x\)的这个数值。

但是实际上,我们并不需要真的知道\(P_G(x)\)的表达式,也不需要计算给定一个\(x\)时\(P_G(x)\)的值就能完成我们想要做的任何事!

在训练过程中,我们需要同时训练一个判别器,让判别器识别输入的数据是“真实数据”?还是由生成器生成的“伪造数据”?在训练过程中由于我们知道输入数据到底是真是假,所以我们可以设计一个损失函数来让判别器尽可能完美地识别出数据真伪。这个函数能够用来衡量生成器所产生的分布\(P_G\)和真实数据分布\(P_{data}\)之间的距离,在原生的GAN中所衡量的是JS Divergence。判别器训练的目标,是在确定生成器的情况下,尽可能地去让判别器自身能够将这个距离拉大————在训练时我们既然知道生成分布和数据分布不一样,那我们就应该让判别器认为这两者是天壤之别,也就能得到一个最有效的判别器!

另一方面,当我们得到了当前状态下最有效的判别器时,我们想要以这个判别器为准绳,让我们的生成器能够强大到足以骗过判别器。在这一来二去之间,对抗学习就会慢慢收敛。

那么判别器和生成器随着训练会发生怎样的变化呢?我们用下面几张草图解释一下。图中横轴表示样本\(x\),红色线表示判别器\(D(x)\),绿色线表示真实数据分布\(P_{data}\),蓝色线表示生成器产生的分布\(P_G\):

训练伊始,生成分布在左边,真实数据在右边。对于生成器来说,它的终极目标是移动自身(通过调整自身参数)让自己与绿色线重叠。训练刚开始的时候,由于生成分布与数据分布完全不同,所以判别器能够轻而易举地判断出来整个\(x\)轴上哪些部分是“真的”,那些部分是“假的”。例如对于靠右侧的部分,大部分都是真实数据,且与生成数据之间重叠较少,所以对于判别器来说\(D(x)\)在这些位置很倾向于给出高概率(\(D(x)\)高就意味着判别器认为这里是真实的)。

同时我们再仔细看看这张图蓝色绿色同时存在的地方。由于这些地方既没有生成数据,又没有真实数据,对于判别器来说,它无法学习到究竟该做出什么判断,于是倾向于给出一个模棱两可的答案:\(D(x)\approx0.5\)

由于判别器能很轻易地分辨出样本是不是生成的,所以对于生成器来说,它的训练任务就是通过让自己向红箭头指向的右边移动,然后

哎呀,移动多了,类似地,下次移动看来是要向左了。这里我们发现真实分布和生成分布开始有重叠了。对于重叠部分来说,如果真实分布多一些的话,那\(D(x)\)就倾向于给一个高值,反之\(D(x)\)就倾向于给个低值。而相等的部分,判别器也分不出真假,那就给个模棱两可的值:\(D(x)\approx0.5\)。既然训练还有继续进行的空间,那就继续训练生成器:

这时候生成分布已经非常接近真实分布了。所以说\(D(x)\)不太容易产生接近于\(1\)的值了,因为对于判别器来说,确实不太容易分辨得出。

当训练到完美的程度,生成分布完全模仿了数据的真实分布。这时候对于\(D(x)\)来说,无论\(x\)取值到什么位置,它来源于生成分布、真实分布的概率各占一半。此时\(D(x)\)在数轴上处处等于0,模型收敛。

这里有一个视频,非常好地可视化了训练过程。视频来自YouTube,需要翻墙。

再回顾一下GAN的训练过程:

  • 初始化D的\(\theta_d\)和G的\(\theta_g\)
  • 在每一个训练循环进行:
    • 从数据分布\(P_{data}(x)\)中采样\(m\)个样本\(\{x^1, x^2, \dots, x^m\}\)
    • 从先验噪声分布\(P_{prior}(z)\)中采样\(m\)个样本\(\{z^1, z^2, \dots, z^m\}\)
    • 将这些噪声样本输入生成器\(G\),得到生成样本\(\{\tilde{x}^1,\tilde{x}^2,\dots,\tilde{x}^m\}, \tilde{x}^i=G(z^i)\)
    • 更新判别器的参数\(\theta_d\),即最大化:
      • \(\tilde{V}=\frac{1}{m}\sum_{i=1}^m\log D(x^i)+\frac{1}{m}\sum_{i=1}^m\log(1-D(\tilde{x}^i))\)
      • \(\theta_d\leftarrow\theta_d+\eta\nabla\tilde{V}(\theta_d)\)
    • 从先验噪声分布\(P_{prior}(z)\)中采样\(m\)个样本\(\{z^1,z^2,\dots,z^m\}\)
    • 更新生成器的参数\(\theta_g\),即最小化:
      • \(\require{cancel}\tilde{V}=\cancel{\frac{1}{m}\sum_{i=1}^m\log D(x^i)}+\frac{1}{m}\sum_{i=1}^m\log(1-D(G(z^i)))\)
      • \(\theta_g\leftarrow\theta_g-\eta\nabla\tilde{V}(\theta_g)\)

上面公式中更新参数\(\theta_g\)时需要的优化目标\(\tilde{V}\)被划掉是因为它内部没有参数\(\theta_g\),所以并不会对更新\(\theta_g\)有帮助。

Divergence的大一统

f-divergence

如果大家了解过机器学习的一些算法,肯定听说过一个叫做广义线性模型(GLM, generalized linear model)的东西。它将很多模型(如逻辑回归,线性回归等)都归纳到了一个统一的框架之下。每一种模型都是通用结构下的一个特例。

广义线性模型的核心叫做指数分布族。它又将各种分布大一统地表示出来。各种我们常见的分布(如伯努利分布、高斯分布等)都可以被统一到指数分布族这个框架之中。

在自然科学领域,科学家们的“一大乐事”就是寻找一个大一统的统一模型,将现有的一些东西看作是由模型生成的特例。

之前我们曾经介绍过。原始GAN采用的是JS divergence来衡量两个分布之间的距离。除此之外这个世界上还存在着各种各样的divergence,例如KL divergence、Reverse KL divergence。那么这些divergence之间是否具有什么“统一”的模式呢?

事实上真的有这样的统一模式的存在。它就是f-divergence。

假设\(P\)和\(Q\)是两个分布。\(p(x)\)和\(q(x)\)是对应样本\(x\)的概率,则有:

$$ D_f(P||Q)=\int_xq(x)f(\frac{p(x)}{q(x)})dx
$$

就是f-divergence,其中\(f\)要求满足以下两个约束:(1)是凸函数。(2) \(f(1)=0\)。

如同KL divergence和JS divergence一样,\(D_f(P||Q)\)能够评价\(P\)和\(Q\)之间的差异。

为什么说\(D_f(P||Q)\)可以作为一种距离的度量呢?因为假设对任意\(x\)都有\(p(x)=q(x)\):

$$ D_f(P||Q)=\int_xq(x)\underbrace{f(\overbrace{\frac{p(x)}{q(x)}}^{=1})}_{=0}dx=0
$$

由于\(p(x)=q(x)\)所以\(\frac{p(x)}{q(x)}=1\),而之前我们要求过\(f(1)=0\),所以此时\(D_f(P||Q)=0\),即当两个分布相同时f-divergence为\(0\)

想要让f-divergence能表示两个分布的距离,不但要求当两个分布相同时距离是\(0\),还需要保证这个\(0\)是f-divergence能取到的最小值。如何证明?

由于\(f(x)\)是凸函数,则有下面的不等式:

$$ \require{cancel} \begin{align} D_f(P||Q)&=\int_xq(x)f(\frac{p(x)}{q(x)})dx \\
&\ge f(\int_x\cancel{q(x)}\frac{p(x)}{\cancel{q(x)}}dx)=f(1)=0 \end{align} $$

所以说,\(D_f(P||Q)\)总是大于等于\(0\)的。

如此我们就可以认为f-divergence就能够作为衡量两个分布之间距离的度量(相等量分布f-divergence为\(0\),f-divergence永远大于等于\(0\))。

如果我们将\(f(x)\)指定为各种不同的函数,我们就将能够得到不同的xxx-divergence。

例如KL divergence:

$$ \begin{align} f(x)&=x\log x \\
D_f(P||Q)&=\int_x q(x)\frac{p(x)}{q(x)}\log(\frac{p(x)}{q(x)})dx=\int_xp(x)\log(\frac{p(x)}{q(x)})dx
\end{align} $$

例如Reverse KL divergence:

$$ \begin{align} f(x)&=-\log x \\
D_f(P||Q)&=\int_xq(x)(-\log(\frac{p(x)}{q(x)}))dx=\int_xq(x)\log(\frac{q(x)}{p(x)})dx
\end{align} $$

再例如Chi Square:

$$ \begin{align} f(x)&=(x-1)^2 \\
D_f(P||Q)&=\int_xq(x)(\frac{p(x)}{q(x)}-1)^2dx=\int_x\frac{(p(x)-q(x))^2}{q(x)}dx
\end{align} $$

Fenchel Conjugate

下面介绍另一个概念:Fenchel Conjugate。即每个凸函数\(f\)都具备一个conjugate函数\(f^*\):

$$ f^*(t) = \max_{x\in \mathbf{dom}(f)}\{xt-f(x)\}
$$

这个函数是在干啥呢?直接去想象这个函数的样子可能会有些困难,那我们先尝试计算在给定一个值的情况下会发生什么事情。假设我们给定\(t=t_1\)。即:

$$ f^*(t_1)=\max_{x\in\mathbf{dom}(f)}\{xt_1-f(x)\}
$$

也就是说固定了\(t\)的值之后,每从\(f\)的定义域中找到一个\(x\)之后,将其带入到\(f^*(t_1)\)之后都能得到一个值:\(\max_{x\in\mathbf{dom}(f)}\{xt_1-f(x)\}\)

也就是说,不同的\(x\)会有不同的\(f^*(x)\)的值。不同的值之间会有大小之分。我们可以用下图来表示:

图中的红线表示当\(t=t_1\)时,不同的\(x_i\)能取到的不同的\(x_it_1-f(x_i)\)。其中取值最大的点(对应\(x=x_1\))就是\(f^*(t_1)\)对应的点。

类似地,当\(t=t_2\)时,又能画出这样一条竖线:不同的\(x_i\)对应了不同的\(x_it_2-f(x_i)\),其中最大的那个就是\(f^*(t_2)\)(下图对应\(x=x_3\)的那个点):

我们还可以从另一个角度去看待这个问题。如果我们将\(t\)看作是一个自变量,\(x_i\)看作是一个参数,那么每一个不同的\(x_i\)都将会对应到一条直线(\(x_it-f(x_i)\)是一条直线)。则\(f^*(t)\)是一条曲线,曲线上的每一点都是对应所有直线中在这个位置上取值最大的点,就好像形成一个包络面一样:

下面我们看一个具体一些的例子

当\(f(x)=x\log x\)时,我们可以将对应的\(f^*(t)\)画出来:

这个图看上去好像像个指数函数?没错,当\(f(x)=x\log x\)时,\(f^*(t)=\exp(t-1)\)。为什么呢?

由于\(f^*(t)=\max_{x\in\mathbf(f)}\{xt-f(x)\}\),假设让\(g(x)=xt-x\log x\),那么现在的问题就变成了:给定一个\(t\)时,求\(g(x)\)的最大值问题。对\(g(x)\)求导并让导数为\(0\):\(\frac{dg(x)}{dx}=t-\log x-1=0\),可解得\(x=\exp(t-1)\)。再带入回原公式可得:\(f^*(t)=\exp(t-1)\times t-\exp(t-1)\times(t-1)=exp(t-1)\)

除此之外,关于Fenchel Conjugate还有两个性质:

  1. 所有的凸函数\(f\)都有一个conjugate函数\(f^*\)
  2. \(((f^*)^*)=f\)
Connection with GAN

前面说了这么多,跟GAN有什么关系呢?当然有关系!

首先,我们将之前式子里得记号换一换:

$$ f^*(t)=\sup_{x\in\mathbf{dom}(f)}\{xt-f(x)\} \leftrightarrow f(x)=\max_{t\in\mathbf{dom}(f^*)}\{xt-f^*(t)\}
$$

实际上就是将原公式中的\(f^*\)与\(f\)互换、将\(x\)与\(t\)互换。

改写了公式之后,我们可以将改写之后的\(f(x)\)当作是f-divergence里面的那个\(f(x)\),输入到\(f\)中的自变量\(x=\frac{p(x)}{q(x)}\)。这样公式就会变成:

$$ \begin{align} D_f(P||Q)&=\int_xq(x)f\left(\frac{p(x)}{q(x)}\right)dx \\
&=\int_xq(x)\left(\max_{t\in\mathbf{dom}(x^*)}\left\{\frac{p(x)}{q(x)}t-f^*(t)\right\}\right)dx \end{align} $$

然后我们此时可以构建一个函数\(D(x)\in\mathbf{dom}(f^*)\),即输入\(x\),而输出则是一个与\(t\)“同等地位”的值。这样我们就可以用\(D(x)\)来替代\(t\)。但是这种替换对原来的公式并不是等价的。因为我们所能用\(D(x)\)找到的\(t\)并不是那个能够让\(f\)最大的那个\(t\)。所以我们替换之后构造的函数永远要小于等于f-divergence:

$$ \begin{align} D_f(P||Q)&\ge\int_xq(x)\left(\frac{p(x)}{q(x)}D(x)-f^*(D(x))\right)dx \\
&=\underbrace{\int_xp(x)D(x)dx-\int_xq(x)f^*(D(x))dx}_{M} \end{align} $$

这就相当于我们找到了一个\(D_f(P||Q)\)的下界。接下来,如果我们能找到一个让上面公式中\(M\)部分最大的\(D\),那么如果我们在公式中采用了这个找到的\(D\),那就可以去逼近\(D_f(P||Q)\),即:

$$ \begin{align} D_f(P||Q)&\approx\max_D\int_xp(x)D(x)dx-\int_xq(x)f^*(D(x))dx \\
&=\max_D\left\{\underbrace{E_{x\sim P}[D(x)]}_{\text{Samples from P}}-\underbrace{E_{x\sim Q}[f^*(D(x))]}_{\text{Samples from Q}}\right\} \end{align} $$

上面公式的第二行是将前面的概率积分变成了求期望。然而我们在工程上没法真的做期望,所以一般的做法就是分别从\(P\)和\(Q\)中去sample数据。

假设现在我们的\(P\)是\(P_{data}\),\(Q\)是\(P_G\),则公式变成:

$$ D_f(P_{data}||P_G)=\max_D\{E_{x\sim P_{data}}[D(x)]-E_{x\sim P_G}[f^*(D(x))]\}
$$

这样我们就能得到\(P_{data}\)和\(P_G\)二者的f-divergence了!然而写到这里读者肯定会问,我们之前不是曾经写出来过f-divergence的公式了嘛?(\(\int_xq(x)f(\frac{p(x)}{q(x)})dx\))。但实际上想要将f-divergence用在GAN中是不可能的,因为我们并不知道\(P_{data}\)和\(P_g\)的表达式。

所以我们这一路的公式推导,最终得到的公式只需要我们简单地从\(P_{data}\)和\(P_g\)中采样就可以计算得到f-divergence。而不再需要直到这两个分布的表达式即可计算。

接下来,假如我们想要去寻找一个能让这个距离最小的\(P_G\),则这个\(G\)应该是:

$$ \begin{align} G^*&=\arg\min_GD_f(P_{data}||P_G) \\
&=\arg\min_G\max_D\{E_{x\sim P_{data}}[D(x)]-E_{x\sin P_G}[f^*(D(x))]\} \\ &= \arg\min_G\max_DV(G, D) \end{align} $$

实际上这就是我们之前的GAN的目标函数。而公式中的\(f^*\)就是生成器。但实际上要注意,此处的\(V(G,D)\)不一定就是原生GAN的形式。事实上,在将其应用到GAN时,有非常多的不同的生成器函数可供使用:

在这个体系下,你只需要从这个表格中挑选不同的\(f^*(t)\)就可以得到不同的f-divergence。而公式中的\(D(x)\),就是我们GAN中的判别器。

WGAN

前面我们介绍了使用f-divergence来将“距离”定义到一个统一框架之中的方法。而Fenchel Conjugate则将这个f-divergence与GAN联系在一起。这么做的目的在于,我们只要能找到一个符合f-divergence要求的函数,就能产生一个距离的度量,从而定义一种不同的GAN。

对于原生的GAN来说,选择特定的度量函数之后,会导致目标函数变成生成分布与真是分布的JS divergence。但是这个divergence有很多问题。比如说一个最严重的问题就是当两个分布之间完全没有重叠时,分布间距离的大小并不会直接反映在divergence上。这对基于迭代的优化算法是个致命问题。所以后面就有人研究了WGAN这个基于Earth Mover's Distance的方法。

标配版的WGAN(使用weight clipping)

Earth Mover's Distance

用一句话描述EM距离:将一个分布\(P\)通过搬运的方式变成另一个分布\(Q\)所需要的最少搬运代价。

比如说我们有下面的两个分布:

如何将\(P\)上的内容“匀一匀”得到\(Q\)呢?比如说把最高的哪一条分开一部分分到其他地方?这或许是一种解决方案:

但是显然除此之外还有很多种方法,例如:

既然移动的方法有很多种,如果每一种都表示了一种代价,那么显然有“好”方法,就会有“坏”方法。假设我们衡量移动方法好坏的总代价是“移动的数量”\(x\)“移动的距离”。那这两个移动的方案肯定是能分出优劣的。

当我们用分布\(Q\)上不同颜色的色块对应分布\(P\)的相应位置,就可以将最好的移动方案画成下面这个样子:

为了便于形式化定义,我们可以将这个变化画为一个矩阵:

对于每一个移动方案\(\gamma\),都能有这样一个矩阵。矩阵的每一行表示分布\(P\)的一个特定位置。该行中的每一列表示需要将该行的内容移动到分布\(Q\)对应位置的数量。即矩阵中的一个元素\((x_p, x_q)\)表示从\(P(x_p)\)移动到\(Q(x_q)\)的数量。

而对于方案\(\gamma\)我们可以定义一个平均移动距离(Average distance of a plan \(\gamma\)):

$$ B(\gamma)=\sum_{x_p,x_q}\gamma(x_p,x_q)||x_p-x_q||
$$

而Earth Mover's Distance就是指所有方案中平均移动距离最小的那个方案:

$$ W(P,Q)=\min_{\gamma\in\prod}B(\gamma)
$$

其中\(\prod\)是所有可能的方案。

为什么说这个EM距离比较好呢?因为它没有JS Divergence的问题。比如说,当第\(0\)次迭代时,两个分布的样子是这样:

这个时候JSD:\(JS(P_{G_0},P_{data})=\log 2\)

当我们的训练继续,到第\(50\)次迭代时,在理想情况下两个分布应该靠近了一些:

这个时候JSD:\(JS(P_{G_{50}}, P_{data})=\log 2\)

再继续训练,当第100次迭代时,假如说\(P_{G_{100}}\)已经与\(P_{data}\)完全重合:

此时JSD:\(JS(P_{G_{100}}, P_{data})=0\)

从上面的训练过程中能看出来迭代过程中JSD总是不变的(永远是\(\log 2\)),直到两个分布重叠的一瞬间,JSD降为\(0\)。

而当我们换成EM距离的时候,即便在两次迭代中两个分布完全没有重叠,但一定有EM距离上的区别。

即:

$$ \begin{align} W(P_{G_0},P_{data})&=d_0 \\
W(P_{G_{50}}, P_{data})&=d_{50} \\
W(P_{G_{100}}, P_{data})&=d_{100} \\
\end{align} $$

与GAN的整合

前面我们介绍了EM距离,接下来我们就将EM距离与GAN联系起来!

回忆一下f-divergence:

$$ D_f(P_{data}||P_G)=\max_D\{E_{x\sim P_{data}}[D(x)]-E_{x\sim P_G}[f^*(D(x))]\}
$$

而WGAN的文章中写到,EM距离也可以类似f-divergence,用一个式子表示出来:

$$ W(P_{data},P_G)=\max_{D\in \text{1-Lipschitz}}\{E_{x\sim P_{data}}[D(x)]-E_{x\sim P_G}[D(x)]\}
$$

公式中\(\text{1-Lipschitz}\)表示了一个函数集。当\(f\)是一个Lipschitz函数时,它应该受到以下约束:\(||f(x_1)-f(x_2)||\le K||x_1-x_2||\)。当\(K=1\)时,这个函数就是\(\text{1-Lipschitz}\)函数。

这个约束有啥用?直观来说,就是让这个函数的变化“缓慢一些”:

图中绿色的线属于\(\text{1-Lipschitz}\)函数,而蓝色的线肯定不是\(\text{1-Lipschitz}\)函数。

为什么要限制生成器\(D\)时\(\text{1-Lipschitz}\)函数呢?我们线考虑一下如果不限制它是\(\text{1-Lipschitz}\)函数时会发生什么。

假设我们现在有两个一维的分布,\(x_1\)和\(x_2\)的距离是\(d\),显然他们之间的EM距离也是\(d\):

此时如果我们想要去优化\(W(P_{data},P_G)=\max_{D\in \text{1-Lipschitz}}\{E_{x\sim P_{data}}[D(x)]-E_{x\sim P_G}[D(x)]\}\),只需要让\(D(x_1)=+\infty\),而让\(D(x_2)=-\infty\)就可以了。

也就是说,如果不加上\(\text{1-Lipschitz}\)的限制的话,只需要让判别器判断\(P_{data}\)时大小是正无穷,判断\(P_G\)时是负无穷就足够了。这样的判别器可能会导致训练起来非常困难:判别器区分能力太强,很难驱使生成器让生成分布fit数据分布(验钞机太高明,假币无论怎么造假也骗不过验钞机,假币就训练不动了)。

这个时候我们加上了这个限制,也就是说\(||D(x_1)-D(x_2)||\le||x_1-x_2||=d\)。此时如果我们想要满足上面的优化目标的话,就可以让\(D(x_1)=k+d\),让\(D(x_2)=k\)。其中k具体是什么无所谓,关键是我们通过\(d\)将判别器在不同分布上的结果限制在了一个较小的范围中。

这样做有什么好处呢?因为我们传统的GAN所使用的判别器是一个最终经过sigmoid输出的神经网络,它的输出曲线肯定是一个S型。在真实分布附近是\(1\),在生成分布附近是\(0\)。而现在我们对判别器施加了这个限制,同时不再在最后一层使用sigmoid,它有有可能是任何形状的线段,只要能让\(D(x_1)-D(x_2)\le d\)即可。如下图所示:

这样做的好处显而易见。传统GAN的判别器是有饱和区的(靠近真实分布和生成分布的地方,函数变化平缓,梯度趋于0)。而现在的GAN如果是一条直线,那就能在训练过程中无差别地提供一个有意义的梯度。

前面说了这么多,核心的观点就是:1不要用sigmoid输出。2换成受限的\(\text{1-Lipschitz}\)来实现一个类似sigmoid的“范围限制”功能。

然而这个\(\text{1-Lipschitz}\)限制应该如何施加?文章中所用的方法非常简单粗暴:截断权重。

如果说,一个判别器\(D\)的形状由其参数决定,当我们需要这个判别器满足\(\text{1-Lipschitz}\)限制,那我们可以通过调整其参数来满足限制。

参数如何调整?由于我们的函数是一个“缓慢变化”的函数,想要让函数缓慢变化,只需要让权值变小一些即可。所以在论文中的处理方法非常直接:在每次参数更新之后,让每个大于\(c\)的参数\(w\)等于\(c\)、让每个小于\(-c\)的参数\(w\)等于\(-c\),即将所有权值参数\(w\)截断在\([-c, c]\)之间。然而这么做实际上保证的并不是\(\text{1-Lipschitz}\),而是\(\text{K-Lipschitz}\),甚至这个\(K\)是多少都是玄学,只能通过调参来测试了。

用一个简单明了的图表示这个过程:

图中斜率比较陡峭的就是没有截断的函数。而截断的函数将会逆时针旋转,从而产生一个类似\(\text{1-Lipschitz}\)限制的效果。

伪代码!

原生的GAN的伪代码是这样:

  • 初始化D的\(\theta_d\)和G的\(\theta_g\)
  • 在每一个训练循环进行:
    • 从数据分布\(P_{data}(x)\)中采样\(m\)个样本\(\{x^1, x^2, \dots, x^m\}\)
    • 从先验噪声分布\(P_{prior}(z)\)中采样\(m\)个样本\(\{z^1, z^2, \dots, z^m\}\)
    • 将这些噪声样本输入生成器\(G\),得到生成样本\(\{\tilde{x}^1,\tilde{x}^2,\dots,\tilde{x}^m\}, \tilde{x}^i=G(z^i)\)
    • 更新判别器的参数\(\theta_d\),即最大化:
      • \(\tilde{V}=\frac{1}{m}\sum_{i=1}^m\log D(x^i)+\frac{1}{m}\sum_{i=1}^m\log(1-D(\tilde{x}^i))\)
      • \(\theta_d\leftarrow\theta_d+\eta\nabla\tilde{V}(\theta_d)\)
    • 从先验噪声分布\(P_{prior}(z)\)中采样\(m\)个样本\(\{z^1,z^2,\dots,z^m\}\)
    • 更新生成器的参数\(\theta_g\),即最小化:
      • \(\require{cancel}\tilde{V}=\cancel{\frac{1}{m}\sum_{i=1}^m\log D(x^i)}+\frac{1}{m}\sum_{i=1}^m\log(1-D(G(z^i)))\)
      • \(\theta_g\leftarrow\theta_g-\eta\nabla\tilde{V}(\theta_g)\)

而WGAN的伪代码则是这样:

  • 初始化D的\(\theta_d\)和G的\(\theta_g\)
  • 在每一个训练循环进行:
    • 从数据分布\(P_{data}(x)\)中采样\(m\)个样本\(\{x^1, x^2, \dots, x^m\}\)
    • 从先验噪声分布\(P_{prior}(z)\)中采样\(m\)个样本\(\{z^1, z^2, \dots, z^m\}\)
    • 将这些噪声样本输入生成器\(G\),得到生成样本\(\{\tilde{x}^1,\tilde{x}^2,\dots,\tilde{x}^m\}, \tilde{x}^i=G(z^i)\)
    • 更新判别器的参数\(\theta_d\),即最大化:
      • \(\tilde{V}=\frac{1}{m}\sum_{i=1}^mD(x^i)-\frac{1}{m}\sum_{i=1}^mD(\tilde{x}^i)\)
      • \(\theta_d\leftarrow\theta_d+\eta\nabla\tilde{V}(\theta_d)\)
      • 更新参数后,截断参数
    • 从先验噪声分布\(P_{prior}(z)\)中采样\(m\)个样本\(\{z^1,z^2,\dots,z^m\}\)
    • 更新生成器的参数\(\theta_g\),即最小化:
      • \(\require{cancel}\tilde{V}=\cancel{\frac{1}{m}\sum_{i=1}^m\log D(x^i)}-\frac{1}{m}\sum_{i=1}^mD(G(z^i))\)
      • \(\theta_g\leftarrow\theta_g-\eta\nabla\tilde{V}(\theta_g)\)

尤其需要注意的是,判别器的输出不再需用sigmoid函数了!并且需要训练\(k\)次判别器,然后只训练一次生成器。

升级版的WGAN(使用gradient penalty)

前面我们提到过,我们要使用weight clipping的技巧来实现对判别器\(D\)的\(\text{1-Lipschitz}\)的等效限制。

而后续一片Improved WGAN的文章中提到,\(\text{1-Lipschitz}\)函数有一个特性:当一个函数是\(\text{1-Lipschitz}\)函数时,它的梯度的norm将永远小于等于1。

也就是说: $$ D\in\text{1-Lipschitz} \leftrightarrow ||\nabla_xD(x)||\le 1 ~~\text{for all x}
$$

记住这里的梯度不是判别器中参数的梯度,而是输入\(x\)的梯度。这个表示也是挺符合直觉的,因为当梯度处处小于等于1时,它的变化就不会太快。当梯度大于1,这个函数的输出有可能就要比输入大,也就不是\(\text{1-Lipschitz}\)函数了。

有了这个理论,我们就可以改变我们的目标函数了。

原来我们优化目标是:

$$ W(P_{data},P_G)=\max_{D\in \text{1-Lipschitz}}\{E_{x\sim P_{data}}[D(x)]-E_{x\sim P_G}[D(x)]\}
$$

此时WGAN的优化目标是在\(\text{1-Lipschitz}\)中挑一个函数作为判别器\(D\)。

而Improved WGAN则是这样:

$$ W(P_{data},P_G)=\max_{D}\{E_{x\sim P_{data}}[D(x)]-E_{x\sim P_G}[D(x)]-\lambda\int_x\max(0,||\nabla_xD(x)||-1)dx\}
$$

也就是说,现在我们寻找判别器的函数集不再是\(\text{1-Lipschitz}\)中的函数了,而是任意函数。但是后面增加了一项惩罚项。这个惩罚项就能够让选中的判别器函数倾向于是一个“对输入梯度为1的函数”。这样也能实现类似weight clipping的效果。

但与之前遇到的问题一样,求积分无法计算,所以我们用采样的方法去加这个惩罚项,即:

$$ \begin{align} W(P_{data},P_G)=\max_{D}\{E_{x\sim P_{data}}[D(x)]-E_{x\sim P_G}[D(x)]\\ -\lambda E_{x\sim P_{penalty}}[\max(0,||\nabla_xD(x)||-1)]\}
\end{align} $$

也就是说,在训练过程中,我们更倾向于得到一个判别器\(D\),它能对从\(P_{penalty}\)中采样得到的每一个\(x\)都能\(||\nabla_xD(x)||\le 1\)

涉及到采样,那就要关系到如何采。而首先是从哪里采?即\(P_{penalty}\)是什么?

Improved WGAN设计了一个特别的\(P_{penalty}\)。它的产生过程如下:

  1. 从\(P_{data}\)中采样一个点
  2. 从\(P_{G}\)中采样一个点
  3. 将这两个点连线
  4. 在连线之上在采样得到一个点,就是一个从\(P_{penalty}\)采样的一个点。

重复上面的过程就能不断采样得到\(x\sim P_{penalty}\)。最终得到下图中的蓝色区域就可以看作是\(P_{penalty}\):

也就是说,我们采样的范围不是整个\(x\),只是\(P_G\)和\(P_{data}\)中间的空间中的一部分。

再更进一步,Improved WGAN真正做的事是这样:

$$ \begin{align} W(P_{data},P_G)=\max_{D}\{E_{x\sim P_{data}}[D(x)]-E_{x\sim P_G}[D(x)]\\ -\lambda E_{x\sim P_{penalty}}[(||\nabla_xD(x)||-1)^2]\}
\end{align} $$

这个惩罚项的目的是让梯度尽可能趋向于等于1。即当梯度大于1或小于1时都会受到惩罚。而原来的惩罚项仅仅在梯度大于1时受到惩罚而已。

这样做是有好处的,就像我们在SVM中强调最大类间距离一样,虽然有多个可以将数据区分开的分类面,但我们希望找到不但能区分数据,还能让区分距离最大的那个分类面。这里这样做的目的是由于可能存在多个判别器,我们想要找到的那个判别器应该有一个“最好的形状”。

一个“好”的判别器应该在\(P_{data}\)附近是尽可能大,要在\(P_G\)附近尽可能小。也就是说处于\(P_{data}\)和\(P_G\)之间的\(P_{penalty}\)区域应该有一个比较“陡峭”的梯度。但是这个陡峭程度是有限制的,这个限制就是\(1\)!

以上就是关于WGAN和Improved WGAN的一些基本思想。WGAN和Improved WGAN (WGAN-GP, Gradient Penalty)只需要对原始的WGAN做一些很简单的修改就能有很好的效果,所以在工程尝试中还是非常值得一试的!

Friskit

继续阅读此作者的更多文章