注:本文为《动手学深度学习》开源内容,部分标注了个人理解,仅为个人学习记录,无抄袭搬运意图
10.10 束搜索
上一节介绍了如何训练输入和输出均为不定长序列的编码器—解码器。本节我们介绍如何使用编码器—解码器来预测不定长的序列。
上一节里已经提到,在准备训练数据集时,我们通常会在样本的输入序列和输出序列后面分别附上一个特殊符号"<eos>“表示序列的终止。我们在接下来的讨论中也将沿用上一节的全部数学符号。为了便于讨论,假设解码器的输出是一段文本序列。设输出文本词典 Y \mathcal{Y} Y(包含特殊符号”<eos>“)的大小为 ∣ Y ∣ \left|\mathcal{Y}\right| ∣Y∣,输出序列的最大长度为 T ′ T' T′。所有可能的输出序列一共有 O ( ∣ Y ∣ T ′ ) \mathcal{O}(\left|\mathcal{Y}\right|^{T'}) O(∣Y∣T′)种。这些输出序列中所有特殊符号”<eos>"后面的子序列将被舍弃。
10.10.1 贪婪搜索
让我们先来看一个简单的解决方案:贪婪搜索(greedy search)。对于输出序列任一时间步 t ′ t' t′,我们从 ∣ Y ∣ |\mathcal{Y}| ∣Y∣个词中搜索出条件概率最大的词
y t ′ = argmax y ∈ Y P ( y ∣ y 1 , … , y t ′ − 1 , c ) y _ { t ^ { \prime } } = \underset { y \in \mathcal { Y } } { \operatorname { argmax } } P \left( y | y _ { 1 } , \ldots , y _ { t ^ { \prime } - 1 } , c \right) yt′=y∈YargmaxP(y∣y1,…,yt′−1,c)
作为输出。一旦搜索出"<eos>"符号,或者输出序列长度已经达到了最大长度 T ′ T' T′,便完成输出。
我们在描述解码器时提到,基于输入序列生成输出序列的条件概率是 ∏ t ′ = 1 T ′ P ( y t ′ ∣ y 1 , … , y t ′ − 1 , c ) \prod_{t'=1}^{T'} P(y_{t'} \mid y_1, \ldots, y_{t'-1}, \boldsymbol{c}) ∏t′=1T′P(