生成对抗网络的Wasserstein距离
作者:禅与计算机程序设计艺术
1. 背景介绍
生成对抗网络(Generative Adversarial Network,GAN)是近年来机器学习领域最重要的创新之一。GAN通过训练两个相互竞争的神经网络模型—生成器(Generator)和判别器(Discriminator),从而学习生成接近真实数据分布的样本。这种对抗训练的方式可以让生成器生成出高质量、接近真实数据分布的样本。
传统的GAN模型使用Jensen-Shannon散度作为生成器和判别器之间的损失函数。然而,Jensen-Shannon散度存在一些问题,比如当生成分布和真实分布没有重叠时梯度会消失,导致训练陷入困境。为了解决这一问题,Wasserstein GAN (WGAN)被提出,它使用Wasserstein距离作为生成器和判别器之间的loss函数。
2. 核心概念与联系
2.1 Wasserstein距离
Wasserstein距离,也称为Earth Mover’s Distance (EMD),是度量两个概率分布之间距离的一种方法。
给定两个概率分布PPP和QQQ,Wasserstein距离定义为:
W(P,Q)=infγ∈Γ(P,Q)E(x,y)∼γ[∣∣x−y∣∣]W(P,Q) = \inf_{\gamma \in \Gamma(P,Q)} \mathbb{E}_{(x,y)\sim \gamma}[||x-y||]W(P,Q)=γ∈Γ(P,Q)infE(x,y)∼γ[∣∣x−y∣∣]
其中Γ(P,Q)\Gamma(P,Q)Γ(P,Q)表示所有可能的联合分布γ(x,y)\gamma(x,y)γ(x,y)的集合,其边缘分布为PPP和QQQ。直观上来说,Wasserstein距离就是将一个分布变形为另一个分布所需要的最小"工作量"。
与KL散度和JS散度不同,Wasserstein距离是一个真正的度量,满足以下性质:
- 非负性:W(P,Q)≥0W(P,Q) \geq 0W(P,Q)≥0,等号成立当且仅当P=QP=QP=Q
- 对称性:W(P,Q)=W(Q,P)W(P,Q) = W(Q,P)W(P,Q)=W(Q,P)
- 三角不等式:W(P,R)≤W(P,Q)+W(Q,R)W(P,R) \leq W(P,Q) + W(Q,R)W(P,R)≤W(P,Q)+W(Q,R)
这些性质使得Wasserstein距离更适合作为GAN的损失函数。
2.2 WGAN
WGAN通过最小化生成器G和判别器D之间的Wasserstein距离来训练GAN,损失函数定义如下:
minGmaxDW(Pg,Pr)\min_G \max_D W(P_g, P_r)GminDmaxW(Pg</