【论文阅读笔记】Native Sparse Attention: Hardware-Aligned and Natively Trainable Sparse Attention

一句话总结

<缝合>了之前几种长序列选择注意力机制的<半稀疏> 注意力运算模块,并把算法写到了Triton 上,总体上,模型的准确率表现有提升但不巨大,运算效率上提升显著。

方法概览(算法部分)

** 本篇blog主要是描述Native Sparse Attention在Attention模块运算上的改造。他们这篇文章还有一个重要改造就在GPU的运算逻辑上。这部分后面单开一篇整理

主要改动在于减少Attention中 KV的长度,为此做了三个动作

1. 整个序列先切块,按照块压缩,再做 attention 运算,形成 result1
2. 截选优质的块,做
3. 选最近一个块,做qk(不跟1.2 混着算)

在这里插入图片描述
另外这个图其实跟一些同类论文的画图顺序有点不一样,一般讲transformer结构改造的论文 数据流多数是从下往上画的(就是输入的xToken序列在图的最下面,然后通过若干运算到达图的最上面),他这个是从上往下画,所以乍一看容易让人蒙一下,但实际表述还是很清楚的。

展开说说

下面会统一用block来表示块(之前的长序列transformer的论文里,通常会使用chunk,我这里猜测主要是为了和后面应建预算那里的术语打通,所以统一用的block)

分Block压缩

操作上,简单的说

  1. 使用rolling window strategy(时序领域里常用这个说法,换成图像领域其实就相当于是1D convolution)把序列切成若干个Block。
  2. 把每个Block都拍扁了:用MLP在seq维度上做降维映射。
  3. 把拍扁了的block连成一个新的短序列和Q做QK的运算
比如一个seq_len=24的 token sequence,shape 是[batch_size, 24, hidden_dim]
第一步:按照窗口长度=8,stride=4,切成5个快
第二步:每个块用一个MLP压缩成这个shape [..., 1 , hidden_dim]
第三步:这样形成的K^cmp,其shape为[batch_size, 5, hidden_dim] ,做QK运算
选择优质Block

虽然文中的名字是TokenSelection,但是这个称呼其实比较容易混淆。因为整个操作都是在Block这个水平上操作的。
操作上,仍然是高度简化版本(后面再说我主要简化了什么)

  1. 还是先把全系列切成若干个Block
  2. 每个Block用MLP压缩
  3. 用压缩过的KcmpK^{cmp}KcmpQQQ做QK的运算,形成一个<重要性分>
  4. 选出<重要性分>最大的top-k个Block 并生成一个Token level的 attention mask
  5. QQQKKK 和4 中做出来的attention mask 做QK的运算(这样就相当于是Q和仅有的K个Block的 K tokens 做QK的运算,不单运算少,还好并行)

说说我简化了什么

  • 从操作步骤上看,这个部分的1-3 是和上面分Block压缩是一致的,那在实践中也确实没必要算两遍,但是压缩块的大小和优质块的大小是可以不一样的,当不一样的时候,又不想重算应该怎么做映射呢?我把论文里的公式贴出来
    在这里插入图片描述
  • 针对在GQA MQA这类常用的attention变种怎么执行,这部分跟上面那部分在论文里相近的位置。
  • 实际生成过程中,序列seq_len 总是不能被llll′l'l 整除,怎么处理
拉最近一个Block

文章中之所以称为sliding window,主要是因为随着生成的进度,这个窗口总是当前生成token最近的一个Block。不过用sliding window反而容易产生混淆。
这里操作就一个,就是取最近的n个token作为K,V,和Q做QK的运算。

总结

这回再看作者的图

  • 最左边的数据流 是把他token上蓝色的四个区域压缩成四块算attention
  • 中间的部分 是借用了左边压缩块的结果,把压缩块中深绿色的Block里的所有token concat到一起形成一个短序列 ,然后和Q做attention
  • 右侧 是最近的一个Block的token单独拿出来和Q做attention。
  • 最右侧的三个网格–>而在真正执行的时候,是按照上面的逻辑算出三个mask来和做attention的运算函数对接的。

评价

从作者角度出发

  1. 作者的初衷是从优化角度上考虑的,主要目标还是在运算和存储都能打满的角度设计运算kernel
  2. 作者提到的trainable主要是压缩MLP这部分会随着训练而一起优化比固定规则强很多。

从同行的角度出发

  1. 之所以我称他为<半稀疏>,主要是这里取的Token基本上都是连续在某个窗口内的,不像Informer/LogFormer等等在Token级别抽取的方案,当然这在很多任务上都是非常合理的,而且也是并行效率很高的一种做法。
  2. 之所以我称他<缝合>了很多同类方法,主要是在时序和长文本领域 分块,分块压缩,选top-k,选最近一个window,其实都是21年前后大量时序transformer玩儿过的方案。

特别是分块和选择topk重要的块进行token concat这部分,21年我被IJCAI拒了的时序采样网络写的就是这个东西(当然我没那个实力连运算一起优化了,我就当审稿人拒我很有理吧)

  1. 分块机制的合理性:由于softmax的机制

导致在超长序列上做Attention甚至可以约等价于在超长序列上做RAG。

即便是把完整的KV算了,但在softmax影响下还是会丢很多信息,尤其是在相对位置编码加成下,更是如此。因此,从减少运算角度上说这个方法是非常合理的
4. 分块优选的隐含风险:其实是有两个,一个运算处理不够简单,这也是在长文本领域里众多Transformer变种最终都没有打败直接并行运算的原因(Flash Attention),另外一个是分块选择这个设计仍然是有一个注入的设计bias,就是

与任务相关的信息会相对密集的存在于长序列的少量片段中。

这个假设在时序预测的层面上是容易站得住脚,但在文本处理领域是否完全站得住脚,我打个问号。(倒不是说这个一定不对,是我也没把常见任务逐个拿起来分析过)

### Native Sparse Attention in Deep Learning Explained In the realm of deep learning, particularly within transformer models and their variants, attention mechanisms play a pivotal role by enabling models to focus on specific parts of input data when making predictions or generating outputs. Traditional full attention mechanisms compute relationships between all pairs of positions in sequences, leading to quadratic complexity with respect to sequence length. Native sparse attention addresses this issue through selective computation of attention scores only for certain predefined patterns or sparsity structures rather than every possible pair of tokens[^1]. This approach significantly reduces computational requirements while maintaining performance levels comparable to those achieved using dense attention schemes. #### Key Characteristics - **Sparsity Structures**: Various methods define different types of sparsity such as local windows where each token attends primarily to nearby elements along with occasional global connections that allow distant dependencies. - **Efficiency Gains**: By limiting interactions according to these structured patterns instead of computing pairwise relations across entire sequences, sparse attention can achieve substantial efficiency gains both in terms of memory usage and processing time required during training and inference phases. - **Performance Preservation**: Despite reducing computations involved compared to standard approaches, research indicates that well-designed sparse architectures do not compromise much—if at all—on predictive accuracy relative to fully connected counterparts under many circumstances[^2]. ```python import torch from transformers import LongformerModel model = LongformerModel.from_pretrained('allenai/longformer-base-4096') input_ids = torch.randint(0, 30000, size=(1, 4096)) # Example input tensor outputs = model(input_ids=input_ids) last_hidden_state = outputs.last_hidden_state ``` This code snippet demonstrates how one might use an implementation supporting native sparse attention from Hugging Face's Transformers library specifically designed for handling longer texts efficiently via long-range dependency modeling without excessive resource consumption.
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值