
KV Cache

The one trick that turns quadratic LLM generation into linear — and eats all your GPU memory doing it

A KV cache stores the key and value tensors of every past token so a transformer can generate each new token in O(n) work instead of recomputing attention over the whole sequence at O(n²) — the single optimization that makes autoregressive LLM inference affordable.

  • Per-token work, no cache: O(n²)
  • Per-token work, cached: O(n)
  • Full-sequence generation: O(n³) → O(n²)
  • Cache size per token: 2 · L · h · d values
  • Decode bottleneck: memory bandwidth


The redundant work a KV cache kills

An autoregressive language model generates text one token at a time. To produce token number t, it runs the whole prompt-so-far through the network, reads the final logits, samples a token, appends it, and repeats. The naive way to do this is to re-feed the entire growing sequence on every step. Generate a 1,000-token reply and the model processes the prompt on step 1, the prompt plus one token on step 2, the prompt plus two tokens on step 3 — and so on, a thousand times over. The first prompt token gets re-encoded a thousand times.

Almost all of that is wasted, and the reason is the structure of self-attention combined with the causal mask. In a decoder transformer, token t can only attend to tokens at positions ≤ t. Crucially, the key and value vectors that token 5 produces in a given layer depend only on token 5's input to that layer, and under the causal mask that input is itself a function of tokens 1 through 5 alone. So they do not change when token 6, 7, or 8 arrive later. They are computed once and then read, unchanged, by every future token's attention, which makes recomputing them on every step pure redundancy.
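A quick numerical check of that claim, as a minimal sketch: a hypothetical two-layer, single-head toy stack in pure Python with random weights (no MLP, no normalization, no positional encoding, all simplifying assumptions). It compares the layer-2 keys of the first five tokens computed from a 5-token sequence against those computed from an 8-token sequence.

```python
import math
import random

D = 4  # toy head dimension (assumption for illustration)

def matvec(W, x):
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def softmax(scores):
    m = max(scores)                      # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    z = sum(exps)
    return [e / z for e in exps]

def causal_layer(xs, Wq, Wk, Wv):
    """Full causal self-attention over xs; returns (outputs, keys)."""
    qs = [matvec(Wq, x) for x in xs]
    ks = [matvec(Wk, x) for x in xs]
    vs = [matvec(Wv, x) for x in xs]
    outs = []
    for t, q in enumerate(qs):
        # Causal mask: token t only sees tokens 0..t.
        w = softmax([dot(q, k) / math.sqrt(D) for k in ks[: t + 1]])
        outs.append([sum(w[i] * vs[i][j] for i in range(t + 1)) for j in range(D)])
    return outs, ks

def random_matrix(rng):
    return [[rng.uniform(-1, 1) for _ in range(D)] for _ in range(D)]

rng = random.Random(0)
layer1 = [random_matrix(rng) for _ in range(3)]   # Wq, Wk, Wv for layer 1
layer2 = [random_matrix(rng) for _ in range(3)]   # Wq, Wk, Wv for layer 2
tokens = [[rng.uniform(-1, 1) for _ in range(D)] for _ in range(8)]

def layer2_keys(seq):
    hidden, _ = causal_layer(seq, *layer1)  # layer-1 outputs are layer 2's inputs
    _, keys = causal_layer(hidden, *layer2)
    return keys

# Keys for tokens 0..4 are identical whether or not tokens 5..7 exist.
print(layer2_keys(tokens[:5]) == layer2_keys(tokens)[:5])  # True
```

The equality is exact, not approximate: for the first five tokens both runs perform the same floating-point operations on the same inputs, because the mask hides the later tokens entirely.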

The fix is exactly what the name says: cache the keys and values. After processing each token, stash its per-layer K and V tensors in a growing buffer. On the next step you feed in only the single newest token, compute its query, and let it attend against the entire cached history. Each new token now costs one row of fresh work instead of re-deriving the whole sequence. That is the KV cache.
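To make the redundancy concrete, here is a small counting sketch. The helper `tokens_encoded` is hypothetical, it counts token positions pushed through the network rather than FLOPs, and the 100-token prompt is an assumed example.

```python
def tokens_encoded(prompt_len: int, new_tokens: int, cached: bool) -> int:
    """Token positions pushed through the network to generate `new_tokens` tokens."""
    if cached:
        # Prefill encodes the prompt once; every later step encodes only the newest token.
        return prompt_len + (new_tokens - 1)
    # Naive: step i re-encodes the prompt plus the i tokens generated so far.
    return sum(prompt_len + i for i in range(new_tokens))

print(tokens_encoded(100, 1000, cached=False))  # 599500
print(tokens_encoded(100, 1000, cached=True))   # 1099
```

Both counts assume at least one new token. The gap keeps widening as the reply grows, because the naive count grows quadratically in the reply length while the cached count grows linearly.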

The precise mechanism, layer by layer

Inside one attention layer, each token's hidden state x is projected into three vectors by learned weight matrices: query q = x·W_Q, key k = x·W_K, value v = x·W_V. Attention for a query at position t is:

attn(q_t) = softmax( q_t · K^T / √d ) · V

  where K = [k_1; k_2; …; k_t]   (all keys up to t)
        V = [v_1; v_2; …; v_t]   (all values up to t)
        d = head dimension

Read that carefully: at step t the only new inputs are q_t and one fresh row each of K and V (k_t and v_t). Every earlier row, k_1 through k_{t-1} and v_1 through v_{t-1}, was computed on a previous step, and once k_5 and v_5 exist they are immutable.
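The formula translates almost line for line into code. A minimal sketch, assuming plain Python lists for vectors and a hypothetical `attend` helper; the hand-checkable example uses two cached tokens with d = 2.

```python
import math

def attend(q, K, V):
    """attn(q_t) = softmax(q_t · K^T / sqrt(d)) · V for a single query row."""
    d = len(q)
    scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in K]  # 1 x t
    m = max(scores)                          # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    z = sum(exps)
    weights = [e / z for e in exps]          # attention weights, sum to 1
    # Weighted average of the cached value rows.
    return [sum(w * v[j] for w, v in zip(weights, V)) for j in range(len(V[0]))]

# Two cached tokens; the query points along the first key.
K = [[1.0, 0.0], [0.0, 1.0]]
V = [[10.0, 0.0], [0.0, 10.0]]
print(attend([1.0, 0.0], K, V))  # ≈ [6.70, 3.30]: mostly token 1's value
```

The scores are 1/√2 and 0, so softmax gives weights of about 0.67 and 0.33, and the output is that mixture of the two value rows.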

So the cached decode step is:

  1. Take the single new token's hidden state, project to q_t, k_t, v_t.
  2. Append k_t to the cached K and v_t to the cached V (one new row each).
  3. Compute softmax(q_t · K^T / √d) · V — a single query against the full cache.
  4. Pass the result up to the next layer, which does the same against its own cache.
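The four steps can be sketched for one layer and one head in pure Python. This is a toy under stated assumptions (random 4×4 weights, no positional encoding, hypothetical names `decode_step`, `LayerCache`, `recompute_last`); it checks, token by token, that the cached step returns the same output as recomputing every key and value from scratch.

```python
import math
import random

def project(W, x):
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def attend(q, K, V):
    """softmax(q · K^T / sqrt(d)) · V for one query row against t cached rows."""
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in K]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    z = sum(exps)
    return [sum(e / z * v[j] for e, v in zip(exps, V)) for j in range(len(V[0]))]

class LayerCache:
    """One layer's K and V buffers; each decode step appends one row to each."""
    def __init__(self):
        self.K, self.V = [], []

def decode_step(x, Wq, Wk, Wv, cache):
    q, k, v = project(Wq, x), project(Wk, x), project(Wv, x)  # 1. project new token
    cache.K.append(k)                                          # 2. append one row each
    cache.V.append(v)
    return attend(q, cache.K, cache.V)                         # 3. one query vs. whole cache

def recompute_last(xs, Wq, Wk, Wv):
    """No cache: re-project every token, then attend with the newest query."""
    K = [project(Wk, x) for x in xs]
    V = [project(Wv, x) for x in xs]
    return attend(project(Wq, xs[-1]), K, V)

rng = random.Random(1)
D = 4
Wq, Wk, Wv = ([[rng.uniform(-1, 1) for _ in range(D)] for _ in range(D)] for _ in range(3))
xs = [[rng.uniform(-1, 1) for _ in range(D)] for _ in range(6)]

cache = LayerCache()
for t, x in enumerate(xs):
    cached_out = decode_step(x, Wq, Wk, Wv, cache)   # 4. would be passed up to the next layer
    fresh_out = recompute_last(xs[: t + 1], Wq, Wk, Wv)
    assert all(math.isclose(a, b) for a, b in zip(cached_out, fresh_out))
print(len(cache.K), len(cache.V))  # 6 6
```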

The cache is per-layer and per-head: every one of the model's L attention layers keeps its own K and V buffers, because each layer projects with different weights. So the model holds L separate caches, all growing one row per generated token in lockstep.
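One way to picture that layout is a container with one buffer per (layer, K/V head) pair. The class below is a hypothetical sketch, not any framework's API; it stores placeholder rows and checks that every buffer has grown to the same length.

```python
class ModelKVCache:
    """Hypothetical container: separate K/V buffers for every layer and every K/V head."""

    def __init__(self, n_layers: int, n_kv_heads: int):
        self.n_layers, self.n_kv_heads = n_layers, n_kv_heads
        # keys[layer][head] is a list of head_dim-sized rows, one per cached token.
        self.keys = [[[] for _ in range(n_kv_heads)] for _ in range(n_layers)]
        self.values = [[[] for _ in range(n_kv_heads)] for _ in range(n_layers)]

    def append(self, layer: int, head: int, k, v) -> None:
        self.keys[layer][head].append(k)
        self.values[layer][head].append(v)

    def seq_len(self) -> int:
        lengths = {len(buf) for per_layer in self.keys + self.values for buf in per_layer}
        if len(lengths) != 1:
            raise ValueError(f"buffers out of lockstep: {sorted(lengths)}")
        return lengths.pop()

cache = ModelKVCache(n_layers=2, n_kv_heads=4)
head_dim = 8
for token in range(3):
    for layer in range(cache.n_layers):
        for head in range(cache.n_kv_heads):
            # Placeholder rows; a real model projects this layer's hidden state
            # with this layer's own W_K and W_V.
            cache.append(layer, head, [0.0] * head_dim, [0.0] * head_dim)
print(cache.seq_len())  # 3
```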

The complexity that actually changes

This is the heart of why the cache matters. Let n be the current sequence length and treat the model dimension as a constant.

  • No cache, per step. You re-encode all n tokens. The attention score matrix is n × n, so a single forward pass is O(n²).
  • With cache, per step. You encode one new token and attend it against n cached keys — a single 1 × n score row. That is O(n) per step.

Now sum over a full generation of n tokens:

  • No cache, whole sequence. Step t costs O(t²), so the total is Σ t² = O(n³).
  • With cache, whole sequence. Step t costs O(t), so the total is Σ t = O(n²).

The cache removes one full factor of n from generation. For a 2,000-token output, taking n³ and n² at face value gives roughly 8 billion versus 4 million units of attention work per layer; the exact sums, Σt² ≈ 2.7 billion and Σt ≈ 2 million, tell the same story, a reduction of more than a thousand-fold. The catch, which the rest of this article keeps returning to, is that you pay for that speed in memory: the cache grows linearly with the sequence and is read in full on every single step.
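The exact sums are easy to check. A sketch with a hypothetical `attention_work` helper that assigns cost t² to an uncached step and t to a cached one, with all constants dropped:

```python
def attention_work(n: int) -> tuple[int, int]:
    naive = sum(t * t for t in range(1, n + 1))   # step t costs ~t^2 without a cache
    cached = sum(range(1, n + 1))                  # step t costs ~t with a cache
    return naive, cached

naive, cached = attention_work(2000)
print(f"{naive:.2e} vs {cached:.2e}, ratio ≈ {naive / cached:.0f}x")
# 2.67e+09 vs 2.00e+06, ratio ≈ 1334x
```

The ratio is (2n + 1)/3 in closed form, so it grows linearly with the output length: the cache saves one factor of n, as claimed.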

Prefill vs decode: two different machines

Cached inference splits into two phases with opposite performance profiles.

Prefill. The prompt is known up front, so all of its tokens go through the model in one parallel pass. This computes K and V for every prompt token and fills the cache in a single batched matmul. Prefill is compute-bound — the GPU's tensor cores are saturated — and it is cheap per token because of the parallelism.

Decode. Now the model emits one token per step, and each step depends on the previous output, so the steps are inherently sequential. Each step does very little arithmetic (one query row per sequence) but must read all of the model weights and the entire cache from GPU memory. Decode is therefore memory-bandwidth-bound: the tensor cores sit mostly idle while the GPU streams weights and gigabytes of cached K/V through. Batching spreads the weight reads across many sequences, but each sequence's cache is its own and must still be read in full. For long replies, decode dominates wall-clock latency, and the whole game of fast LLM serving becomes a game of moving the KV cache around efficiently.
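A rough roofline-style estimate shows why the two phases behave so differently. This is a sketch under simplifying assumptions (fp16 values, head_dim 128, each cached K/V row read from memory once and reused by every query in the step, weight matmuls and the causal mask ignored); `attention_intensity` is a hypothetical name. The result is arithmetic intensity, FLOPs per byte read, and GPUs need far more than 1 FLOP per byte to keep their tensor cores busy.

```python
def attention_intensity(query_rows: int, head_dim: int = 128, bytes_per_value: int = 2) -> float:
    """FLOPs per byte of cached K/V read, for one head's attention over the cache.

    Assumes every cached K and V row is read from memory once and reused by all
    `query_rows` queries; ignores weight matmuls and the causal mask.
    """
    flops = query_rows * 4 * head_dim            # q·k (2d) plus weight * v (2d), per cached token
    bytes_read = 2 * head_dim * bytes_per_value  # one K row + one V row per cached token
    return flops / bytes_read

print(attention_intensity(1))      # decode: 1.0 FLOP per byte in fp16
print(attention_intensity(2048))   # prefill of a 2,048-token prompt: 2048.0
```

Decode sits at about one FLOP per byte regardless of context length, while prefill reuses every K/V row across all prompt queries. The causal mask roughly halves prefill's real figure, which is still far above decode's.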

KV-cache strategies compared

| | No cache | Dense MHA cache | MQA cache | GQA cache | PagedAttention |
|---|---|---|---|---|---|
| Per-step compute | O(n²) | O(n) | O(n) | O(n) | O(n) |
| K/V heads stored | n/a | h (all) | 1 | g groups | g groups |
| Cache size vs MHA | 0 | 1 (baseline) | 1/h | g/h | g/h, no fragmentation |
| Memory layout | n/a | contiguous slab | contiguous slab | contiguous slab | fixed-size pages |
| Wasted memory | none | 60–80% (over-reserved) | 60–80% | 60–80% | <4% |
| Quality impact | none | none | small drop | negligible | none |
| Used by | (teaching only) | GPT-2, Llama-2-7B | PaLM, Falcon | Llama-2-70B, Mistral | vLLM, TGI, TensorRT-LLM |

The progression down this table is the history of LLM serving over the last few years: first get the algorithm linear (any cache at all), then shrink what you store per token (MQA/GQA), then stop wasting the memory you do allocate (PagedAttention). They compose — production stacks use GQA and paging together.

What the numbers actually say

  • 512 KB per token for Llama-2-7B. The per-token cache is 2 × L × n_kv_heads × head_dim × bytes_per_value. For 32 layers, 32 heads, head_dim 128, fp16: 2 × 32 × 32 × 128 × 2 ≈ 512 KB. A single 4,096-token context is about 2 GB of cache, on top of ~13 GB of weights.
  • The cache, not the weights, caps your batch size. On an 80 GB A100 running Llama-2-7B, weights eat ~13 GB; the remaining ~67 GB of cache budget is what limits how many simultaneous sequences you can serve. Double the context length and you halve the concurrent request count.
  • MQA shrinks this model's cache 32×. Storing one K/V head instead of 32 cuts that 512 KB/token down to ~16 KB/token, turning a 2 GB context into 64 MB. GQA with 8 groups gives a 4× reduction: most of the win, almost none of the quality loss.
  • PagedAttention recovered most of the 60–80% of cache memory that earlier systems wasted, cutting waste to under 4%, and lifted serving throughput roughly 2–4× over the previous best systems (vLLM, 2023), mainly by packing far more concurrent sequences into the same GPU.
  • A thousand-fold compute cut on a 2K generation. Naive whole-sequence cost is O(n³); cached is O(n²). At n = 2,000 that is ~8 × 10⁹ versus ~4 × 10⁶ attention operations per layer with constants dropped, or ~2.7 × 10⁹ versus ~2 × 10⁶ using the exact sums; either way, more than a thousand-fold.
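The arithmetic in these bullets can be reproduced in a few lines. The helper `kv_bytes_per_token` is hypothetical; the shapes are the Llama-2-7B figures quoted above, the MQA and 8-group GQA variants are hypothetical reshapings of that same model, and the 67 GB budget ignores activations and framework overhead. The 512 KB and 2 GB figures come out exactly in binary units (KiB, GiB).

```python
def kv_bytes_per_token(n_layers, n_kv_heads, head_dim, bytes_per_value=2):
    # 2 tensors (K and V) per layer, per K/V head, per head dimension.
    return 2 * n_layers * n_kv_heads * head_dim * bytes_per_value

KiB, MiB, GiB = 1024, 1024**2, 1024**3

mha = kv_bytes_per_token(32, 32, 128)   # Llama-2-7B shapes, fp16
mqa = kv_bytes_per_token(32, 1, 128)    # same model with one shared K/V head (hypothetical)
gqa = kv_bytes_per_token(32, 8, 128)    # same model with 8 K/V groups (hypothetical)
ctx = 4096

print(mha / KiB, mha * ctx / GiB)   # 512.0 2.0   (KiB per token, GiB per 4K context)
print(mqa / KiB, mqa * ctx / MiB)   # 16.0 64.0   (KiB per token, MiB per 4K context)
print(mha // gqa)                   # 4           (GQA-8 is 4x smaller than MHA)

budget = 67 * 10**9                 # ~67 GB left after ~13 GB of weights (from the text)
print(budget // (mha * ctx))        # 31 concurrent 4K-token sequences
```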

JavaScript implementation

A single-head decoder layer with an explicit growing cache. The arithmetic is real — dot products, scaled softmax, value mixing — just at toy dimensions so it runs in a console.

// Toy single-head causal attention with a KV cache.
const D = 4; // head dimension

const dot = (a, b) => a.reduce((s, x, i) => s + x * b[i], 0);
const matvec = (W, x) => W.map(row => dot(row, x));   // W is D×D

function softmax(scores) {
  const m = Math.max(...scores);                      // numerical stability
  const ex = scores.map(s => Math.exp(s - m));
  const sum = ex.reduce((a, b) => a + b, 0);
  return ex.map(e => e / sum);
}

class KVCache {
  constructor() { this.K = []; this.V = []; }         // each: array of D-vectors
  get length() { return this.K.length; }
  append(k, v) { this.K.push(k); this.V.push(v); }
}

// Process ONE new token. The cache already holds every prior token's k, v.
function decodeStep(x, Wq, Wk, Wv, cache) {
  const q = matvec(Wq, x);
  const k = matvec(Wk, x);
  const v = matvec(Wv, x);
  cache.append(k, v);                                 // immutable once stored

  const scale = 1 / Math.sqrt(D);
  const scores = cache.K.map(ki => dot(q, ki) * scale); // 1 × n, not n × n
  const weights = softmax(scores);                    // attend over all history

  // weighted sum of cached values
  const out = new Array(D).fill(0);
  cache.V.forEach((vi, i) => vi.forEach((c, j) => { out[j] += weights[i] * c; }));
  return out;
}

// --- usage: generate 5 steps, feeding back the previous output ---
const rand = () => Array.from({ length: D }, () => Array.from({ length: D }, Math.random));
const [Wq, Wk, Wv] = [rand(), rand(), rand()];
const cache = new KVCache();

let x = [1, 0, 0, 0];                                 // first token embedding
for (let t = 0; t < 5; t++) {
  x = decodeStep(x, Wq, Wk, Wv, cache);               // O(n) work, cache grows by 1
  console.log(`step ${t}: cache length = ${cache.length}`);
}

The two lines that matter are cache.append(k, v) and the cache.K.map(...) that scores one query against the whole buffer. Notice the score array is length n, not n × n — that single dimension reduction is the entire O(n²)→O(n) win, made concrete.

Python / PyTorch sketch

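Real frameworks thread the cache through every forward pass and concatenate new keys and values along the sequence axis. Hugging Face Transformers historically passed a past_key_values tuple into generate()'s forward calls and got a new one back (newer releases wrap it in Cache classes such as DynamicCache); the sketch below follows that tuple-based shape.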

import torch, torch.nn.functional as F

class CachedAttention(torch.nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.h, self.dk = n_heads, d_model // n_heads
        self.Wq = torch.nn.Linear(d_model, d_model, bias=False)
        self.Wk = torch.nn.Linear(d_model, d_model, bias=False)
        self.Wv = torch.nn.Linear(d_model, d_model, bias=False)
        self.Wo = torch.nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, past=None):
        # x: (batch, new_len, d_model). During decode, new_len == 1.
        B, T, _ = x.shape
        shape = lambda t: t.view(B, T, self.h, self.dk).transpose(1, 2)  # (B,h,T,dk)
        q, k, v = shape(self.Wq(x)), shape(self.Wk(x)), shape(self.Wv(x))

        if past is not None:                       # append to the cache
            pk, pv = past
            k = torch.cat([pk, k], dim=2)          # grow along the seq axis
            v = torch.cat([pv, v], dim=2)
        present = (k, v)                           # hand the new cache back

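        n = k.size(2)                              # total length: cached + new tokens
        scores = (q @ k.transpose(-2, -1)) / self.dk ** 0.5   # (B,h,T,n)
        if T > 1:
            # Prefill (or chunked prefill): new query i sits at absolute position
            # n - T + i and may only see keys 0 .. n - T + i.
            mask = torch.ones(T, n, device=x.device).tril(diagonal=n - T).bool()
            scores = scores.masked_fill(~mask, float("-inf"))
        # Decode (T == 1): the single newest query may see the whole cache, no mask needed.
        attn = F.softmax(scores, dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(B, T, -1)
        return self.Wo(out), present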


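# --- greedy decode loop (toy: one layer plus an lm_head projection to vocab logits) ---
def generate(layer, embed, lm_head, prompt_ids, n_steps):
    # prompt_ids: (B, P) token ids. The first call is the prefill; later calls feed ONE token.
    past, tok = None, prompt_ids
    for _ in range(n_steps):
        out, past = layer(embed(tok), past)        # cache grows by tok.size(1) rows
        logits = lm_head(out[:, -1])               # (B, vocab) for the newest position
        tok = logits.argmax(-1, keepdim=True)      # (B, 1) next-token id
    return past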

The single most common bug lives in this loop and not in the math: during decode you feed only the newest token, so its position must be the current cache length, not zero. If you reset rotary (RoPE) or absolute positional offsets to 0 on every step, the cached keys and the fresh query disagree about where they are in the sequence and the output silently degrades into gibberish.
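The failure is easy to reproduce with a minimal rotary-embedding sketch in pure Python. It assumes the common pairwise-rotation form with base 10000 (implementations differ in how they pair dimensions), and `rope` is a hypothetical helper. Because RoPE scores depend only on the distance between query and key, the true offset (the cache length) reproduces the expected score and a reset-to-zero offset does not.

```python
import math

def rope(vec, pos, base=10000.0):
    """Rotate each (even, odd) pair of `vec` by angle pos * base**(-i/d), i = 0, 2, 4, ..."""
    d = len(vec)
    out = list(vec)
    for i in range(0, d, 2):
        theta = pos * base ** (-i / d)
        c, s = math.cos(theta), math.sin(theta)
        out[i] = vec[i] * c - vec[i + 1] * s
        out[i + 1] = vec[i] * s + vec[i + 1] * c
    return out

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

q, k = [0.3, -1.2, 0.8, 0.5], [1.0, 0.4, -0.7, 0.2]
cache_len = 7                                   # tokens already cached at positions 0..6
right = dot(rope(q, cache_len), rope(k, 3))     # new query at its true position 7
wrong = dot(rope(q, 0), rope(k, 3))             # buggy: position reset to 0
reference = dot(rope(q, 4), rope(k, 0))         # same relative distance (4) as `right`
print(math.isclose(right, reference), math.isclose(wrong, reference))  # True False
```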

Variants worth knowing

Multi-Query Attention (MQA). Use the full set of query heads but a single shared key/value head. The cache shrinks by a factor equal to the head count (32× for a 32-head model), a large win because the cache is the binding constraint on serving. Introduced by Shazeer (2019); used by PaLM and Falcon. The cost is a small, measurable quality drop.

Grouped-Query Attention (GQA). The middle ground: partition the h query heads into g groups that each share one K/V head, shrinking the cache by h/g. With 32 query heads and g = 8 you get a 4× cache reduction at near-zero quality loss. This is now the common default: Llama-2-70B, Llama-3, and Mistral all ship GQA.
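The head-to-group mapping is one integer division. A sketch assuming contiguous groups (query heads 0 to 3 share K/V head 0, and so on), with a hypothetical `kv_head_for` helper; MHA and MQA fall out as the two extremes g = h and g = 1.

```python
def kv_head_for(query_head: int, n_query_heads: int, n_kv_groups: int) -> int:
    """K/V head shared by `query_head`, assuming contiguous groups of query heads."""
    if n_query_heads % n_kv_groups:
        raise ValueError("query heads must split evenly into groups")
    return query_head // (n_query_heads // n_kv_groups)

h = 32
for name, g in [("MHA", 32), ("GQA", 8), ("MQA", 1)]:
    # Query head qh attends against the cached K/V of head kv_head_for(qh, h, g).
    mapping = [kv_head_for(qh, h, g) for qh in range(h)]
    print(f"{name}: {len(set(mapping))} K/V heads cached, cache is {g}/{h} of MHA")
```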

PagedAttention. A systems variant rather than a model one. Instead of reserving a contiguous slab per request (which over-allocates for the worst-case length), store the cache in fixed-size blocks like OS virtual-memory pages, allocate on demand, and let multiple requests share identical prompt-prefix blocks via copy-on-write. The basis of vLLM (2023).
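The bookkeeping side of that idea fits in a toy block-table allocator. This is a hypothetical sketch in the spirit of PagedAttention, not vLLM's API: each sequence maps logical blocks to physical block ids, a new block is taken from the free list only when the previous one fills up, and `fork` shares a prefix through reference counts with copy-on-write. The K/V tensors themselves, and the attention kernel that gathers from scattered blocks, are left out.

```python
class PagedKVCache:
    """Toy block-table allocator (hypothetical API, bookkeeping only)."""

    def __init__(self, num_blocks: int, block_size: int):
        self.block_size = block_size
        self.free = list(range(num_blocks))      # physical block ids
        self.refcount = [0] * num_blocks
        self.tables = {}                          # seq id -> list of physical block ids
        self.lengths = {}                         # seq id -> tokens stored

    def _alloc(self) -> int:
        if not self.free:
            raise MemoryError("out of KV blocks")
        block = self.free.pop()
        self.refcount[block] = 1
        return block

    def append_token(self, seq: str) -> tuple[int, int]:
        """Reserve a slot for one new token; returns (physical block, offset in block)."""
        table = self.tables.setdefault(seq, [])
        n = self.lengths.get(seq, 0)
        if n % self.block_size == 0:              # no block yet, or the last one is full
            table.append(self._alloc())
        elif self.refcount[table[-1]] > 1:        # writing into a shared partial block
            self.refcount[table[-1]] -= 1         # copy on write: take a private block
            table[-1] = self._alloc()             # (a real system also copies the data)
        self.lengths[seq] = n + 1
        return table[-1], n % self.block_size

    def fork(self, parent: str, child: str) -> None:
        """Let `child` share every block of `parent` (e.g. a common prompt prefix)."""
        self.tables[child] = list(self.tables[parent])
        self.lengths[child] = self.lengths[parent]
        for block in self.tables[child]:
            self.refcount[block] += 1

cache = PagedKVCache(num_blocks=8, block_size=4)
for _ in range(6):                  # 6-token prompt: one full block, one half-full block
    cache.append_token("a")
cache.fork("a", "b")                # b shares a's prompt blocks
cache.append_token("b")             # b writes into the shared partial block, so it is copied
print(cache.tables["a"], cache.tables["b"], len(cache.free))  # [7, 6] [7, 5] 5
```

Memory is committed one block at a time as sequences actually grow, so the only waste per sequence is the unused tail of its last block.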

Sliding-window / streaming caches. For very long contexts you can drop the oldest entries and keep only the last W tokens (Mistral's sliding window) or keep a few "attention sink" tokens plus a recent window (StreamingLLM). The cache becomes bounded instead of growing linearly, trading some long-range recall for constant memory.
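The eviction policy can be sketched with a bounded deque. `SinkWindowCache` is a hypothetical name; it keeps the first `n_sink` tokens forever plus the most recent `window` tokens, and with `n_sink=0` it degenerates to a plain sliding window. How positions are assigned to the retained tokens is left out here, and implementations differ on it.

```python
from collections import deque

class SinkWindowCache:
    """Bounded cache: the first `n_sink` tokens plus the most recent `window` tokens."""

    def __init__(self, n_sink: int, window: int):
        self.n_sink = n_sink
        self.sinks = []                         # never evicted ("attention sinks")
        self.recent = deque(maxlen=window)      # deque drops its oldest entry when full

    def append(self, pos: int, kv) -> None:
        if len(self.sinks) < self.n_sink:
            self.sinks.append((pos, kv))
        else:
            self.recent.append((pos, kv))

    def positions(self) -> list:
        return [p for p, _ in self.sinks] + [p for p, _ in self.recent]

cache = SinkWindowCache(n_sink=2, window=4)
for pos in range(10):
    cache.append(pos, kv=None)                  # placeholder for the real (k, v) tensors
print(cache.positions())                        # [0, 1, 6, 7, 8, 9]
```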

Cache quantization. Store K and V in int8 or even 4-bit instead of fp16, halving or quartering the cache at a small accuracy cost — independent of, and stackable with, MQA/GQA and paging.
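One simple scheme is symmetric per-row int8 with a single float scale per cached row. That granularity is an assumption made for the sketch; real systems use several (per-channel, per-group) and formats. The helpers are hypothetical, and the round-trip error is bounded by half the scale.

```python
def quantize_int8(row):
    """Symmetric per-row int8 quantization: row ≈ scale * q, with q in [-127, 127]."""
    scale = max(abs(x) for x in row) / 127 or 1.0   # avoid a zero scale for all-zero rows
    return [round(x / scale) for x in row], scale

def dequantize(q, scale):
    return [v * scale for v in q]

k_row = [0.12, -0.97, 0.44, 0.05, -0.31, 0.88, -0.02, 0.63]   # a toy cached key row
q, scale = quantize_int8(k_row)
restored = dequantize(q, scale)
max_err = max(abs(a - b) for a, b in zip(k_row, restored))
print(q)
print(f"max abs error {max_err:.4f} <= scale/2 = {scale / 2:.4f}")
print("bytes: fp16", 2 * len(k_row), "-> int8", len(q), "+ one scale")
```

The per-row scale adds a small fixed overhead, so the saving over fp16 approaches 2× as the row (head dimension) gets longer.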

Common bugs and edge cases

  • Stale positional offset. The number-one bug: during decode the new token's position must equal the current cache length. Resetting RoPE/positional indices to 0 each step corrupts attention silently — no crash, just degrading text.
  • Mismatched cache length and mask. If you keep a causal mask sized for the prompt but the cache has grown, the mask either truncates valid history or indexes out of bounds.
  • Forgetting the cache is per-layer. Every attention layer needs its own K/V buffer with its own projection weights; sharing one buffer across layers produces nonsense.
  • Cache OOM at long context. The cache grows linearly and is read every step, so a long prompt or runaway generation can exhaust GPU memory mid-decode. Bound it with a max-length, eviction, or a sliding window.
  • Batched sequences of different lengths. When you batch requests, their caches grow at different rates; left-padding plus careful masking (or a paged layout) is required or short sequences attend to garbage.
  • Beam search and the cache. Each beam needs its own cache state, and when beams are reordered or pruned the caches must be reordered to match — a frequent source of subtle decoding bugs.
  • Reusing a cache across prompts. The cache is specific to one sequence's history. Starting a new prompt without clearing (or correctly prefix-sharing) it leaks one conversation into another.
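The beam-search case is the easiest to get wrong, so here is a sketch of the reordering step. It assumes a hypothetical layout (a list with one entry per beam, each a per-layer list of (K rows, V rows)) and a hypothetical `reorder_beam_caches` helper. Each surviving beam gets its own copy of its parent's cache, so later appends cannot leak between beams that share a parent.

```python
import copy

def reorder_beam_caches(caches, parent_beams):
    """After a beam-search step, beam i continues from beam parent_beams[i].

    `caches` is a list (one entry per beam) of per-layer lists of (K rows, V rows).
    Each surviving beam gets an independent deep copy of its parent's cache.
    """
    return [copy.deepcopy(caches[p]) for p in parent_beams]

# Two beams, one layer each; beam 1 was pruned and beam 0 spawned both survivors.
caches = [
    [([[0.1], [0.2]], [[1.0], [2.0]])],   # beam 0: layer 0 -> (K rows, V rows)
    [([[0.1], [0.9]], [[1.0], [9.0]])],   # beam 1
]
caches = reorder_beam_caches(caches, parent_beams=[0, 0])
caches[1][0][0].append([0.5])             # extend only the second survivor's K rows
print(len(caches[0][0][0]), len(caches[1][0][0]))   # 2 3
```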

Frequently asked questions

Why do you cache keys and values but not queries?

Because of causal masking, each token's query only ever attends to keys and values at its own position and earlier. A past token's query is never re-used once that token has been generated, so caching it would waste memory. But its key and value are read by every future token's attention, so those are exactly what you keep.

How big does a KV cache actually get?

Per token, the cache holds 2 × n_layers × n_kv_heads × head_dim values. For Llama-2-7B (32 layers, 32 heads, head_dim 128) in fp16 that is 2 × 32 × 32 × 128 × 2 bytes ≈ 512 KB per token. A 4,096-token context costs about 2 GB — for a single sequence, on top of the 13 GB of weights.

What's the difference between prefill and decode?

Prefill processes the whole prompt in one parallel pass, computing and storing K and V for every prompt token — it is compute-bound and fast per token. Decode generates one token at a time, each step reading the entire cache and appending one new K/V row — it is memory-bandwidth-bound and dominates latency for long generations.

Does the KV cache change the model's output?

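No. It is a pure optimization: mathematically, the cached K and V are exactly what a from-scratch recomputation would produce, so the logits and sampled tokens are unchanged. In practice, different kernel shapes can change the floating-point summation order, so results may differ in the last bits, which very occasionally flips a near-tied greedy choice. The real risk is implementation bugs, most commonly forgetting to advance positional encodings or RoPE offsets to match the cache length.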

Why are multi-query and grouped-query attention so popular now?

The KV cache, not the weights, is what limits how many requests you can batch. Multi-query attention shares one K/V head across all query heads, shrinking the cache by the head count (32× for a 32-head model) at a small quality cost. Grouped-query attention is the middle ground used by Llama-2-70B and Mistral, sharing K/V across small groups of query heads.

What is PagedAttention and why does it matter?

A naive cache reserves a contiguous slab for the maximum context length per request, wasting 60–80% of it to over-reservation and fragmentation. PagedAttention (vLLM, 2023) stores the cache in fixed-size blocks like operating-system virtual memory, allocating blocks on demand and sharing identical prefix blocks across requests. It raised serving throughput by roughly 2–4× over the previous best systems.