Foundations of Diffusion Models and DDPM
Forward Process
- Data distribution: $q\left(x^{(0)}\right)$
- Target distribution: a simple, analytically tractable distribution $\pi\left(x^{(T)}\right)$
- Markov kernel $T_{\pi}$: a general Markov transition kernel that gives the probability of moving from the current state to the next state, with diffusion rate $\beta_t$ at step $t$
$$
\begin{aligned}
q\left(x^{(t)}|x^{(t-1)}\right) &= T_{\pi}\left( x^{(t)}|x^{(t-1)};\beta_{t} \right)
\\[5pt]
q\left( x^{(0\cdots T)} \right) &= q\left(x^{(0)}\right)\prod_{t=1}^{T}{q\left( x^{(t)}|x^{(t-1)} \right)}
\end{aligned}
$$
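The factorization can be made concrete with a short sketch. The Python below samples one forward trajectory and accumulates $\log q\left(x^{(0\cdots T)}\right)$ as $\log q\left(x^{(0)}\right)$ plus one transition term per step. It assumes, purely for illustration, 1-D data with $q\left(x^{(0)}\right)=\mathcal{N}(0,1)$ and the Gaussian kernel $\mathcal{N}\left(\sqrt{1-\beta_t}\,x^{(t-1)},\beta_t\right)$ that appears later in this article; the function names are hypothetical.

```python
import math
import random


def gaussian_logpdf(x, mean, var):
    """Log-density of the 1-D Gaussian N(mean, var) at x."""
    return -0.5 * (math.log(2 * math.pi * var) + (x - mean) ** 2 / var)


def sample_forward_trajectory(betas, rng):
    """Sample x^(0..T) and return it together with log q(x^(0..T)).

    Illustrative assumptions: q(x^(0)) = N(0, 1) and
    q(x^(t) | x^(t-1)) = N(sqrt(1 - beta_t) * x^(t-1), beta_t).
    """
    x = rng.gauss(0.0, 1.0)
    log_q = gaussian_logpdf(x, 0.0, 1.0)               # log q(x^(0))
    trajectory = [x]
    for beta in betas:                                  # t = 1, ..., T
        mean = math.sqrt(1.0 - beta) * x
        x = rng.gauss(mean, math.sqrt(beta))            # x^(t) ~ q(. | x^(t-1))
        log_q += gaussian_logpdf(x, mean, beta)         # + log q(x^(t) | x^(t-1))
        trajectory.append(x)
    return trajectory, log_q


traj, log_q = sample_forward_trajectory([0.1, 0.2, 0.3], random.Random(0))
print(len(traj), log_q)
```

Because the chain is Markov, each factor only needs the previous state, which is why the log-probability is a running sum.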
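Reverse Process
- Starting distribution: the simple distribution $\pi\left(x^{(T)}\right)$
- Target distribution: the data distribution
$$
\begin{aligned}
p\left( x^{(T)} \right) &= \pi\left( x^{(T)} \right)
\\[5pt]
p\left( x^{(0\cdots T)} \right) &= p\left(x^{(T)}\right)\prod_{t=1}^{T}{p\left( x^{(t-1)}|x^{(t)} \right)}
\end{aligned}
$$
Key assumption: when the noise added at each step is very small, if the forward process is Gaussian (or binomial), then the reverse process is also Gaussian (or binomial). The reverse process can therefore be pinned down by very few parameters, e.g. a mean and a covariance per step in the Gaussian case.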
The Reverse Process in Detail
To make $p\left(x^{(0)}\right)$ tractable, integrate the joint reverse distribution over $x^{(1\cdots T)}$, then multiply and divide by the forward trajectory probability $q\left( x^{(1\cdots T)}|x^{(0)} \right)$. This turns the integral into an average over forward trajectories:
$$
\begin{aligned}
p\left( x^{(0\cdots T)} \right) &= p\left(x^{(T)}\right)\prod_{t=1}^{T}{p\left( x^{(t-1)}|x^{(t)} \right)}
\\[5pt]
p\left( x^{(0)} \right) &= \int{dx^{(1)}\int{dx^{(2)}\cdots\int{dx^{(T)}\,p\left( x^{(0\cdots T)} \right)}}}
\\[5pt]
&=\int{dx^{(1\cdots T)}p\left( x^{(0\cdots T)} \right)\frac{q\left( x^{(1\cdots T)}|x^{(0)} \right)}{q\left( x^{(1\cdots T)}|x^{(0)} \right)}}
\\[5pt]
&=\int{dx^{(1\cdots T)}q\left( x^{(1\cdots T)}|x^{(0)} \right)\frac{p\left( x^{(0\cdots T)} \right)}{q\left( x^{(1\cdots T)}|x^{(0)} \right)}}
\\[5pt]
&=\int{dx^{(1\cdots T)}q\left( x^{(1\cdots T)}|x^{(0)} \right)\cdot p\left( x^{(T)} \right)\prod_{t=1}^{T}\frac{p\left( x^{(t-1)}|x^{(t)} \right)}{q\left( x^{(t)}|x^{(t-1)} \right)}}
\end{aligned}
$$
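The last line expresses $p\left(x^{(0)}\right)$ as an average of trajectory weights under the forward process, which is exactly an importance-sampling estimator. A minimal sketch for a toy case, assuming $T=1$, 1-D data, $\pi=\mathcal{N}(0,1)$, a Gaussian forward kernel $q\left(x^{(1)}|x^{(0)}\right)=\mathcal{N}\left(\sqrt{1-\beta}\,x^{(0)},\beta\right)$ and a hypothetical reverse model $p\left(x^{(0)}|x^{(1)}\right)=\mathcal{N}\left(a\,x^{(1)},s^2\right)$. In this toy the exact marginal is $\mathcal{N}(0,a^2+s^2)$, so the estimate can be checked.

```python
import math
import random


def normal_pdf(x, mean, var):
    """Density of the 1-D Gaussian N(mean, var) at x."""
    return math.exp(-(x - mean) ** 2 / (2.0 * var)) / math.sqrt(2.0 * math.pi * var)


def estimate_p_x0(x0, beta, a, s2, n, rng):
    """Importance-sampling estimate of p(x^(0)) for T = 1.

    p(x^(0)) = E_{x^(1) ~ q(x^(1)|x^(0))}[ p(x^(1)) p(x^(0)|x^(1)) / q(x^(1)|x^(0)) ]
    """
    mean_q = math.sqrt(1.0 - beta) * x0
    total = 0.0
    for _ in range(n):
        x1 = rng.gauss(mean_q, math.sqrt(beta))          # forward sample x^(1) ~ q(x^(1) | x^(0))
        total += (normal_pdf(x1, 0.0, 1.0)               # p(x^(T)) = pi = N(0, 1)
                  * normal_pdf(x0, a * x1, s2)           # hypothetical reverse model p(x^(0) | x^(1))
                  / normal_pdf(x1, mean_q, beta))        # divide by q(x^(1) | x^(0))
    return total / n


x0, beta, a, s2 = 0.5, 0.2, 0.9, 0.1
estimate = estimate_p_x0(x0, beta, a, s2, 50_000, random.Random(0))
exact = normal_pdf(x0, 0.0, a * a + s2)                  # exact marginal of the toy model
print(estimate, exact)
```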
Log-Likelihood and Its ELBO
We now have a tractable expression for $p\left(x^{(0)}\right)$, so we can measure how well the model output $p\left(x^{(0)}\right)$ matches the true data distribution $q\left(x^{(0)}\right)$ with the log-likelihood:
$$
\begin{aligned}
L &= \int{dx^{(0)}q\left(x^{(0)}\right)\log{p\left( x^{(0)} \right)}}
\\[5pt]
&= \int{dx^{(0)}q\left(x^{(0)}\right)\log{\left[\int{dx^{(1\cdots T)}q\left( x^{(1\cdots T)}|x^{(0)} \right)\cdot p\left( x^{(T)} \right)\prod_{t=1}^{T}\frac{p\left( x^{(t-1)}|x^{(t)} \right)}{q\left( x^{(t)}|x^{(t-1)} \right)}}\right]}}
\end{aligned}
$$
By Jensen's inequality, every concave function $f$ satisfies $f[E(x)]\ge E[f(x)]$. Since $\log$ is concave, we obtain a lower bound on the log-likelihood:
$$
\begin{aligned}
L &=\int{dx^{(0)}q\left(x^{(0)}\right)\log{E_{x^{(1\cdots T)}\sim q\left( x^{(1\cdots T)}|x^{(0)} \right)}\left[p\left( x^{(T)} \right)\prod_{t=1}^{T}\frac{p\left( x^{(t-1)}|x^{(t)} \right)}{q\left( x^{(t)}|x^{(t-1)} \right)}\right]}}
\\[5pt]
&\ge \int{dx^{(0)}q\left(x^{(0)}\right)E_{x^{(1\cdots T)}\sim q\left( x^{(1\cdots T)}|x^{(0)} \right)}\log{\left[p\left( x^{(T)} \right)\prod_{t=1}^{T}\frac{p\left( x^{(t-1)}|x^{(t)} \right)}{q\left( x^{(t)}|x^{(t-1)} \right)}\right]}}
\\[5pt]
&=\int{dx^{(0)}dx^{(1\cdots T)}q\left( x^{(0)} \right)q\left( x^{(1\cdots T)}|x^{(0)} \right)\log{\left[p\left( x^{(T)} \right)\prod_{t=1}^{T}\frac{p\left( x^{(t-1)}|x^{(t)} \right)}{q\left( x^{(t)}|x^{(t-1)} \right)}\right]}}
\\[5pt]
&=\int{dx^{(0\cdots T)}q\left( x^{(0\cdots T)} \right)\log{\left[p\left( x^{(T)} \right)\prod_{t=1}^{T}\frac{p\left( x^{(t-1)}|x^{(t)} \right)}{q\left( x^{(t)}|x^{(t-1)} \right)}\right]}}
\\[5pt]
&=\int{dx^{(0\cdots T)}q\left( x^{(0\cdots T)} \right)\left[\sum_{t=1}^{T}\log{\frac{p\left( x^{(t-1)}|x^{(t)} \right)}{q\left( x^{(t)}|x^{(t-1)} \right)}}+\log{p\left( x^{(T)} \right)}\right]}
\\[5pt]
&= K
\end{aligned}
$$
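Jensen's inequality can be checked numerically. In this sketch, hypothetical positive weights stand in for the bracketed ratio $p\left(x^{(T)}\right)\prod_t p\left(x^{(t-1)}|x^{(t)}\right)/q\left(x^{(t)}|x^{(t-1)}\right)$: the log of their mean plays the role of the log-likelihood and the mean of their logs plays the role of the bound. For log-normal weights with $\log w\sim\mathcal{N}(0,1)$ the true gap is $\sigma^2/2=0.5$.

```python
import math
import random


def log_mean_and_mean_log(weights):
    """Return (log of the mean weight, mean of the log weights)."""
    log_of_mean = math.log(sum(weights) / len(weights))
    mean_of_log = sum(math.log(w) for w in weights) / len(weights)
    return log_of_mean, mean_of_log


rng = random.Random(1)
# Hypothetical positive weights standing in for p(x^(T)) * prod_t p(x^(t-1)|x^(t)) / q(x^(t)|x^(t-1)).
weights = [math.exp(rng.gauss(0.0, 1.0)) for _ in range(10_000)]
log_likelihood, lower_bound = log_mean_and_mean_log(weights)
print(log_likelihood, lower_bound)    # the first is never smaller than the second
```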
We can process the lower bound $K$ further. Using the forward-process factorization $q\left( x^{(0\cdots T)} \right)=q\left(x^{(0)}\right)\prod_{t=1}^{T}{q\left( x^{(t)}|x^{(t-1)} \right)}$, every variable except $x^{(T)}$ integrates out of the $\log p\left(x^{(T)}\right)$ term, leaving the marginal $q\left(x^{(T)}\right)$. Since the forward process is designed so that $q\left(x^{(T)}\right)\approx\pi\left(x^{(T)}\right)=p\left(x^{(T)}\right)$, this term becomes the negative entropy $-H_{p}\left( X^{(T)} \right)$:
$$
\begin{aligned}
K &=\int{dx^{(0\cdots T)}q\left( x^{(0\cdots T)} \right)\sum_{t=1}^{T}\log{\left[\frac{p\left( x^{(t-1)}|x^{(t)} \right)}{q\left( x^{(t)}|x^{(t-1)} \right)}\right]}}+\int{dx^{(T)}q\left( x^{(T)} \right)\log{p\left( x^{(T)} \right)}}
\\[5pt]
&=\int{dx^{(0\cdots T)}q\left( x^{(0\cdots T)} \right)\sum_{t=1}^{T}\log{\left[\frac{p\left( x^{(t-1)}|x^{(t)} \right)}{q\left( x^{(t)}|x^{(t-1)} \right)}\right]}}+\int{dx^{(T)}\pi\left( x^{(T)} \right)\log{\pi\left( x^{(T)} \right)}}
\\[5pt]
&=\sum_{t=1}^{T}\int{dx^{(0\cdots T)}q\left( x^{(0\cdots T)} \right)\log{\left[\frac{p\left( x^{(t-1)}|x^{(t)} \right)}{q\left( x^{(t)}|x^{(t-1)} \right)}\right]}}-H_{p}\left( X^{(T)} \right)
\end{aligned}
$$
Next, the boundary term $t=1$ has to be treated separately, much like grabbing both ends of a rope and bringing them together before you can fold it neatly in half. The last reverse step is chosen to be identical to the first forward diffusion step:
$$
p\left( x^{(0)}|x^{(1)} \right) = q\left( x^{(1)}|x^{(0)} \right)\frac{\pi\left( x^{(0)} \right)}{\pi\left( x^{(1)} \right)} = T_{\pi}\left( x^{(0)}|x^{(1)};\beta_{1} \right)
$$
Splitting off the $t=1$ term of $K$ and substituting this choice gives:
$$
\begin{aligned}
K&=\sum_{t=2}^{T}{\int{dx^{(0\cdots T)}q\left( x^{(0\cdots T)} \right)\log{\left[\frac{p\left( x^{(t-1)}|x^{(t)} \right)}{q\left( x^{(t)}|x^{(t-1)} \right)}\right]}}}+{\int{dx^{(0)}\int{dx^{(1)}q\left( x^{(0)},x^{(1)} \right)\log{\left[\frac{p\left( x^{(0)}|x^{(1)} \right)}{q\left( x^{(1)}|x^{(0)} \right)}\right]}}}}-H_{p}\left( X^{(T)} \right)
\\[5pt]
&=\sum_{t=2}^{T}{\int{dx^{(0\cdots T)}q\left( x^{(0\cdots T)} \right)\log{\left[\frac{p\left( x^{(t-1)}|x^{(t)} \right)}{q\left( x^{(t)}|x^{(t-1)} \right)}\right]}}}+{\int{dx^{(0)}\int{dx^{(1)}q\left( x^{(0)},x^{(1)} \right)\log{\left[\frac{q\left( x^{(1)}|x^{(0)} \right)\pi\left( x^{(0)} \right)}{q\left( x^{(1)}|x^{(0)} \right)\pi\left( x^{(1)} \right)}\right]}}}}-H_{p}\left( X^{(T)} \right)
\\[5pt]
&=\sum_{t=2}^{T}{\int{dx^{(0\cdots T)}q\left( x^{(0\cdots T)} \right)\log{\left[\frac{p\left( x^{(t-1)}|x^{(t)} \right)}{q\left( x^{(t)}|x^{(t-1)} \right)}\right]}}}+{\int{dx^{(0)}\int{dx^{(1)}q\left( x^{(0)},x^{(1)} \right)\log{\left[\frac{\pi\left( x^{(0)} \right)}{\pi\left( x^{(1)} \right)}\right]}}}}-H_{p}\left( X^{(T)} \right)
\end{aligned}
$$
Because the first forward step is very small, $x^{(0)}$ and $x^{(1)}$ have almost the same marginal distribution, so the two cross-entropy terms in the last integral, $\int dx^{(0)}dx^{(1)}\,q\left(x^{(0)},x^{(1)}\right)\log\pi\left(x^{(0)}\right)$ and $\int dx^{(0)}dx^{(1)}\,q\left(x^{(0)},x^{(1)}\right)\log\pi\left(x^{(1)}\right)$, essentially cancel. Then we use the Markov property $q\left( x^{(t)}|x^{(t-1)} \right)=q\left( x^{(t)}|x^{(t-1)},x^{(0)} \right)$ and Bayes' rule $q\left( x^{(t)}|x^{(t-1)},x^{(0)} \right)=q\left( x^{(t-1)}|x^{(t)},x^{(0)} \right)\frac{q\left( x^{(t)}|x^{(0)} \right)}{q\left( x^{(t-1)}|x^{(0)} \right)}$; the remaining ratio integrates to a difference of conditional entropies, and the sum of those differences telescopes:
$$
\begin{aligned}
K &=\sum_{t=2}^{T}{\int{dx^{(0\cdots T)}q\left( x^{(0\cdots T)} \right)\log{\left[\frac{p\left( x^{(t-1)}|x^{(t)} \right)}{q\left( x^{(t)}|x^{(t-1)} \right)}\right]}}}-H_{p}\left( X^{(T)} \right)
\\[5pt]
&=\sum_{t=2}^{T}{\int{dx^{(0\cdots T)}q\left( x^{(0\cdots T)} \right)\log{\left[\frac{p\left( x^{(t-1)}|x^{(t)} \right)}{q\left( x^{(t)}|x^{(t-1)},x^{(0)} \right)}\right]}}}-H_{p}\left( X^{(T)} \right)
\\[5pt]
&=\sum_{t=2}^{T}{\int{dx^{(0\cdots T)}q\left( x^{(0\cdots T)} \right)\log{\left[\frac{p\left( x^{(t-1)}|x^{(t)} \right)}{q\left( x^{(t-1)}|x^{(t)},x^{(0)} \right)}\frac{q\left( x^{(t-1)}|x^{(0)} \right)}{q\left( x^{(t)}|x^{(0)} \right)}\right]}}}-H_{p}\left( X^{(T)} \right)
\\[5pt]
&=\sum_{t=2}^{T}{\int{dx^{(0\cdots T)}q\left( x^{(0\cdots T)} \right)\log{\left[\frac{p\left( x^{(t-1)}|x^{(t)} \right)}{q\left( x^{(t-1)}|x^{(t)},x^{(0)} \right)}\right]}}}+\sum_{t=2}^{T}\left[ H_q\left( X^{(t)}|X^{(0)} \right) - H_q\left( X^{(t-1)}|X^{(0)} \right) \right]-H_{p}\left( X^{(T)} \right)
\\[5pt]
&=\sum_{t=2}^{T}{\int{dx^{(0\cdots T)}q\left( x^{(0\cdots T)} \right)\log{\left[\frac{p\left( x^{(t-1)}|x^{(t)} \right)}{q\left( x^{(t-1)}|x^{(t)},x^{(0)} \right)}\right]}}}+ H_q\left( X^{(T)}|X^{(0)} \right) - H_q\left( X^{(1)}|X^{(0)} \right) - H_{p}\left( X^{(T)} \right)
\end{aligned}
$$
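The telescoping of the entropy sum can be checked with a small sketch. It assumes 1-D data and the Gaussian kernel $\mathcal{N}\left(\sqrt{1-\beta_t}\,x^{(t-1)},\beta_t\right)$ introduced later in this article, under which $X^{(t)}|X^{(0)}$ is Gaussian with variance $\sigma_t^2$ and $H_q\left(X^{(t)}|X^{(0)}\right)=\tfrac{1}{2}\log\left(2\pi e\,\sigma_t^2\right)$; the ten-step schedule is hypothetical.

```python
import math


def conditional_entropies(betas):
    """H_q(X^(t) | X^(0)) for t = 1..T in 1-D.

    Assumes q(x^(t) | x^(t-1)) = N(sqrt(1 - beta_t) * x^(t-1), beta_t), so
    X^(t) | X^(0) is Gaussian and its variance can be propagated exactly.
    """
    var = 0.0                                   # x^(0) is given: zero conditional variance
    entropies = []
    for beta in betas:
        var = (1.0 - beta) * var + beta         # Var[x^(t) | x^(0)]
        entropies.append(0.5 * math.log(2 * math.pi * math.e * var))
    return entropies


betas = [0.01 * (t + 1) for t in range(10)]     # hypothetical schedule, T = 10
H = conditional_entropies(betas)                # H[0] is t = 1, H[-1] is t = T
telescoped = sum(H[t] - H[t - 1] for t in range(1, len(H)))   # sum over t = 2..T
print(telescoped, H[-1] - H[0])                 # both equal H(X^(T)|X^(0)) - H(X^(1)|X^(0))
```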
Next, handle the large first term. The expression inside the $\log$ depends only on $x^{(t)}$, $x^{(t-1)}$ and $x^{(0)}$, so in each summand all other variables integrate out (their integrals equal 1), leaving the joint marginal $q\left( x^{(0)},x^{(t-1)},x^{(t)} \right)$:
$$
\begin{aligned}
&\int{dx^{(0\cdots T)}q\left( x^{(0\cdots T)} \right)\log{\left[\frac{p\left( x^{(t-1)}|x^{(t)} \right)}{q\left( x^{(t-1)}|x^{(t)},x^{(0)} \right)}\right]}}
\\[5pt]
=&\int{dx^{(0)}dx^{(t-1)}dx^{(t)}q\left( x^{(0)},x^{(t-1)},x^{(t)} \right)\log{\left[\frac{p\left( x^{(t-1)}|x^{(t)} \right)}{q\left( x^{(t-1)}|x^{(t)},x^{(0)} \right)}\right]}}
\\[5pt]
=&\int{dx^{(0)}dx^{(t-1)}dx^{(t)}q\left( x^{(0)},x^{(t)} \right)q\left( x^{(t-1)}|x^{(t)},x^{(0)} \right)\log{\left[\frac{p\left( x^{(t-1)}|x^{(t)} \right)}{q\left( x^{(t-1)}|x^{(t)},x^{(0)} \right)}\right]}}
\\[5pt]
=&\int{dx^{(0)}dx^{(t)}q\left( x^{(0)},x^{(t)} \right)\int{dx^{(t-1)}q\left( x^{(t-1)}|x^{(t)},x^{(0)} \right)\log{\left[\frac{p\left( x^{(t-1)}|x^{(t)} \right)}{q\left( x^{(t-1)}|x^{(t)},x^{(0)} \right)}\right]}}}
\\[5pt]
=&-\int{dx^{(0)}dx^{(t)}q\left( x^{(0)},x^{(t)} \right)D_{KL}\left( q\left( x^{(t-1)}|x^{(t)},x^{(0)} \right)\parallel p\left( x^{(t-1)}|x^{(t)} \right) \right)}
\end{aligned}
$$
Therefore $K$ can be written in terms of KL divergences:
$$
\begin{aligned}
K =&-\sum_{t=2}^{T}\int{dx^{(0)}dx^{(t)}q\left( x^{(0)},x^{(t)} \right)D_{KL}\left( q\left( x^{(t-1)}|x^{(t)},x^{(0)} \right)\parallel p\left( x^{(t-1)}|x^{(t)} \right) \right)}
\\
&+ H_q\left( X^{(T)}|X^{(0)} \right) - H_q\left( X^{(1)}|X^{(0)} \right) - H_{p}\left( X^{(T)} \right)
\end{aligned}
$$
The entropy terms do not depend on the learned reverse process: the $H_q$ terms are fixed by the forward process and $H_{p}\left( X^{(T)} \right)$ by $\pi$. The only way to make $K$ as large as possible, and thereby raise the lower bound on $L$, is to minimize $D_{KL}\left( q\left( x^{(t-1)}|x^{(t)},x^{(0)} \right)\parallel p\left( x^{(t-1)}|x^{(t)} \right) \right)$.
Forward and Reverse Diffusion Kernels
For the forward diffusion we are free to define the noising rule:
$$
\begin{aligned}
q\left(x^{(t)}|x^{(t-1)}\right) &=\mathcal{N}\left( x^{(t)};\sqrt{1-\beta_{t}}\cdot x^{(t-1)},\beta_{t}\mathbf{I} \right)
\\[5pt]
x^{(t)} &= \sqrt{1-\beta_t}\cdot x^{(t-1)}+\sqrt{\beta_t}\cdot\epsilon_{t},\quad\epsilon_{t}\sim\mathcal{N}(0,\mathbf{I})
\end{aligned}
$$
Here $\beta_{t}$ is a user-defined schedule; you can choose, for example, a linearly increasing or a trigonometric (e.g. cosine) schedule. Each noising step pulls the mean of the distribution toward 0 while driving the variance toward 1, so that $x^{(T)}$ ends up following the standard normal distribution $\mathcal{N}\left( x^{(T)};0,\mathbf{I} \right)$.
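A minimal sketch of this noising rule in 1-D Python. The linear schedule from $10^{-4}$ to $0.02$ over $T=1000$ steps is an assumption, chosen to match the values reported in the DDPM paper; `propagate_moments` tracks the exact mean and variance, which makes the drift toward $\mathcal{N}(0,1)$ visible without sampling noise.

```python
import math
import random


def linear_beta_schedule(T, beta_start=1e-4, beta_end=0.02):
    """beta_1..beta_T spaced linearly from beta_start to beta_end."""
    return [beta_start + (beta_end - beta_start) * i / (T - 1) for i in range(T)]


def forward_step(x_prev, beta, rng):
    """One noising step: x^(t) = sqrt(1 - beta_t) * x^(t-1) + sqrt(beta_t) * eps_t."""
    return math.sqrt(1.0 - beta) * x_prev + math.sqrt(beta) * rng.gauss(0.0, 1.0)


def propagate_moments(mean, var, betas):
    """Exact mean and variance of x^(t) after applying every step in betas."""
    for beta in betas:
        mean = math.sqrt(1.0 - beta) * mean
        var = (1.0 - beta) * var + beta
    return mean, var


betas = linear_beta_schedule(1000)
mean_T, var_T = propagate_moments(3.0, 0.0, betas)   # start from the fixed point x^(0) = 3
print(mean_T, var_T)                                 # mean near 0, variance near 1

rng = random.Random(0)
x = 3.0
for beta in betas:                                   # one sampled trajectory
    x = forward_step(x, beta, rng)
print(x)
```

Since $1-\mathrm{Var}$ is multiplied by $1-\beta_t$ at every step, the variance converges to 1 regardless of where it starts.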
The reverse process has to learn two quantities:
$$
p\left( x^{(t-1)}|x^{(t)} \right) = \mathcal{N}\left( x^{(t-1)};f_{\mu}\left( x^{(t)};t \right),f_{\Sigma}\left( x^{(t)};t \right) \right)
$$
that is, the mean and the covariance of each reverse step.
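Sampling from the reverse process then means starting from $\pi$ and applying the learned kernel $T$ times. The sketch below is a 1-D illustration in which `f_mu` and `f_sigma` are hypothetical callables standing in for the trained networks $f_\mu$ and $f_\Sigma$, with `f_sigma` assumed to return a variance.

```python
import math
import random


def ancestral_sample(f_mu, f_sigma, T, rng):
    """Run the 1-D reverse chain x^(T) -> x^(0).

    f_mu(x, t) and f_sigma(x, t) are hypothetical stand-ins for the learned
    mean and variance of p(x^(t-1) | x^(t)).
    """
    x = rng.gauss(0.0, 1.0)                          # x^(T) ~ pi = N(0, 1)
    for t in range(T, 0, -1):                        # t = T, T-1, ..., 1
        mean, var = f_mu(x, t), f_sigma(x, t)
        x = mean + math.sqrt(var) * rng.gauss(0.0, 1.0)   # x^(t-1) ~ N(mean, var)
    return x


# Toy check: with zero variance the chain is deterministic, x^(0) = 0.9**T * x^(T).
x_T = random.Random(0).gauss(0.0, 1.0)               # the same first draw the sampler makes
x_0 = ancestral_sample(lambda x, t: 0.9 * x, lambda x, t: 0.0, T=5, rng=random.Random(0))
print(x_0, 0.9 ** 5 * x_T)
```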
Simplified Formulation (Where DDPM Comes In)
1. Introducing $\alpha$
Clearly, if we trained the model by literally adding and removing noise one step at a time, training would take a very long time: every step takes the previous step's noised (or denoised) result as its input, so while earlier timesteps are being computed, later timesteps can only wait.
Instead, notice that the cumulative process from $x^{(0)}$ to $x^{(t)}$ can be unrolled recursively:
$$
\begin{aligned}
x^{(1)} &= \sqrt{1-\beta_1}\cdot x^{(0)}+\sqrt{\beta_1}\cdot\epsilon_{1}
\\[10pt]
x^{(2)} &= \sqrt{1-\beta_2}\cdot x^{(1)}+\sqrt{\beta_2}\cdot\epsilon_{2}
\\[5pt]
&= \sqrt{1-\beta_2}\cdot\left( \sqrt{1-\beta_1}\cdot x^{(0)}+\sqrt{\beta_1}\cdot\epsilon_{1} \right)+\sqrt{\beta_2}\cdot\epsilon_{2}
\\[5pt]
&=\sqrt{(1-\beta_2)(1-\beta_1)}\cdot x^{(0)}+\sqrt{(1-\beta_2)\beta_1}\cdot\epsilon_{1}+\sqrt{\beta_2}\cdot\epsilon_2
\end{aligned}
$$
Since $\epsilon_1$ and $\epsilon_2$ are independent and both follow $\mathcal{N}(0,\mathbf{I})$, a weighted sum of them is again Gaussian with the variances added, so the two noise terms merge into a single Gaussian:
$$
\begin{aligned}
\left(\sqrt{(1-\beta_2)\beta_1}\cdot\epsilon_{1}+\sqrt{\beta_2}\cdot\epsilon_2\right)&\sim\mathcal{N}\left( 0,[(1-\beta_2)\beta_1+\beta_2]\mathbf{I} \right)
\\[5pt]
&= \mathcal{N}\left(0,[1-(1-\beta_1)(1-\beta_2)]\mathbf{I}\right)
\end{aligned}
$$
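The merge can be checked by Monte Carlo. This sketch draws the two-term noise for hypothetical values $\beta_1=0.1$, $\beta_2=0.2$ and compares its empirical variance with $1-(1-\beta_1)(1-\beta_2)=0.28$.

```python
import math
import random


def merged_noise_stats(beta1, beta2, n, rng):
    """Empirical mean and variance of sqrt((1-beta2)*beta1)*eps1 + sqrt(beta2)*eps2."""
    samples = [math.sqrt((1 - beta2) * beta1) * rng.gauss(0.0, 1.0)
               + math.sqrt(beta2) * rng.gauss(0.0, 1.0) for _ in range(n)]
    mean = sum(samples) / n
    var = sum((s - mean) ** 2 for s in samples) / (n - 1)
    return mean, var


beta1, beta2 = 0.1, 0.2
empirical_mean, empirical_var = merged_noise_stats(beta1, beta2, 100_000, random.Random(0))
predicted_var = 1 - (1 - beta1) * (1 - beta2)      # = (1 - beta2) * beta1 + beta2
print(empirical_mean, empirical_var, predicted_var)
```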
Generalizing to $t$ steps, define $\alpha_t = 1-\beta_t$ and $\bar\alpha_t = \prod_{s=1}^{t}\alpha_s$. Then $x^{(t)}$ can be written as:
$$
\begin{aligned}
x^{(t)} &=\sqrt{\bar\alpha_t}\cdot x^{(0)}+\sqrt{1-\bar\alpha_t}\cdot\epsilon,\quad\epsilon\sim\mathcal{N}(0,\mathbf{I})
\\[5pt]
q\left( x^{(t)}|x^{(0)} \right) &= \mathcal{N}\left( x^{(t)};\sqrt{\bar\alpha_t}\cdot x^{(0)},(1-\bar\alpha_t)\mathbf{I} \right)
\end{aligned}
$$
This expression depends only on $\bar\alpha_t$, the cumulative product of the $\alpha_s$, so $x^{(t)}$ can be sampled from $x^{(0)}$ for any $t$ in a single step, without computing the noise term step by step.
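A minimal sketch of the closed form, assuming 1-D data and a hypothetical four-step schedule; `alpha_bars` and `q_sample` are illustrative names. It also checks that the closed-form variance $1-\bar\alpha_t$ agrees with propagating the variance one step at a time, which is the merging argument above applied repeatedly.

```python
import math


def alpha_bars(betas):
    """alpha_bar_t = prod_{s <= t} (1 - beta_s) for t = 1..T (index t-1 in the list)."""
    out, prod = [], 1.0
    for beta in betas:
        prod *= 1.0 - beta
        out.append(prod)
    return out


def q_sample(x0, t, betas, eps):
    """Jump directly to x^(t) = sqrt(abar_t) * x^(0) + sqrt(1 - abar_t) * eps (t is 1-based)."""
    abar = alpha_bars(betas)[t - 1]
    return math.sqrt(abar) * x0 + math.sqrt(1.0 - abar) * eps


betas = [0.1, 0.2, 0.3, 0.4]                    # hypothetical schedule, T = 4
iterated_vars, var = [], 0.0
for beta in betas:                              # step by step: Var <- (1 - beta) * Var + beta
    var = (1.0 - beta) * var + beta
    iterated_vars.append(var)
closed_form_vars = [1.0 - ab for ab in alpha_bars(betas)]
print(iterated_vars)
print(closed_form_vars)
print(q_sample(x0=1.5, t=3, betas=betas, eps=0.3))
```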
2. Simplifying $D_{KL}$ for Gaussian Distributions, and the Loss Function
As shown above, our goal is to minimize $D_{KL}\left( q\left( x^{(t-1)}|x^{(t)},x^{(0)} \right)\parallel p_{\theta}\left( x^{(t-1)}|x^{(t)} \right) \right)$ so that $K$ is as large as possible, where $\theta$ denotes the parameters of the neural network.
For two Gaussian distributions with the same isotropic covariance, $\mathcal{N}(\mu_1,\sigma^{2}\mathbf{I})$ and $\mathcal{N}(\mu_2,\sigma^{2}\mathbf{I})$, the $D_{KL}$ is:
$$
D_{KL}\left( \mathcal{N}(\mu_1,\sigma^{2}\mathbf{I})\parallel\mathcal{N}(\mu_2,\sigma^{2}\mathbf{I}) \right) = \frac{1}{2\sigma^2}\left\| \mu_1-\mu_2 \right\|^{2}
$$
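As a sanity check, the general KL divergence between diagonal Gaussians, $\sum_i \tfrac12\left[\log\tfrac{\sigma_{2,i}^2}{\sigma_{1,i}^2}+\tfrac{\sigma_{1,i}^2+(\mu_{1,i}-\mu_{2,i})^2}{\sigma_{2,i}^2}-1\right]$, reduces to the formula above when both covariances equal $\sigma^2\mathbf{I}$. The vectors below are hypothetical.

```python
import math


def kl_diag_gaussians(mu1, var1, mu2, var2):
    """KL( N(mu1, diag(var1)) || N(mu2, diag(var2)) ) for vectors given as lists."""
    return sum(0.5 * (math.log(v2 / v1) + (v1 + (m1 - m2) ** 2) / v2 - 1.0)
               for m1, v1, m2, v2 in zip(mu1, var1, mu2, var2))


def kl_shared_isotropic(mu1, mu2, sigma2):
    """Special case with shared covariance sigma^2 I: ||mu1 - mu2||^2 / (2 sigma^2)."""
    return sum((p - q) ** 2 for p, q in zip(mu1, mu2)) / (2.0 * sigma2)


mu1, mu2, sigma2 = [0.3, -1.0, 2.0], [0.1, -0.5, 1.0], 0.25   # hypothetical values
general = kl_diag_gaussians(mu1, [sigma2] * 3, mu2, [sigma2] * 3)
special = kl_shared_isotropic(mu1, mu2, sigma2)
print(general, special)    # both equal 1.29 / (2 * 0.25) = 2.58
```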
Next we need the true posterior of the forward process, $q\left( x^{(t-1)}|x^{(t)},x^{(0)}\right)$. It is also Gaussian, though its parameters are more involved:
$$
q\left( x^{(t-1)}|x^{(t)},x^{(0)}\right) = \mathcal{N}\left( x^{(t-1)};\tilde{\mu}_{t}\left( x^{(t)},x^{(0)} \right),\tilde{\beta}_{t}\mathbf{I} \right)
$$
Both the mean and the variance have closed-form solutions; the derivation follows.
The goal is $q\left( x^{(t-1)}|x^{(t)},x^{(0)}\right)$, the conditional distribution of $x^{(t-1)}$ given $x^{(t)}$ and $x^{(0)}$.
By Bayes' rule:
$$
q\left( x^{(t-1)}|x^{(t)},x^{(0)}\right) = \frac{q\left( x^{(t)}|x^{(t-1)},x^{(0)}\right)\cdot q\left( x^{(t-1)}|x^{(0)} \right)}{q\left( x^{(t)}|x^{(0)} \right)}
$$
Because the forward process is a Markov chain, $q\left( x^{(t)}|x^{(t-1)},x^{(0)}\right) = q\left( x^{(t)}|x^{(t-1)}\right)$, and the denominator does not depend on $x^{(t-1)}$, so:
$$
q\left( x^{(t-1)}|x^{(t)},x^{(0)}\right) \propto q\left( x^{(t)}|x^{(t-1)}\right)\cdot q\left( x^{(t-1)}|x^{(0)}\right)
$$
Now write out each Gaussian explicitly:
- Single-step forward transition:
$$
q\left( x^{(t)}|x^{(t-1)}\right) \propto \exp{\left( -\frac{\left\| x^{(t)}-\sqrt{\alpha_t}x^{(t-1)} \right\|^{2}}{2\beta_t} \right)}
$$
- Cumulative distribution from $x^{(0)}$ to $x^{(t-1)}$:
$$
q\left( x^{(t-1)}|x^{(0)}\right)\propto\exp{\left( -\frac{\left\| x^{(t-1)}-\sqrt{\bar\alpha_{t-1}}x^{(0)} \right\|^{2}}{2(1-\bar\alpha_{t-1})} \right)}
$$
- The exponent of the product, i.e. of the unnormalized posterior:
$$
-\frac{\left\| x^{(t)}-\sqrt{\alpha_t}x^{(t-1)} \right\|^{2}}{2\beta_t} -\frac{\left\| x^{(t-1)}-\sqrt{\bar\alpha_{t-1}}x^{(0)} \right\|^{2}}{2(1-\bar\alpha_{t-1})}
$$
Expanding the exponent and collecting terms gives the quadratic and linear terms in $x^{(t-1)}$:
- Quadratic term:
$$
-\frac{1}{2}\left( \frac{\alpha_t}{\beta_t} +\frac{1}{1-\bar\alpha_{t-1}}\right)\left(x^{(t-1)}\right)^{2}
$$
- Linear term (the cross terms of both squares enter with a positive sign):
$$
\left( \frac{\sqrt{\alpha_t}}{\beta_t}x^{(t)}+\frac{\sqrt{\bar\alpha_{t-1}}}{1-\bar\alpha_{t-1}}x^{(0)} \right)x^{(t-1)}
$$
Match these against the exponent of a generic Gaussian:
$$
-\frac{(x-\mu)^2}{2\sigma^2}=-\frac{x^2}{2\sigma^2}+\frac{\mu}{\sigma^2}x-\frac{\mu^2}{2\sigma^2}
$$
Comparing coefficients gives the mean and variance of the posterior:
- Variance (using $\alpha_t+\beta_t=1$ and $\alpha_t\bar\alpha_{t-1}=\bar\alpha_t$):
$$
\frac{1}{\tilde{\beta}_t} =\frac{\alpha_t}{\beta_t} +\frac{1}{1-\bar\alpha_{t-1}} =\frac{\alpha_t(1-\bar\alpha_{t-1})+\beta_t}{\beta_t(1-\bar\alpha_{t-1})} =\frac{1-\bar\alpha_t}{\beta_t(1-\bar\alpha_{t-1})} \Longrightarrow \tilde{\beta}_t = \frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}\beta_t
$$
- Mean:
$$
\frac{\tilde{\mu}_t}{\tilde{\beta}_t} =\frac{\sqrt{\alpha_t}}{\beta_t}x^{(t)}+\frac{\sqrt{\bar\alpha_{t-1}}}{1-\bar\alpha_{t-1}}x^{(0)} \Longrightarrow\tilde{\mu}_t=\frac{\sqrt{\alpha_t}(1-\bar\alpha_{t-1})}{1-\bar\alpha_{t}}x^{(t)}+\frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1-\bar\alpha_t}x^{(0)}
$$
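The closed-form posterior parameters can be checked against the completed-square coefficients directly. This 1-D sketch uses a hypothetical three-step schedule and 1-based $t$; `posterior_params` is an illustrative name implementing $\tilde\beta_t$ and $\tilde\mu_t$ as derived above.

```python
import math


def posterior_params(x_t, x0, t, betas):
    """Mean and variance of q(x^(t-1) | x^(t), x^(0)) in 1-D, for 1-based t >= 2."""
    alphas = [1.0 - b for b in betas]
    abar_t = math.prod(alphas[:t])              # alpha_bar_t
    abar_prev = math.prod(alphas[:t - 1])       # alpha_bar_{t-1}
    alpha_t, beta_t = alphas[t - 1], betas[t - 1]
    var = (1.0 - abar_prev) / (1.0 - abar_t) * beta_t
    mean = (math.sqrt(alpha_t) * (1.0 - abar_prev) / (1.0 - abar_t) * x_t
            + math.sqrt(abar_prev) * beta_t / (1.0 - abar_t) * x0)
    return mean, var


betas = [0.1, 0.2, 0.3]                         # hypothetical schedule
x_t, x0, t = 0.7, -0.4, 3
mean, var = posterior_params(x_t, x0, t, betas)

# Cross-check with the coefficients read off the completed square.
alpha_t, abar_prev = 1.0 - betas[t - 1], (1.0 - betas[0]) * (1.0 - betas[1])
precision = alpha_t / betas[t - 1] + 1.0 / (1.0 - abar_prev)          # 1 / tilde beta_t
linear = math.sqrt(alpha_t) / betas[t - 1] * x_t + math.sqrt(abar_prev) / (1.0 - abar_prev) * x0
print(mean, linear / precision)
print(var, 1.0 / precision)
```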
Next, from the forward relation $x^{(t)} =\sqrt{\bar\alpha_t}\cdot x^{(0)}+\sqrt{1-\bar\alpha_t}\cdot\epsilon$, solve for $x^{(0)}$:
$$
x^{(0)} = \frac{1}{\sqrt{\bar\alpha_t}}\left( x^{(t)}-\sqrt{1-\bar\alpha_t}\cdot\epsilon \right)
$$
Substituting $x^{(0)}$ into the mean (and using $\bar\alpha_t=\alpha_t\bar\alpha_{t-1}$) gives:
$$
\tilde{\mu}_t = \frac{1}{\sqrt{\alpha_t}}\left( x^{(t)}-\frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\epsilon \right)
$$
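Both forms of $\tilde\mu_t$ should agree whenever $x^{(t)}$ is generated from $x^{(0)}$ and $\epsilon$ by the forward relation. A minimal 1-D check, with hypothetical values $\beta_t=0.3$ and $\bar\alpha_{t-1}=0.72$ and $\bar\alpha_t$ set to $\bar\alpha_{t-1}\alpha_t$ so the identity used above holds:

```python
import math
import random


def mu_tilde_from_x0(x_t, x0, abar_t, abar_prev, beta_t):
    """Posterior mean written in terms of x^(t) and x^(0)."""
    alpha_t = 1.0 - beta_t
    return (math.sqrt(alpha_t) * (1.0 - abar_prev) / (1.0 - abar_t) * x_t
            + math.sqrt(abar_prev) * beta_t / (1.0 - abar_t) * x0)


def mu_tilde_from_eps(x_t, eps, abar_t, beta_t):
    """Posterior mean written in terms of x^(t) and the noise eps."""
    alpha_t = 1.0 - beta_t
    return (x_t - beta_t / math.sqrt(1.0 - abar_t) * eps) / math.sqrt(alpha_t)


beta_t, abar_prev = 0.3, 0.72                  # hypothetical values
abar_t = abar_prev * (1.0 - beta_t)            # abar_t = abar_{t-1} * alpha_t
rng = random.Random(0)
x0, eps = rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)
x_t = math.sqrt(abar_t) * x0 + math.sqrt(1.0 - abar_t) * eps   # forward sample
print(mu_tilde_from_x0(x_t, x0, abar_t, abar_prev, beta_t), mu_tilde_from_eps(x_t, eps, abar_t, beta_t))
```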
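Therefore, if the network parameterizes the reverse-process mean $\mu_\theta$ in the same form, with a predicted noise $\epsilon_\theta$ in place of the true noise $\epsilon$:
$$
\mu_{\theta}\left( x^{(t)},t \right) =\frac{1}{\sqrt{\alpha_t}}\left( x^{(t)}-\frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\epsilon_{\theta}\left( x^{(t)},t \right) \right)
$$
then, with the reverse variance fixed to the posterior variance ($\sigma^2=\tilde{\beta}_t$), $D_{KL}$ simplifies to: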
$$
\begin{aligned}
&D_{KL}\left( q\left( x^{(t-1)}|x^{(t)},x^{(0)} \right)\parallel p_{\theta}\left( x^{(t-1)}|x^{(t)} \right) \right)
\\[5pt]
= &\frac{1}{2\sigma^2}\left\| \tilde{\mu}_t-\mu_{\theta}\left( x^{(t)},t \right) \right\|^{2}
\\[5pt]
= &\frac{\beta_{t}^{2}}{2\tilde{\beta}_{t}\alpha_t(1-\bar\alpha_t)}\left\| \epsilon-\epsilon_{\theta}\left( x^{(t)},t \right) \right\|^{2}
\end{aligned}
$$
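The weight in front of the noise MSE can be verified numerically. This sketch assumes $\sigma^2=\tilde\beta_t$ and uses hypothetical scalar values; because $\tilde\mu_t$ has the same form as $\mu_\theta$ with the true $\epsilon$ in place of $\epsilon_\theta$, one helper (`mu_theta`, an illustrative name) computes both.

```python
import math


def mu_theta(x_t, eps_pred, abar_t, beta_t):
    """Reverse mean parameterized by a noise prediction eps_pred."""
    return (x_t - beta_t / math.sqrt(1.0 - abar_t) * eps_pred) / math.sqrt(1.0 - beta_t)


beta_t, abar_prev = 0.3, 0.72                        # hypothetical values
alpha_t = 1.0 - beta_t
abar_t = abar_prev * alpha_t
beta_tilde = (1.0 - abar_prev) / (1.0 - abar_t) * beta_t   # sigma^2 fixed to tilde beta_t
x_t, eps, eps_pred = 0.8, 0.5, 0.2                   # hypothetical values

mu_true = mu_theta(x_t, eps, abar_t, beta_t)         # same formula with the true eps gives tilde mu_t
kl_from_means = (mu_true - mu_theta(x_t, eps_pred, abar_t, beta_t)) ** 2 / (2.0 * beta_tilde)
weight = beta_t ** 2 / (2.0 * beta_tilde * alpha_t * (1.0 - abar_t))
kl_from_noise = weight * (eps - eps_pred) ** 2
print(kl_from_means, kl_from_noise)
```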
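Dropping the weighting coefficient and drawing $t$ uniformly from $\{1,\dots,T\}$, the final loss function is:
$$
\mathcal{L}_{\text{simple}}(\theta) = \mathbf{E}_{t,x^{(0)},\epsilon}\left[ \left\| \epsilon-\epsilon_{\theta}\left( x^{(t)},t \right) \right\|^{2} \right],\quad x^{(t)}=\sqrt{\bar\alpha_t}\cdot x^{(0)}+\sqrt{1-\bar\alpha_t}\cdot\epsilon
$$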
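A minimal sketch of estimating this objective on a batch, assuming 1-D data, a hypothetical `eps_model(x_t, t)` callable in place of the network $\epsilon_\theta$, and $t$ drawn uniformly from $\{1,\dots,T\}$. Training would minimize this quantity by gradient descent on $\theta$, which is omitted here. A model that always predicts zero noise should score about $\mathbf{E}[\epsilon^2]=1$.

```python
import math
import random


def diffusion_loss(eps_model, x0_batch, betas, rng):
    """Monte Carlo estimate of E_{t, x0, eps} (eps - eps_model(x_t, t))^2 for 1-D data."""
    alpha_bar, prod = [], 1.0
    for beta in betas:
        prod *= 1.0 - beta
        alpha_bar.append(prod)
    total = 0.0
    for x0 in x0_batch:
        t = rng.randint(1, len(betas))               # t ~ Uniform{1, ..., T}
        eps = rng.gauss(0.0, 1.0)                    # eps ~ N(0, 1)
        abar = alpha_bar[t - 1]
        x_t = math.sqrt(abar) * x0 + math.sqrt(1.0 - abar) * eps   # closed-form forward sample
        total += (eps - eps_model(x_t, t)) ** 2
    return total / len(x0_batch)


rng = random.Random(0)
betas = [1e-4 + (0.02 - 1e-4) * i / 999 for i in range(1000)]
data = [rng.gauss(2.0, 0.5) for _ in range(1024)]     # hypothetical 1-D dataset
loss_zero = diffusion_loss(lambda x_t, t: 0.0, data, betas, rng)
print(loss_zero)                                      # roughly 1 for a model that predicts zero noise
```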
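- Gaussian assumption: the forward and reverse processes are both Gaussian, and the reverse variance is fixed.
- Mean matching: the KL divergence reduces to an MSE between means.
- Noise prediction: through reparameterization, predicting the mean becomes directly predicting the noise $\epsilon$.
- Loss function: the final loss is the MSE between the predicted noise and the true noise.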