🐳 Native Sparse Attention

Efficient Triton implementations for Native Sparse Attention: Hardware-Aligned and Natively Trainable Sparse Attention.

News

  • [2025/02/21] We support a variable number of selected blocks for queries across different positions and batches (see the sketch below).
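
Concretely, s[b, t, h] tells the kernel how many of the sorted entries in indices[b, t, h] to use for the query at position t; unused slots can be left as the out-of-range sentinel T. A minimal sketch of reading one query's selection back, assuming the semantics of the Usage example below:

# the kernel reads only the first s[b, t, h] block indices for this query
n = s[b, t, h]
selected = indices[b, t, h, :n]  # ids of the key/value blocks this query attends to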

Usage

import torch
import triton

from native_sparse_attention.ops.parallel import parallel_nsa

B, T, H, HQ, D = 4, 2048, 4, 64, 64  # batch, seq len, KV heads, query heads, head dim
block_size = 64
S = 16                    # max selected blocks per query (example value, not fixed by the API)
dtype = torch.bfloat16    # example dtype
q = torch.randn((B, T, HQ, D), dtype=dtype, device='cuda').requires_grad_(True)
k = torch.randn((B, T, H, D), dtype=dtype, device='cuda').requires_grad_(True)
v = torch.randn((B, T, H, D), dtype=dtype, device='cuda').requires_grad_(True)
# randomly generated block indices, padded with the out-of-range sentinel T
indices = torch.full((B, T, H, S), T, dtype=torch.long, device='cuda')
# variable number of selected blocks per (batch, position, KV head)
s = torch.randint(1, S + 1, (B, T, H), device='cuda')
for b in range(B):
    for t in range(T):
        for h in range(H):
            # sample up to S distinct blocks from the causal prefix of position t
            i_i = torch.randperm(max(1, triton.cdiv(t, block_size)))[:S]
            indices[b, t, h, :len(i_i)] = i_i
# sort so valid block indices come first (ascending), sentinels last
indices = indices.sort(-1)[0]

o = parallel_nsa(
  q=q,
  k=k,
  v=v,
  indices=indices,
  s=s,
  block_size=block_size
)
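# parallel_nsa returns the attention output; a quick sanity check
# (assumption: the output mirrors the query layout (B, T, HQ, D))
assert o.shape == (B, T, HQ, D)
o.sum().backward()  # backward is supported; requires_grad_ was set above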

# variable-length inputs are supported as well
# randomly split the sequence into N segments
N, T = 4, 2048
offsets = torch.cat([
    torch.tensor([0], dtype=torch.long),
    torch.arange(16, T)[torch.randperm(T - 16)[:N - 1]],  # index within the T - 16 candidate cut points
    torch.tensor([T], dtype=torch.long)
], 0).cuda().sort()[0]
# variable-length inputs must be packed seq-first into a single batch (batch size 1)
q = torch.randn((1, T, HQ, D), dtype=dtype, device='cuda').requires_grad_()
k = torch.randn((1, T, H, D), dtype=dtype, device='cuda').requires_grad_()
v = torch.randn((1, T, H, D), dtype=dtype, device='cuda').requires_grad_()

indices = torch.full((1, T, H, S), T, dtype=torch.long, device='cuda')
s = torch.randint(1, S + 1, (1, T, H), device='cuda')  # batch size 1 to match the packed layout
# map each packed token to its (sequence id, intra-sequence position);
# the import path below is assumed from the repository layout
from native_sparse_attention.ops.utils import prepare_token_indices
seq_indices = prepare_token_indices(offsets).tolist()
for i in range(T):
    _, t = seq_indices[i]  # t is token i's position within its own sequence
    for h in range(H):
        i_i = torch.randperm(max(1, triton.cdiv(t, block_size)))[:S]
        indices[0, i, h, :len(i_i)] = i_i
indices = indices.sort(-1)[0]
o = parallel_nsa(
    q=q,
    k=k,
    v=v,
    indices=indices,
    s=s,
    block_size=block_size,
    cu_seqlens=offsets
)
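# offsets plays the role of cu_seqlens: token i belongs to the sequence
# spanning offsets[n]:offsets[n + 1] in the packed layout; as above, the
# output is assumed to follow the query layout
assert o.shape == (1, T, HQ, D)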

Benchmarks

Performance against a Flash Attention baseline across sequence lengths T
(nsa/flash: forward pass; nsa_bwd/flash_bwd: backward pass):

       T        nsa     nsa_bwd      flash   flash_bwd
   128.0   0.116224    0.561968   0.019552    0.123888
   256.0   0.216896    0.963808   0.041472    0.223840
   512.0   0.414688    1.951680   0.093168    0.486176
  1024.0   0.813952    4.039584   0.260000    1.252896
  2048.0   1.672784    9.081648   0.855856    3.794176
  4096.0   3.518624   19.852303   3.196768   12.965824
  8192.0   7.535328   43.620705  12.336976   47.652878
 16384.0  16.107521  102.203011  48.110847  186.464386
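
A sketch of how such timings could be collected with triton.testing.do_bench, using the fixed-length tensors from the first Usage snippet; this is an assumed harness, not necessarily the repository's benchmark script:

# average time in ms for the forward and backward passes of parallel_nsa
fwd_ms = triton.testing.do_bench(
    lambda: parallel_nsa(q=q, k=k, v=v, indices=indices, s=s, block_size=block_size))
o = parallel_nsa(q=q, k=k, v=v, indices=indices, s=s, block_size=block_size)
do = torch.randn_like(o)
bwd_ms = triton.testing.do_bench(lambda: o.backward(do, retain_graph=True))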

Citations

@inproceedings{Yuan2025NativeSA,
    title   = {Native Sparse Attention: Hardware-Aligned and Natively Trainable Sparse Attention},
    author  = {Jingyang Yuan and Huazuo Gao and Damai Dai and Junyu Luo and Liang Zhao and Zhengyan Zhang and Zhenda Xie and Y. X. Wei and Lean Wang and Zhiping Xiao and Yuqing Wang and Chong Ruan and Ming Zhang and Wenfeng Liang and Wangding Zeng},
    year    = {2025},
    url     = {https://api.semanticscholar.org/CorpusID:276408911}
}
