注意力机制
在整个注意力过程中,模型会学习了三个权重:查询、键和值。查询、键和值的思想来源于信息检索系统。所以我们先理解数据库查询的思想。
假设有一个数据库,里面有所有一些作家和他们的书籍信息。现在我想读一些Rabindranath写的书:
在数据库中,作者名字类似于键,图书类似于值。查询的关键词Rabindranath是这个问题的键。所以需要计算查询和数据库的键(数据库中的所有作者)之间的相似度,然后返回最相似作者的值(书籍)。
同样,注意力有三个矩阵,分别是查询矩阵(Q)、键矩阵(K)和值矩阵(V)。它们中的每一个都具有与输入嵌入相同的维数。模型在训练中学习这些度量的值。
我们可以假设我们从每个单词中创建一个向量,这样我们就可以处理信息。对于每个单词,生成一个512维的向量。所有3个矩阵都是512x512(因为单词嵌入的维度是512)。对于每个标记嵌入,我们将其与所有三个矩阵(Q, K, V)相乘,每个标记将有3个长度为512的中间向量。
接下来计算分数,它是查询和键向量之间的点积。分数决定了当我们在某个位置编码单词时,对输入句子的其他部分的关注程度。
然后将点积除以关键向量维数的平方根。这种缩放是为了防止点积变得太大或太小(取决于正值或负值),因为这可能导致训练期间的数值不稳定。选择比例因子是为了确保点积的方差近似等于1。
然后通过softmax操作传递结果。这将分数标准化:它们都是正的,并且加起来等于1。softmax输出决定了我们应该从不同的单词中获取多少信息或特征(值),也就是在计算权重。
这里需要注意的一点是,为什么需要其他单词的信息/特征?因为我们的语言是有上下文含义的,一个相同的单词出现在不同的语境,含义也不一样。
最后一步就是计算softmax与这些值的乘积,并将它们相加。
可视化图解
上面逻辑都是文字内容,看起来有一些枯燥,下面我们可视化它的矢量化实现。这样可以更加深入的理解。
查询键和矩阵的计算方法如下
同样的方法可以计算键向量和值向量。
最后计算得分和注意力输出。
简单代码实现
import torch
import torch.nn as nn
from typing import List
def get_input_embeddings(words: List[str], embeddings_dim: int):
# we are creating random vecto