FLOW MATCHING FOR GENERATIVE MODELING翻译

摘要

我们引入了一种基于 Continuous Normalizing Flows (CNF) 的生成建模新范式,使我们能够以前所未有的规模训练 CNF。具体而言,我们提出了 Flow Matching (FM) 的概念,这是一种无需模拟的 CNF 训练方法,该方法基于固定条件概率路径的回归矢量场。Flow Matching 与一类通用的高斯概率路径兼容,用于在噪声和数据样本之间进行转换——这将现有的扩散路径归纳为特定实例。有趣的是,我们发现将 FM 与扩散路径结合使用,可以为训练扩散模型提供一种更稳健、更稳定的替代方案。此外,Flow Matching 为使用其他非扩散概率路径训练 CNF 开辟了道路。一个特别值得关注的例子是使用最优传输 (OT) 位移插值来定义条件概率路径。这些路径比扩散路径更高效,可以提供更快的训练和采样速度,并带来更好的泛化效果。在 ImageNet 上使用 Flow Matching 训练 CNF,在似然度和样本质量方面均比其他基于扩散的方法获得更好的性能,并且可以使用现成的数值 ODE 求解器快速可靠地生成样本。

1.INTRODUCTION

在这里插入图片描述

深度生成模型是一类旨在从未知数据分布中进行估计和采样的深度学习算法。近年来,生成模型领域取得了一系列令人瞩目的进展,例如在图像生成领域,这主要得益于基于扩散的模型的可扩展性和相对稳定的训练。然而,由于局限于简单的扩散过程,采样概率路径的空间相当狭窄,导致训练时间非常长,并且需要采用专门的方法来实现高效采样。

在本文中,我们探讨了 Continuous Normalizing Flows (CNF) 的通用且确定性的框架。CNF 能够建模任意概率路径,尤其以涵盖由扩散过程建模的概率路径而闻名。然而,除了可以通过例如 denoising score matching 等方法高效训练的扩散之外,尚无已知的可扩展 CNF 训练算法。事实上,最大似然训练(例如 Grathwohl 等人 (2018))需要昂贵的数值 ODE 模拟,而现有的无需模拟的方法要么涉及难以处理的积分,要么涉及有偏梯度。

本研究的目标是提出 Flow Matching (FM),这是一种高效的无需模拟的 CNF 模型训练方法,允许采用通用概率路径来监督 CNF 训练。重要的是,FM 打破了除扩散之外可扩展 CNF 训练的障碍,并且无需推理扩散过程,直接使用概率路径。

具体来说,我们提出了 Flow Matching 目标(第 3 节),这是一个简单直观的训练目标,用于回归到生成所需概率路径的目标向量场。我们首先证明,我们可以通过每个示例(即条件)公式构建此类目标向量场。然后,受 denoising score matching 的启发,我们证明了一个称为 Conditional Flow Matching (CFM) 的基于示例的训练目标,它提供等效梯度,并且不需要明确了解难以处理的目标向量场。此外,我们讨论了可用于 Flow Matching 的一系列基于示例的概率路径(第 4 节),它将现有的扩散路径作为特殊实例。即使在扩散路径上,我们也发现使用 FM 可以提供更稳健、更稳定的训练,并且比 score matching 获得更优异的性能。此外,这一系列概率路径还包含一个特别有趣的情况:对应于最优传输 (OT) 位移插值的向量场。我们发现,条件 OT 路径比扩散路径更简单,形成直线轨迹,而扩散路径则产生曲线路径。这些特性似乎从经验上转化为更快的训练、更快的生成和更好的性能。

我们在 ImageNet(一个规模庞大且高度多样化的图像数据集)上通过最优传输路径构建 Flow Matching 算法进行了实证验证。我们发现,相比其他基于扩散的竞争方法,我们可以轻松训练模型,使其在似然估计和样本质量方面均取得优异表现。此外,我们发现,与现有方法相比,我们的模型在计算成本和样本质量之间取得了更佳的平衡。图 1 展示了我们模型中选定的 128×128 非条件 ImageNet 样本。

2.PRELIMINARIES: CONTINUOUS NORMALIZING FLOWS

