The KV cache is a mechanism for avoiding repeated calculations during model inference. Without it, every time a new token is generated, the key, value, and query vectors are recomputed for the entire sequence. But the existing tokens stay the same, so the key and value vectors computed from them do not need to be recalculated. The only new part is the vectors associated with the new token, or more precisely, the new row appended to each of the growing key and value matrices. Therefore, we cache the key and value vectors computed so far and reuse them, which speeds up inference.
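To get a rough sense of the savings, the sketch below counts how many token rows get multiplied by the key weight matrix (the count for values is the same) when a prompt is processed and new tokens are then added one at a time. It is a back-of-the-envelope count under simplifying assumptions: a single attention layer, one pass over the prompt followed by one pass per new token, and no other costs. The function name kv_rows_projected is hypothetical.

```python
def kv_rows_projected(prompt_len: int, num_new_tokens: int, use_cache: bool) -> int:
    """Count token rows multiplied by the key weight matrix (hypothetical helper).

    Assumes one pass over the prompt, then one pass per new token, so the
    sequence lengths are prompt_len, prompt_len + 1, ..., prompt_len + num_new_tokens.
    """
    if use_cache:
        # Every row is projected exactly once: the prompt during the first pass,
        # then only the single new row in each later pass.
        return prompt_len + num_new_tokens
    # Without a cache, every pass re-projects the whole sequence seen so far.
    return sum(range(prompt_len, prompt_len + num_new_tokens + 1))


# Same sizes as the example below: a 6-token prompt followed by 4 new tokens.
print(kv_rows_projected(6, 4, use_cache=False))  # 40
print(kv_rows_projected(6, 4, use_cache=True))   # 10
```

Without a cache the count grows quadratically with the number of new tokens; with a cache it grows linearly. The attention scores still have to be computed against every cached key, so the cache removes repeated projections rather than all per-step work.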

Proof

In this section, we show why we can cache key and value vectors in causal attention. Consider the following example:

import torch
import torch.nn as nn

First, we set up a simple causal attention mechanism, as implemented in a previous post. The input has 6 elements (tokens), each a 3-dimensional vector.

torch.manual_seed(123)
inputs = torch.tensor(
    [
        [0.43, 0.15, 0.89],
        [0.55, 0.87, 0.66],
        [0.57, 0.85, 0.64],
        [0.22, 0.58, 0.33],
        [0.77, 0.25, 0.10],
        [0.05, 0.80, 0.55],
    ]
)
key_embeddings = nn.Parameter(torch.rand(inputs.shape[1], inputs.shape[1]))
query_embeddings = nn.Parameter(torch.rand(inputs.shape[1], inputs.shape[1]))
value_embeddings = nn.Parameter(torch.rand(inputs.shape[1], inputs.shape[1]))

We write a function to calculate the context vector:

def calculate_context_vector(
    inputs, key_embeddings, query_embeddings, value_embeddings
):
    # project every input token into key, query and value space
    keys = inputs @ key_embeddings
    queries = inputs @ query_embeddings
    values = inputs @ value_embeddings
    attn_scores = queries @ keys.T
    # causal mask: token i may only attend to tokens 0..i
    context_length = attn_scores.shape[0]
    mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
    masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
    # scaled softmax over each row, then a weighted sum of the value vectors
    attn_weights = torch.softmax(masked / keys.shape[-1] ** 0.5, dim=1)
    context_vec = attn_weights @ values
    return context_vec

We calculate the context vectors for all input elements:

ctxt = calculate_context_vector(
    inputs, key_embeddings, query_embeddings, value_embeddings
)
ctxt
tensor([[0.4976, 0.9655, 0.7614],
        [0.7674, 1.2199, 1.2528],
        [0.8186, 1.2667, 1.3497],
        [0.7324, 1.1287, 1.2029],
        [0.6963, 1.0718, 1.1713],
        [0.6824, 1.0370, 1.1307]], grad_fn=<MmBackward0>)

Now we add a new token (a new row) to the input matrix:

torch.manual_seed(123)
# assume now we have a new token
new_row = torch.rand(1, inputs.shape[1])
new_inputs = torch.cat([inputs, new_row], dim=0)
new_inputs
tensor([[0.4300, 0.1500, 0.8900],
        [0.5500, 0.8700, 0.6600],
        [0.5700, 0.8500, 0.6400],
        [0.2200, 0.5800, 0.3300],
        [0.7700, 0.2500, 0.1000],
        [0.0500, 0.8000, 0.5500],
        [0.2961, 0.5166, 0.2517]])

We calculate a new context vector from the new input and use torch.allclose to check whether it shares its first 6 rows with the previous context vector:

new_ctxt = calculate_context_vector(
    new_inputs, key_embeddings, query_embeddings, value_embeddings
)
print(new_ctxt, torch.allclose(ctxt, new_ctxt[:-1]), sep="\n")
tensor([[0.4976, 0.9655, 0.7614],
        [0.7674, 1.2199, 1.2528],
        [0.8186, 1.2667, 1.3497],
        [0.7324, 1.1287, 1.2029],
        [0.6963, 1.0718, 1.1713],
        [0.6824, 1.0370, 1.1307],
        [0.6538, 0.9875, 1.0863]], grad_fn=<MmBackward0>)
True

So appending one element to the input sequence only adds a single new row to the context vector; the earlier rows are unchanged. We can exploit this to avoid much of the repeated work in calculate_context_vector. To compute the new row, the last one, we run the same procedure on the last row of the query matrix only, new_queries[-1:]. No causal mask is needed here, because the last token is allowed to attend to every token in the sequence, including itself:

new_keys = new_inputs @ key_embeddings
new_values = new_inputs @ value_embeddings
new_queries = new_inputs @ query_embeddings

# attend from the newest token only: one query row against all keys
attn_scores = new_queries[-1:] @ new_keys.T
attn_weights = torch.softmax(attn_scores / new_keys.shape[-1] ** 0.5, dim=-1)
context_vec_row = attn_weights @ new_values
print(context_vec_row, torch.allclose(new_ctxt[-1:], context_vec_row), sep="\n")
tensor([[0.6538, 0.9875, 1.0863]], grad_fn=<MmBackward0>)
True

As shown above, to get the new context vector row we only need the last row of the query matrix together with the full key and value matrices. But the key and value matrices are not entirely new either: only their last row is new, and the remaining rows are identical to the previous iteration:

keys = inputs @ key_embeddings
values = inputs @ value_embeddings
print(torch.allclose(new_keys[:-1], keys))
print(torch.allclose(new_values[:-1], values))
 
True
True

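To see why this is the case, we can use some simple linear algebra. Let $X \in \mathbb{R}^{n \times d}$ be the current input, with one token per row, and let $W_K$ be the key weight matrix (key_embeddings in the code). The key matrix is

$$K = X W_K.$$

When a new token $x_{n+1} \in \mathbb{R}^{1 \times d}$ arrives, the input becomes $X' = \begin{bmatrix} X \\ x_{n+1} \end{bmatrix}$, and

$$K' = X' W_K = \begin{bmatrix} X W_K \\ x_{n+1} W_K \end{bmatrix} = \begin{bmatrix} K \\ x_{n+1} W_K \end{bmatrix}.$$

So $K'$ is just $K$ with one new row appended. The same holds for the value matrix $V = X W_V$.

This is the basis of the KV cache.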
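The same reasoning explains why appending a token leaves the earlier rows of the context matrix unchanged. Write $Q = X W_Q$, $K = X W_K$, $V = X W_V$ for the query, key and value matrices of the input $X$, let $q_i$, $k_j$, $v_j$ denote their rows, and let $d_k$ be the key dimension (keys.shape[-1] in the code). With the causal mask, row $i$ of the context matrix is

$$c_i = \sum_{j=1}^{i} \frac{\exp\left(q_i k_j^\top / \sqrt{d_k}\right)}{\sum_{l=1}^{i} \exp\left(q_i k_l^\top / \sqrt{d_k}\right)}\, v_j,$$

which involves only tokens $1, \dots, i$. So for $i \le n$ nothing changes when a new token $x_{n+1}$ arrives, and the only new row is

$$c_{n+1} = \operatorname{softmax}\left(\frac{q_{n+1} K'^\top}{\sqrt{d_k}}\right) V',$$

where $K'$ and $V'$ are the key and value matrices with the new token's row appended. This is exactly the computation done above with new_queries[-1:], new_keys and new_values.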

Implement KV Cache

Since the key and value vectors of earlier tokens stay unchanged during decoding, we can store the already calculated k, v vectors and reuse them. To implement a KV cache, it is useful to view inference as two separate stages: prefill and decode. The prefill stage processes the user's prompt, which stays the same for every additional output token. The decode stage generates the response one token at a time and appends the key and value vectors of each new token to the KV cache. (Running these two stages on separate hardware is what is known as prefill-decode disaggregation; here we only need the conceptual split.)

In the prefill stage, we calculate the k, v vectors as before. Since the prompt tokens (the input) are already given, we can calculate their key and value vectors once and save them in the KV cache.

kv_cache = {"keys": None, "values": None}
torch.manual_seed(123)
new_tokens = torch.rand(4, inputs.shape[1])
print("new_tokens", new_tokens)
 
 
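def prefill(inputs, key_embeddings, query_embeddings, value_embeddings):
    # project the whole prompt once and store the results in the cache;
    # query_embeddings is unused here because we skip computing the prompt's own outputs
    keys = inputs @ key_embeddings
    values = inputs @ value_embeddings
    kv_cache["keys"] = keys
    kv_cache["values"] = values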
 
 
prefill(inputs, key_embeddings, query_embeddings, value_embeddings)
print(
    "Prefill complete. Cached keys shape:",
    kv_cache["keys"].shape,
    "values shape",
    kv_cache["values"].shape,
)
 
new_tokens tensor([[0.2961, 0.5166, 0.2517],
        [0.6886, 0.0740, 0.8665],
        [0.1366, 0.1025, 0.1841],
        [0.7264, 0.3153, 0.6871]])
Prefill complete. Cached keys shape: torch.Size([6, 3]) values shape torch.Size([6, 3])
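The prefill above only fills the cache. In a full model, prefill would also compute the context vectors of the prompt itself, since the last of them is what produces the first output token. A minimal sketch of such a variant, assuming the same single-head setup, is shown below; it returns the context vectors and the cache instead of writing to a global dict. The name prefill_with_outputs is hypothetical.

```python
import torch


def prefill_with_outputs(inputs, key_embeddings, query_embeddings, value_embeddings):
    """Hypothetical prefill variant: causal attention over the whole prompt,
    returning the prompt's context vectors together with a fresh KV cache."""
    keys = inputs @ key_embeddings
    queries = inputs @ query_embeddings
    values = inputs @ value_embeddings
    n = inputs.shape[0]
    # Causal mask: token i may only attend to tokens 0..i.
    mask = torch.triu(torch.ones(n, n, dtype=torch.bool), diagonal=1)
    scores = (queries @ keys.T).masked_fill(mask, -torch.inf)
    weights = torch.softmax(scores / keys.shape[-1] ** 0.5, dim=-1)
    context = weights @ values
    cache = {"keys": keys, "values": values}
    return context, cache


torch.manual_seed(0)
prompt = torch.rand(6, 3)
W_k, W_q, W_v = (torch.rand(3, 3) for _ in range(3))
context, cache = prefill_with_outputs(prompt, W_k, W_q, W_v)
print(context.shape, cache["keys"].shape)  # torch.Size([6, 3]) torch.Size([6, 3])
```

Returning the cache rather than mutating a module-level dict makes it easier to keep several sequences (each with its own cache) around at once.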

In the decoding stage, we reuse the cached key and value vectors. The query vectors of earlier tokens are no longer needed: the context vector of a new token depends on that token's query and on the cached keys and values, not on the queries of earlier tokens. So we only calculate a query vector for each new token.

For each new token, we also need its key and value vectors. Instead of recomputing the full key and value matrices, we calculate only the new addition, a single row corresponding to the new token, and append it to the KV cache.

To keep things simple, we skip the step that generates the next token from the context vector; instead, the new tokens are the random rows in new_tokens drawn above. We compute the context vectors for these tokens and then compare them with the full calculation used earlier to show that the KV cache gives the same result.

pd_context_vec = torch.empty(0, 3)
 
 
def decode(key_embeddings, value_embeddings, query_embeddings, new_tokens):
    global pd_context_vec
    for i in range(new_tokens.shape[0]):
        new_token = new_tokens[i : i + 1, :]
        print("new_token", new_token)
        # only the new token's query is needed
        query = new_token @ query_embeddings
        # compute the new key/value rows and append them to the KV cache
        key_row = new_token @ key_embeddings
        value_row = new_token @ value_embeddings
        kv_cache["keys"] = torch.cat([kv_cache["keys"], key_row], dim=0)
        kv_cache["values"] = torch.cat([kv_cache["values"], value_row], dim=0)

        # attend over all cached keys; the newest token needs no causal mask
        attn_scores = query @ kv_cache["keys"].T
        attn_weights = torch.softmax(
            attn_scores / kv_cache["keys"].shape[-1] ** 0.5, dim=1
        )

        context_vec = attn_weights @ kv_cache["values"]
        pd_context_vec = torch.cat((pd_context_vec, context_vec), dim=0)
        # in a real model the next token would be generated from context_vec;
        # here it is simply taken from new_tokens instead
        print("context_vec", context_vec)
 
 
decode(key_embeddings, value_embeddings, query_embeddings, new_tokens)
print(
    "Final kv cache keys shape:",
    kv_cache["keys"].shape,
    "values shape",
    kv_cache["values"].shape,
)
 
new_token tensor([[0.2961, 0.5166, 0.2517]])
context_vec tensor([[0.6538, 0.9875, 1.0863]], grad_fn=<MmBackward0>)
new_token tensor([[0.6886, 0.0740, 0.8665]])
context_vec tensor([[0.6674, 1.0268, 1.1071]], grad_fn=<MmBackward0>)
new_token tensor([[0.1366, 0.1025, 0.1841]])
context_vec tensor([[0.5850, 0.9149, 0.9716]], grad_fn=<MmBackward0>)
new_token tensor([[0.7264, 0.3153, 0.6871]])
context_vec tensor([[0.6361, 0.9934, 1.0588]], grad_fn=<MmBackward0>)
Final kv cache keys shape: torch.Size([10, 3]) values shape torch.Size([10, 3])
# verify that the KV-cached decode produces the same context vectors as the full calculation
all_tokens = torch.cat([inputs, new_tokens], dim=0)
ctxt_naive = calculate_context_vector(
    all_tokens, key_embeddings, query_embeddings, value_embeddings
)
print(
    ctxt_naive[6:],
    pd_context_vec,
    torch.allclose(pd_context_vec, ctxt_naive[6:]),
    sep="\n",
)
tensor([[0.6538, 0.9875, 1.0863],
        [0.6674, 1.0268, 1.1071],
        [0.5850, 0.9149, 0.9716],
        [0.6361, 0.9934, 1.0588]], grad_fn=<SliceBackward0>)
tensor([[0.6538, 0.9875, 1.0863],
        [0.6674, 1.0268, 1.1071],
        [0.5850, 0.9149, 0.9716],
        [0.6361, 0.9934, 1.0588]], grad_fn=<CatBackward0>)
True
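The decode loop above grows the cache with torch.cat, which allocates a new tensor and copies the entire cache at every step. A common alternative is to reserve buffers for a maximum sequence length up front and write each new key and value row into them in place. The sketch below illustrates this, assuming the maximum length is known in advance; PreallocatedKVCache and decode_step are hypothetical names, not part of PyTorch. The final check mirrors the verification above by comparing against full causal attention.

```python
import torch


class PreallocatedKVCache:
    """Hypothetical fixed-capacity KV cache that writes rows in place."""

    def __init__(self, max_len: int, dim: int):
        self.keys = torch.empty(max_len, dim)
        self.values = torch.empty(max_len, dim)
        self.length = 0  # number of rows filled so far

    def append(self, key_rows: torch.Tensor, value_rows: torch.Tensor) -> None:
        n = key_rows.shape[0]
        if self.length + n > self.keys.shape[0]:
            raise ValueError("KV cache capacity exceeded")
        self.keys[self.length : self.length + n] = key_rows
        self.values[self.length : self.length + n] = value_rows
        self.length += n

    def view(self):
        # Only the first `length` rows contain real data.
        return self.keys[: self.length], self.values[: self.length]


def decode_step(new_token, cache, key_embeddings, query_embeddings, value_embeddings):
    """Context vector for one new token (shape 1 x dim), updating the cache first."""
    cache.append(new_token @ key_embeddings, new_token @ value_embeddings)
    keys, values = cache.view()
    query = new_token @ query_embeddings
    # No mask: the newest token may attend to every cached token.
    weights = torch.softmax(query @ keys.T / keys.shape[-1] ** 0.5, dim=-1)
    return weights @ values


torch.manual_seed(0)
W_k, W_q, W_v = (torch.rand(3, 3) for _ in range(3))
prompt, gen_tokens = torch.rand(6, 3), torch.rand(4, 3)

cache = PreallocatedKVCache(max_len=10, dim=3)
cache.append(prompt @ W_k, prompt @ W_v)  # prefill
cached_ctx = torch.cat(
    [decode_step(gen_tokens[i : i + 1], cache, W_k, W_q, W_v) for i in range(4)]
)

# Reference: full causal attention over all 10 tokens, keeping the last 4 rows.
all_tokens = torch.cat([prompt, gen_tokens])
q, k, v = all_tokens @ W_q, all_tokens @ W_k, all_tokens @ W_v
mask = torch.triu(torch.ones(10, 10, dtype=torch.bool), diagonal=1)
scores = (q @ k.T).masked_fill(mask, -torch.inf) / k.shape[-1] ** 0.5
full_ctx = torch.softmax(scores, dim=-1) @ v
print(torch.allclose(cached_ctx, full_ctx[6:]))  # True
```

The trade-off is that memory for max_len rows is reserved up front, even if generation stops early.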