基本背景:
KV Cache(Key-Value缓存)主要用于 加速自回归模型(如Transformer)的序列生成,解决以下核心问题:
-
重复计算:传统自回归生成时,每次预测新token都需要重新计算所有历史token的Key和Value,计算成本随序列长度平方级增长(O(n²))。
-
内存瓶颈:长序列生成时,反复投影历史token的特征矩阵会占用大量显存带宽。
KV Cache通过缓存历史token的中间计算结果,将复杂度降至 O(n),成为GPT、LLaMA等大模型生成文本/语音的核心优化技术。
原理讲解:
自注意力计算的公式为:
-
Q (Query):代表当前需要计算的位置(即新生成的token),每次解码时唯一变化的部分。
-
K (Key)/V (Value):代表历史token的上下文信息,需要被重复利用。
基于这个特性,我们可以考虑缓存K、V而避免重复计算增加效率。
由于 Decoder 中一般会有掩码矩阵,因此Q往往是个下三角矩阵,计算公式如下:
可以看到,结果矩阵的第 k 行只用到了矩阵 X 的 第 k 个行向量。所以 X 不需要进行全部的矩阵乘法,每一步只取第 k 个行向量即可,这就很大程度上减少了计算量,也就是 KV Cache 的数学原理。
在没有 KV Cache 的情况下,如果要计算第 m+1 行,需要重新计算前 m 行,但是显然这样会造成大量的重复运算,因此我们可以保存前 m 行的结果,而只计算第 m+1 行即可。
例如
在计算Att2时已经保存了K1、K2、V1、V2,这样在计算Att3时就可以直接使用而无需充型计算
代码实现
def decode_next_token(
self,
x: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
attn_mask: torch.Tensor = None,
torch_sdpa: bool = True,
):
# Q、K、V计算
q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)
# KV cache拼接
k_cache = torch.cat([k_cache, k], dim=1)
v_cache = torch.cat([v_cache, v], dim=1)
batch_size = q.shape[0]
q_len = q.shape[1]
kv_len = k_cache.shape[1]
# Q、K、V准备
q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2)
k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
# 注意力计算
if torch_sdpa:
attn = F.scaled_dot_product_attention(q, k, v, (~attn_mask) if attn_mask is not None else None)
else:
attn = scaled_dot_product_attention(q, k, v, attn_mask)
attn = attn.transpose(1, 2).reshape(batch_size, q_len, -1)
attn = F.linear(attn, self.out_w, self.out_b)
x = x + attn
x = F.layer_norm(
x,
[self.hidden_dim],
self.norm_w1,
self.norm_b1,
self.norm_eps1,
)
x = x + self.mlp.forward(x)
x = F.layer_norm(
x,
[self.hidden_dim],
self.norm_w2,
self.norm_b2,
self.norm_eps2,
)
return x, k_cache, v_cache
参考文章: