multi-heads attention 机制和代码详解

Self-Attention

说下面的句子是我们要翻译的输入句子:

”The animal didn’t cross the street because it was too tired”

这句话中的“它”指的是什么? 是指街道还是动物? 对人类来说,这是一个简单的问题,但对算法而言却不那么简单。

当模型处理单词“ it”时,自我关注使它可以将“ it”与“ animal”相关联。

在模型处理每个单词(输入序列中的每个位置)时,自我关注使其能够查看输入序列中的其他位置以寻找线索,从而有助于更好地对该单词进行编码。

如果您熟悉RNN,请考虑一下如何通过保持隐藏状态来使RNN将其已处理的先前单词/向量的表示形式与当前正在处理的单词/向量进行合并。 Transformer使用self attention将其他相关单词的“理解”融入到我们当前正在处理的单词中的过程中。

在这里插入图片描述

计算 Self-Attention 的 attention

首先,让我们看看如何使用向量来计算自我注意力,然后着眼于如何使用矩阵来实现自我注意力。

计算自我注意力的第一步是从每个编码器的输入向量中创建三个向量(在这种情况下,是每个单词的嵌入)。 因此,对于每个单词,我们创建一个查询向量,一个键向量和一个值向量。 这些向量是通过将嵌入乘以我们在训练过程中训练的三个矩阵而创建的。

请注意,这些新向量的维数小于嵌入向量的维数。 它们的维数为64,而嵌入和编码器输入/输出矢量的维数为512。它们不必较小,这是使多头注意力(大部分)计算保持恒定的体系结构选择。

在这里插入图片描述

什么是“query”,“key”和“vector”向量?

它们是抽象,对于计算和思考注意力非常有用。计算自我注意力的第二步是计算score。 假设我们正在计算此示例“Thinking”中第一个单词的self attention。 我们需要根据该单词对输入句子的每个单词score。 score 决定了当我们在某个位置对单词进行编码时,将attention 集中在输入句子的其他部分上的程度。 Score 是是通过将“query”向量和各个单词“key”向量的点积得出的。

在这里插入图片描述

第三和第四步是将score除以8(本文中使用的“key”向量的维数的平方根–64。这将导致梯度更稳定。此处可能存在其他可能的值,但这是 默认值),然后通过softmax操作传递结果。 Softmax对分数进行归一化,使它们均为正并加1。

在这里插入图片描述

这个softmax score确定每个单词在此位置将被表达多少。 显然,该位置的单词的softmax得分最高,但是 我们也同时需要用attention 去关注其他相关的单词,这要用到multi heads attentions。

第五步是将每个值向量乘以softmax分数(准备将它们相加)。 直觉是保持我们要关注的单词的值完整,并淹没无关的单词(例如,将它们乘以0.001之类的小数字)。

第六步是对加权向量进行求和。 这将在此位置(对于第一个单词)产生自我注意层的输出。

在这里插入图片描述

我们可以发送生成的向量到前馈神经网络。 但是,在实际实现中,此计算以矩阵形式进行,以加快处理速度。

Self-Attention 矩阵的计算

第一步是计算“query”和“key”值的矩阵。 为此,我们将嵌入内容打包到矩阵X中,然后将其乘以我们训练过的权重矩阵(WQ,WK,WV)。

在这里插入图片描述

最后,由于我们要处理矩阵,因此我们可以将步骤2到6压缩成一个公式,以计算自我注意层的输出。

“multi-headed” attention在这里插入图片描述

如果我们执行上面概述的相同的自注意力计算,最终将得到2个不同的Z矩阵 在这里插入图片描述

这给我们带来了一些挑战。 前馈层只要有一个矩阵(每个单词一个向量)。 因此,我们需要一种将这2个矩阵压缩为一个矩阵的方法。

我们该怎么做? 我们合并矩阵,然后将它们乘以其他权重矩阵WO。
在这里插入图片描述

multi heads attention 的计算过程如下:
在这里插入图片描述

例如 这个例子中我们有8个attention heads,第一个attention head的注意力显示 it 和 because 最相关,第二个attention head的注意力显示 it 和 cross 最相关,等等…
在这里插入图片描述

