【模型细节】Cross-Attention:多头交叉意力机制 (Multi-head Cross-Attention) 详细解释,使用 PyTorch代码示例说明

交叉注意力机制(Cross-Attention)详细说明,包含代码示例详细注释

交叉注意力是注意力机制的一种变体,核心思想是让一个序列(查询序列)关注另一个序列(键值序列)。它在多模态任务(如图文匹配)和Transformer解码器中广泛应用,通过动态计算不同序列元素间的关联权重实现信息融合。

数学原理

给定查询序列 Q∈Rn×dkQ \in \mathbb{R}^{n \times d_k}QRn×dk 和键值序列 K,V∈Rm×dkK, V \in \mathbb{R}^{m \times d_k}K,VRm×dk,交叉注意力计算过程为:

  1. 相似度计算
    Similarity=QKTdk \text{Similarity} = \frac{QK^T}{\sqrt{d_k}} Similarity=dkQKT
    其中 dkd_kdk 为向量维度,缩放因子避免梯度消失

  2. 权重归一化
    A=softmax(QKTdk) A = \text{softmax}\left( \frac{QK^T}{\sqrt{d_k}} \right) A=softmax(dkQKT)

  3. 加权输出
    Output=AV \text{Output} = AV Output=AV
    ================================================

  4. 输入格式

    • query/key/value: 形状为 (S,N,E)(S, N, E)(S,N,E)(N,S,E)(N, S, E)(N,S,E)(通过 batch_first 参数控制)
    • SSS: 序列长度,NNN: 批次大小,EEE: 特征维度
    • e.g:
      Q:[N,S1,E]Q:[N, S1, E]Q:[N,S1,E]
      K:[N,S2,E]=V:[N,S2,E]K:[N, S2, E] = V:[N, S2, E]K:[N,S2,E]=V:[N,S2,E]
      Q∗KT=[N,S1,E]∗[N,E,S2]=[N,S1,S2]Q*K^T=[N,S1,E]*[N, E, S2]=[N,S1,S2]QKT=[N,S1,E][N,E,S2]=[N,S1,S2]
      QK(Attention)∗V=[N,S1,S2]∗[N,S2,E]=[N,S1,E]QK(Attention)*V=[N,S1,S2]*[N, S2, E]=[N,S1,E]QK(Attention)V=[N,S1,S2][N,S2,E]=[N,S1,E]
      输出保持与query维度一致
  5. 输出

    • output: 注意力结果,形状与 query 相同
    • attn_weights: 注意力权重矩阵,形状为 (N×h,S,S)(N \times h, S, S)(N×h,S,S)
  6. 典型应用场景

    • 自注意力:query = key = value
    • 编码器-解码器注意力:query 来自解码器,key/value 来自编码器

提示:实际使用时需注意处理 key_padding_mask(屏蔽填充符)和 attn_mask(防止未来信息泄露)。

Python代码nn.MultiheadAttention实现
import torch
import torch.nn as nn
import torch.nn.functional as F

class CrossAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)

    def forward(self, query, key, value):
        """
        query: [target_len, batch_size, embed_dim]
        key:   [source_len, batch_size, embed_dim]
        value: [source_len, batch_size, embed_dim]
        返回: [target_len, batch_size, embed_dim]
        """
        attn_output, _ = self.multihead_attn(
            query, key, value, 
            need_weights=False
        )
        return attn_output

# 示例用法
if __name__ == "__main__":
    embed_dim = 512
    num_heads = 8
    batch_size = 4
    src_len = 20  # 源序列长度
    tgt_len = 15  # 目标序列长度

    # 初始化模块
    cross_attn = CrossAttention(embed_dim, num_heads)
    
    # 生成模拟数据
    query = torch.rand(tgt_len, batch_size, embed_dim)  # 目标序列
    key = value = torch.rand(src_len, batch_size, embed_dim)  # 源序列
    
    # 执行交叉注意力
    output = cross_attn(query, key, value)
    print("输出张量形状:", output.shape)  # [15, 4, 512]
关键特性
  1. 动态权重分配:根据当前查询动态计算键值序列的权重
  2. 序列长度无关:可处理不同长度的输入序列
  3. 并行计算:矩阵运算支持GPU加速
  4. 可解释性:注意力权重可视化(需修改代码返回权重矩阵)

其中nn.MultiheadAttention 详解

nn.MultiheadAttention 是 PyTorch 中实现多头注意力机制的模块,广泛应用于 Transformer 架构(如 BERT、GPT 等)。其核心思想是将输入拆分为多个"头",并行计算注意力,最后合并结果以捕捉更丰富的上下文信息。

数学原理

