从零实现大语言模型:深入理解带可训练权重的自注意力机制
自注意力机制概述
自注意力机制是现代大语言模型(如GPT系列)的核心组件之一。它允许模型在处理序列数据时,动态地关注输入序列的不同部分,从而更好地捕捉长距离依赖关系。本节我们将深入探讨带有可训练权重的自注意力机制实现。
自注意力机制的基本原理
自注意力机制的核心思想是通过三个关键组件——查询(Query)、键(Key)和值(Value)来计算输入序列中各个元素之间的相关性。与3.3节介绍的简化版本不同,本节实现的自注意力机制引入了可训练的权重矩阵,使模型能够学习如何更好地计算这些相关性。
查询、键和值的概念
- 查询(Query):代表当前需要关注的位置
- 键(Key):代表输入序列中所有可能被关注的位置
- 值(Value):包含实际要聚合的信息
这三个组件通过可训练的权重矩阵从输入向量转换而来,使得模型能够学习最适合当前任务的表示方式。
逐步实现自注意力机制
1. 初始化权重矩阵
首先需要初始化三个关键的权重矩阵:
W_query = torch.nn.Parameter(torch.rand(d_in, d_out))
W_key = torch.nn.Parameter(torch.rand(d_in, d_out))
W_value = torch.nn.Parameter(torch.rand(d_in, d_out))
这些矩阵将在训练过程中不断优化,使模型学会如何更好地计算注意力。
2. 计算查询、键和值向量
对于输入序列中的每个元素x,我们计算其对应的查询、键和值向量:
query = x @ W_query
key = x @ W_key
value = x @ W_value
这一步将输入向量投影到不同的空间,为后续的注意力计算做准备。
3. 计算注意力得分
注意力得分衡量查询与各个键之间的相关性:
attn_scores = queries @ keys.T
这里使用了点积运算来计算相似度,点积越大表示相关性越强。
4. 计算注意力权重
为了使注意力得分可比较,我们使用softmax函数进行归一化:
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
这里除以键向量维度的平方根是为了防止梯度消失问题,这是缩放点积注意力的关键步骤。
5. 计算上下文向量
最后,我们使用注意力权重对值向量进行加权求和,得到最终的上下文向量:
context_vec = attn_weights @ values
这个上下文向量包含了模型认为在当前查询下最重要的信息。
封装为Python类
为了便于使用,我们可以将上述步骤封装为一个PyTorch模块:
class SelfAttention(nn.Module):
def __init__(self, d_in, d_out):
super().__init__()
self.W_query = nn.Linear(d_in, d_out)
self.W_key = nn.Linear(d_in, d_out)
self.W_value = nn.Linear(d_in, d_out)
def forward(self, x):
queries = self.W_query(x)
keys = self.W_key(x)
values = self.W_value(x)
attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
context_vec = attn_weights @ values
return context_vec
这个类可以直接集成到更大的神经网络架构中。
关键实现细节
1. 缩放因子的重要性
在计算注意力权重时,我们使用键向量维度的平方根进行缩放:
attn_scores / keys.shape[-1]**0.5
这种缩放对于稳定训练至关重要,因为它可以防止点积结果过大导致softmax函数的梯度消失。
2. 权重初始化
使用nn.Linear
比手动初始化参数更优,因为:
- 它提供了更合理的默认初始化方案
- 可以方便地添加偏置项
- 内置了转置操作,使代码更简洁
3. 批处理支持
上述实现天然支持批处理,可以同时处理多个输入序列,这对于实际训练非常重要。
自注意力机制的优势
- 长距离依赖:可以捕捉序列中任意位置之间的关系
- 并行计算:所有位置的注意力可以同时计算
- 可解释性:注意力权重可以直观展示模型关注的重点
- 灵活性:通过可训练权重适应不同任务
总结
本节我们详细实现了带有可训练权重的自注意力机制,这是构建现代大语言模型的基础组件。理解这一机制对于后续实现更复杂的模型架构至关重要。在下一节中,我们将在此基础上引入因果掩码和多头注意力,进一步完善模型的表达能力。
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考