multi-heads attention 的代码

这里我们用一个文本2分类的任务融合 attention机制来解释 multi heads attention理论
import torch
import torch.nn as nn
import numpy as np
import torch.optim as optim
from torch.autograd import Variable
import random

def set_random_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

set_random_seed(6688)
device = torch.device('cuda' if torch.cuda.is_available else 'cpu')
data_path = 'F:/data/'

from collections import Counter

def build_vocab(sents, max_words=50000):
    word_counts = Counter()
    for word in sents:
        word_counts[word] += 1
    itos = [w for w, c in word_counts.most_common(max_words)]
    itos = itos + ["UNK"]
    stoi = {w: i for i, w in enumerate(itos)}
    return itos, stoi


tokenize = lambda x: x.split()
text = open(data_path + 'senti.train.tsv').read()
vob = tokenize(text.lower())
itos, stoi = build_vocab(vob)
itos[0:5]
['1', '0', 'the', ',', 'a']
设计数据集
class Corpus:
    def __init__(self, data_path, sort_by_len=False):
        self.vocab = vob
        self.sort_by_len = sort_by_len
        self.train_data, self.train_label = self.tokenize(data_path + 'train.tsv')
        self.valid_data, self.valid_label = self.tokenize(data_path + 'dev.tsv')
        self.test_data, self.test_label = self.tokenize(data_path + 'test.tsv')

    def tokenize(self, text_path):
        with open(text_path) as f:
            index_data = []  # 索引数据,存储每个样本的单词索引列表
            labels = []
            for line in f.readlines():
                sentence, label = line.split('\t')
                index_data.append(
                    self.sentence_to_index(sentence.lower())
                )
                labels.append(
                    int(label[0])
                )
        if self.sort_by_len:  # 为了提升训练速度,可以考虑将样本按照长度排序,这样可以减少padding
            index_data = sorted(index_data, key=lambda x: len(x), reverse=True)
        return index_data, labels

    def sentence_to_index(self, s):
        a = []
        for w in s.split():
            if w in stoi.keys():
                a.append(stoi[w])
            else:
                a.append(stoi["UNK"])
        return a

    def index_to_sentence(self, x):
        return ' '.join([itos[i] for i in x])

corpus = Corpus(data_path, sort_by_len=False)
设计batches
def get_minibatches(text_idx, labels, batch_size=64, sort=False):
    if sort:
        text_idx_and_labels = sorted(list(zip(text_idx, labels)), key=lambda x: len(x[0]))
    else:
        text_idx_and_labels = (list(zip(text_idx, labels)))
    text_idx_batches = []
    label_batches = []
    for i in range(0, len(text_idx), batch_size):
        text_batch = [t for t, l in text_idx_and_labels[i:i + batch_size]]
        label_batch = [l for t, l in text_idx_and_labels[i:i + batch_size]]
        text_idx_batches.append(text_batch)
        label_batches.append(label_batch)
    return text_idx_batches, label_batches
BATCH_SIZE = 256
VOCAB_SIZE = len(itos)
EMBEDDING_SIZE = 256
OUTPUT_SIZE = 1

train_batches, train_label_batches = get_minibatches(corpus.train_data, corpus.train_label, BATCH_SIZE)
dev_batches, dev_label_batches = get_minibatches(corpus.valid_data, corpus.valid_label, BATCH_SIZE)
test_batches, test_label_batches = get_minibatches(corpus.test_data, corpus.test_label, BATCH_SIZE)

设计attention 中的 positional encoding
import math

class PositionalEncoding(nn.Module):
    "Implement the PE function."

    def __init__(self, d_model, dropout, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) *
                             -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + Variable(self.pe[:, :x.size(1)], requires_grad=False)
        return self.dropout(x)