Rd\mathbb R^dRd 表示数据空间,其中数据点为 x=(x1,...,xd)∈Rdx = (x^1, ..., x^d) ∈ \mathbb R^dx=(x1,...,xd)Rd。本文中使用的两个重要对象是:概率密度路径 p:[0,1]×Rd→R>0p: [0, 1] × \mathbb R^d → \mathbb R_{>0}p:[0,1]×RdR>0,它是一个时间相关的概率密度函数,即 ∫pt(x)dx=1\int p_t(x)dx = 1pt(x)dx=1;以及时间相关的矢量场 v:[0,1]×Rd→Rdv: [0, 1] × \mathbb R^d → \mathbb R^dv:[0,1]×RdRd。矢量场 vtv_tvt 可用于构建时间相关的微分同胚映射,称为流 ϕ:[0,1]×Rd→Rd\phi: [0, 1] × \mathbb R^d → \mathbb R^dϕ:[0,1]×RdRd,其通过常微分方程 (ODE) 定义:

ddtϕt(x)=vt(ϕt(x))(1)\frac{d}{dt}\phi_t(x)=v_t(\phi_t(x))\tag{1}dtdϕt(x)=vt(ϕt(x))(1)

ϕ0(x)=x(2)\phi_0(x)=x\tag{2}ϕ0(x)=x(2)

此前,Chen et al. (2018) 建议用神经网络 vt(x;θ)v_t(x;θ)vt(x;θ) 对矢量场 vtv_tvt 进行建模,其中 θ∈Rpθ ∈ \mathbb R^pθRp 是其可学习参数,从而得到一个关于流 ϕt\phi_tϕt 的深度参数模型,称为 Continuous Normalizing Flow (CNF)。CNF 用于通过前推方程将简单的先验密度 p0p_0p0(例如纯噪声)重塑为更复杂的密度 p1p_1p1

pt=[ϕt]∗p0(3)p_t=[\phi_t]_*p_0\tag{3}pt=[ϕt]p0(3)

其中前推(或变量改变)运算符 ∗∗ 定义为

[ϕt]∗p0(x)=p0(ϕt−1(x))det[∂ϕt−1∂x(x)](4)[\phi_t]_*p_0(x)=p_0(\phi^{-1}_t(x))det[\frac{\partial \phi^{-1}_t}{\partial x}(x)]\tag{4}[ϕt]p0(x)=p0(ϕt1(x))det[xϕt1(x)](4)

如果矢量场 vtv_tvt 的流 ϕt\phi_tϕt 满足方程 3,则称该矢量场生成概率密度路径 ptp_tpt。测试矢量场是否生成概率路径的一种实用方法是使用连续性方程,它是我们证明中的关键组成部分,请参阅附录 B。我们在附录 C 中回顾了有关 CNF 的更多信息,特别是如何计算任意点 x∈Rdx ∈ \mathbb R^dxRd 处的概率 p1(x)p_1(x)p1(x)

3.FLOW MATCHING

x1x_1x1 表示服从某个未知数据分布 q(x1)q(x_1)q(x1) 的随机变量。我们假设我们只能访问 q(x1)q(x_1)q(x1) 中的数据样本,但无法访问其密度函数本身。此外,我们令 ptp_tpt 为一条概率路径,其中 p0=pp_0 = pp0=p 为简单分布,例如标准正态分布 p(x)=N(x∣0,I)p(x) = \mathcal N(x|0, I)p(x)=N(x∣0,I),并p1p_1p1 的分布与 qqq 近似相等。稍后我们将讨论如何构建这样的路径。Flow Matching 目标旨在匹配该目标概率路径,从而使我们能够从 p0p_0p0 流向 p1p_1p1

给定目标概率密度路径 pt(x)p_t(x)pt(x) 和相应的向量场 ut(x)u_t(x)ut(x)(生成 pt(x)p_t(x)pt(x)),我们将 Flow Matching (FM) 目标定义为

LFM(θ)=Et,pt(x)∣∣vt(x)−ut(x)∣∣2,(5)\mathcal L_{FM}(\theta)=\mathbb E_{t,p_t(x)}||v_t(x)-u_t(x)||^2,\tag{5}LFM(θ)=Et,pt(x)∣∣vt(x)ut(x)2,(5)

