KV Cache原理详解 + 代码理解

基本背景:

        KV Cache(Key-Value缓存)主要用于 加速自回归模型(如Transformer)的序列生成,解决以下核心问题:

  • 重复计算:传统自回归生成时,每次预测新token都需要重新计算所有历史token的Key和Value,计算成本随序列长度平方级增长(O(n²))。

  • 内存瓶颈:长序列生成时,反复投影历史token的特征矩阵会占用大量显存带宽。

        KV Cache通过缓存历史token的中间计算结果,将复杂度降至 O(n),成为GPT、LLaMA等大模型生成文本/语音的核心优化技术。

原理讲解:

        自注意力计算的公式为:

  • Q (Query):代表当前需要计算的位置(即新生成的token),每次解码时唯一变化的部分。

  • K (Key)/V (Value):代表历史token的上下文信息,需要被重复利用。

       基于这个特性,我们可以考虑缓存K、V而避免重复计算增加效率。

        由于 Decoder 中一般会有掩码矩阵,因此Q往往是个下三角矩阵,QK^{T}计算公式如下:

        

        可以看到,结果矩阵的第 k 行只用到了矩阵 X  的 第 k  个行向量。所以 X 不需要进行全部的矩阵乘法,每一步只取第 k 个行向量即可,这就很大程度上减少了计算量,也就是 KV Cache 的数学原理。
        在没有 KV Cache 的情况下,如果要计算第 m+1 行,需要重新计算前 m 行,但是显然这样会造成大量的重复运算,因此我们可以保存前 m 行的结果,而只计算第 m+1 行即可。

        例如

        在计算Att2时已经保存了K1、K2、V1、V2,这样在计算Att3时就可以直接使用而无需充型计算

        

代码实现

def decode_next_token(
        self,
        x: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
        attn_mask: torch.Tensor = None,
        torch_sdpa: bool = True,
    ):
        # Q、K、V计算
        q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)

        # KV cache拼接
        k_cache = torch.cat([k_cache, k], dim=1)
        v_cache = torch.cat([v_cache, v], dim=1)

        batch_size = q.shape[0]
        q_len = q.shape[1]
        kv_len = k_cache.shape[1]

        # Q、K、V准备
        q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2)
        k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
        v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)

        # 注意力计算
        if torch_sdpa:
            attn = F.scaled_dot_product_attention(q, k, v, (~attn_mask) if attn_mask is not None else None)
        else:
            attn = scaled_dot_product_attention(q, k, v, attn_mask)

        attn = attn.transpose(1, 2).reshape(batch_size, q_len, -1)
        attn = F.linear(attn, self.out_w, self.out_b)

        x = x + attn
        x = F.layer_norm(
            x,
            [self.hidden_dim],
            self.norm_w1,
            self.norm_b1,
            self.norm_eps1,
        )
        x = x + self.mlp.forward(x)
        x = F.layer_norm(
            x,
            [self.hidden_dim],
            self.norm_w2,
            self.norm_b2,
            self.norm_eps2,
        )
        return x, k_cache, v_cache

参考文章:

        https://blog.csdn.net/weixin_43799388/article/details/142164166

KV Cache(Key-Value Cache)是现代深度学习模型中,尤其是在Transformer架构中用于加速推理过程的关键技术之一。它通过缓存先前计算得到的Key(K)和Value(V)向量,避免在每次推理步骤中重复计算这些向量,从而显著提高效率。 在Transformer模型中,每个解码步骤都会生成新的Query(Q)向量,并需要与之前所有步骤中生成的Key(K)和Value(V)向量进行Attention计算。由于这种机制,随着生成序列的增长,计算量会线性增加。KV Cache的作用就是存储这些历史的K和V向量,使得在后续步骤中可以直接使用缓存中的值,而不需要重新计算。 KV Cache通常分为两种类型:逻辑KV Cache(Logical KV Cache)和物理KV Cache(Physical KV Cache)。逻辑KV Cache是模型在推理过程中逻辑上使用的K和V向量,而物理KV Cache则是实际存储在内存中的结构。逻辑KV Cache和物理KV Cache之间的映射关系可以通过一些算法来管理,例如PagedAttention算法[^1]。这种算法允许逻辑KV Cache块以非连续的方式映射到物理KV Cache块,类似于操作系统的虚拟内存管理机制。这种映射方式可以更灵活地利用内存空间,减少内存碎片,并提高内存利用率。 此外,为了进一步优化KV Cache的使用,一些压缩策略也被提出,例如YOCO和CLA。YOCO方法在模型的前一半层使用self-attention生成KV,而在后一半层使用cross-attention,不再生成新的KV,而是复用前面层的结果。CLA方法则不固定生成KV的层数,而是将生成KV的层交替分布在模型的不同深度,邻近层复用附近层生成的KV Cache进行Attention计算[^2]。 这些策略和算法共同作用,使得KV Cache成为一种高效管理Transformer模型中注意力机制所需资源的技术。 ```python # 示例代码KV Cache的基本概念 class KVCache: def __init__(self): self.cache = {} def add(self, key, value): self.cache[key] = value def get(self, key): return self.cache.get(key, None) # 初始化KV Cache kv_cache = KVCache() # 添加Key和Value到Cache kv_cache.add('key1', 'value1') # 从Cache获取Value print(kv_cache.get('key1')) # 输出: value1 ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值