设计attention score 的叉乘操作
class ScaledDotProductAttention(nn.Module):
    def __init__(self):
        super(ScaledDotProductAttention, self).__init__()

    def forward(self, Q, K, V, attn_mask):
        # Q: [ batch_size ,n_heads ,seq_length ,d_k ]
        # K: [ batch_size ,n_heads ,seq_length ,d_k ]
        # V: [ batch_size ,n_heads ,seq_length ,d_k ]

        # scores: [ batch_size ,n_heads ,seq_length ,seq_length ]
        scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(d_k)

        # Fills elements of self tensor with value where mask is one.
        scores.masked_fill_(attn_mask, -1e9)

        # attn: [ batch_size ,n_heads ,seq_length,seq_length ]
        attn = nn.Softmax(dim=-1)(scores)

        # context: [batch_size , n_heads ,seq_length, d_k]
        Z = torch.matmul(attn, V)

        return Z
设计多头的attention
class MultiHeadAttention(nn.Module):
    def __init__(self, embedding_size, d_k, n_heads):
        super(MultiHeadAttention, self).__init__()
        # self.W_Q.weight: [ n_heads*d_k, EMBEDDING_SIZE ]
        self.W_Q = nn.Linear(EMBEDDING_SIZE, n_heads * d_k)

        # self.W_K.weight: [ n_heads*d_k, EMBEDDING_SIZE ]
        self.W_K = nn.Linear(EMBEDDING_SIZE, n_heads * d_k)

        # self.W_V.weight: [ n_heads*d_k, EMBEDDING_SIZE ]
        self.W_V = nn.Linear(EMBEDDING_SIZE, n_heads * d_k)

        self.n_heads = n_heads
        self.d_model = embedding_size
        self.d_k = d_k

    def forward(self, Q, attn_mask):
        # q: [batch_size,seq_length, EMBEDDING_SIZE]
        # residual, batch_size = Q, Q.size(0)

        batch_size = Q.size(0)
        # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W)

        # q_s: [batch_size, n_heads, seq_length, d_k]
        q_s = self.W_Q(Q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1,2)

        # k_s: [batch_size, n_heads, seq_length, d_k]
        k_s = self.W_K(Q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1,2)

        # v_s: [batch_size, n_heads, seq_length, d_k]
        v_s = self.W_V(Q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1,2)

        attn_mask = attn_mask.eq(0)

        # attn_mask : [batch_size, n_heads, seq_length, seq_length]
        attn_mask = attn_mask.unsqueeze(1).unsqueeze(3).repeat(1, self.n_heads, 1, k_s.size(2))

        # Z : [batch_size, n_heads, seq_length, d_k]
        Z = ScaledDotProductAttention()(q_s, k_s, v_s, attn_mask)

        # Z : [batch_size , seq_length , n_heads * d_k]
        Z = Z.transpose(1, 2).contiguous().view(batch_size, -1, self.d_k * self.n_heads)
        
        # output : [batch_size , seq_length , embedding_size]
        output = nn.Linear(self.d_k * self.n_heads, self.d_model).to(device)(Z)

        return output
设置 attention的头数量 以及 q k v的维度
d_k = 4  # dimension of K(=Q), V
heads_num = 2
class Encoder(nn.Module):
    def __init__(self, vocab_size, embedding_size, output_size, dropout_p=0.5):
        super(Encoder, self).__init__()
        self.embed = nn.Embedding(vocab_size, embedding_size)
        initrange = 0.1
        self.embed.weight.data.uniform_(-initrange, initrange)
        self.embed_words = nn.Embedding(vocab_size, embedding_size)
        self.linear = nn.Linear(embedding_size, output_size)
        self.dropout = nn.Dropout(dropout_p)
        self.attentions = MultiHeadAttention(embedding_size, d_k, heads_num)
        self.Pos = PositionalEncoding(embedding_size, dropout_p, max_len=5000)