其中 θθθ 表示 CNF 向量场 vtv_tvt(定义见第二节)的可学习参数,t∼U[0,1]t ∼ \mathcal U[0, 1]tU[0,1](均匀分布),x∼pt(x)x ∼ p_t(x)xpt(x)。简而言之,FM 损失利用神经网络 vtv_tvt 对向量场 utu_tut 进行回归。当损失达到零时,学习到的 CNF 模型将生成 pt(x)p_t(x)pt(x)

Flow Matching 是一个简单且有吸引力的目标,但就其本身而言,由于我们事先不知道合适的 ptp_tptutu_tut 是什么,因此在实践中很难应用。有很多概率路径可以满足 p1(x)≈q(x)p_1(x) ≈ q(x)p1(x)q(x) ,更重要的是,我们通常无法获得生成所需 ptp_tpt 的闭式 utu_tut。在本节中,我们将展示如何使用仅针对每个样本定义的概率路径和矢量场来构建 ptp_tptutu_tut,并且适当的聚合方法可以提供所需的 ptp_tptutu_tut。此外,这种构造使我们能够为 Flow Matching 创建一个更易于处理的目标。

3.1 CONSTRUCTING pt,utp_t, u_tpt,ut FROM CONDITIONAL PROBABILITY PATHS AND VECTOR FIELDS

构建目标概率路径的一种简单方法是通过混合更简单的概率路径:给定一个特定的数据样本 x1x_1x1,我们用 pt(x∣x1)p_t(x|x_1)pt(xx1) 表示一个条件概率路径,使得它在时间 t=0t = 0t=0 时满足 p0(x∣x1)=p(x)p_0(x|x_1) = p(x)p0(xx1)=p(x),并且我们将 t=1t = 1t=1 时的 p1(x∣x1)p_1(x|x_1)p1(xx1) 设计为集中在 x=x1x = x_1x=x1 周围的分布,例如,p1(x∣x1)=N(x∣x1,σ2I)p_1(x|x_1) = N(x|x_1, σ^2I)p1(xx1)=N(xx1,σ2I),这是一个具有 x1x_1x1 均值和足够小的标准差 σ>0σ > 0σ>0 的正态分布。对 q(x1)q(x_1)q(x1) 上的条件概率路径进行边缘化,得到边缘概率路径:

pt(x)=∫pt(x∣x1)q(x1)dx1,(6)p_t(x)=\int p_t(x|x_1)q(x_1)dx_1,\tag{6}pt(x)=pt(xx1)q(x1)dx1,(6)

其中,特别是在时间 t=1t = 1t=1 时,边缘概率 p1p_1p1 是一个混合分布,与数据分布 q 非常接近,

p1(x)=∫p1(x∣x1)q(x1)dx1≈q(x).(7)p_1(x)=\int p_1(x|x_1)q(x_1)dx_1\approx q(x).\tag{7}p1(x)=p1(xx1)q(x1)dx1q(x).(7)

有趣的是,我们还可以定义一个边缘矢量场,通过以以下方式对条件矢量场进行“边缘化”(我们假设对于所有 tttxxxpt(x)>0p_t(x) > 0pt(x)>0):

ut(x)=∫ut(x∣x1)pt(x∣x1)q(x1)pt(x)dx1,(8)u_t(x)=\int u_t(x|x_1)\frac{p_t(x|x_1)q(x_1)}{p_t(x)}dx_1,\tag{8}ut(x)=ut(xx1)pt(x)pt(xx1)q(x1)dx1,(8)

其中 ut(⋅∣x1):Rd→Rdu_t(·|x_1) : \mathbb R^d → \mathbb R^dut(x1):RdRd 是一个生成 pt(⋅∣x1)p_t(·|x_1)pt(x1) 的条件向量场。这看起来可能不太明显,但这种聚合条件向量场的方式实际上可以得到用于建模边缘概率路径的正确向量场。

我们的第一个关键观察是:

The marginal vector field (equation 8) generates the marginal probability path (equation 6).

这在条件概率函数(生成条件概率路径的函数)和边缘概率函数(生成边缘概率路径的函数)之间建立了一种令人惊讶的联系。这种联系使我们能够将未知且难以处理的边缘概率函数分解为更简单的条件概率函数,由于它们仅依赖于单个数据样本,因此定义起来要简单得多。我们将此定理形式化为以下定理。

