Foundations of Diffusion Models and DDPM

The forward process

  1. Data distribution: $q\left(x^{(0)}\right)$
  2. Target distribution: a simple, analytically tractable distribution $\pi\left(x^{(T)}\right)$
  3. Markov kernel $T_{\pi}$: a general Markov transition kernel that gives the probability of moving from the current state to the next state

$$
q\left(x^{(t)}|x^{(t-1)}\right) = T_{\pi}\left( x^{(t)}|x^{(t-1)};\beta_{t} \right)
$$

$$
q\left( x^{(0\cdots T)} \right) = q\left(x^{(0)}\right)\prod_{t=1}^{T} q\left( x^{(t)}|x^{(t-1)} \right) \tag{1}
$$

The reverse process

  1. Initial distribution: the simple distribution $\pi\left(x^{(T)}\right)$
  2. Target distribution: the data distribution

$$
\begin{aligned}
p\left( x^{(T)} \right) &= \pi\left( x^{(T)} \right) \\
p\left( x^{(0\cdots T)} \right) &= p\left(x^{(T)}\right)\prod_{t=1}^{T} p\left( x^{(t-1)}|x^{(t)} \right)
\end{aligned}
$$

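Key assumption: when each noising step is very small, a Gaussian (or binomial) forward process has a reverse process that is also Gaussian (or binomial). The reverse process can therefore be pinned down with very few parameters.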

The reverse process in detail

$$
\begin{aligned}
p\left( x^{(0\cdots T)} \right) &= p\left(x^{(T)}\right)\prod_{t=1}^{T} p\left( x^{(t-1)}|x^{(t)} \right) \\[5pt]
p\left( x^{(0)} \right) &= \int dx^{(1)}\int dx^{(2)}\cdots\int dx^{(T)}\, p\left( x^{(0\cdots T)} \right) \\[5pt]
&= \int dx^{(1\cdots T)}\, p\left( x^{(0\cdots T)} \right)\frac{q\left( x^{(1\cdots T)}|x^{(0)} \right)}{q\left( x^{(1\cdots T)}|x^{(0)} \right)} \\[5pt]
&= \int dx^{(1\cdots T)}\, q\left( x^{(1\cdots T)}|x^{(0)} \right)\frac{p\left( x^{(0\cdots T)} \right)}{q\left( x^{(1\cdots T)}|x^{(0)} \right)} \\[5pt]
&= \int dx^{(1\cdots T)}\, q\left( x^{(1\cdots T)}|x^{(0)} \right)\cdot p\left( x^{(T)} \right)\prod_{t=1}^{T}\frac{p\left( x^{(t-1)}|x^{(t)} \right)}{q\left( x^{(t)}|x^{(t-1)} \right)}
\end{aligned}
$$

Log-likelihood and its ELBO

We now have a tractable expression for $p\left(x^{(0)}\right)$, so we can measure how well the model distribution $p\left(x^{(0)}\right)$ matches the data distribution $q\left(x^{(0)}\right)$ using the log-likelihood:

$$
\begin{aligned}
L &= \int dx^{(0)}\, q\left(x^{(0)}\right)\log p\left( x^{(0)} \right) \\[5pt]
&= \int dx^{(0)}\, q\left(x^{(0)}\right)\log\left[\int dx^{(1\cdots T)}\, q\left( x^{(1\cdots T)}|x^{(0)} \right)\cdot p\left( x^{(T)} \right)\prod_{t=1}^{T}\frac{p\left( x^{(t-1)}|x^{(t)} \right)}{q\left( x^{(t)}|x^{(t-1)} \right)}\right]
\end{aligned}
$$
By Jensen's inequality, every concave function $f$ satisfies $f[E(x)]\ge E[f(x)]$. Since $\log$ is concave, this gives a lower bound on the log-likelihood:

$$
\begin{aligned}
L &= \int dx^{(0)}\, q\left(x^{(0)}\right)\log E_{x^{(1\cdots T)}\sim q\left( x^{(1\cdots T)}|x^{(0)} \right)}\left[p\left( x^{(T)} \right)\prod_{t=1}^{T}\frac{p\left( x^{(t-1)}|x^{(t)} \right)}{q\left( x^{(t)}|x^{(t-1)} \right)}\right] \\[5pt]
&\ge \int dx^{(0)}\, q\left(x^{(0)}\right)E_{x^{(1\cdots T)}\sim q\left( x^{(1\cdots T)}|x^{(0)} \right)}\log\left[p\left( x^{(T)} \right)\prod_{t=1}^{T}\frac{p\left( x^{(t-1)}|x^{(t)} \right)}{q\left( x^{(t)}|x^{(t-1)} \right)}\right] \\[5pt]
&= \int dx^{(0)}dx^{(1\cdots T)}\, q\left( x^{(0)} \right)q\left( x^{(1\cdots T)}|x^{(0)} \right)\log\left[p\left( x^{(T)} \right)\prod_{t=1}^{T}\frac{p\left( x^{(t-1)}|x^{(t)} \right)}{q\left( x^{(t)}|x^{(t-1)} \right)}\right] \\[5pt]
&= \int dx^{(0\cdots T)}\, q\left( x^{(0\cdots T)} \right)\log\left[p\left( x^{(T)} \right)\prod_{t=1}^{T}\frac{p\left( x^{(t-1)}|x^{(t)} \right)}{q\left( x^{(t)}|x^{(t-1)} \right)}\right] \\[5pt]
&= \int dx^{(0\cdots T)}\, q\left( x^{(0\cdots T)} \right)\left[\sum_{t=1}^{T}\log\frac{p\left( x^{(t-1)}|x^{(t)} \right)}{q\left( x^{(t)}|x^{(t-1)} \right)}+\log p\left( x^{(T)} \right)\right] \\[5pt]
&= K
\end{aligned}
$$
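The Jensen step can be checked numerically. The following is a minimal sketch using only the Python standard library; the log-normal samples are an arbitrary stand-in (an assumption for illustration) for the positive ratio that appears inside the logarithm.

```python
import math
import random

random.seed(0)
# Hypothetical positive samples standing in for the ratio inside the log.
samples = [random.lognormvariate(0.0, 1.0) for _ in range(50_000)]

# f(E[x]) with f = log
log_of_mean = math.log(sum(samples) / len(samples))
# E[f(x)]
mean_of_log = sum(math.log(s) for s in samples) / len(samples)
# Non-negative because log is concave
jensen_gap = log_of_mean - mean_of_log

print(f"log E[x] = {log_of_mean:.3f}, E[log x] = {mean_of_log:.3f}, gap = {jensen_gap:.3f}")
```

For this particular distribution the gap is close to 0.5; the size of the gap is exactly what the bound $K$ gives up relative to $L$.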
We can now simplify the lower bound $K$ further. By equation (1), $q\left( x^{(0\cdots T)} \right) = q\left(x^{(0)}\right)\prod_{t=1}^{T} q\left( x^{(t)}|x^{(t-1)} \right)$. In the $\log p\left(x^{(T)}\right)$ term every variable other than $x^{(T)}$ integrates out, leaving the marginal $q\left(x^{(T)}\right)$:

$$
\begin{aligned}
K &= \int dx^{(0\cdots T)}\, q\left( x^{(0\cdots T)} \right)\sum_{t=1}^{T}\log\left[\frac{p\left( x^{(t-1)}|x^{(t)} \right)}{q\left( x^{(t)}|x^{(t-1)} \right)}\right]+\int dx^{(T)}\, q\left( x^{(T)} \right)\log p\left( x^{(T)} \right) \\[5pt]
&= \int dx^{(0\cdots T)}\, q\left( x^{(0\cdots T)} \right)\sum_{t=1}^{T}\log\left[\frac{p\left( x^{(t-1)}|x^{(t)} \right)}{q\left( x^{(t)}|x^{(t-1)} \right)}\right]+\int dx^{(T)}\, \pi\left( x^{(T)} \right)\log \pi\left( x^{(T)} \right) \\[5pt]
&= \sum_{t=1}^{T}\int dx^{(0\cdots T)}\, q\left( x^{(0\cdots T)} \right)\log\left[\frac{p\left( x^{(t-1)}|x^{(t)} \right)}{q\left( x^{(t)}|x^{(t-1)} \right)}\right]-H_{p}\left( X^{(T)} \right)
\end{aligned}
$$

The second line uses $p\left(x^{(T)}\right)=\pi\left(x^{(T)}\right)$ and treats $q\left(x^{(T)}\right)$ as equal to $\pi\left(x^{(T)}\right)$, which holds once the forward process has converged to $\pi$.
Next comes the boundary step, which needs separate treatment, much like holding the two ends of a rope together in order to fold it exactly in half. The last reverse step is fixed to mirror the first forward step:

$$
p\left( x^{(0)}|x^{(1)} \right) = q\left( x^{(1)}|x^{(0)} \right)\frac{\pi\left( x^{(0)} \right)}{\pi\left( x^{(1)} \right)} = T_{\pi}\left( x^{(0)}|x^{(1)};\beta_{1} \right)
$$

Separating the $t=1$ term from the sum:

$$
\begin{aligned}
K &= \sum_{t=2}^{T}\int dx^{(0\cdots T)}\, q\left( x^{(0\cdots T)} \right)\log\left[\frac{p\left( x^{(t-1)}|x^{(t)} \right)}{q\left( x^{(t)}|x^{(t-1)} \right)}\right]+\int dx^{(0)}\int dx^{(1)}\, q\left( x^{(0)},x^{(1)} \right)\log\left[\frac{p\left( x^{(0)}|x^{(1)} \right)}{q\left( x^{(1)}|x^{(0)} \right)}\right]-H_{p}\left( X^{(T)} \right) \\[5pt]
&= \sum_{t=2}^{T}\int dx^{(0\cdots T)}\, q\left( x^{(0\cdots T)} \right)\log\left[\frac{p\left( x^{(t-1)}|x^{(t)} \right)}{q\left( x^{(t)}|x^{(t-1)} \right)}\right]+\int dx^{(0)}\int dx^{(1)}\, q\left( x^{(0)},x^{(1)} \right)\log\left[\frac{q\left( x^{(1)}|x^{(0)} \right)\pi\left( x^{(0)} \right)}{q\left( x^{(1)}|x^{(0)} \right)\pi\left( x^{(1)} \right)}\right]-H_{p}\left( X^{(T)} \right) \\[5pt]
&= \sum_{t=2}^{T}\int dx^{(0\cdots T)}\, q\left( x^{(0\cdots T)} \right)\log\left[\frac{p\left( x^{(t-1)}|x^{(t)} \right)}{q\left( x^{(t)}|x^{(t-1)} \right)}\right]+\int dx^{(0)}\int dx^{(1)}\, q\left( x^{(0)},x^{(1)} \right)\log\left[\frac{\pi\left( x^{(0)} \right)}{\pi\left( x^{(1)} \right)}\right]-H_{p}\left( X^{(T)} \right)
\end{aligned}
$$

The log-ratio in the middle term splits into two cross-entropy terms against $\pi$, one for $x^{(0)}$ and one for $x^{(1)}$. Because the first noising step is very small, these two terms cancel. Using the Markov property $q\left( x^{(t)}|x^{(t-1)} \right) = q\left( x^{(t)}|x^{(t-1)},x^{(0)} \right)$ and then Bayes' theorem, we get:

$$
\begin{aligned}
K &= \sum_{t=2}^{T}\int dx^{(0\cdots T)}\, q\left( x^{(0\cdots T)} \right)\log\left[\frac{p\left( x^{(t-1)}|x^{(t)} \right)}{q\left( x^{(t)}|x^{(t-1)} \right)}\right]-H_{p}\left( X^{(T)} \right) \\[5pt]
&= \sum_{t=2}^{T}\int dx^{(0\cdots T)}\, q\left( x^{(0\cdots T)} \right)\log\left[\frac{p\left( x^{(t-1)}|x^{(t)} \right)}{q\left( x^{(t)}|x^{(t-1)},x^{(0)} \right)}\right]-H_{p}\left( X^{(T)} \right) \\[5pt]
&= \sum_{t=2}^{T}\int dx^{(0\cdots T)}\, q\left( x^{(0\cdots T)} \right)\log\left[\frac{p\left( x^{(t-1)}|x^{(t)} \right)}{q\left( x^{(t-1)}|x^{(t)},x^{(0)} \right)}\frac{q\left( x^{(t-1)}|x^{(0)} \right)}{q\left( x^{(t)}|x^{(0)} \right)}\right]-H_{p}\left( X^{(T)} \right) \\[5pt]
&= \sum_{t=2}^{T}\int dx^{(0\cdots T)}\, q\left( x^{(0\cdots T)} \right)\log\left[\frac{p\left( x^{(t-1)}|x^{(t)} \right)}{q\left( x^{(t-1)}|x^{(t)},x^{(0)} \right)}\right]+\sum_{t=2}^{T}\left[ H_q\left( X^{(t)}|X^{(0)} \right) - H_q\left( X^{(t-1)}|X^{(0)} \right) \right]-H_{p}\left( X^{(T)} \right) \\[5pt]
&= \sum_{t=2}^{T}\int dx^{(0\cdots T)}\, q\left( x^{(0\cdots T)} \right)\log\left[\frac{p\left( x^{(t-1)}|x^{(t)} \right)}{q\left( x^{(t-1)}|x^{(t)},x^{(0)} \right)}\right]+ H_q\left( X^{(T)}|X^{(0)} \right) - H_q\left( X^{(1)}|X^{(0)} \right) - H_{p}\left( X^{(T)} \right)
\end{aligned}
$$
Now for the bulky first term. The argument of the $\log$ depends only on $x^{(t)},x^{(t-1)},x^{(0)}$, so in each integral all the other variables can be integrated out, leaving the joint marginal over these three:

$$
\begin{aligned}
&\int dx^{(0\cdots T)}\, q\left( x^{(0\cdots T)} \right)\log\left[\frac{p\left( x^{(t-1)}|x^{(t)} \right)}{q\left( x^{(t-1)}|x^{(t)},x^{(0)} \right)}\right] \\[5pt]
=&\int dx^{(0)}dx^{(t-1)}dx^{(t)}\, q\left( x^{(0)},x^{(t-1)},x^{(t)} \right)\log\left[\frac{p\left( x^{(t-1)}|x^{(t)} \right)}{q\left( x^{(t-1)}|x^{(t)},x^{(0)} \right)}\right] \\[5pt]
=&\int dx^{(0)}dx^{(t-1)}dx^{(t)}\, q\left( x^{(0)},x^{(t)} \right)q\left( x^{(t-1)}|x^{(t)},x^{(0)} \right)\log\left[\frac{p\left( x^{(t-1)}|x^{(t)} \right)}{q\left( x^{(t-1)}|x^{(t)},x^{(0)} \right)}\right] \\[5pt]
=&\int dx^{(0)}dx^{(t)}\, q\left( x^{(0)},x^{(t)} \right)\int dx^{(t-1)}\, q\left( x^{(t-1)}|x^{(t)},x^{(0)} \right)\log\left[\frac{p\left( x^{(t-1)}|x^{(t)} \right)}{q\left( x^{(t-1)}|x^{(t)},x^{(0)} \right)}\right] \\[5pt]
=&-\int dx^{(0)}dx^{(t)}\, q\left( x^{(0)},x^{(t)} \right)D_{KL}\left( q\left( x^{(t-1)}|x^{(t)},x^{(0)} \right)\parallel p\left( x^{(t-1)}|x^{(t)} \right) \right)
\end{aligned}
$$
We can therefore express $K$ in terms of KL divergences:

$$
\begin{aligned}
K =&-\sum_{t=2}^{T}\int dx^{(0)}dx^{(t)}\, q\left( x^{(0)},x^{(t)} \right)D_{KL}\left( q\left( x^{(t-1)}|x^{(t)},x^{(0)} \right)\parallel p\left( x^{(t-1)}|x^{(t)} \right) \right) \\
&+ H_q\left( X^{(T)}|X^{(0)} \right) - H_q\left( X^{(1)}|X^{(0)} \right) - H_{p}\left( X^{(T)} \right)
\end{aligned}
$$

Since the trailing entropy terms are all constants, the only way to make $K$ as large as possible is to minimize $D_{KL}\left( q\left( x^{(t-1)}|x^{(t)},x^{(0)} \right)\parallel p\left( x^{(t-1)}|x^{(t)} \right) \right)$, which in turn raises the lower bound on $L$.

Forward and reverse diffusion kernels

In the forward diffusion we are free to define the noising rule ourselves:

$$
\begin{aligned}
q\left(x^{(t)}|x^{(t-1)}\right) &= \mathcal{N}\left( x^{(t)};\sqrt{1-\beta_{t}}\cdot x^{(t-1)},\beta_{t}\mathbf{I} \right) \\[5pt]
x^{(t)} &= \sqrt{1-\beta_t}\cdot x^{(t-1)}+\sqrt{\beta_t}\cdot\epsilon_{t},\;\epsilon_{t}\sim\mathcal{N}(0,\mathbf{I})
\end{aligned}
$$

Here $\beta_{t}$ is a user-defined schedule; it can grow linearly or follow a trigonometric curve. Since the variance of the kernel is $\beta_t\mathbf{I}$, the noise enters with standard deviation $\sqrt{\beta_t}$. Each noising step pulls the mean toward 0 and increases the accumulated noise variance, so that $x^{(T)}$ ends up approximately standard normal, $\mathcal{N}\left( x^{(T)};0,\mathbf{I} \right)$.
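The step-by-step forward process can be simulated directly. This is a minimal sketch for scalar data using only the standard library; the schedule endpoints `1e-4` and `0.02`, the number of steps, and the starting point `3.0` are assumed values chosen for illustration, and the helper names are hypothetical. It uses the kernel $\mathcal{N}\left(x^{(t)};\sqrt{1-\beta_t}\,x^{(t-1)},\beta_t\mathbf{I}\right)$, so the noise is scaled by $\sqrt{\beta_t}$.

```python
import math
import random

def linear_betas(T, beta_start=1e-4, beta_end=0.02):
    """Linearly increasing noise schedule; the endpoints are assumed values."""
    return [beta_start + (beta_end - beta_start) * t / (T - 1) for t in range(T)]

def forward_step(x_prev, beta_t, rng):
    """One noising step: x_t = sqrt(1 - beta_t) * x_{t-1} + sqrt(beta_t) * eps."""
    eps = rng.gauss(0.0, 1.0)
    return math.sqrt(1.0 - beta_t) * x_prev + math.sqrt(beta_t) * eps

def run_forward_chain(x0, betas, rng):
    """Apply every noising step in order and return x_T."""
    x = x0
    for beta_t in betas:
        x = forward_step(x, beta_t, rng)
    return x

rng = random.Random(0)
betas = linear_betas(T=1000)
# Start many independent chains from the same scalar "data point" x0 = 3.0.
finals = [run_forward_chain(3.0, betas, rng) for _ in range(500)]
mean = sum(finals) / len(finals)
var = sum((v - mean) ** 2 for v in finals) / len(finals)
print(f"mean of x_T = {mean:.3f}, variance of x_T = {var:.3f}")
```

Under these assumptions the empirical mean of $x^{(T)}$ ends up near 0 and the variance near 1, regardless of the starting point.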

The reverse process has two quantities to learn:

$$
p\left( x^{(t-1)}|x^{(t)} \right) = \mathcal{N}\left( x^{(t-1)};f_{\mu}\left( x^{(t)};t \right),f_{\Sigma}\left( x^{(t)};t \right) \right)
$$

namely the mean and the covariance of each reverse step.
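Sampling from this reverse kernel amounts to evaluating the two functions and adding scaled Gaussian noise. The sketch below is for scalar data; `toy_mu` and `toy_sigma2` are hypothetical stand-ins for the learned $f_{\mu}$ and $f_{\Sigma}$ (a real model would use a neural network), and the chain length of 10 is arbitrary.

```python
import math
import random

def reverse_step(x_t, t, f_mu, f_sigma2, rng):
    """Draw x_{t-1} ~ N(f_mu(x_t, t), f_sigma2(x_t, t)) for scalar x."""
    mean = f_mu(x_t, t)
    var = f_sigma2(x_t, t)
    return mean + math.sqrt(var) * rng.gauss(0.0, 1.0)

# Hypothetical placeholders for the learned mean and variance functions.
def toy_mu(x_t, t):
    return 0.99 * x_t

def toy_sigma2(x_t, t):
    return 0.01

rng = random.Random(0)
x = rng.gauss(0.0, 1.0)          # start from x_T ~ N(0, 1)
for t in range(10, 0, -1):       # walk t = T, ..., 1 with T = 10
    x = reverse_step(x, t, toy_mu, toy_sigma2, rng)
print(x)
```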

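Simplifying the expressions (where DDPM comes in)

1. Introducing $\alpha$

Clearly, if we trained the model by literally noising (or denoising) one step at a time, training would take a very long time: each step takes the previous step's output as its input, so later time steps can only wait while earlier ones are being processed.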

The accumulation from $x^{(0)}$ to $x^{(t)}$ can be unrolled recursively:

$$
\begin{aligned}
x^{(1)} &= \sqrt{1-\beta_1}\cdot x^{(0)}+\sqrt{\beta_1}\cdot\epsilon_{1} \\[10pt]
x^{(2)} &= \sqrt{1-\beta_2}\cdot x^{(1)}+\sqrt{\beta_2}\cdot\epsilon_{2} \\[5pt]
&= \sqrt{1-\beta_2}\cdot\left( \sqrt{1-\beta_1}\cdot x^{(0)}+\sqrt{\beta_1}\cdot\epsilon_{1} \right)+\sqrt{\beta_2}\cdot\epsilon_{2} \\[5pt]
&= \sqrt{(1-\beta_2)(1-\beta_1)}\cdot x^{(0)}+\sqrt{(1-\beta_2)\beta_1}\cdot\epsilon_{1}+\sqrt{\beta_2}\cdot\epsilon_2
\end{aligned}
$$
Since $\epsilon_1,\epsilon_2$ are independent and both follow $\mathcal{N}(0,\mathbf{I})$, the two noise terms can be merged into a single Gaussian whose variance is the sum of the two variances:

$$
\begin{aligned}
\left(\sqrt{(1-\beta_2)\beta_1}\cdot\epsilon_{1}+\sqrt{\beta_2}\cdot\epsilon_2\right)&\sim\mathcal{N}\left( 0,[(1-\beta_2)\beta_1+\beta_2]\mathbf{I} \right) \\[5pt]
&= \mathcal{N}(0,[1-(1-\beta_1)(1-\beta_2)]\mathbf{I})
\end{aligned}
$$
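As a quick arithmetic check of the merged variance, take the illustrative values $\beta_1 = 0.1$ and $\beta_2 = 0.2$ (assumed here, not taken from the text). Both forms give the same number:

$$
\begin{aligned}
(1-\beta_2)\beta_1+\beta_2 &= 0.8\times 0.1 + 0.2 = 0.28 \\
1-(1-\beta_1)(1-\beta_2) &= 1 - 0.9\times 0.8 = 0.28
\end{aligned}
$$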
This generalizes to $t$ steps. Define $\alpha_t = 1-\beta_t,\;\bar\alpha_t = \prod_{s=1}^{t}\alpha_s$; then $x^{(t)}$ can be written as:

$$
\begin{aligned}
x^{(t)} &= \sqrt{\bar\alpha_t}\cdot x^{(0)}+\sqrt{1-\bar\alpha_t}\cdot\epsilon,\;\epsilon\sim\mathcal{N}(0,\mathbf{I}) \\[5pt]
q\left( x^{(t)}|x^{(0)} \right) &= \mathcal{N}\left( x^{(t)};\sqrt{\bar\alpha_t}\cdot x^{(0)},(1-\bar\alpha_t)\mathbf{I} \right)
\end{aligned}
$$

This expression depends only on $\bar\alpha_t$, the cumulative product of the $\alpha_s$, so $x^{(t)}$ can be sampled from $x^{(0)}$ in one shot without computing the noise step by step.
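The closed form can be sketched in a few lines and compared against the step-by-step recursion. This is a scalar, standard-library sketch; the three-step schedule is an assumed toy example and the function names are hypothetical.

```python
import math
import random

def alpha_bars(betas):
    """Cumulative products alpha_bar_t = prod_{s<=t} (1 - beta_s)."""
    out, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        out.append(running)
    return out

def q_sample(x0, t, abar, rng):
    """Sample x_t directly from x_0 (t is 1-based); returns (x_t, eps)."""
    eps = rng.gauss(0.0, 1.0)
    a = abar[t - 1]
    return math.sqrt(a) * x0 + math.sqrt(1.0 - a) * eps, eps

def stepwise_variance(betas):
    """Noise variance accumulated by the step-by-step recursion, for comparison."""
    var = 0.0
    for beta in betas:
        var = (1.0 - beta) * var + beta   # rescale the old noise, then add new noise
    return var

betas = [0.1, 0.2, 0.3]          # assumed toy schedule
abar = alpha_bars(betas)
print(1.0 - abar[-1], stepwise_variance(betas))   # the two variances agree
x_t, eps = q_sample(x0=1.5, t=3, abar=abar, rng=random.Random(0))
```

Because `q_sample` needs only `abar[t - 1]`, any time step can be drawn independently of the others, which is what makes training on randomly chosen time steps practical.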

2. Simplifying $D_{KL}$ under the Gaussian assumption, and the loss function

As shown above, our goal is to minimize $D_{KL}\left( q\left( x^{(t-1)}|x^{(t)},x^{(0)} \right)\parallel p_{\theta}\left( x^{(t-1)}|x^{(t)} \right) \right)$ so that $K$ becomes as large as possible, where $\theta$ denotes the parameters of the neural network.

For two Gaussians with the same covariance, $\mathcal{N}(\mu_1,\sigma^{2}\mathbf{I})$ and $\mathcal{N}(\mu_2,\sigma^{2}\mathbf{I})$, the KL divergence between them is:

$$
D_{KL}\left( \mathcal{N}(\mu_1,\sigma^{2}\mathbf{I})\parallel\mathcal{N}(\mu_2,\sigma^{2}\mathbf{I}) \right) = \frac{1}{2\sigma^2}\left\| \mu_1-\mu_2 \right\|^{2}
$$
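The equal-variance formula can be checked against a direct numerical integration of $\int p\log(p/q)$. This is a one-dimensional sketch with the standard library; the means, variance, integration range and grid size are arbitrary illustrative choices.

```python
import math

def log_gauss(x, mu, sigma2):
    """Log-density of a scalar Gaussian N(mu, sigma2)."""
    return -0.5 * math.log(2.0 * math.pi * sigma2) - (x - mu) ** 2 / (2.0 * sigma2)

def kl_closed_form(mu1, mu2, sigma2):
    """KL(N(mu1, sigma2) || N(mu2, sigma2)) for scalars with a shared variance."""
    return (mu1 - mu2) ** 2 / (2.0 * sigma2)

def kl_numeric(mu1, mu2, sigma2, lo=-10.0, hi=10.0, n=20_000):
    """Midpoint-rule estimate of the integral of p(x) * (log p(x) - log q(x))."""
    dx = (hi - lo) / n
    total = 0.0
    for i in range(n):
        x = lo + (i + 0.5) * dx
        log_p = log_gauss(x, mu1, sigma2)
        log_q = log_gauss(x, mu2, sigma2)
        total += math.exp(log_p) * (log_p - log_q) * dx
    return total

print(kl_closed_form(0.3, -0.5, 0.25), kl_numeric(0.3, -0.5, 0.25))
```

Working with log-densities rather than the ratio of densities avoids dividing by a value that has underflowed to zero in the tails.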
Next we need the true posterior of the forward process, $q\left( x^{(t-1)}|x^{(t)},x^{(0)}\right)$. It is also Gaussian, though its parameters are more involved:

$$
q\left( x^{(t-1)}|x^{(t)},x^{(0)}\right) = \mathcal{N}\left( x^{(t-1)};\tilde{\mu}_{t}\left( x^{(t)},x^{(0)} \right),\tilde{\beta}_{t}\mathbf{I} \right)
$$

Both the mean and the variance have closed-form solutions; the derivation follows.

The goal is to find $q\left( x^{(t-1)}|x^{(t)},x^{(0)}\right)$, that is, the conditional distribution of $x^{(t-1)}$ given $x^{(t)}$ and $x^{(0)}$.

By Bayes' theorem:

$$
q\left( x^{(t-1)}|x^{(t)},x^{(0)}\right) = \frac{q\left( x^{(t)}|x^{(t-1)},x^{(0)}\right)\cdot q\left( x^{(t-1)}|x^{(0)} \right)}{q\left( x^{(t)}|x^{(0)} \right)}
$$

Since the forward process is a Markov chain, $q\left( x^{(t)}|x^{(t-1)},x^{(0)}\right) = q\left( x^{(t)}|x^{(t-1)}\right)$, and the denominator does not depend on $x^{(t-1)}$. Therefore:

$$
q\left( x^{(t-1)}|x^{(t)},x^{(0)}\right) \propto q\left( x^{(t)}|x^{(t-1)}\right)\cdot q\left( x^{(t-1)}|x^{(0)}\right)
$$
Next, write out each Gaussian explicitly:

  1. The single-step forward transition:

$$
q\left( x^{(t)}|x^{(t-1)}\right) \propto \exp\left( -\frac{\left\| x^{(t)}-\sqrt{\alpha_t}x^{(t-1)} \right\|^{2}}{2\beta_t} \right)
$$

  2. The accumulated distribution from $x^{(0)}$ to $x^{(t-1)}$:

$$
q\left( x^{(t-1)}|x^{(0)}\right)\propto\exp\left( -\frac{\left\| x^{(t-1)}-\sqrt{\bar\alpha_{t-1}}x^{(0)} \right\|^{2}}{2(1-\bar\alpha_{t-1})} \right)
$$

  3. The exponent of their product, which is the exponent of the posterior:

$$
-\frac{\left\| x^{(t)}-\sqrt{\alpha_t}x^{(t-1)} \right\|^{2}}{2\beta_t} -\frac{\left\| x^{(t-1)}-\sqrt{\bar\alpha_{t-1}}x^{(0)} \right\|^{2}}{2(1-\bar\alpha_{t-1})}
$$

Expanding the exponent and collecting like terms gives the quadratic and linear terms in $x^{(t-1)}$:

  1. Quadratic term:

$$
-\frac{1}{2}\left( \frac{\alpha_t}{\beta_t} +\frac{1}{1-\bar\alpha_{t-1}}\right)\left(x^{(t-1)}\right)^{2}
$$

  2. Linear term:

$$
\left( \frac{\sqrt{\alpha_t}}{\beta_t}x^{(t)}+\frac{\sqrt{\bar\alpha_{t-1}}}{1-\bar\alpha_{t-1}}x^{(0)} \right)x^{(t-1)}
$$

These map term by term onto the exponent of a generic Gaussian:

$$
-\frac{(x-\mu)^2}{2\sigma^2}=-\frac{x^2}{2\sigma^2}+\frac{\mu}{\sigma^2}x-\frac{\mu^2}{2\sigma^2}
$$
Matching coefficients gives the mean and variance of the posterior:

  1. Variance:

$$
\frac{1}{\tilde{\beta}_t} =\frac{\alpha_t}{\beta_t} +\frac{1}{1-\bar\alpha_{t-1}}\Longrightarrow \tilde{\beta}_t = \frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}\beta_t
$$

  2. Mean:

$$
\frac{\tilde{\mu}_t}{\tilde{\beta}_t} =\frac{\sqrt{\alpha_t}}{\beta_t}x^{(t)}+\frac{\sqrt{\bar\alpha_{t-1}}}{1-\bar\alpha_{t-1}}x^{(0)} \Longrightarrow\tilde{\mu}_t=\frac{\sqrt{\alpha_t}(1-\bar\alpha_{t-1})}{1-\bar\alpha_{t}}x^{(t)}+\frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1-\bar\alpha_t}x^{(0)}
$$
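The coefficient matching can be verified numerically by computing the posterior parameters twice: once from the simplified closed forms and once directly from the quadratic and linear coefficients. This is a scalar sketch; the function names and the toy schedule are assumptions for illustration.

```python
import math

def posterior_params(x_t, x0, t, betas):
    """Mean and variance of q(x_{t-1} | x_t, x_0) for scalar x; t is 1-based, t >= 2."""
    beta_t = betas[t - 1]
    alpha_t = 1.0 - beta_t
    abar_prev = math.prod(1.0 - b for b in betas[: t - 1])   # alpha_bar_{t-1}
    abar_t = abar_prev * alpha_t                              # alpha_bar_t
    var = (1.0 - abar_prev) / (1.0 - abar_t) * beta_t         # beta_tilde_t
    mean = (math.sqrt(alpha_t) * (1.0 - abar_prev) * x_t
            + math.sqrt(abar_prev) * beta_t * x0) / (1.0 - abar_t)
    return mean, var

def posterior_params_from_coefficients(x_t, x0, t, betas):
    """Same quantities, read off the quadratic and linear coefficients directly."""
    beta_t = betas[t - 1]
    alpha_t = 1.0 - beta_t
    abar_prev = math.prod(1.0 - b for b in betas[: t - 1])
    precision = alpha_t / beta_t + 1.0 / (1.0 - abar_prev)    # 1 / beta_tilde_t
    linear = (math.sqrt(alpha_t) / beta_t * x_t
              + math.sqrt(abar_prev) / (1.0 - abar_prev) * x0)
    return linear / precision, 1.0 / precision

betas = [0.1, 0.2, 0.3]   # assumed toy schedule
print(posterior_params(0.7, 1.5, 3, betas))
print(posterior_params_from_coefficients(0.7, 1.5, 3, betas))
```

The formulas require $t\ge 2$: at $t=1$ the factor $1-\bar\alpha_{0}$ is zero, which matches the sum over $t=2,\dots,T$ in the bound.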

Using the forward process $x^{(t)} =\sqrt{\bar\alpha_t}\cdot x^{(0)}+\sqrt{1-\bar\alpha_t}\cdot\epsilon$, we can solve for $x^{(0)}$:

$$
x^{(0)} = \frac{1}{\sqrt{\bar\alpha_t}}\left( x^{(t)}-\sqrt{1-\bar\alpha_t}\cdot\epsilon \right)
$$

Substituting $x^{(0)}$ into the expression for the mean gives:

$$
\tilde{\mu}_t = \frac{1}{\sqrt{\alpha_t}}\left( x^{(t)}-\frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\epsilon \right)
$$
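The substitution can be sanity-checked by evaluating the posterior mean in both forms for one concrete draw. The schedule, time step and data value below are assumed toy values.

```python
import math
import random

betas = [0.1, 0.2, 0.3]     # assumed toy schedule
t = 3                       # 1-based time step
beta_t = betas[t - 1]
alpha_t = 1.0 - beta_t
abar_prev = math.prod(1.0 - b for b in betas[: t - 1])   # alpha_bar_{t-1}
abar_t = abar_prev * alpha_t                              # alpha_bar_t

rng = random.Random(0)
x0 = 1.5
eps = rng.gauss(0.0, 1.0)
x_t = math.sqrt(abar_t) * x0 + math.sqrt(1.0 - abar_t) * eps

# Posterior mean written in terms of (x_t, x_0).
mu_from_x0 = (math.sqrt(alpha_t) * (1.0 - abar_prev) * x_t
              + math.sqrt(abar_prev) * beta_t * x0) / (1.0 - abar_t)
# Posterior mean written in terms of (x_t, eps).
mu_from_eps = (x_t - beta_t / math.sqrt(1.0 - abar_t) * eps) / math.sqrt(alpha_t)
print(mu_from_x0, mu_from_eps)
```

The two values agree up to floating-point rounding, which confirms that the prefactor in the noise form is $1/\sqrt{\alpha_t}$ and not $1/\sqrt{\bar\alpha_t}$.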

Therefore, if the mean $\mu_\theta$ of the reverse process predicted by the neural network is parameterized in the same form:

$$
\mu_{\theta}\left( x^{(t)},t \right) =\frac{1}{\sqrt{\alpha_t}}\left( x^{(t)}-\frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\epsilon_{\theta}\left( x^{(t)},t \right) \right)
$$

then, taking the reverse variance to be $\sigma^2=\tilde{\beta}_t$, $D_{KL}$ simplifies to:

$$
\begin{aligned}
&D_{KL}\left( q\left( x^{(t-1)}|x^{(t)},x^{(0)} \right)\parallel p_{\theta}\left( x^{(t-1)}|x^{(t)} \right) \right) \\[5pt]
= &\frac{1}{2\sigma^2}\left\| \tilde{\mu}_t-\mu_{\theta}\left( x^{(t)},t \right) \right\|^{2} \\[5pt]
= &\frac{\beta_{t}^{2}}{2\tilde{\beta}_{t}\alpha_t(1-\bar\alpha_t)}\left\| \epsilon-\epsilon_{\theta}\left( x^{(t)},t \right) \right\|^{2}
\end{aligned}
$$
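To get a feel for the weight in front of the squared noise error, the sketch below evaluates $\beta_t^{2}/\left(2\tilde{\beta}_t\alpha_t(1-\bar\alpha_t)\right)$ for $t=2,\dots,T$. The function name and the toy schedule are assumptions for illustration.

```python
def kl_weights(betas):
    """Weight beta_t^2 / (2 * beta_tilde_t * alpha_t * (1 - alpha_bar_t)) for t = 2..T."""
    weights = {}
    abar_prev = 1.0 - betas[0]            # alpha_bar_1
    for t in range(2, len(betas) + 1):
        beta_t = betas[t - 1]
        alpha_t = 1.0 - beta_t
        abar_t = abar_prev * alpha_t
        beta_tilde = (1.0 - abar_prev) / (1.0 - abar_t) * beta_t
        weights[t] = beta_t ** 2 / (2.0 * beta_tilde * alpha_t * (1.0 - abar_t))
        abar_prev = abar_t                # becomes alpha_bar_{t-1} for the next step
    return weights

weights = kl_weights([0.1, 0.2, 0.3])     # assumed toy schedule
print(weights)
```

The weight differs from step to step, so it rescales how much each time step contributes to the bound.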
Dropping the weight coefficient, the final loss function is:

$$
\mathcal{L}_t(\theta) = \mathbf{E}_{x^{(0)},\epsilon,t}\left[ \left\| \epsilon-\epsilon_{\theta}\left( x^{(t)},t \right) \right\|^{2} \right]
$$
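A Monte Carlo estimate of this loss can be sketched as follows for scalar data. Everything named here is an assumption for illustration: `zero_model` is a hypothetical untrained predictor, the dataset is random stand-in data, the schedule is an assumed linear one, and $t$ is drawn uniformly from $1,\dots,T$.

```python
import math
import random

def make_alpha_bars(betas):
    """Cumulative products alpha_bar_t = prod_{s<=t} (1 - beta_s)."""
    out, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        out.append(running)
    return out

def noise_prediction_loss(x0_batch, eps_model, abar, rng):
    """Monte Carlo estimate of E[(eps - eps_model(x_t, t))^2] for scalar data."""
    total = 0.0
    for x0 in x0_batch:
        t = rng.randint(1, len(abar))                 # uniform time step, 1-based
        eps = rng.gauss(0.0, 1.0)                     # true noise
        a = abar[t - 1]
        x_t = math.sqrt(a) * x0 + math.sqrt(1.0 - a) * eps
        total += (eps - eps_model(x_t, t)) ** 2
    return total / len(x0_batch)

# Hypothetical untrained predictor that always outputs zero noise.
def zero_model(x_t, t):
    return 0.0

rng = random.Random(0)
betas = [1e-4 + (0.02 - 1e-4) * i / 999 for i in range(1000)]   # assumed linear schedule
abar = make_alpha_bars(betas)
data = [rng.uniform(-1.0, 1.0) for _ in range(5000)]            # stand-in dataset
loss = noise_prediction_loss(data, zero_model, abar, rng)
print(f"loss of the zero predictor: {loss:.3f}")
```

A predictor that always returns zero scores about 1, the variance of the true noise; a trained $\epsilon_\theta$ would be expected to do better than this baseline.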

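  1. Gaussian assumption: both the forward and the reverse process are Gaussian, with fixed variance.
  2. Mean matching: the KL divergence reduces to a squared error between means.
  3. Noise prediction: a reparameterization turns prediction of the mean into direct prediction of the noise $\epsilon$.
  4. Loss function: the final loss is the MSE between the predicted noise and the true noise.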