设计attention 加权平均模型
class WordAVGModel(nn.Module):
    def __init__(self, vocab_size, embedding_size, output_size, dropout_p=0.5):
        super(WordAVGModel, self).__init__()
        self.embedding_size = embedding_size
        self.output_size = output_size
        self.encoder = Encoder(vocab_size, embedding_size, output_size, dropout_p)

    def forward(self, text, mask):
        # text: [batch_size * max_seq_len]  mask: [batch_size * max_seq_len]       

        # embedded: [batch_size, max_seq_len, embedding_size]
        embedded = self.encoder.embed(text)
        
        # embedded: [batch_size, max_seq_len, embedding_size]
        #embedded = self.encoder.Pos(embedded)
        
        # embedded: [batch_size, max_seq_len, embedding_size]
        embedded = self.encoder.dropout(embedded) 
      
        # enc_inputs to same Q,K,V 为模型加入 multi-heads attention
        # a_ts: [batch_size , max_seq_len , embedding_size]
        a_ts = self.encoder.attentions(embedded,mask)        

        # a_t: [batch_size , max_seq_len]
        a_t = torch.sum(a_ts,2)  
        
        # a_t: [batch_size , max_seq_len]
        a_t = torch.softmax(a_t, dim=1)        
        
        # h_self: [batch_size ,embedding_size]
        h_self = torch.bmm(a_t.unsqueeze(1), embedded).squeeze()    
        
        # mask: [batch_size, max_seq_len, 1], 1 represents word, 0 represents padding
        mask = mask.float().unsqueeze(2)
        
        # embedded: [batch_size, max_seq_len, embedding_size]
        embedded = embedded * mask

        # h_av: [batch_size, embedding_size]
        h_av = embedded.sum(1) / (mask.sum(1) + 1e-9)  # 防止mask.sum为0,那么不能除以零。      
        
        # out: [batch_size, output_size]
        out = self.encoder.linear(h_self)
        #out = self.encoder.linear(h_self + h_av)
        
        return out
    
model = WordAVGModel(vocab_size=VOCAB_SIZE,
                     embedding_size=EMBEDDING_SIZE,
                     output_size=OUTPUT_SIZE,
                     dropout_p=0.5)

optimizer = torch.optim.Adam(model.parameters())
crit = nn.BCEWithLogitsLoss()
model = model.to(device)
def binary_accuracy(preds, y):
    rounded_preds = torch.round(torch.sigmoid(preds))
    correct = (rounded_preds == y).float()
    acc = correct.sum() / len(correct)
    return acc