Theorem 1。给定生成条件概率路径 pt(x∣x1)p_t(x|x_1)pt(xx1) 的向量场 ut(x∣x1)u_t(x|x_1)ut(xx1),对于任意分布 q(x1)q(x_1)q(x1),公式 8 中的边缘向量场 utu_tut 能够生成公式 6 中的边缘概率路径 ptp_tpt,即 utu_tutptp_tpt 满足连续性方程(公式 26)。

我们定理的完整证明均在附录 A 中提供。定理 1 也可以从 Peluchetti (2021) 中的扩散混合表示定理中推导出来,该定理提供了扩散 SDE 中的边缘漂移和扩散系数的公式。

3.2 CONDITIONAL FLOW MATCHING

遗憾的是,由于边缘概率路径和 VF(公式 6 和 8)定义中存在难以处理的积分,因此仍然难以计算 utu_tut,进而难以简单地计算原始 Flow Matching 目标的无偏估计量。为此,我们提出了一个更简单的目标,令人惊讶的是,它将产生与原始目标相同的最优值。具体而言,我们考虑 Conditional Flow Matching (CFM) 目标,

LCFM(θ)=Et,q(x1),pt(x∣x1)∣∣vt(x)−ut(x∣x1)∣∣2,(9)\mathcal L_{CFM}(\theta)=\mathbb E_{t,q(x_1),p_t(x|x_1)}||v_t(x)-u_t(x|x_1)||^2,\tag{9}LCFM(θ)=Et,q(x1),pt(xx1)∣∣vt(x)ut(xx1)2,(9)

其中 t∼U[0,1],x1∼q(x1)t ∼ \mathcal U[0, 1], x_1 ∼ q(x_1)tU[0,1],x1q(x1),现在 x∼pt(x∣x1)x ∼ p_t(x|x_1)xpt(xx1)。与 FM 目标函数不同,CFM 目标函数允许我们轻松地采样无偏估计,只要我们能够高效地从 pt(x∣x1)p_t(x|x_1)pt(xx1) 采样并计算 ut(x∣x1)u_t(x|x_1)ut(xx1) 即可。由于这两个计算都是基于每个样本定义的,因此很容易完成。因此,我们的第二个关键观察结果是:

The FM (equation 5) and CFM (equation 9) objectives have identical gradients w.r.t. θ

也就是说,优化 CFM 目标函数(在期望上)等价于优化 FM 目标函数。因此,我们能够训练一个 CNF 来生成边缘概率路径 ptp_tpt(具体来说,它近似于 t=1t=1t=1 时刻的未知数据分布 qqq),而无需直接访问边缘概率路径或边际向量场。我们只需设计合适的条件概率路径和向量场即可。我们将此性质形式化为以下定理。

Theorem 2。假设对于所有 x∈Rdx ∈ \mathbb R^dxRdt∈[0,1],pt(x)>0t ∈ [0, 1], p_t(x) > 0t[0,1],pt(x)>0,那么,在与 θ 无关的常数之前,LCFM\mathcal L_{CFM}LCFMLFM\mathcal L_{FM}LFM 相等。因此,∇θLFM(θ)=∇θLCFM(θ)∇_θ\mathcal L_{FM}(θ) = ∇_θ\mathcal L_{CFM}(θ)θLFM(θ)=θLCFM(θ)

4.CONDITIONAL PROBABILITY PATHS AND VECTOR FIELDS

Conditional Flow Matching 目标适用于任何条件概率路径和条件向量场的选择。在本节中,我们讨论高斯条件概率路径中的 pt(x∣x1)p_t(x|x_1)pt(xx1)ut(x∣x1)u_t(x|x_1)ut(xx1) 的构造。也就是说,我们考虑如下形式的条件概率路径

pt(x∣x1)=N(x∣µt(x1),σt(x1)2I),(10)p_t(x|x_1)=\mathcal N(x|µ_t(x_1),\sigma_t(x_1)^2I),\tag{10}pt(xx1)=N(xµt(x1),σt(x1)2I),(10)

