基于随机微分方程的分数生成建模
作为会议论文发表于 ICLR 2021
作者:
- Yang Song* (斯坦福大学, yangsong@cs.stanford.edu )
- Jascha Sohl-Dickstein (Google Brain, jaschasd@google.com )
- Diederik P. Kingma (Google Brain, durk@google.com )
- Abhishek Kumar (Google Brain, abhishk@google.com )
- Stefano Ermon (斯坦福大学, ermon@cs.stanford.edu )
- Ben Poole (Google Brain, pooleb@google.com )
*部分工作完成于 Google Brain 实习期间
摘要
从数据生成噪声是容易的;从噪声生成数据则是生成建模。
- 我们提出一个随机微分方程(Stochastic Differential Equation, SDE),通过缓慢注入噪声,将复杂数据分布平滑地转化为已知先验分布;以及一个相应的逆向时间 SDE,通过缓慢去除噪声,将先验分布转化回数据分布。
- 关键的是,逆向时间 SDE 仅依赖于扰动数据分布的时变梯度场(即分数)。
- 通过利用基于评分的生成建模的最新进展,我们可以用神经网络准确估计这些分数,并使用数值 SDE 求解器生成样本。
- 我们展示了这一框架整合了先前基于评分的生成建模和扩散概率建模 (diffusion probabilistic modeling) 的方法,从而支持新的采样策略和建模能力。
- 特别地,我们引入了预测器-校正器(predictor-corrector)框架,来纠正离散化逆向时间 SDE 演化中的误差。
- 我们还推导了一个等效的神经常微分方程(neural ODE),该方程从与 SDE 相同的分布中采样,但额外支持精确似然计算(exact likelihood computation)并提高了采样效率。
- 此外,我们提供了一种利利用基于评分的模型解决逆问题(inverse problems)的新方法,并通过类条件生成(class-conditional generation)、图像修复(image inpainting)和着色(colorization)实验进行了验证。
- 结合多种架构改进,我们在 CIFAR-10上实现了无条件图像生成的破纪录性能:初始分数(Inception score)为9.89,FID 为2.20,似然值达到2.99比特/维度(bits/dim)的竞争性水平,并首次从基于分数的生成模型中展示了1024×1024图像的高保真度生成。
1 引言
- 两类成功的概率生成模型(probabilistic generative models)都涉及用缓慢增加的噪声顺序破坏训练数据,然后学习逆转这种破坏以形成数据的生成模型。
- 基于朗之万动力学的分数匹配(Score Matching with Langevin Dynamics, SMLD)(Song & Ermon, 2019) 估计每个噪声尺度下的分数(即数据对数概率密度的梯度),然后在生成过程中使用朗之万动力学(Langevin dynamics)从一系列递减的噪声尺度中采样。
- 去噪扩散概率建模(Denoising Diffusion Probabilistic Modeling, DDPM)(Sohl-Dickstein et al., 2015; Ho et al., 2020) 训练一系列概率模型来逆转每个噪声破坏步骤,利用逆向分布的函数形式使训练易于处理。对于连续状态空间,DDPM 训练目标隐式计算了每个噪声尺度下的分数。
- 因此,我们将这两类模型统称为基于分数的生成模型(score-based generative models)。
基于分数的生成模型及相关技术 (Bordes et al., 2017; Goyal et al., 2017; Du & Mordatch, 2019) 已被证明在图像 (Song & Ermon, 2019; 2020; Ho et al., 2020)、音频 (Chen et al., 2020; Kong et al., 2020)、图 (Niu et al., 2020) 和形状 (Cai et al., 2020) 生成方面是有效的。
为了启用新的采样方法并进一步扩展基于分数的生成模型的能力,我们提出了一个统一框架,通过随机微分方程(SDEs)的视角推广了先前的方法。
具体而言,我们不再用有限数量的噪声分布扰动数据,而是考虑一个随时间演化、服从扩散过程(diffusion process)的连续分布族。
该过程逐步将数据点扩散为随机噪声,由一个预定义的、不依赖数据且无训练参数的 SDE 给出。
通过逆转这个过程,我们可以将随机噪声平滑地塑造成数据以生成样本。
至关重要的是,这个逆向过程满足一个逆向时间 SDE (Anderson, 1982),该 SDE 可以从正向 SDE 结合时间边缘概率密度的分数推导出来。
因此,我们可以通过训练一个时变神经网络来估计分数以近似逆向时间 SDE,然后使用数值 SDE 求解器生成样本。我们的核心思想总结在图1中。
图1:求解逆向时间SDE产生一个基于分数的生成模型。将数据转化为简单噪声分布可以通过一个连续时间SDE实现。如果我们知道每个中间时间步分布 ∇xlogpt(x)\nabla_{\mathbf{x}} \log p_{\mathbf{t}}(\mathbf{x})∇xlogpt(x) 的分数,这个 SDE 可以被逆转。
我们提出的框架具有以下几个理论和实践贡献:
- 灵活的采样和似然计算:
- 我们可以使用任何通用 SDE 求解器对逆向时间 SDE 进行积分以进行采样。
- 此外,我们提出了两种对一般 SDE 不可行的特殊方法:
- 预测器-校正器 (Predictor-Corrector, PC)采样器,将数值 SDE 求解器与基于分数的 MCMC 方法(如朗之万 MCMC (Parisi, 1981) 和 HMC (Neal et al., 2011))相结合;
- 统一并改进了现有基于分数模型的采样方法。
- 基于概率流常微分方程 (probability flow ordinary differential equation, ODE) 的确定性采样器(deterministic samplers)。
- 允许通过黑盒 ODE 求解器进行快速自适应采样(fast adaptive sampling),通过潜在编码(latent codes)进行灵活的数据操作(flexible data manipulation),实现唯一可识别的编码(uniquely identifiable encoding),并显著支持精确似然计算(exact likelihood computation)。
- 预测器-校正器 (Predictor-Corrector, PC)采样器,将数值 SDE 求解器与基于分数的 MCMC 方法(如朗之万 MCMC (Parisi, 1981) 和 HMC (Neal et al., 2011))相结合;
- 可控生成 (Controllable generation):
- 我们可以通过 对训练时不可用的信息取条件对训练时不可用的信息取条件对训练时不可用的信息取条件 来调控生成过程,因为条件逆向时间 SDE 可以从一个非传统模型中高效估计。
- 这使得类条件生成、图像修复、着色和其他逆问题成为可能,所有这些都可以使用单个无条件分数模型、无需重新训练。
- 统一框架:
- 我们的框架提供了一种统一的方式来探索和调整各种 SDE 以改进基于分数的生成模型。
- SMLD 和 DDPM 的方法可以被整合到我们的框架中,作为两个不同 SDE 的离散化。
- 尽管最近报道 DDPM (Ho et al., 2020) 比 SMLD (Song & Ermon, 2019; 2020) 实现了更高的样本质量,但我们证明,通过更好的架构和我们框架允许的新采样算法,后者可以迎头赶上——它在 CIFAR-10上实现了新的最先进初始分数(9.89)和 FID 分数(2.20),并首次从基于分数的模型中实现了1024×1024图像的高保真度生成。
- 此外,我们在框架下提出了一种新的 SDE,在均匀去量化(uniformly dequantized)的 CIFAR-10图像上实现了2.99 bits/dim 的似然值,为该任务设定了新记录。
2 背景
2.1 基于朗之万动力学的去噪分数匹配 (SMLD)
令 pσ(x^∣x):=N(x^;x,σ2I)p_\sigma(\hat{\mathbf{x}} \mid \mathbf{x}):=\mathcal{N}\left(\hat{\mathbf{x}} ; \mathbf{x}, \sigma^2 \mathbf{I}\right)pσ(x^∣x):=N(x^;x,σ2I) 为扰动核(perturbation kernel),且 pσ(x^):=∫pdata (x)pσ(x^∣x)dxp_\sigma(\hat{\mathbf{x}}):=\int p_{\text {data }}(\mathbf{x}) p_\sigma(\hat{\mathbf{x}} \mid \mathbf{x}) \mathrm{d} \mathbf{x}pσ(x^):=∫pdata (x)pσ(x^∣x)dx,其中 pdata (x)p_{\text {data }}(\mathbf{x})pdata (x) 表示数据分布。
考虑一个正噪声尺度序列 σmin=σ1<σ2<⋯<σN=σmax\sigma_{\min }=\sigma_1<\sigma_2<\cdots<\sigma_N=\sigma_{\max }σmin=σ1<σ2<⋯<σN=σmax。
通常,σmin\sigma_{\min }σmin 足够小使得 pσmin(x)≈pdata (x)p_{\sigma_{\min }}(\mathbf{x}) \approx p_{\text {data }}(\mathbf{x})pσmin(x)≈pdata (x),σmax\sigma_{\max }σmax 足够大使得 pσmax(x)≈N(x;0,σmax2I)p_{\sigma_{\max }}(\mathbf{x}) \approx \mathcal{N}\left(\mathbf{x} ; \mathbf{0}, \sigma_{\max }^2 \mathbf{I}\right)pσmax(x)≈N(x;0,σmax2I)。
Song & Ermon (2019) 提出训练一个噪声条件分数网络(Noise Conditional Score Network, NCSN),记为 sθ(x,σ)\mathbf{s}_\theta(\mathbf{x}, \sigma)sθ(x,σ),其目标函数是去噪分数匹配 (Vincent, 2011) 目标的加权和:
θ∗=argminθ∑i=1Nσi2Epdata (x)Epσi(x^∣x)[∥sθ(x^,σi)−∇x^logpσi(x^∣x)∥22].(1)
\boldsymbol{\theta}^*=\underset{\boldsymbol{\theta}}{\arg \min } \sum_{i=1}^N \sigma_i^2 \mathbb{E}_{p_{\text {data }}(\mathbf{x})} \mathbb{E}_{p_{\sigma_i}(\hat{\mathbf{x}} \mid \mathbf{x})} \left[ \left\| \mathbf{s}_{\boldsymbol{\theta}}\left(\hat{\mathbf{x}}, \sigma_i\right)-\nabla_{\hat{\mathbf{x}}} \log p_{\sigma_i}(\hat{\mathbf{x}} \mid \mathbf{x}) \right\|_2^2 \right]. \quad (1)
θ∗=θargmini=1∑Nσi2Epdata (x)Epσi(x^∣x)[∥sθ(x^,σi)−∇x^logpσi(x^∣x)∥22].(1)
给定充足的数据和模型容量,最优的基于分数的模型 sθ∗(x,σ)\mathbf{s}_{\theta *}(\mathbf{x}, \sigma)sθ∗(x,σ) 在几乎所有地方都匹配 ∇x^logpσ(x)\nabla_{\hat{\mathbf{x}}} \log p_\sigma(\mathbf{x})∇x^logpσ(x),对于 σ∈{σi}i=1N\sigma \in\left\{\sigma_i\right\}_{i=1}^Nσ∈{σi}i=1N。
对于采样,Song & Ermon (2019) 为每个 pσi(x)p_{\sigma_i}(\mathbf{x})pσi(x) 顺序运行 MMM 步朗之万 MCMC 以获取样本:
xim=xim−1+ϵisθ∗(xim−1,σi)+2ϵizim,m=1,2,⋯ ,M,(2)
\mathbf{x}_i^m=\mathbf{x}_i^{m-1}+\epsilon_i \mathbf{s}_{\theta *}\left(\mathbf{x}_i^{m-1}, \sigma_i\right)+\sqrt{2 \epsilon_i} \mathbf{z}_i^m, \quad m=1,2, \cdots, M, \quad (2)
xim=xim−1+ϵisθ∗(xim−1,σi)+2ϵizim,m=1,2,⋯,M,(2)
其中 ϵi>0\epsilon_i>0ϵi>0 是步长,zim\mathbf{z}_i^mzim 是标准正态分布。
上述过程依次对 i=N,N−1,⋯ ,1i=N, N-1, \cdots, 1i=N,N−1,⋯,1 重复执行,其中 xN0∼N(x∣0,σmax2I)\mathbf{x}_N^0 \sim \mathcal{N}\left(\mathbf{x} \mid \mathbf{0}, \sigma_{\max }^2 \mathbf{I}\right)xN0∼N(x∣0,σmax2I) 且当 i<Ni<Ni<N 时 xi0=xi+1M\mathbf{x}_i^0=\mathbf{x}_{i+1}^Mxi0=xi+1M。当 M→∞M \rightarrow \inftyM→∞ 且对所有 iii 有 ϵi→0\epsilon_i \rightarrow 0ϵi→0 时,在一定的正则条件下,x1M\mathbf{x}_1^Mx1M 成为来自 pσmin(x)≈pdata (x)p_{\sigma_{\min }}(\mathbf{x}) \approx p_{\text {data }}(\mathbf{x})pσmin(x)≈pdata (x) 的精确样本。
2.2 去噪扩散概率模型 (DDPM)
Sohl-Dickstein et al. (2015); Ho et al. (2020) 考虑一个正噪声尺度序列 0<β1,β2,⋯ ,βN<10<\beta_1, \beta_2, \cdots, \beta_N<10<β1,β2,⋯,βN<1。
对于每个训练数据点 x0∼pdata (x)\mathbf{x}_0 \sim p_{\text {data }}(\mathbf{x})x0∼pdata (x),构建一个离散马尔可夫链 {x0,x1,⋯ ,xN}\left\{\mathbf{x}_0, \mathbf{x}_1, \cdots, \mathbf{x}_N\right\}{x0,x1,⋯,xN},使得 p(xi∣xi−1)=N(xi;1−βixi−1,βiI)p\left(\mathbf{x}_i \mid \mathbf{x}_{i-1}\right)=\mathcal{N}\left(\mathbf{x}_i ; \sqrt{1-\beta_i} \mathbf{x}_{i-1}, \beta_i \mathbf{I}\right)p(xi∣xi−1)=N(xi;1−βixi−1,βiI),因此 pαi(xi∣x0)=N(xi;αix0,(1−αi)I)p_{\alpha_i}\left(\mathbf{x}_i \mid \mathbf{x}_0\right)=\mathcal{N}\left(\mathbf{x}_i ; \sqrt{\alpha_i} \mathbf{x}_0,\left(1-\alpha_i\right) \mathbf{I}\right)pαi(xi∣x0)=N(xi;αix0,(1−αi)I),其中 αi:=∏j=1i(1−βj)\alpha_i:=\prod_{j=1}^i\left(1-\beta_j\right)αi:=∏j=1i(1−βj)。
类似于SMLD,我们将扰动数据分布表示为 pαi(x^):=∫pdata (x)pαi(x^∣x)dxp_{\alpha_i}(\hat{\mathbf{x}}):=\int p_{\text {data }}(\mathbf{x}) p_{\alpha_i}(\hat{\mathbf{x}} \mid \mathbf{x}) \mathrm{d} \mathbf{x}pαi(x^):=∫pdata (x)pαi(x^∣x)dx。噪声尺度的设定使得 xN\mathbf{x}_NxN 近似服从 N(0,I)\mathcal{N}(\mathbf{0}, \mathbf{I})N(0,I)。反向方向的变分马尔可夫链参数化为 pθ(xi−1∣xi)=N(xi−1;11−βi(xi+βisθ(xi,i)),βiI)p_\theta\left(\mathbf{x}_{i-1} \mid \mathbf{x}_i\right)=\mathcal{N}\left(\mathbf{x}_{i-1} ; \frac{1}{\sqrt{1-\beta_i}}\left(\mathbf{x}_i+\beta_i \mathbf{s}_\theta\left(\mathbf{x}_i, i\right)\right), \beta_i \mathbf{I}\right)pθ(xi−1∣xi)=N(xi−1;1−βi1(xi+βisθ(xi,i)),βiI),并使用证据下界(Evidence Lower Bound, ELBO)的重新加权变体进行训练:
θ∗=argminθ∑i=1N(1−αi)Epdata (x)Epαi(x^∣x)[∥sθ(x^,i)−∇x^logpαi(x^∣x)∥22].(3)
\boldsymbol{\theta}^*=\underset{\boldsymbol{\theta}}{\arg \min } \sum_{i=1}^N \left(1-\alpha_i\right) \mathbb{E}_{p_{\text {data }}(\mathbf{x})} \mathbb{E}_{p_{\alpha_i}(\hat{\mathbf{x}} \mid \mathbf{x})} \left[ \left\| \mathbf{s}_\theta(\hat{\mathbf{x}}, i)-\nabla_{\hat{\mathbf{x}}} \log p_{\alpha_i}(\hat{\mathbf{x}} \mid \mathbf{x}) \right\|_2^2 \right]. \quad (3)
θ∗=θargmini=1∑N(1−αi)Epdata (x)Epαi(x^∣x)[∥sθ(x^,i)−∇x^logpαi(x^∣x)∥22].(3)
在求解方程(3)得到最优模型 sθ∗(x,i)\mathbf{s}_{\theta *}(\mathbf{x}, i)sθ∗(x,i) 后,可以通过从 xN∼N(0,I)\mathbf{x}_N \sim \mathcal{N}(\mathbf{0}, \mathbf{I})xN∼N(0,I) 开始,并遵循估计的逆向马尔可夫链生成样本:
xi−1=11−βi(xi+βisθ∗(xi,i))+βizi,i=N,N−1,⋯ ,1.(4)
\mathbf{x}_{i-1}=\frac{1}{\sqrt{1-\beta_i}}\left(\mathbf{x}_i+\beta_i \mathbf{s}_{\theta *}\left(\mathbf{x}_i, i\right)\right)+\sqrt{\beta_i} \mathbf{z}_i, \quad i=N, N-1, \cdots, 1. \quad (4)
xi−1=1−βi1(xi+βisθ∗(xi,i))+βizi,i=N,N−1,⋯,1.(4)
我们将此方法称为祖先采样(ancestral sampling),因为它对应于从图模型 ∏i=1Npθ(xi−1∣xi)\prod_{i=1}^N p_\theta\left(\mathbf{x}_{i-1} \mid \mathbf{x}_i\right)∏i=1Npθ(xi−1∣xi) 进行祖先采样。
这里描述的方程(3)是Ho et al. (2020) 中的 Lsimple L_{\text {simple }}Lsimple ,其形式旨在暴露与方程(1)的更多相似性。
与方程(1)类似,方程(3)也是去噪分数匹配目标的加权和,这意味着最优模型 sθ∗(x^,i)\mathbf{s}_{\theta *}(\hat{\mathbf{x}}, i)sθ∗(x^,i) 匹配扰动数据分布的分数 ∇xlogpαi(x)\nabla_{\mathbf{x}} \log p_{\alpha_i}(\mathbf{x})∇xlogpαi(x)。
值得注意的是,方程(1)和方程(3)中第 iii 个求和项的权重,即 σi2\sigma_i^2σi2 和 (1−αi)\left(1-\alpha_i\right)(1−αi),与相应扰动核(perturbation kernels)具有相同的函数形式:σi2∝1/E[∥∇xlogpσi(x^∣x)∥22]\sigma_i^2 \propto 1 / \mathbb{E}\left[\left\|\nabla_{\mathbf{x}} \log p_{\sigma_i}(\hat{\mathbf{x}} \mid \mathbf{x})\right\|_2^2\right]σi2∝1/E[∥∇xlogpσi(x^∣x)∥22] 和 (1−αi)∝1/E[∥∇xlogpαi(x^∣x)∥22]\left(1-\alpha_i\right) \propto 1 / \mathbb{E}\left[\left\|\nabla_{\mathbf{x}} \log p_{\alpha_i}(\hat{\mathbf{x}} \mid \mathbf{x})\right\|_2^2\right](1−αi)∝1/E[∥∇xlogpαi(x^∣x)∥22]。
3 基于 SDEs 的分数生成建模
使用多个噪声尺度扰动数据是先前方法成功的关键。我们建议将这一思想进一步推广到无限数量的噪声尺度,使得扰动数据分布随着噪声强度的变化根据一个 SDE 演化。我们框架的概述见图2。
图2:通过SDEs进行基于分数的生成建模概述。我们可以用一个SDE将数据映射到噪声分布(先验)(第3.1节),并逆转这个SDE进行生成建模(第3.2节)。我们还可以逆转相关的概率流ODE(第4.3节),这会产生一个确定性过程,其采样分布与SDE相同。逆向时间SDE和概率流ODE都可以通过估计分数 ∇xlogpt(x)\nabla_{\mathbf{x}} \log p_{\mathbf{t}}(\mathbf{x})∇xlogpt(x) 获得(第3.3节)。
3.1 用 SDEs 扰动数据
我们的目标是构建一个扩散过程 {x(t)}t=0T\{\mathbf{x}(t)\}_{t=0}^T{x(t)}t=0T,其索引为连续时间变量 t∈[0,T]t \in [0, T]t∈[0,T],使得 x(0)∼p0\mathbf{x}(0) \sim p_0x(0)∼p0 (我们拥有其独立同分布样本的数据集),且 x(T)∼pT\mathbf{x}(T) \sim p_Tx(T)∼pT (我们拥有其高效生成样本的易处理形式)。
换句话说,p0p_0p0 是数据分布,pTp_TpT 是先验分布。
这个扩散过程可以建模为一个 Itô SDE 的解:
dx=f(x,t)dt+g(t)dw,(5)
\mathrm{d} \mathbf{x} = \mathbf{f}(\mathbf{x}, t) \mathrm{d} t + g(t) \mathrm{d} \mathbf{w}, \quad (5)
dx=f(x,t)dt+g(t)dw,(5)
其中 w\mathbf{w}w 是标准维纳过程(又称布朗运动),f(⋅,t):Rd→Rd\mathbf{f}(\cdot, t): \mathbb{R}^d \rightarrow \mathbb{R}^df(⋅,t):Rd→Rd 是一个向量值函数,称为 x(t)\mathbf{x}(t)x(t) 的漂移系数(drift coefficient),g(⋅):R→Rg(\cdot): \mathbb{R} \rightarrow \mathbb{R}g(⋅):R→R 是一个标量函数,称为 x(t)\mathbf{x}(t)x(t) 的扩散系数(diffusion coefficient)。
为简化表述,我们假设扩散系数是标量(而不是 d×dd \times dd×d 矩阵)且不依赖于 x\mathbf{x}x,但我们的理论可以推广到这些情况(见附录A)。只要系数在状态和时间上都是全局Lipschitz的,该SDE就具有唯一的强解 (Øksendal, 2003)。
我们此后用 pt(x)p_{\mathbf{t}}(\mathbf{x})pt(x) 表示 x(t)\mathbf{x}(t)x(t) 的概率密度,并用 pst(x(t)∣x(s))p_{s t}(\mathbf{x}(t) \mid \mathbf{x}(s))pst(x(t)∣x(s)) 表示从 x(s)\mathbf{x}(s)x(s) 到 x(t)\mathbf{x}(t)x(t) 的转移核(transition kernel),其中 0⩽s<t⩽T0 \leqslant s < t \leqslant T0⩽s<t⩽T。
通常,pTp_TpT 是一个无结构的先验分布,不包含 p0p_0p0 的信息,例如具有固定均值和方差的高斯分布。设计方程(5)中的 SDE 有多种方式,可以将其扩散为固定的先验分布。
我们稍后在3.4节提供几个源自 SMLD 和 DDPM 连续推广的示例。
3.2 通过逆转 SDE 生成样本
从 x(T)∼pT\mathbf{x}(T) \sim p_Tx(T)∼pT 的样本开始并逆转该过程,我们可以获得样本 x(0)∼p0\mathbf{x}(0) \sim p_0x(0)∼p0。Anderson (1982) 的一个显著结果表明,扩散过程的逆过程也是一个扩散过程,在时间上向后运行,并由逆向时间 SDE 给出:
dx=[f(x,t)−g(t)2∇xlogpt(x)]dt+g(t)dw‾,(6)
\mathrm{d} \mathbf{x} = \left[ \mathbf{f}(\mathbf{x}, t) - g(t)^2 \nabla_{\mathbf{x}} \log p_{\mathbf{t}}(\mathbf{x}) \right] \mathrm{d} t + g(t) \mathrm{d} \overline{\mathbf{w}}, \quad (6)
dx=[f(x,t)−g(t)2∇xlogpt(x)]dt+g(t)dw,(6)
其中 w‾\overline{\mathbf{w}}w 是时间从 TTT 向后流动到 000 时的标准维纳过程,dt\mathrm{d} tdt 是无穷小的负时间步长。一旦知道每个边缘分布 pt(x)p_{\mathbf{t}}(\mathbf{x})pt(x) 的分数 ∇xlogpt(x)\nabla_{\mathbf{x}} \log p_{\mathbf{t}}(\mathbf{x})∇xlogpt(x),我们就可以从方程(6)推导出逆向扩散过程并对其进行模拟以从 p0p_0p0 采样。
3.3 为 SDE 估计分数
分布的分数可以通过在样本上使用分数匹配(score matching)(Hyvärinen, 2005; Song et al., 2019a) 训练一个基于分数的模型来估计。为了估计 ∇xlogpt(x)\nabla_{\mathbf{x}} \log p_{\mathbf{t}}(\mathbf{x})∇xlogpt(x),我们可以训练一个时变的基于分数的模型 sθ(x,t)\mathbf{s}_\theta(\mathbf{x}, t)sθ(x,t),通过方程(1)和(3)的连续推广形式:
θ∗=argminθEt{λ(t)Ex(0)Ex(t)∣x(0)[∥sθ(x(t),t)−∇x(t)logp0t(x(t)∣x(0))∥22]},(7)
\boldsymbol{\theta}^* = \underset{\boldsymbol{\theta}}{\arg \min } \mathbb{E}_t \left\{ \lambda(t) \mathbb{E}_{\mathbf{x}(0)} \mathbb{E}_{\mathbf{x}(t) \mid \mathbf{x}(0)} \left[ \left\| \mathbf{s}_{\boldsymbol{\theta}}(\mathbf{x}(t), t) - \nabla_{\mathbf{x}(t)} \log p_{0 t}(\mathbf{x}(t) \mid \mathbf{x}(0)) \right\|_2^2 \right] \right\}, \quad (7)
θ∗=θargminEt{λ(t)Ex(0)Ex(t)∣x(0)[sθ(x(t),t)−∇x(t)logp0t(x(t)∣x(0))22]},(7)
其中 λ:[0,T]→R>0\lambda: [0, T] \rightarrow \mathbb{R}_{>0}λ:[0,T]→R>0 是一个正加权函数,ttt 在 [0,T][0, T][0,T] 上均匀采样,x(0)∼p0(x)\mathbf{x}(0) \sim p_0(\mathbf{x})x(0)∼p0(x),且 x(t)∼p0t(x(t)∣x(0))\mathbf{x}(t) \sim p_{0 t}(\mathbf{x}(t) \mid \mathbf{x}(0))x(t)∼p0t(x(t)∣x(0))。
在充足的数据和模型容量下,分数匹配确保方程(7)的最优解,记为 sθ(x,t)\mathbf{s}_{\boldsymbol{\theta}}(\mathbf{x}, t)sθ(x,t),在几乎所有 x\mathbf{x}x 和 ttt 处等于 ∇xlogpt(x)\nabla_{\mathbf{x}} \log p_{\mathbf{t}}(\mathbf{x})∇xlogpt(x)。
如同在SMLD和DDPM中,我们通常可以选择 λ∝1/E[∥∇x(t)logp0t(x(t)∣x(0))∥22]\lambda \propto 1 / \mathbb{E}\left[ \left\| \nabla_{\mathbf{x}(t)} \log p_{0 t}(\mathbf{x}(t) \mid \mathbf{x}(0)) \right\|_2^2 \right]λ∝1/E[∇x(t)logp0t(x(t)∣x(0))22]。
注意方程(7)使用了去噪分数匹配,但其他分数匹配目标,如切片分数匹配(sliced score matching)(Song et al., 2019a) 和有限差分分数匹配(finite-difference score matching)(Pang et al., 2020),也适用于此。
我们通常需要知道转移核 p0t(x(t)∣x(0))p_{0 t}(\mathbf{x}(t) \mid \mathbf{x}(0))p0t(x(t)∣x(0)) 才能高效求解方程(7)。
当 f(⋅,t)\mathbf{f}(\cdot, t)f(⋅,t) 是仿射(affine)时,转移核总是高斯分布,其均值和方差通常有闭式解,可以用标准技术获得(参见Särkkä & Solin, 2019 的5.5节)。
对于更一般的SDE,我们可以求解Kolmogorov前向方程(forward equation)(Øksendal, 2003) 来获得 p0t(x(t)∣x(0))p_{0 t}(\mathbf{x}(t) \mid \mathbf{x}(0))p0t(x(t)∣x(0))。或者,我们可以模拟SDE以从 p0t(x(t)∣x(0))p_{0 t}(\mathbf{x}(t) \mid \mathbf{x}(0))p0t(x(t)∣x(0)) 采样,并替换 ∇x(t)logp0t(x(t)∣x(0))\nabla_{\mathbf{x}(t)} \log p_{0 t}(\mathbf{x}(t) \mid \mathbf{x}(0))∇x(t)logp0t(x(t)∣x(0)) 的组合(见附录 A)。
3.4 示例:VE、VP SDEs 及其他
- SMLD 和 DDPM 中使用的噪声扰动可以看作是两个不同 SDE 的离散化。
- 下面我们简要讨论,更多细节见附录 B。
- 当使用总共 NNN 个噪声尺度时,SMLD的每个扰动核 pσi(x∣x0)p_{\sigma_i}(\mathbf{x} \mid \mathbf{x}_0)pσi(x∣x0) 对应于以下马尔可夫链中的 xi\mathbf{x}_ixi 分布:xi=xi−1+σi2−σi−12zi−1,i=1,⋯ ,N,(8)\mathbf{x}_i = \mathbf{x}_{i-1} + \sqrt{\sigma_i^2 - \sigma_{i-1}^2} \mathbf{z}_{i-1}, \quad i=1, \cdots, N, \quad (8)xi=xi−1+σi2−σi−12zi−1,i=1,⋯,N,(8)
- 其中 zi−1∼N(0,I)\mathbf{z}_{i-1} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})zi−1∼N(0,I),我们引入了 σ0=0\sigma_0=0σ0=0 以简化符号。
- 在 N→∞N \rightarrow \inftyN→∞ 的极限下,{σi}i=1N\left\{\sigma_i \right\}_{i=1}^N{σi}i=1N 变成一个函数 σ(t)\sigma(t)σ(t),zi\mathbf{z}_izi 变成 z(t)\mathbf{z}(t)z(t),马尔可夫链 {xi}i=1N\{\mathbf{x}_i\}_{i=1}^N{xi}i=1N 变成一个连续的随机过程 {x(t)}t=01\{\mathbf{x}(t)\}_{t=0}^1{x(t)}t=01,其中我们使用连续时间变量 t∈[0,1]t \in [0,1]t∈[0,1] 进行索引,而不是整数 iii。
- 过程 {x(t)}t=01\{\mathbf{x}(t)\}_{t=0}^1{x(t)}t=01 由以下 SDE 给出:dx=d[σ(t)]2dtdw.(9)\mathrm{d} \mathbf{x} = \sqrt{\frac{d [\sigma(t)]^2}{dt}} \mathrm{d} \mathbf{w}. \quad (9)dx=dtd[σ(t)]2dw.(9)
- 方程(9)中的SDE在 t→∞t \rightarrow \inftyt→∞ 时总是产生方差爆炸的过程:将方程(9)称为方差爆炸(Variance Exploding, VE)SDE
- 类似地,对于DDPM的扰动核 {pαi(x∣x0)}i=1N\left\{ p_{\alpha_i}(\mathbf{x} \mid \mathbf{x}_0 \right)\}_{i=1}^N{pαi(x∣x0)}i=1N,离散马尔可夫链是:xi=1−βixi−1+βizi−1,i=1,…,N.(10)\mathbf{x}_i = \sqrt{1-\beta_i} \mathbf{x}_{i-1} + \sqrt{\beta_i} \mathbf{z}_{i-1}, \quad i=1,\ldots,N. \quad (10)xi=1−βixi−1+βizi−1,i=1,…,N.(10)
- 当 N→∞N \rightarrow \inftyN→∞ 时,方程(10)收敛到以下 SDE:dx=−12β(t)xdt+β(t)dw.(11)\mathrm{d} \mathbf{x} = -\frac{1}{2} \beta(t) \mathbf{x} \mathrm{d} t + \sqrt{\beta(t)} \mathrm{d} \mathbf{w}. \quad (11)dx=−21β(t)xdt+β(t)dw.(11)
- 方程(11)中的 SDE 在初始分布具有单位方差时产生方差为1的固定方差过程(证明见附录 B):将方程(11)称为方差保持(Variance Preserving, VP)SDE
- 因此,SMLD和DDPM中使用的噪声扰动分别对应于方程(9)和(11)中SDE的离散化。
- 受 VP SDE 启发,我们提出了一种新型 SDE,其在似然方面表现特别好(见第4.3节),由下式给出:dx=−12β(t)xdt+β(t)(1−e−2∫0tβ(s)ds)dw.(12)\mathrm{d} \mathbf{x} = -\frac{1}{2} \beta(t) \mathbf{x} \mathrm{d} t + \sqrt{\beta(t)(1-e^{-2 \int_0^t \beta(s) \mathrm{d} s})} \mathrm{d} \mathbf{w}. \quad (12)dx=−21β(t)xdt+β(t)(1−e−2∫0tβ(s)ds)dw.(12)
- 当使用相同的 β(t)\beta(t)β(t) 并从相同的初始分布开始时,由方程(12)诱导的随机过程的方差在每一步都小于或等于由 VP SDE 诱导的方差(证明见附录 B)。
- 因此,我们将方程(12)命名为次方差保持(sub-Variance Preserving, sub-VP)SDE。
- 由于VE、VP和sub-VP SDEs都具有仿射漂移系数,它们的扰动核 p0t(x(t))∣x(0))p_{0 t}(\mathbf{x}(t)) \mid \mathbf{x}(0))p0t(x(t))∣x(0)) 都是高斯的,并且可以以闭式计算,如第3.3节所讨论的。这使得使用方程(7)进行训练特别高效。
4 求解逆向 SDE
在训练了一个时变的基于分数的模型 sθ\mathbf{s}_\thetasθ 之后,我们可以用它来构造逆向时间SDE,然后使用数值方法进行模拟,以从 p0p_0p0 生成样本。
表1:在 CIFAR-10上比较不同的逆向时间 SDE 求解器
方差爆炸 SDE (SMLD) | 方差保持 SDE (DDPM) | ||||||||
---|---|---|---|---|---|---|---|---|---|
采样器 | P1000 | P2000 | C2000 | PC1000 | P1000 | P2000 | C2000 | PC1000 | |
预测器 (Predictor) | |||||||||
祖先采样 (ancestral sampling) | 4.98 ±\pm± .06 | 4.88 ±\pm± .06 | 3.62 ±\pm± .03 | 3.24 ±\pm± .02 | 3.24 ±\pm± .02 | 3.21 ±\pm± .02 | |||
逆向扩散 (reverse diffusion) | 4.79 ±\pm± .07 | 4.74 ±\pm± .08 | 20.43 ±\pm± .07 | 3.60 ±\pm± .02 | 3.21 ±\pm± .02 | 3.19 ±\pm± .02 | 19.06 ±\pm± .06 | 3.18 ±\pm± .01 | |
概率流 (probability flow) | 15.41 ±\pm± .15 | 10.54 ±\pm± .08 | 3.51 ±\pm± .04 | 3.59 ±\pm± .04 | 3.23 ±\pm± .03 | 3.06 ±\pm± .03 |
表1:在 CIFAR-10上比较不同的逆向时间 SDE 求解器。阴影区域表示使用相同计算量(分数函数评估次数)获得的结果。均值和标准差在五次采样运行中报告。 “P1000” 或 “P2000”:仅使用预测器的采样器,分别使用1000或2000步。 “C2000”:仅使用校正器的采样器,使用2000步校正。“PC1000”:预测器-校正器(PC)采样器,使用1000步预测器和1000步校正器。
4.1 通用数值 SDE 求解器
数值求解器为 SDEs 提供近似轨迹。存在许多用于求解 SDEs 的通用数值方法,例如 Euler-Maruyama 和随机 Runge-Kutta 方法 (Kloeden & Platen, 2013),它们对应于对随机动力学的不同离散化。我们可以将其中任何一种应用于逆向时间 SDE 进行样本生成。
祖先采样(Ancestral sampling),即 DDPM 的采样方法(方程(4)),实际上对应于逆向时间 VP SDE(方程(11))的一种特殊离散化(见附录 E)。然而,为新 SDE 推导祖先采样规则可能非平凡。为了弥补这一点,我们提出了逆向扩散采样器(reverse diffusion samplers)(细节见附录 E),它以与正向 SDE 相同的方式离散化逆向时间 SDE,因此在给定正向离散化后可以容易地推导出来。如表1所示,对于 CIFAR-10上的 SMLD 和 DDPM 模型,逆向扩散采样器的性能略优于祖先采样(DDPM 类型的祖先采样也适用于 SMLD 模型,见附录 F)。
4.2 预测器-校正器采样器
与通用SDE不同,我们拥有可以利用的额外信息。由于我们有一个基于分数的模型 sθ∗(x,t)≈∇xlogpt(x)\mathbf{s}_{\theta *}(\mathbf{x}, t) \approx \nabla_{\mathbf{x}} \log p_t(\mathbf{x})sθ∗(x,t)≈∇xlogpt(x)),我们可以采用基于分数的MCMC方法,例如朗之万MCMC (1981) 或HMC (Neal et al., 2011),直接从 ptp_tpt 采样,并纠正数值 SDE 求解器的解。
具体来说,在每个时间步,数值 SDE 求解器首先给出下一个时间步样本的估计,我们称之为扮演“预测器(predictor)”的角色。然后,基于分数的 MCMC 方法校正估计样本的边缘分布误差,扮演“校正器(corrector)”的角色。这个想法类似于预测器-校正器方法(Predictor-Corrector methods),这是一类用于求解方程组的数值延拓技术 (Allgower & Georg, 2012),我们类似地将我们的混合采样算法命名为预测器-校正器(Predictor-Corrector, PC)采样器。伪代码和完整描述请参见附录 G。PC 采样器推广了 SMLD 和 DDPM 的原始采样方法:前者使用恒等函数作为预测器,退火朗之万动力学(annealed Langevin dynamics)作为校正器;后者使用祖先采样作为预测器,恒等函数作为校正器。
我们在使用原始离散目标(方程(1)和(3))训练的 SMLD 和 DDPM 模型上测试 PC 采样器(参见附录 G 中的算法2和3)。这证明了 PC 采样器与使用固定数量噪声尺度训练的基于分数模型的兼容性。我们在表4中总结了不同采样器的性能,其中概率流(probability flow)是将在第4.3节讨论的预测器。详细的实验设置和额外结果见附录 G。我们观察到,我们的逆向扩散采样器总是优于祖先采样,并且仅使用校正器的方法(C2000)在相同计算量下比其他竞争者(P2000, PC1000)表现更差(实际上,为了匹配其他采样器的性能,我们需要每个噪声尺度更多的校正步骤,从而需要更多计算)。对于所有预测器,为每个预测器步骤添加一个校正器步骤(PC1000)会使计算量翻倍,但总能提高样本质量(相对于 P1000)。此外,它通常比在不添加校正器的情况下将预测器步数加倍(P2000)更好,对于 SMLD/DDPM 模型,我们需要临时插值噪声尺度(附录 G 中有详细说明)。在图9(附录 G)中,我们额外提供了在256×256 LSUN 图像上使用 VE SDE 和连续目标方程(7)训练的模型的定性比较,其中在计算量相当的情况下,使用适当数量的校正器步骤时,PC 采样器明显优于仅使用预测器的采样器。
表2:CIFAR-10上的负对数似然(NLLs)和 FID(ODE)
模型 | NLL 测试集 ↓\downarrow↓ | FID ↓\downarrow↓ |
---|---|---|
RealNVP (Dinh et al., 2016) | 3.49 | - |
iResNet (Behrmann et al., 2019) | 3.45 | - |
Glow (Kingma & Dhariwal, 2018) | 3.35 | - |
MintNet (Song et al., 2019b) | 3.32 | - |
Residual Flow (Chen et al., 2019) | 3.28 | 46.37 |
FFJORD (Grathwohl et al., 2018) | 3.40 | - |
HAAA (H4A4, 2018) | 3.98 | - |
DDPM (LLL) (Ho et al., 2020) | ≤3.70\leq 3.70≤3.70 | 13.51 |
DDPM (LsimpleL_{\text{simple}}Lsimple) (Ho et al., 2020) | ≤3.75\leq 3.75≤3.75 | 3.17 |
DDPM | 3.28 | 3.37 |
DDPM 连续 (VP) | 3.21 | 3.69 |
DDPM 连续 (sub-VP) | 3.05 | 3.56 |
DDPM + 连续 (VP) | 3.16 | 3.93 |
DDPM + 连续 (sub-VP) | 3.02 | 3.16 |
DDPM + 连续 (深度, VP) | 3.13 | 3.08 |
DDPM + 连续 (深度, sub-VP) | 2.99 | 2.92 |
表2:CIFAR-10上的负对数似然(NLLs)和 FID(使用 ODE 求解器采样)。
表3:CIFAR-10样本质量
模型 | FID↓ | IS↑ |
---|---|---|
条件生成 (Conditional) | ||
BigGAN (Brock et al., 2018) | 14.73 | 9.22 |
StyleGAN2-ADA (Karras et al., 2020a) | 2.42 | 10.14 |
无条件生成 (Unconditional) | ||
StyleGAN2-ADA (Karras et al., 2020a) | 2.92 | 9.83 |
NCSN (Song & Ermon, 2019) | 25.32 | 8.87 ± .12 |
NCSNv2 (Song & Ermon, 2020) | 10.87 | 8.40 ± .07 |
DDPM (Ho et al., 2020) | 3.17 | 9.46 ± .11 |
DDPM++ | 2.78 | 9.64 |
DDPM++ 连续 (VP) | 2.55 | 9.58 |
DDPM++ 连续 (sub-VP) | 2.61 | 9.56 |
DDPM++ 连续 (深度, VP) | 2.41 | 9.68 |
DDPM++ 连续 (深度, sub-VP) | 2.41 | 9.57 |
NCSN++ | 2.45 | 9.73 |
NCSN++ 连续 (VE) | 2.38 | 9.83 |
NCSN++ 连续 (深度, VE) | 2.20 | 9.89 |
表3:CIFAR-10样本质量。
4.3 概率流 ODE
基于分数的模型启用了另一种求解逆向时间SDE的数值方法。对于所有扩散过程,存在一个相应的确定性过程(deterministic process),其轨迹与SDE共享相同的边缘概率密度 {ρt(x)}t=0T\{\rho_t(x)\}_{t=0}^T{ρt(x)}t=0T。这个确定性过程满足一个ODE(更多细节见附录D.1):
dxdt=[f(x,t)−12g(t)2∇xlogρt(x)]dt,(13) \frac{dx}{dt} = \left[ f(x, t) - \frac{1}{2} g(t)^2 \nabla_x \log \rho_t(x) \right] dt, \quad (13) dtdx=[f(x,t)−21g(t)2∇xlogρt(x)]dt,(13)
一旦分数已知,该 ODE 可以从 SDE 确定。我们将方程(13)中的 ODE 命名为概率流 ODE(probability flow ODE)。当分数函数由时变的基于分数的模型(通常是神经网络)近似时,这就是神经 ODE(neural ODE)(Chen et al., 2018) 的一个例子。
精确似然计算
利用与神经ODE的联系,我们可以通过瞬时变量变换公式(instantaneous change of variables formula)(Chen et al., 2018) 计算由方程(13)定义的似然。这允许我们在任何输入数据上计算精确似然(exact likelihood)(细节见附录D.2)。作为一个例子,我们在表2中报告了在 CIFAR-10数据集上以比特/维度(bits/dim)衡量的负对数似然(negative log-likelihoods, NLLs)。我们在均匀去量化(uniformly dequantized)的数据上计算对数似然,并且仅与以相同方式评估的模型进行比较(省略了使用变分去量化(variational dequantization)(Ho et al., 2019) 或离散数据评估的模型),除了 DDPM (LL/sample),其 ELBO 值(标注为*)是在离散数据上报告的。主要结果:(i) 对于 Ho et al. (2020) 中相同的 DDPM 模型,由于我们的似然是精确的,我们获得了比 ELBO 更好的 bits/dim;(ii) 使用相同的架构,我们使用方程(7)中的连续目标训练了另一个 DDPM 模型(即 DDPM cont.),进一步改善了似然;(iii) 使用 sub-VP SDEs,我们总是获得比 VP SDEs 更高的似然;(iv) 通过改进的架构(即 DDPM++ cont.,细节见第4.4节)和 sub-VP SDE,我们甚至在没有最大似然训练的情况下,在均匀去量化的 CIFAR-10上创下了2.99 bits/dim 的新记录。
操作潜在表示
通过对方程(13)积分,我们可以将任何数据点 x(0)x(0)x(0) 编码到潜在空间 x(T)x(T)x(T)。解码可以通过对相应的逆向时间SDE的ODE进行积分来实现。正如在其他可逆模型(如神经ODE和标准化流(normalizing flows)(Dinh et al., 2016; Kingma & Dhariwal, 2018))中所做的那样,我们可以操作这个潜在表示进行图像编辑,例如插值(interpolation)和温度缩放(temperature scaling)(见图3和附录D.4)。
唯一可识别编码
与大多数当前的可逆模型不同,我们的编码是唯一可识别的(uniquely identifiable),这意味着在充足的训练数据、模型容量和优化精度下,输入的编码由数据分布唯一确定 (Roeder et al., 2020)。这是因为我们的正向SDE(方程(5))没有可训练参数,并且其相关的概率流ODE(方程(13))在给定完美估计的分数时提供相同的轨迹。我们在附录D.5中提供了关于此性质的额外经验验证。
高效采样
与神经ODE一样,我们可以通过从不同的最终条件 x(T)∼ρTx(T) \sim \rho_Tx(T)∼ρT 求解方程(13)来采样 x(0)∼ρ0x(0) \sim \rho_0x(0)∼ρ0。使用固定的离散化策略,我们可以生成具有竞争力的样本,特别是与校正器结合使用时(表1,“概率流采样器”,细节见附录D.3)。使用黑盒ODE求解器(Dormand & Prince, 1980)不仅产生高质量样本(表2,细节见附录D.4),而且允许我们显式地在精度和效率之间进行权衡。使用更大的误差容忍度(error tolerance),函数评估次数可以减少90%以上而不影响样本的视觉质量(图3)。
4.4 架构改进
我们探索了几种新的架构设计,用于改进基于 VE 和 VP SDEs 的基于分数模型(细节见附录 H),其中我们使用与 SMLD/DDPM 相同的离散目标训练模型。由于相似性,我们将 VP SDEs 的架构直接转移到 sub-VP SDEs。我们的 VE SDE 最优架构,命名为 NCSN++,在使用 PC 采样器时在 CIFAR-10上实现了2.45的 FID,而我们 VP SDE 的最优架构,称为 DDPM++,实现了2.78的 FID。
通过切换到方程(7)中的连续训练目标并增加网络深度,我们可以进一步提高所有模型的样本质量。由此产生的架构在表3中分别表示为用于 VE 和 VP/sub-VP SDEs 的 NCSN++ cont. 和 DDPM++ cont.。表3中报告的结果是训练过程中 FID 最小时的检查点(checkpoint),样本使用 PC 采样器生成。相比之下,表2中的 FID 分数和 NLL 值是针对最后一个训练检查点报告的,样本使用黑盒 ODE 求解器获得。如表3所示,VE SDEs 通常比 VP/sub-VP SDEs 提供更好的样本质量,但我们也凭经验观察到它们的似然比对应的 VP/sub-VP SDEs 差。这表明从业者可能需要针对不同的领域和架构尝试不同的 SDE。
我们用于样本质量的最佳模型 NCSN++ cont. (deep, VE) 使网络深度加倍,并为CIFAR-10的无条件生成设定了初始分数和FID的新记录。令人惊讶的是,我们可以在不需要标记数据的情况下获得比先前最佳条件生成模型更好的FID。结合所有改进,我们还在CelebA-HQ 1024×10241024 \times 10241024×1024 上获得了来自基于分数模型的第一个高保真样本集(见附录H.3)。我们用于似然的最佳模型 DDPM++ cont. (deep, sub-VP) 同样使网络深度加倍,并使用方程(7)中的连续目标实现了 2.99 bits/dim 的对数似然。据我们所知,这是在均匀去量化的 CIFAR-10上的最高似然。
5 可控生成
我们框架的连续结构不仅允许我们从 p0p_0p0 生成数据样本,而且如果 pt(y∣x(t))p_{\mathbf{t}}(\mathbf{y} \mid \mathbf{x}(\mathbf{t}))pt(y∣x(t)) 已知,还可以从 p0(x(0)∣y)p_0(\mathbf{x}(0) \mid \mathbf{y})p0(x(0)∣y) 生成样本。给定如方程(5)的正向SDE,我们可以通过从 pT(x(T)∣y)p_T(\mathbf{x}(T) \mid \mathbf{y})pT(x(T)∣y) 开始并求解条件逆向时间SDE来从 pt(x(t)∣y)p_{\mathbf{t}}(\mathbf{x}(t) \mid \mathbf{y})pt(x(t)∣y) 采样:
dx={f(x,t)−g(t)2[∇xlogpt(x)+∇xlogpt(y∣x)]}dt+g(t)dw.(14) \mathrm{d} \mathbf{x} = \left\{ \mathbf{f}(\mathbf{x}, t) - g(t)^2 \left[ \nabla_{\mathbf{x}} \log p_{\mathbf{t}}(\mathbf{x}) + \nabla_{\mathbf{x}} \log p_{\mathbf{t}}(\mathbf{y} \mid \mathbf{x}) \right] \right\} \mathrm{d} t + g(t) \mathrm{d} \mathbf{w}. \quad (14) dx={f(x,t)−g(t)2[∇xlogpt(x)+∇xlogpt(y∣x)]}dt+g(t)dw.(14)
通常,一旦给定前向过程梯度 ∇xlogpt(y∣x(t))\nabla_{\mathbf{x}} \log p_{\mathbf{t}}(\mathbf{y} \mid \mathbf{x}(t))∇xlogpt(y∣x(t)) 的估计,我们就可以使用方程(14)用基于分数的生成模型解决一大类逆问题。在某些情况下,可以训练一个单独的模型来学习前向过程 logpt(y∣x(t))\log p_{\mathbf{t}}(\mathbf{y} \mid \mathbf{x}(t))logpt(y∣x(t)) 并计算其梯度。否则,我们可以利用启发式方法和领域知识来估计梯度。在附录I.4中,我们提供了一种广泛适用的方法来获得这种估计,而无需训练辅助模型。
我们考虑使用这种方法进行可控生成的三种应用:类条件生成(class-conditional generation)、图像修复(image inpainting) 和着色(colorization)。当 y\mathbf{y}y 表示类别标签时,我们可以训练一个时变分类器 pt(y∣x(t))p_{\mathbf{t}}(\mathbf{y} \mid \mathbf{x}(t))pt(y∣x(t)) 用于类条件采样。由于正向SDE是易处理的,我们可以轻松创建训练数据对 (x(t),y)(\mathbf{x}(t), \mathbf{y})(x(t),y),首先从数据集中采样 (x(0),y)(\mathbf{x}(0), \mathbf{y})(x(0),y),然后获得 x(t)∼p0t(x(t)∣x(0))\mathbf{x}(t) \sim p_{0 t}(\mathbf{x}(t) \mid \mathbf{x}(0))x(t)∼p0t(x(t)∣x(0))。之后,我们可以使用类似于方程(7)的交叉熵损失混合在不同时间步上来训练时变分类器 pt(y∣x(t))p_{\mathbf{t}}(\mathbf{y} \mid \mathbf{x}(t))pt(y∣x(t))。我们在图4(左)中提供了类条件 CIFAR-10样本,更多细节和结果见附录 I。
图4:左:32×3232 \times 3232×32 CIFAR-10上的类条件样本。上四行是汽车,下四行是马。右:256×256256 \times 256256×256 LSUN 上的修复(上两行)和着色(下两行)结果。第一列是原始图像,第二列是掩码/灰度图像,其余列是采样的图像补全或着色结果。
插补(Imputation)是条件采样的一个特例。假设我们有一个不完整的数据点 y\mathbf{y}y,其中只有某个子集 Ω(y)\Omega(\mathbf{y})Ω(y) 是已知的。插补等同于从 p(x(0)∣Ω(x(0))=y)p(\mathbf{x}(0) \mid \Omega(\mathbf{x}(0))=\mathbf{y})p(x(0)∣Ω(x(0))=y) 采样,我们可以使用无条件模型实现(见附录I.2)。着色是插补的一个特例,除了已知的数据维度是耦合的。我们可以使用正交线性变换(orthogonal linear transformation)解耦这些数据维度,并在变换后的空间中进行插补(细节见附录I.3)。图4(右)显示了使用无条件时变分数模型实现的修复和着色结果。
6 结论
我们提出了一个基于 SDEs 的分数生成建模框架。我们的工作能够更好地理解现有方法,引入新的采样算法,支持精确似然计算,实现唯一可识别的编码,支持潜在代码操作,并将新的条件生成能力引入基于分数的生成模型家族。
虽然我们提出的采样方法改进了结果并实现了更高效的采样,但在相同数据集上采样仍然比 GANs (Goodfellow et al., 2014) 慢。识别将基于分数生成模型的稳定学习与 GANs 等隐式模型(implicit models)的快速采样相结合的方法仍然是一个重要的研究方向。此外,当可以访问分数函数时,可以使用多种采样器,这引入了许多超参数。未来的工作将受益于改进的方法来自动选择和调整这些超参数,以及对各种采样器的优缺点进行更广泛的研究。