设输入序列为 X∈RL×dmodelX \in \mathbb{R}^{L \times d_{\text{model}}}XRL×dmodelLLL 为序列长度,dmodeld_{\text{model}}dmodel 为特征维度)。首先通过线性变换生成查询(Query)、键(Key)、值(Value):
Q=XWQ,K=XWK,V=XWVQ = XW^Q, \quad K = XW^K, \quad V = XW^VQ=XWQ,K=XWK,V=XWV
其中 WQ,WK,WV∈Rdmodel×dkW^Q, W^K, W^V \in \mathbb{R}^{d_{\text{model}} \times d_k}WQ,WK,WVRdmodel×dk

多头注意力将 Q,K,VQ,K,VQ,K,V 拆分为 hhh 个头(每个头维度 dk=dmodel/hd_k = d_{\text{model}}/hdk=dmodel/h),独立计算缩放点积注意力:
headi=Softmax(QiKiTdk)Vi\text{head}_i = \text{Softmax}\left(\frac{Q_i K_i^T}{\sqrt{d_k}}\right)V_iheadi=Softmax(dkQiKiT)Vi
最终输出是各头结果的拼接与线性变换:
MultiHead(Q,K,V)=Concat(head1,...,headh)WO\text{MultiHead}(Q,K,V) = \text{Concat}(\text{head}_1,...,\text{head}_h)W^OMultiHead(Q,K,V)=Concat(head1,...,headh)WO


关键参数
参数说明
embed_dim输入特征维度 dmodeld_{\text{model}}dmodel
num_heads注意力头数量 hhh
dropout注意力权重 dropout 概率
bias是否在线性层添加偏置
add_bias_kv是否为 K/V 添加额外偏置
kdim, vdim可指定 K/V 与 Q 不同的维度

Cross Attention 自写编码实现

Python 实现
import torch
import torch.nn as nn
import torch.nn.functional as F

class CrossAttention(nn.Module):
    def __init__(self, d_model: int, d_k: int, d_v: int):
        super().__init__()
        self.W_q = nn.Linear(d_model, d_k)  # Query 投影
        self.W_k = nn.Linear(d_model, d_k)  # Key 投影
        self.W_v = nn.Linear(d_model, d_v)  # Value 投影

    def forward(self, 
                query: torch.Tensor, 
                key_value: torch.Tensor) -> torch.Tensor:
        """
        query:   [batch_size, m, d_model]
        key_value: [batch_size, n, d_model]
        输出:     [batch_size, m, d_v]
        """
        # 投影变换
        Q = self.W_q(query)       # [batch_size, m, d_k]
        K = self.W_k(key_value)   # [batch_size, n, d_k]
        V = self.W_v(key_value)   # [batch_size, n, d_v]

        # 计算注意力分数
        scores = torch.matmul(Q, K.transpose(-2, -1))  # [batch_size, m, n]
        scores /= torch.sqrt(torch.tensor(K.size(-1), dtype=torch.float32))

        # 归一化权重
        attn_weights = F.softmax(scores, dim=-1)  # [batch_size, m, n]

        # 加权求和
        output = torch.matmul(attn_weights, V)  # [batch_size, m, d_v]
        return output


#### 使用示例
# 参数设置
batch_size, m, n = 4, 10, 20
d_model, d_k, d_v = 512, 64, 64

# 创建数据
query = torch.randn(batch_size, m, d_model)
key_value = torch.randn(batch_size, n, d_model)

# 初始化模块
attn = CrossAttention(d_model, d_k, d_v)

# 前向传播
output = attn(query, key_value)
print(output.shape)  # 输出: torch.Size([4, 10, 64])
关键特性
  1. 维度处理

    • Query 序列长度 mmm 与 Key-Value 序列长度 nnn 可不同
    • 输出维度由 dvd_vdv 控制,与输入 dmodeld_{model}dmodel 解耦
  2. 计算效率

    • 复杂度为 O(m×n)O(m \times n)O(m×n)
    • 使用矩阵乘法并行计算
  3. 缩放因子

    • 1dk\frac{1}{\sqrt{d_k}}dk1 防止点积过大导致 softmax 梯度消失

此实现可直接集成到 Transformer 架构中,适用于机器翻译、文本生成等跨序列建模任务。

典型应用场景
  • 机器翻译:解码器关注编码器输出
  • 图文检索:文本查询关注图像特征
  • 语音识别:声学特征关注语言模型
  • 多模态融合:跨模态特征交互

注:实际使用时需配合层归一化(LayerNorm)和残差连接(Residual Connection)构建完整模块,如Transformer解码器中的交叉注意力层。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

一碗白开水一

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值