其中 µ:[0,1]×Rd→Rdµ: [0, 1] × \mathbb R^d → \mathbb R^dµ:[0,1]×RdRd 是高斯分布的时间相关均值,而 σ:[0,1]×R→R>0σ : [0, 1] × \mathbb R → \mathbb R_{>}0σ:[0,1]×RR>0 描述时间相关的标量标准差 (std)。我们设 µ0(x1)=0µ_0(x_1) = 0µ0(x1)=0σ0(x1)=1σ_0(x_1) = 1σ0(x1)=1,使得所有条件概率路径在 t=0t = 0t=0 时收敛到相同的标准高斯噪声分布,p(x)=N(x∣0,I)p(x) = \mathcal N(x|0, I)p(x)=N(x∣0,I)。然后,我们设 µ1(x1)=x1µ_1(x_1) = x_1µ1(x1)=x1σ1(x1)=σminσ_1(x_1) = σ_{min}σ1(x1)=σmin,该值设置得足够小,使得 p1(x∣x1)p_1(x|x_1)p1(xx1) 是以 x_1 为中心的集中高斯分布。

生成任意特定概率路径的矢量场数量是无穷无尽的(例如,通过在连续性方程中添加一个不散度的分量,参见公式 26),但其中绝大多数是由于存在使底层分布保持不变的分量(例如,当分布具有旋转不变性时,旋转分量的存在)导致不必要的额外计算。我们决定使用对应于高斯分布正则变换的最简单的矢量场。具体来说,考虑流(以 x1x_1x1 为条件)

ψt(x)=σt(x1)x+μt(x1).(11)ψ_t(x)=\sigma_t(x_1)x+\mu_t(x_1).\tag{11}ψt(x)=σt(x1)x+μt(x1).(11)

xxx 服从标准高斯分布时,ψt(x)ψ_t(x)ψt(x) 是仿射变换,它映射到一个均值为 µt(x1)µ_t(x_1)µt(x1) 且标准差为 σt(x1)σ_t(x_1)σt(x1) 的正态分布随机变量。也就是说,根据公式 4,ψtψ_tψt 将噪声分布 p0(x∣x1)=p(x)p_0(x|x_1) = p(x)p0(xx1)=p(x) 推到 pt(x∣x1)p_t(x|x_1)pt(xx1),即:

[ψt]∗p(x)=pt(x∣x1).(12)[ψ_t]_*p(x)=p_t(x|x_1).\tag{12}[ψt]p(x)=pt(xx1).(12)

然后,该流提供一个生成条件概率路径的矢量场(公式1):

ddtψt(x)=ut(ψt(x)∣x1).(13)\frac{d}{dt}ψ_t(x)=u_t(ψ_t(x)|x_1).\tag{13}dtdψt(x)=ut(ψt(x)x1).(13)

仅用 x0x_0x0 重新参数化 pt(x∣x1)p_t(x|x_1)pt(xx1),并将公式 13 代入 CFM 损失中,我们得到

LCFM=Et,q(x1),p(x0)∣∣vt(ψt(x0))−ddtψt(x0)∣∣2.(14)\mathcal L_{CFM}=\mathbb E_{t,q(x_1),p(x_0)}||v_t(ψ_t(x_0))-\frac{d}{dt}ψ_t(x_0)||^2.\tag{14}LCFM=Et,q(x1),p(x0)∣∣vt(ψt(x0))dtdψt(x0)2.(14)

由于 ψtψ_tψt 是一个简单的(可逆的)仿射映射,我们可以使用公式13以闭合形式求解 utu_tut。令 f′f'f 表示关于时间的导数,即对于与时间相关的函数 ffff′=ddtff' =\frac{d}{dt}ff=dtdf

Theorem 3。令 pt(x∣x1)p_t(x|x_1)pt(xx1) 为高斯概率路径,如公式 10 所示,ψtψ_tψt 为其对应的流图,如公式 11 所示。那么,定义 ψtψ_tψt 的唯一矢量场具有以下形式:

ut(x∣x1)=σt′(x1)σt(x1)(x−μt(x1))+μt′(x1).(15)u_t(x|x_1)=\frac{\sigma'_t(x_1)}{\sigma_t(x_1)}(x-\mu_t(x_1))+\mu'_t(x_1).\tag{15}ut(xx1)=σt(x1)σt(x1)(xμt(x1))+μt(x1).(15)

因此,ut(x∣x1)u_t(x|x_1)ut(xx1) 生成高斯路径 pt(x∣x1)p_t(x|x_1)pt(xx1)

4.1 SPECIAL INSTANCES OF GAUSSIAN CONDITIONAL PROBABILITY PATHS

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值