《从零构建大模型》系列(4):Transformer架构——大语言模型的心脏引擎

 

划时代突破:2017年谷歌论文《Attention Is All You Need》提出的Transformer架构,彻底改变了NLP领域格局。截至2024年,98%的大语言模型基于此架构,包括ChatGPT、Claude等顶级模型。


一、Transformer核心设计:编码器-解码器双模块

机器翻译流程示例(以英译德为例):

  1. 编码器处理英语句子:"This is an example"

  2. 生成蕴含上下文信息的向量表示

  3. 解码器接收向量,逐步生成德语:"Das ist ein Beispiel"

关键革新:完全摒弃RNN/CNN,仅依赖自注意力机制处理序列数据


二、自注意力机制:Transformer的灵魂

传统RNN痛点

  • 长距离依赖衰减(梯度消失)

  • 无法并行计算(训练慢)

自注意力解决方案

数学本质

 

比喻理解:如同人类阅读时,大脑自动给不同词语分配关注度。例如:
"苹果公司发布了新款iPhone"

  • "苹果"获得高权重(公司名)

  • "发布"中等权重(动作)

  • "了"低权重(助词)


三、编码器 vs 解码器:BERT与GPT的分水岭

组件层结构核心功能代表模型
编码器自注意力层 + FFN理解文本语义BERT, RoBERTa
解码器掩码自注意力层 + FFN生成连贯文本GPT系列, LLaMA

3.1 编码器架构(BERT类模型)

训练目标:掩码语言建模(Masked Language Modeling)
示例:
输入:"科技[MASK]改变世界" → 预测:"技术"

3.2 解码器架构(GPT类模型)

关键限制:掩码机制确保当前位置只能看到左侧信息
训练目标:因果语言建模(Causal Language Modeling)
示例:
输入:"人工智能将" → 预测:"改变未来"


四、Transformer如何驱动现代LLM

4.1 模型演进图谱

4.2 任务适应能力对比

能力BERT类模型GPT类模型示例场景
文本分类★★★★★★★★☆☆情感分析/垃圾邮件检测
文本生成★★☆☆☆★★★★★小说创作/代码补全
问答系统★★★★☆★★★★☆知识库问答
零样本学习★★☆☆☆★★★★★未知任务泛化

工业案例:Twitter(现X平台)使用BERT过滤有害内容,处理速度达5万条/秒


五、视觉Transformer:突破NLP边界

架构创新

# ViT (Vision Transformer) 图像处理流程
import torch
from vit_pytorch import ViT

model = ViT(
    image_size=256,       # 输入图像尺寸
    patch_size=32,        # 分块大小
    num_classes=1000,     # 分类数
    dim=1024,             # 向量维度
    depth=6,              # Transformer层数
    heads=16,             # 注意力头数
    mlp_dim=2048
)

# 将图像分割为8x8=64个patch
img = torch.randn(1, 3, 256, 256) 
preds = model(img)  # 输出分类结果

性能对比(ImageNet准确率):

模型参数量Top-1 Acc训练成本
ResNet-5025M76.5%1x
ViT-Large307M85.2%3.2x
Swin-B88M86.4%2.7x

六、高效Transformer变体:解决计算瓶颈

三大优化方向

计算优化技术占比(%)
稀疏注意力45
混合精度训练30
模型蒸馏25

明星架构对比

变体名称核心创新速度提升适用场景
FlashAttention硬件感知IO优化4.2x所有Transformer
Linformer低秩投影3.1x长文本处理
Performer随机正交特征映射2.8x实时推理系统

FlashAttention代码示例

# 安装: pip install flash-attn
from flash_attn import flash_attention

Q = torch.randn(4, 64, 1024)  # [batch, seq_len, dim]
K = torch.randn(4, 64, 1024)
V = torch.randn(4, 64, 1024)

output = flash_attention(Q, K, V)  # 比标准Attention快4倍

七、PyTorch实现核心组件(可运行代码)

7.1 自注意力层实现

import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super().__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads
        
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)
    
    def forward(self, values, keys, query, mask):
        N = query.shape[0]  # 批大小
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]
        
        # 分割嵌入向量到多个注意力头
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)
        
        # 计算注意力分数
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))
        
        attention = torch.softmax(energy / (self.embed_size ** (1/2)), dim=3)
        
        # 应用注意力权重
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values])
        out = out.reshape(N, query_len, self.heads * self.head_dim)
        
        return self.fc_out(out)

7.2 Transformer块集成

class TransformerBlock(nn.Module):
    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super().__init__()
        self.attention = SelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion * embed_size, embed_size)
        )
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, value, key, query, mask):
        attention = self.attention(value, key, query, mask)
        x = self.dropout(self.norm1(attention + query))
        forward = self.feed_forward(x)
        out = self.dropout(self.norm2(forward + x))
        return out

八、未来趋势:Transformer的挑战者

2024年新架构评测

架构计算复杂度长文本支持硬件适配性代表模型
TransformerO(n²)★★★☆☆★★★★☆GPT-4
RetNetO(n)★★★★★★★★☆☆RetNet-7B
MambaO(n)★★★★☆★★★★☆Mamba-130B
RWKVO(n)★★★★☆★★★★☆RWKV-5
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Sonal_Lynn

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值