划时代突破:2017年谷歌论文《Attention Is All You Need》提出的Transformer架构,彻底改变了NLP领域格局。截至2024年,98%的大语言模型基于此架构,包括ChatGPT、Claude等顶级模型。
一、Transformer核心设计:编码器-解码器双模块
机器翻译流程示例(以英译德为例):
-
编码器处理英语句子:"This is an example"
-
生成蕴含上下文信息的向量表示
-
解码器接收向量,逐步生成德语:"Das ist ein Beispiel"
关键革新:完全摒弃RNN/CNN,仅依赖自注意力机制处理序列数据
二、自注意力机制:Transformer的灵魂
传统RNN痛点:
-
长距离依赖衰减(梯度消失)
-
无法并行计算(训练慢)
自注意力解决方案:
数学本质:
比喻理解:如同人类阅读时,大脑自动给不同词语分配关注度。例如:
"苹果公司发布了新款iPhone"
"苹果"获得高权重(公司名)
"发布"中等权重(动作)
"了"低权重(助词)
三、编码器 vs 解码器:BERT与GPT的分水岭
组件 | 层结构 | 核心功能 | 代表模型 |
---|---|---|---|
编码器 | 自注意力层 + FFN | 理解文本语义 | BERT, RoBERTa |
解码器 | 掩码自注意力层 + FFN | 生成连贯文本 | GPT系列, LLaMA |
3.1 编码器架构(BERT类模型)
训练目标:掩码语言建模(Masked Language Modeling)
示例:
输入:"科技[MASK]改变世界" → 预测:"技术"
3.2 解码器架构(GPT类模型)
关键限制:掩码机制确保当前位置只能看到左侧信息
训练目标:因果语言建模(Causal Language Modeling)
示例:
输入:"人工智能将" → 预测:"改变未来"
四、Transformer如何驱动现代LLM
4.1 模型演进图谱
4.2 任务适应能力对比
能力 | BERT类模型 | GPT类模型 | 示例场景 |
---|---|---|---|
文本分类 | ★★★★★ | ★★★☆☆ | 情感分析/垃圾邮件检测 |
文本生成 | ★★☆☆☆ | ★★★★★ | 小说创作/代码补全 |
问答系统 | ★★★★☆ | ★★★★☆ | 知识库问答 |
零样本学习 | ★★☆☆☆ | ★★★★★ | 未知任务泛化 |
工业案例:Twitter(现X平台)使用BERT过滤有害内容,处理速度达5万条/秒
五、视觉Transformer:突破NLP边界
架构创新:
# ViT (Vision Transformer) 图像处理流程
import torch
from vit_pytorch import ViT
model = ViT(
image_size=256, # 输入图像尺寸
patch_size=32, # 分块大小
num_classes=1000, # 分类数
dim=1024, # 向量维度
depth=6, # Transformer层数
heads=16, # 注意力头数
mlp_dim=2048
)
# 将图像分割为8x8=64个patch
img = torch.randn(1, 3, 256, 256)
preds = model(img) # 输出分类结果
性能对比(ImageNet准确率):
模型 | 参数量 | Top-1 Acc | 训练成本 |
---|---|---|---|
ResNet-50 | 25M | 76.5% | 1x |
ViT-Large | 307M | 85.2% | 3.2x |
Swin-B | 88M | 86.4% | 2.7x |
六、高效Transformer变体:解决计算瓶颈
三大优化方向:
计算优化技术 | 占比(%) |
---|---|
稀疏注意力 | 45 |
混合精度训练 | 30 |
模型蒸馏 | 25 |
明星架构对比:
变体名称 | 核心创新 | 速度提升 | 适用场景 |
---|---|---|---|
FlashAttention | 硬件感知IO优化 | 4.2x | 所有Transformer |
Linformer | 低秩投影 | 3.1x | 长文本处理 |
Performer | 随机正交特征映射 | 2.8x | 实时推理系统 |
FlashAttention代码示例:
# 安装: pip install flash-attn
from flash_attn import flash_attention
Q = torch.randn(4, 64, 1024) # [batch, seq_len, dim]
K = torch.randn(4, 64, 1024)
V = torch.randn(4, 64, 1024)
output = flash_attention(Q, K, V) # 比标准Attention快4倍
七、PyTorch实现核心组件(可运行代码)
7.1 自注意力层实现
import torch
import torch.nn as nn
class SelfAttention(nn.Module):
def __init__(self, embed_size, heads):
super().__init__()
self.embed_size = embed_size
self.heads = heads
self.head_dim = embed_size // heads
self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
self.fc_out = nn.Linear(heads * self.head_dim, embed_size)
def forward(self, values, keys, query, mask):
N = query.shape[0] # 批大小
value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]
# 分割嵌入向量到多个注意力头
values = values.reshape(N, value_len, self.heads, self.head_dim)
keys = keys.reshape(N, key_len, self.heads, self.head_dim)
queries = query.reshape(N, query_len, self.heads, self.head_dim)
# 计算注意力分数
energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
if mask is not None:
energy = energy.masked_fill(mask == 0, float("-1e20"))
attention = torch.softmax(energy / (self.embed_size ** (1/2)), dim=3)
# 应用注意力权重
out = torch.einsum("nhql,nlhd->nqhd", [attention, values])
out = out.reshape(N, query_len, self.heads * self.head_dim)
return self.fc_out(out)
7.2 Transformer块集成
class TransformerBlock(nn.Module):
def __init__(self, embed_size, heads, dropout, forward_expansion):
super().__init__()
self.attention = SelfAttention(embed_size, heads)
self.norm1 = nn.LayerNorm(embed_size)
self.norm2 = nn.LayerNorm(embed_size)
self.feed_forward = nn.Sequential(
nn.Linear(embed_size, forward_expansion * embed_size),
nn.ReLU(),
nn.Linear(forward_expansion * embed_size, embed_size)
)
self.dropout = nn.Dropout(dropout)
def forward(self, value, key, query, mask):
attention = self.attention(value, key, query, mask)
x = self.dropout(self.norm1(attention + query))
forward = self.feed_forward(x)
out = self.dropout(self.norm2(forward + x))
return out
八、未来趋势:Transformer的挑战者
2024年新架构评测:
架构 | 计算复杂度 | 长文本支持 | 硬件适配性 | 代表模型 |
---|---|---|---|---|
Transformer | O(n²) | ★★★☆☆ | ★★★★☆ | GPT-4 |
RetNet | O(n) | ★★★★★ | ★★★☆☆ | RetNet-7B |
Mamba | O(n) | ★★★★☆ | ★★★★☆ | Mamba-130B |
RWKV | O(n) | ★★★★☆ | ★★★★☆ | RWKV-5 |