废话不多说直接上模型:
这是一个非常经典的对话生成模型,叫做HRED(Hierarchical RNN Enconder-Decoder)。思路很简单,就是用一个RNN来建模前j−1j-1j−1句话,再用一个RNN来建模第jjj句话的k−1k-1k−1个词,然后再用一个RNN来解码第jjj句话的第kkk个词。
HRED模型的训练:给定一个词的上文,最大化这个词出现的对数似然(极大似然估计)
即:argmax(logpθ(wj,k∣w1,w2,...,wj−1,wj,1,wj,2,...,wj,k))argmax(logp_\theta(w_{j,k}|w_{1}, w_{2}, ..., w_{j-1}, w_{j, 1}, w_{j, 2}, ..., w_{j,k}))argmax(logpθ(wj,k∣w1,w2,...,wj−1,wj,1,wj,2,...,wj,k))
为了方便下文的推导,let www = wj,kw_{j,k}wj,k, ccc = w1,w2,...,wj−1,wj,1,wj,2,...,wj,kw_{1}, w_{2}, ..., w_{j-1}, w_{j, 1}, w_{j, 2}, ..., w_{j,k}w1,w2,...,wj−1,wj,1,wj,2,...,wj,k
训练目标简写为:argmax(logpθ(w∣c))argmax(logp_\theta(w|c))argmax(logpθ(w∣c))
而pθp_\thetapθ其实就是我们定义的RNN,RNN的公式网上一查就查到了,因此计算梯度,反向传播,参数更新,都是理所当然的事情,最后我们就得到了一个对话生成模型。
但是极大似然估计有一个问题,那就是模型容易置信度过高。而在真实的对话流程中,一句输入,其实可能有无数种回答的方式,每一种回答都是合理的(因为是闲聊嘛)。将模型的所有变量都建模为可见的参数θ\thetaθ,会导致倾向于生成单调的通用性回复,比如“我不知道”,“嗯嗯”,等。这种回复本质上是没错的,但是很无聊。
于是我们提出了下面这个模型:

这个模型叫做VHRED(Variational Hierarchical RNN Encoder-Decoder),引入一个不可观测的隐变量,用于生成对话回复。隐变量每次都采样于一个分布,这样,回复的多样性极大的增加了。
VHRED模型的训练:在生成每一个词的时候,引入了一个隐变量zzz,这个隐变量无法观测到,所以无法用参数θ\thetaθ建模。因此,训练目标重新写为:
argmax(logpθ(w∣c))=argmax(log∫pθ(w∣c,z)pθ(z∣c)dz)argmax(logp_\theta(w|c)) = argmax(log\int p_\theta(w|c,z)p_\theta(z|c) d_{z})argmax(logpθ(w∣c))=argmax(log∫pθ(w∣c,z)pθ(z∣c)dz)
问题出现了:zzz是一个高维变量,遍历所有zzz是不可行的!也就是说,等式右边的式子你是写不出来的,自然也无法计算梯度,进行梯度反向传播,参数更新。
万幸,提出VAE(Variational Autoencoder,VHRED的V就是从这里来的)的论文提供了一个复杂且精妙的解决方案,那就是引入变分推断,用一个认知网络qϕ(z∣c,w)q_\phi(z|c,w)qϕ(z∣c,w)拟合真实的后验分布pθ(z∣c,w)p_\theta(z|c,w)pθ(z∣c,w)。
(一大波推导来袭)
logpθ(w∣c)=Ez∼qϕ(z∣c,w)logpθ(w∣z)#样本分布与认知网络无关=Ezlog(pθ(w∣c,z)×pθ(z∣c)pθ(z∣c,w))=Ezlog(pθ(w∣c,z)×pθ(z∣c)pθ(z∣c,w)×qϕ(z∣c,w)qϕ(z∣c,w))=Ezlog(pθ(w∣c,z)×pθ(z∣c)qϕ(z∣c,w)×qϕ(z∣c,w)pθ(z∣c,w))=Ezlog(pθ(w∣c,z)−Ezlog(qϕ(z∣c,w)pθ(z∣c))+Ezlog(qϕ(z∣c,w)pθ(z∣c,w))=Ezlog(pθ(w∣c,z)−∫qϕ(z∣c,w)logqϕ(z∣c,w)pθ(z∣c)dz+∫qϕ(z∣c,w)qϕ(z∣c,w)pθ(z∣c,w)dz=Ez∼qϕ(z∣c,w)logpθ(w∣c,z)−KL(qϕ(z∣c,w)∣∣pθ(z∣c))+KL(qϕ(z∣c,w)∣∣pθ(z∣c,w)) \begin{aligned} logp_\theta(w|c) &= \mathbb{E}_{z \sim q_\phi(z|c,w)}logp_\theta(w|z) \#样本分布与认知网络无关\\ & = \mathbb{E}_zlog( \frac{p_\theta(w|c, z) \times p_\theta(z|c)}{p_\theta(z|c, w)})\\ & = \mathbb{E}_zlog(\frac{p_\theta(w|c, z) \times p_\theta(z|c)}{p_\theta(z|c, w)} \times \frac{q_\phi(z|c, w)}{q_\phi(z|c, w)}) \\ & = \mathbb{E}_zlog(p_\theta(w|c, z) \times \frac{p_\theta(z|c)}{q_\phi(z|c,w)} \times \frac{q_\phi(z|c, w)}{p_\theta(z|c, w)}) \\ & = \mathbb{E}_zlog(p_\theta(w|c, z) - \mathbb{E}_z log(\frac{q_\phi(z|c,w)}{p_\theta(z|c)}) + \mathbb{E}_z log(\frac{q_\phi(z|c, w)}{p_\theta(z|c, w)}) \\ & = \mathbb{E}_zlog(p_\theta(w|c, z) - \int q_\phi(z|c,w)log\frac{q_\phi(z|c,w)}{p_\theta(z|c)} d_z + \int q_\phi(z|c,w) \frac{q_\phi(z|c, w)}{p_\theta(z|c, w)}d_z \\ & = \mathbb{E}_{z \sim q_\phi(z|c,w)}logp_\theta(w|c,z) - KL(q_\phi(z|c,w)||p_\theta(z|c)) + KL(q_\phi(z|c,w)||p_\theta(z|c, w)) \end{aligned} logpθ(w∣c)=Ez∼qϕ(z∣c,w)logpθ(w∣z)#样本分布与认知网络无关=Ezlog(