KV Cache in Transformer - Yigeng’s Blog

KV Cache in Transformer

yigeng 2024-05-25 {Machine Learning} [Transformer, Python]

最近在研究Transformer,经常碰到KV-Cache,每次都要去查一下才能回想起来这个东西。但稍微研究了一下,觉得kv-cache从数学上来说是非常显然的,在这里总结一下我对kv-cache的理解。

KV-Cache源自于Transformer,全称是Key-Value Cache,也就是对Key和Value的缓存。

Transformer相对于CNN最大的特点就是引入了Attention计算,其计算公式如下: $$ \mathrm{Attention}(Q,K,V)=\mathrm{softmax}(\frac{QK^T}{\sqrt{d_k}})V \tag{1} $$ 其中$Q,K,V$分别代表Query、Key和Value矩阵,$d_k$代表Scaling factor,属于实数。所谓KV-Cache便是在Attention计算中,通过对Key和Value进行缓存,从而减小不必要的开销,这里“不必要”指由于我们已经缓存了部分Key和Value,因此不需要再进行重复计算以得到这部分Key和Value,使用之前计算好了的即可。

那么为什么说后续的Attention计算中会用到之前的Key和Value呢?这就要讲一下Transformer是如何进行推理了。

auto-regressive generation of the decoder In the auto-regressive generation of the decoder, given an input the model predicts the next token, and then taking the combined input in the next step the next prediction is made. (Image source: https://jalammar.github.io/illustrated-gpt2/).

如上图所示GPT-2使用的是Transformer的decoder架构,这类模型在推理时采取的是auto-regressive自回归式的风格,具体的说,在第$i$个round时,模型输出1个token,例如上图中的“robot”。

到了第$i+1$个$\mathrm{round}$模型会继续推理,但这时模型的输入会发生变化,它会将上个$\mathrm{round}$预测得到的token,“robot” append到第$i$个$\mathrm{round}$的输入recite ... A的后面,作为第$i+1$个$\mathrm{round}$模型输入。

这里可以稍微看下Transformer推理的源码,auto-regressive的体现便是在torch.cat函数。

# Generate the translation word by word
while decoder_input.size(1) < seq_len:
    # build mask for target and calculate output
    decoder_mask = torch.triu(torch.ones((1, decoder_input.size(1), decoder_input.size(1))), diagonal=1).type(torch.int).type_as(source_mask).to(device)
    out = model.decode(encoder_output, source_mask, decoder_input, decoder_mask)

    # project next token
    prob = model.project(out[:, -1])
    _, next_word = torch.max(prob, dim=1)
    decoder_input = torch.cat([decoder_input, torch.empty(1, 1).type_as(source).fill_(next_word.item()).to(device)], dim=1)

    # print the translated word
    print(f"{tokenizer_tgt.decode([next_word.item()])}", end=' ')

    # break if we predict the end of sentence token
    if next_word == tokenizer_tgt.token_to_id('[EOS]'):
        break

一句话总结,auto-regressive类模型会将预测的结果Concatenate到输入末尾作为新的输入,然后继续预测。

了解了Transformer是如何推理的,我们再回到Attention的计算。

假设从$\mathrm{round_1}$开始模型输入是1个token,在具体的实现上,我们会用Vector向量来表示1个token,正如我们在cv中用矩阵表示图像那样,我们记这个token输入向量为$x_1$,为了方便后续数学表达,当我们定义一个向量时,默认它为行向量,区别于传统线代中默认向量为列向量。

这里顺便说一下,在Transformer的代码实现上,Attention机制中的$Q_w,K_w,V_w
$是一组可学习的权重参数,用nn.Linear来表示的线性层。

在$\mathrm{round_1}$推理过程中,我们将$x_1$分别和$Q_w,K_w,V_w
$相乘得到 $q_1,k_1,v_1$,然后带入到等式1中完成Attention的计算。整个模型forward结束,会得到模型输出的1个token,我们把它记为$x_2$。

在$\mathrm{round_2}$时,我们将$x_1, x_2$组合作为输入,和$Q_w,K_w,V_w
$做矩阵乘得到$q_1, q_2$,$k_1,k_2$以及$v_1 v_2
$。

在$\mathrm{round_n}$时,我们手里有$x_1,x_2,…x_n$,记$X=\begin{pmatrix}x_1\\x_2\\ \vdots\\x_n\end{pmatrix}$。我们将$X$分别与$Q_w,K_w,V_w
$相乘得到$Q,K,V$

有了$Q,K,V$,我们来看一下Attention的计算,注意Attention计算的关键在于QKV矩阵乘,softmax对每一行元素做归一化,Scaling factor对矩阵中每个元素的value做放缩,总的来说softmax和Scaling factor都只是对矩阵中元素值做放缩。因此,我们忽略softmax和Scaling factor并不会对KV-Cache的理解有影响。那么,Attention的计算展开成向量形式如下: $$ \mathrm{Attention}(Q,K,V) \approx \begin{pmatrix}q_1 \\q_2 \\ \vdots\\q_n\end{pmatrix}\begin{pmatrix}k^T_1, k^T_2,\cdots, k^T_n\end{pmatrix}\begin{pmatrix}v_1 \\v_2 \ \vdots\\v_n\end{pmatrix}\\ =\begin{pmatrix}q_1k^T_1&&&\\q_2k^T_1&q_2k^T_2&&\\ \vdots&\vdots&\ddots&\\q_nk^T_1&q_nk^T_2&\cdots&q_nk^T_n\end{pmatrix}\begin{pmatrix}v_1 \\v_2 \\ \vdots\\v_n\end{pmatrix}\\ =\begin{pmatrix}q_1k^T_1v_1 \\q_2k^T_1v_1 + q_2k^T_2v_2 \\ \vdots\\q_nk^T_1v_1 + q_nk^T_2v_2+\cdots+q_nk^T_nv_n\end{pmatrix} \tag{2} $$ 注意在上面第二步推导中,$QK^T$是一个下三角矩阵,因为在$\mathrm{round_i}$时,我们仅有$q_j,k_j,v_j$其中$j \leq i$。令$A$表示式子2中的最终结果,$A$的第$i$个行向量$A_i=\sum_{j}^{i}q_ik_j^Tv_j$。

看到这里,不知道你有没有一种恍然大悟的感觉😉。由于我们在$\mathrm{round_{j < i}}$的过程中,缓存了$k_j,v_j$,那么在$\mathrm{round_i}$时我们只需要再计算一下$k_i,v_i$,复用缓存的key和value便能完成整个Attention的计算。

这里可以再多扯一下,KV-Cache的唯一作用便是避免冗余计算。在Transformer的训练过程中是不存在KV-cache的,因为在训练的过程中,我们将一个句子的前n-1个token喂给模型,将预测得到的token和ground truth做cross-entropy,也就是说在训练过程中是one-shot风格而不是自回归式的,训练代码为证。

for i, batch in enumerate(iterator):
    src = batch.src
    trg = batch.trg

    optimizer.zero_grad()
    output = model(src, trg[:, :-1])
    output_reshape = output.contiguous().view(-1, output.shape[-1])
    trg = trg[:, 1:].contiguous().view(-1)

    loss = criterion(output_reshape, trg)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
    optimizer.step()

KV-cahe仅存在于推理阶段,而且只存在于decoder中。下图是Transformer结构,在右侧decoder中,Multi-Head Attention模块的Key和Value来自于Encoder的输出,是一次性全部生成好了的,因此我们通常是在让decoder预测前在内存中缓存好encoder的output,作为Key和Value用于cross-Attention的计算。而decoder中的Masked Multi-Head Attention模块便是上文所讲,在逐token的生成中不断缓存key-value。

The full model architecture of the transformer The full model architecture of the transformer. (Image source: Fig 1 & 2 in Vaswani, et al., 2017.)

References

  1. Transformers KV Caching Explained
  2. pytorch-transformer
  3. Ashish Vaswani, et al. “Attention is all you need.” NIPS 2017.
  4. Transformer: PyTorch Implementation of “Attention Is All You Need”
comments powered by Disqus