def train(model, text_idxs, labels, optimizer, crit):
    epoch_loss, epoch_acc = 0., 0.
    model.train()
    total_len = 0.
    for text, label in zip(text_idxs, labels):
        text = [torch.tensor(x).long().to(device) for x in (text)]
        label = [torch.tensor(label).long().to(device)]
        lengths = torch.tensor([len(x) for x in text]).long().to(device)
        text = nn.utils.rnn.pad_sequence(text, batch_first=True)
        mask = (text != 0).float().to(device)

        # 在之后的训练中因为还要进行pack_padded_sequence操作,所以在这里按照长度降序排列
        lengths, perm_index = lengths.sort(descending=True)
        text = text[perm_index]
        label = label[0][perm_index]

        preds = model(text, mask).squeeze()  # [batch_size, sent_length]
        loss = crit(preds, label.float())
        acc = binary_accuracy(preds, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item() * len(label)
        epoch_acc += acc.item() * len(label)
        total_len += len(label)

    return epoch_loss / total_len, epoch_acc / total_len

def evaluate(model, text_idxs, labels, crit):
    epoch_loss, epoch_acc = 0., 0.
    model.train()
    total_len = 0.
    for text, label in zip(text_idxs, labels):
        text = [torch.tensor(x).long().to(device) for x in (text)]
        label = [torch.tensor(label).long().to(device)]
        lengths = torch.tensor([len(x) for x in text]).long().to(device)
        text = nn.utils.rnn.pad_sequence(text, batch_first=True)
        mask = (text != 0).float().to(device)
        pad_attn_mask = mask.eq(0).unsqueeze(1)

        # 在之后的训练中因为还要进行pack_padded_sequence操作,所以在这里按照长度降序排列
        lengths, perm_index = lengths.sort(descending=True)
        text = text[perm_index]
        label = label[0][perm_index]
        with torch.no_grad():
            preds = model(text, mask).squeeze()  # [batch_size, sent_length]
        loss = crit(preds, label.float())
        acc = binary_accuracy(preds, label)
        epoch_loss += loss.item() * len(label)
        epoch_acc += acc.item() * len(label)
        total_len += len(label)

    return epoch_loss / total_len, epoch_acc / total_len
训练模型
N_EPOCHS = 30
best_valid_acc = 0.
record = 0
for epoch in range(N_EPOCHS):
    train_loss, train_acc = train(model, train_batches, train_label_batches, optimizer, crit)
    valid_loss, valid_acc = evaluate(model, dev_batches, dev_label_batches, crit)

    if valid_acc > best_valid_acc:
        record = 0
        best_valid_acc = valid_acc
        torch.save(model.state_dict(), "wordavg-model-Adam.pth")
    else:
        record += 1
    if record > 30:
        print("early stopping at epoch", epoch)
        break
    if epoch % 5 == 0:
        print("Epoch", epoch, "Train Loss", train_loss, "Train Acc", train_acc)
        print("Epoch", epoch, "Valid Loss", valid_loss, "Valid Acc", valid_acc)
Epoch 0 Train Loss 0.6855089499452708 Train Acc 0.5557145779421909
Epoch 0 Valid Loss 0.6951910366705798 Valid Acc 0.5091743135671003
Epoch 5 Train Loss 0.5996944122435344 Train Acc 0.7028215090279393
Epoch 5 Valid Loss 0.5999153551705386 Valid Acc 0.7075688084331128
Epoch 10 Train Loss 0.4333473568592937 Train Acc 0.8600490696955205
Epoch 10 Valid Loss 0.4965302976993246 Valid Acc 0.7694954188591844
Epoch 15 Train Loss 0.32046079018673657 Train Acc 0.914127990246989
Epoch 15 Valid Loss 0.46561843430230376 Valid Acc 0.7970183519048428
Epoch 20 Train Loss 0.24862328401142725 Train Acc 0.9407074219169999
Epoch 20 Valid Loss 0.4701578086669292 Valid Acc 0.7935779838387026
Epoch 25 Train Loss 0.20012277661322134 Train Acc 0.9533837661619349
Epoch 25 Valid Loss 0.4839585598455657 Valid Acc 0.7970183535453377
预测test
model.load_state_dict(torch.load("wordavg-model-Adam.pth"))
test_loss, test_acc = evaluate(model, test_batches, test_label_batches, crit)
print("Test Loss", test_loss, "Test Acc", test_acc)
Test Loss 0.4726336712957672 Test Acc 0.7935200437027828

<think>我们正在处理用户关于多头注意力(Multi-head Attention)模型结构图的请求。根据用户需求,我们应提供相关可视化示意图或结构图。由于文本交互的限制,我们无法直接显示图像,但可以描述其结构,并建议如何获取或生成这样的图。 根据引用[2][3]: - 多头注意力机制是将自注意力(Self-Attention)重复多次(h次),每次使用不同的线性变换(即不同的权重矩阵),然后将h次的结果拼接起来,再进行一次线性变换得到最终输出。 结构描述: 1. 输入:输入序列(例如,词嵌入向量)经过三个不同的线性变换(权重矩阵)得到查询(Q)、键(K)、值(V)向量。 2. 多头:将Q、K、V按头的数量h进行分割(通常是通过线性变换投影到较低维度),然后对每个头分别计算自注意力。 3. 自注意力计算:每个头独立计算注意力分数(通常是缩放点积注意力),公式为: $$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$ 其中$d_k$是键向量的维度。 4. 拼接:将h个头的输出拼接(concatenate)起来。 5. 线性变换:将拼接后的结果通过一个线性变换(权重矩阵$W^O$)得到多头注意力的最终输出。 结构图通常包括以下部分: - 输入(Input)经过三个线性层(分别对应Q、K、V的投影)。 - 每个头(Head)的拆分独立计算。 - 每个头内的自注意力机制(包括softmax点积)。 - 多个头的输出拼接。 - 最后的线性变换(输出投影)。 由于无法直接提供图片,我们可以: 1. 建议用户搜索“multi-head attention architecture diagram”或“transformer multi-head attention diagram”,会有大量现成的图示。 2. 提供代码生成结构图(如果环境支持绘图): - 使用PyTorch或TensorFlow的模型摘要功能(如`torchsummary`)可以输出文本结构。 - 使用`graphviz`或`torchviz`绘制模型计算图(需要安装相应库)。 3. 用文本描述一个简化的结构图(ASCII art),但可能不够直观。 这里尝试用文本简单描绘: ``` 输入 (Input) │ ├─线性变换 W_q ──► 查询向量 Q ──分割为h个头─┬─► 头1 (计算自注意力) ──┐ ├─线性变换 W_k ──► 键向量 K ──分割为h个头─┼─► 头1 (计算自注意力) ─┤ └─线性变换 W_v ──► 值向量 V ──分割为h个头─┴─► 头1 (计算自注意力) ─┘ ... ... ┌───────────────────────────────────────────────────────────────────┐ │ 每个头独立计算:Attention(Q_i, K_i, V_i) = softmax(Q_i*K_i^T/sqrt(d_k))*V_i │ └───────────────────────────────────────────────────────────────────┘ │ ▼ 拼接 (Concat) [头1输出, 头2输出, ..., 头h输出] │ ▼ 线性变换 W_o │ ▼ 输出 (Output) ``` 注意:实际实现中,通常不会真的分割输入,而是通过矩阵运算同时计算多个头(通过调整矩阵维度)。 另外,根据引用[4],我们可以使用PyTorchTransformers库来实现多头注意力。以下是一个简单的多头注意力层的PyTorch实现,可以帮助理解结构: ```python import torch import torch.nn as nn class MultiHeadAttention(nn.Module): def __init__(self, d_model, num_heads): super(MultiHeadAttention, self).__init__() self.d_model = d_model self.num_heads = num_heads self.head_dim = d_model // num_heads # 线性变换层,用于生成Q、K、V self.W_q = nn.Linear(d_model, d_model) self.W_k = nn.Linear(d_model, d_model) self.W_v = nn.Linear(d_model, d_model) self.W_o = nn.Linear(d_model, d_model) def forward(self, x): batch_size, seq_len, _ = x.size() # 生成Q、K、V Q = self.W_q(x) # [batch_size, seq_len, d_model] K = self.W_k(x) V = self.W_v(x) # 将d_model维度拆分为num_heads * head_dim,并调整维度以便并行计算多个头 Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1,2) # [batch_size, num_heads, seq_len, head_dim] K = K.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1,2) V = V.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1,2) # 计算缩放点积注意力 attn_scores = torch.matmul(Q, K.transpose(-2,-1)) / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)) attn_probs = torch.softmax(attn_scores, dim=-1) attn_output = torch.matmul(attn_probs, V) # [batch_size, num_heads, seq_len, head_dim] # 将多头输出拼接 attn_output = attn_output.transpose(1,2).contiguous().view(batch_size, seq_len, self.d_model) # 通过线性变换得到最终输出 output = self.W_o(attn_output) return output ``` 这个类实现了多头注意力的基本结构。用户可以通过实例化这个类并传入输入张量来运行。 为了生成结构图,我们可以使用`torchviz`来绘制计算图(需要安装`torchviz``graphviz`): ```python from torchviz import make_dot # 创建模型实例 model = MultiHeadAttention(d_model=512, num_heads=8) # 随机输入 x = torch.randn(1, 10, 512) # [batch_size, seq_len, d_model] # 前向传播 output = model(x) # 绘制计算图 dot = make_dot(output, params=dict(model.named_parameters())) dot.render("multi_head_attention", format="png") ``` 这样会生成一个名为`multi_head_attention.png`的文件,展示计算图。但请注意,计算图可能非常详细(包含所有操作),而不是通常见到的简化结构图。 因此,我们更推荐用户参考已有的示意图(例如原始Transformer论文中的图,或知名博客如The Illustrated Transformer)以获得更清晰的可视化。 最后,生成相关问题: §§相关问题§§ 1. 多头注意力中每个头的维度为什么要设置为d_model/h? 2. 在多头注意力中,为什么需要缩放点积注意力(除以$\sqrt{d_k}$)? 3. 多头注意力与单头注意力相比有什么优势? 4. 如何理解多头注意力中的“头”? 5. 在Transformer模型中,多头注意力是如何应用的?</think>### Multi-Head Attention 模型架构详解 多头注意力(Multi-Head Attention)是Transformer架构的核心组件,其结构可描述如下: #### 基本结构图 ``` 输入序列 (Input Sequence) │ ├─线性投影层 (Linear Projections)─┬─► 头1 (Head 1) ─┐ │ ├─► 头2 (Head 2) ─┤ │ ... │ │ └─► 头h (Head h) ─┘ │ │ │ ▼ │ 缩放点积注意力 (Scaled Dot-Product Attention) │ 公式:$ \text{Attention}(Q,K,V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V $ │ │ └────────────────────────────────────▼ 拼接 (Concat) │ ▼ 输出投影层 (Output Projection) │ ▼ 输出序列 (Output Sequence) ``` #### 关键组件解析 1. **线性投影层** 输入序列通过三个不同的权重矩阵($W^Q, W^K, W^V$)投影,生成查询(Query)、键(Key)、值(Value)向量: $$Q = XW^Q, \quad K = XW^K, \quad V = XW^V$$ 其中$X$是输入矩阵,维度为$n \times d_{\text{model}}$(n为序列长度)[^2]。 2. **多头分割** 投影后的$Q/K/V$被分割为$h$个头(通常$h=8$),每个头的维度降为$d_k = d_{\text{model}}/h$: $$ \text{Head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) $$ 3. **自注意力计算** 每个头独立计算注意力权重: $$ \text{Attention}(Q,K,V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V $$ 缩放因子$\sqrt{d_k}$防止点积过大导致梯度消失[^2][^3]。 4. **拼接与输出** $h$个头的输出拼接后通过线性变换: $$ \text{MultiHead}(Q,K,V) = \text{Concat}(\text{Head}_1, ..., \text{Head}_h)W^O $$ 输出维度恢复为$d_{\text{model}}$[^3]。 #### 可视化建议 1. **经典示意图**: ![多头注意力架构](https://miro.medium.com/v2/resize:fit:1400/1*O0Ri6-qt5zVfHLVQY1oFbw.png) *(来源:Jay Alammar "The Illustrated Transformer")* 2. **交互式可视化工具**: - [Tensor2Tensor 注意力可视化](https://nlp.seas.harvard.edu/2018/04/03/attention.html) - [BERTViz 工具](https://github.com/jessevig/bertviz) 3. **PyTorch实现核心代码**(基于引用[4]): ```python import torch import torch.nn as nn class MultiHeadAttention(nn.Module): def __init__(self, d_model, num_heads): super().__init__() self.d_k = d_model // num_heads self.num_heads = num_heads self.W_q = nn.Linear(d_model, d_model) self.W_k = nn.Linear(d_model, d_model) self.W_v = nn.Linear(d_model, d_model) self.W_o = nn.Linear(d_model, d_model) def forward(self, x): # 投影到Q/K/V空间 Q = self.W_q(x) # [batch, seq_len, d_model] K = self.W_k(x) V = self.W_v(x) # 分割多头 Q = Q.view(-1, seq_len, self.num_heads, self.d_k).transpose(1,2) K = K.view(-1, seq_len, self.num_heads, self.d_k).transpose(1,2) V = V.view(-1, seq_len, self.num_heads, self.d_k).transpose(1,2) # 计算注意力 scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.d_k)) attn_weights = torch.softmax(scores, dim=-1) attn_output = torch.matmul(attn_weights, V) # 拼接多头输出 attn_output = attn_output.transpose(1,2).contiguous().view(-1, seq_len, d_model) return self.W_o(attn_output) ``` #### 架构特点 - **并行计算**:多个注意力头可并行计算,提升效率 - **多样化表示**:不同头学习不同的语义关系(如语法/指代关系) - **残差连接**:实际实现中会添加残差连接层归一化[^4]
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值