交叉注意力机制(Cross-Attention)详细说明,包含代码示例详细注释
交叉注意力是注意力机制的一种变体,核心思想是让一个序列(查询序列)关注另一个序列(键值序列)。它在多模态任务(如图文匹配)和Transformer解码器中广泛应用,通过动态计算不同序列元素间的关联权重实现信息融合。
数学原理
给定查询序列 Q∈Rn×dkQ \in \mathbb{R}^{n \times d_k}Q∈Rn×dk 和键值序列 K,V∈Rm×dkK, V \in \mathbb{R}^{m \times d_k}K,V∈Rm×dk,交叉注意力计算过程为:
-
相似度计算:
Similarity=QKTdk \text{Similarity} = \frac{QK^T}{\sqrt{d_k}} Similarity=dkQKT
其中 dkd_kdk 为向量维度,缩放因子避免梯度消失 -
权重归一化:
A=softmax(QKTdk) A = \text{softmax}\left( \frac{QK^T}{\sqrt{d_k}} \right) A=softmax(dkQKT) -
加权输出:
Output=AV \text{Output} = AV Output=AV
================================================ -
输入格式
query/key/value
: 形状为 (S,N,E)(S, N, E)(S,N,E) 或 (N,S,E)(N, S, E)(N,S,E)(通过batch_first
参数控制)- SSS: 序列长度,NNN: 批次大小,EEE: 特征维度
- e.g:
Q:[N,S1,E]Q:[N, S1, E]Q:[N,S1,E]
K:[N,S2,E]=V:[N,S2,E]K:[N, S2, E] = V:[N, S2, E]K:[N,S2,E]=V:[N,S2,E]
Q∗KT=[N,S1,E]∗[N,E,S2]=[N,S1,S2]Q*K^T=[N,S1,E]*[N, E, S2]=[N,S1,S2]Q∗KT=[N,S1,E]∗[N,E,S2]=[N,S1,S2]
QK(Attention)∗V=[N,S1,S2]∗[N,S2,E]=[N,S1,E]QK(Attention)*V=[N,S1,S2]*[N, S2, E]=[N,S1,E]QK(Attention)∗V=[N,S1,S2]∗[N,S2,E]=[N,S1,E]
输出保持与query维度一致
-
输出
output
: 注意力结果,形状与query
相同attn_weights
: 注意力权重矩阵,形状为 (N×h,S,S)(N \times h, S, S)(N×h,S,S)
-
典型应用场景
- 自注意力:
query = key = value
- 编码器-解码器注意力:
query
来自解码器,key/value
来自编码器
- 自注意力:
提示:实际使用时需注意处理
key_padding_mask
(屏蔽填充符)和attn_mask
(防止未来信息泄露)。
Python代码nn.MultiheadAttention实现
import torch
import torch.nn as nn
import torch.nn.functional as F
class CrossAttention(nn.Module):
def __init__(self, embed_dim, num_heads):
super().__init__()
self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
def forward(self, query, key, value):
"""
query: [target_len, batch_size, embed_dim]
key: [source_len, batch_size, embed_dim]
value: [source_len, batch_size, embed_dim]
返回: [target_len, batch_size, embed_dim]
"""
attn_output, _ = self.multihead_attn(
query, key, value,
need_weights=False
)
return attn_output
# 示例用法
if __name__ == "__main__":
embed_dim = 512
num_heads = 8
batch_size = 4
src_len = 20 # 源序列长度
tgt_len = 15 # 目标序列长度
# 初始化模块
cross_attn = CrossAttention(embed_dim, num_heads)
# 生成模拟数据
query = torch.rand(tgt_len, batch_size, embed_dim) # 目标序列
key = value = torch.rand(src_len, batch_size, embed_dim) # 源序列
# 执行交叉注意力
output = cross_attn(query, key, value)
print("输出张量形状:", output.shape) # [15, 4, 512]
关键特性
- 动态权重分配:根据当前查询动态计算键值序列的权重
- 序列长度无关:可处理不同长度的输入序列
- 并行计算:矩阵运算支持GPU加速
- 可解释性:注意力权重可视化(需修改代码返回权重矩阵)
其中nn.MultiheadAttention
详解
nn.MultiheadAttention
是 PyTorch 中实现多头注意力机制的模块,广泛应用于 Transformer 架构(如 BERT、GPT 等)。其核心思想是将输入拆分为多个"头",并行计算注意力,最后合并结果以捕捉更丰富的上下文信息。
数学原理
设输入序列为 X∈RL×dmodelX \in \mathbb{R}^{L \times d_{\text{model}}}X∈RL×dmodel(LLL 为序列长度,dmodeld_{\text{model}}dmodel 为特征维度)。首先通过线性变换生成查询(Query)、键(Key)、值(Value):
Q=XWQ,K=XWK,V=XWVQ = XW^Q, \quad K = XW^K, \quad V = XW^VQ=XWQ,K=XWK,V=XWV
其中 WQ,WK,WV∈Rdmodel×dkW^Q, W^K, W^V \in \mathbb{R}^{d_{\text{model}} \times d_k}WQ,WK,WV∈Rdmodel×dk。
多头注意力将 Q,K,VQ,K,VQ,K,V 拆分为 hhh 个头(每个头维度 dk=dmodel/hd_k = d_{\text{model}}/hdk=dmodel/h),独立计算缩放点积注意力:
headi=Softmax(QiKiTdk)Vi\text{head}_i = \text{Softmax}\left(\frac{Q_i K_i^T}{\sqrt{d_k}}\right)V_iheadi=Softmax(dkQiKiT)Vi
最终输出是各头结果的拼接与线性变换:
MultiHead(Q,K,V)=Concat(head1,...,headh)WO\text{MultiHead}(Q,K,V) = \text{Concat}(\text{head}_1,...,\text{head}_h)W^OMultiHead(Q,K,V)=Concat(head1,...,headh)WO
关键参数
参数 | 说明 |
---|---|
embed_dim | 输入特征维度 dmodeld_{\text{model}}dmodel |
num_heads | 注意力头数量 hhh |
dropout | 注意力权重 dropout 概率 |
bias | 是否在线性层添加偏置 |
add_bias_kv | 是否为 K/V 添加额外偏置 |
kdim , vdim | 可指定 K/V 与 Q 不同的维度 |
Cross Attention 自写编码实现
Python 实现
import torch
import torch.nn as nn
import torch.nn.functional as F
class CrossAttention(nn.Module):
def __init__(self, d_model: int, d_k: int, d_v: int):
super().__init__()
self.W_q = nn.Linear(d_model, d_k) # Query 投影
self.W_k = nn.Linear(d_model, d_k) # Key 投影
self.W_v = nn.Linear(d_model, d_v) # Value 投影
def forward(self,
query: torch.Tensor,
key_value: torch.Tensor) -> torch.Tensor:
"""
query: [batch_size, m, d_model]
key_value: [batch_size, n, d_model]
输出: [batch_size, m, d_v]
"""
# 投影变换
Q = self.W_q(query) # [batch_size, m, d_k]
K = self.W_k(key_value) # [batch_size, n, d_k]
V = self.W_v(key_value) # [batch_size, n, d_v]
# 计算注意力分数
scores = torch.matmul(Q, K.transpose(-2, -1)) # [batch_size, m, n]
scores /= torch.sqrt(torch.tensor(K.size(-1), dtype=torch.float32))
# 归一化权重
attn_weights = F.softmax(scores, dim=-1) # [batch_size, m, n]
# 加权求和
output = torch.matmul(attn_weights, V) # [batch_size, m, d_v]
return output
#### 使用示例
# 参数设置
batch_size, m, n = 4, 10, 20
d_model, d_k, d_v = 512, 64, 64
# 创建数据
query = torch.randn(batch_size, m, d_model)
key_value = torch.randn(batch_size, n, d_model)
# 初始化模块
attn = CrossAttention(d_model, d_k, d_v)
# 前向传播
output = attn(query, key_value)
print(output.shape) # 输出: torch.Size([4, 10, 64])
关键特性
-
维度处理:
- Query 序列长度 mmm 与 Key-Value 序列长度 nnn 可不同
- 输出维度由 dvd_vdv 控制,与输入 dmodeld_{model}dmodel 解耦
-
计算效率:
- 复杂度为 O(m×n)O(m \times n)O(m×n)
- 使用矩阵乘法并行计算
-
缩放因子:
- 1dk\frac{1}{\sqrt{d_k}}dk1 防止点积过大导致 softmax 梯度消失
此实现可直接集成到 Transformer 架构中,适用于机器翻译、文本生成等跨序列建模任务。
典型应用场景
- 机器翻译:解码器关注编码器输出
- 图文检索:文本查询关注图像特征
- 语音识别:声学特征关注语言模型
- 多模态融合:跨模态特征交互
注:实际使用时需配合层归一化(LayerNorm)和残差连接(Residual Connection)构建完整模块,如Transformer解码器中的交叉注意力层。