Transformer Attention

Compute Q·Kᵀ/√d, softmax, multiply by V — every token attends to every other in parallel

Self-attention is the operation that powers transformers, including GPT, BERT, Llama, and Claude, and it replaced recurrent networks for language modeling. Each input token is projected to a query Q, a key K, and a value V. Attention output = softmax(QKᵀ/√d) × V: multiplying Q × Kᵀ produces a matrix of pairwise relevance scores, softmax normalizes each row into weights, and multiplying by V produces the weighted output. Multi-head attention runs h independent attention heads in parallel, concatenates their outputs, and applies an output projection; typical values are h=12 (BERT-base) and h=96 (GPT-3 175B). Cost: O(n²·d) for sequence length n. This quadratic bottleneck drives linear-attention research and IO-aware kernels such as FlashAttention (Dao et al. 2022). Introduced by Vaswani et al. 2017 ("Attention Is All You Need", 100,000+ citations).

  • Operation: softmax(QKᵀ/√d) V
  • Cost: O(n²·d)
  • Heads: 8-128 typical (12 in BERT-base, 96 in GPT-3 175B)
  • Introduced: Vaswani et al. 2017
  • Citations: 100,000+
  • Efficient variants: Performer, Linformer (linear-time approximations); FlashAttention (exact, IO-aware)

Why attention matters

  • Every modern LLM. GPT-4, Claude, Gemini, Llama 3: all are built on transformers, and the openly documented ones, such as Llama 3, are decoder-only models whose only sequence-mixing primitive is multi-head self-attention. The 2017 architecture fundamentally changed what state-of-the-art language modeling looked like.
  • Vision transformers. ViT (Dosovitskiy et al. 2020) splits images into patches and applies the same self-attention to the patch sequence. SAM, DINOv2, and most modern multimodal vision encoders use ViT backbones.
  • Multimodal models. Cross-attention enables text-conditioned image generation (Stable Diffusion uses cross-attention from latent image patches to CLIP text embeddings). VLMs like GPT-4o, Gemini, Claude with vision use shared transformer trunks over interleaved text and image tokens.
  • Parallel along the sequence. Unlike RNNs, which compute step by step, attention computes representations for every position in the sequence in parallel (during training and prompt processing). This is the key practical reason transformers scale better on GPUs than LSTMs ever did.
  • Long-range dependencies. A token at position 5000 can attend directly to a token at position 1; the path length is constant. RNNs degrade with distance because gradients must flow through every intermediate step.
  • Mechanistic interpretability. Attention heads turn out to compute interpretable patterns — previous-token, induction, name-mover heads. Anthropic and others use these patterns to reverse-engineer LLM circuits.
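  • Hardware-friendly. Most of attention's FLOPs are matrix multiplications, which GPUs and TPUs accelerate massively. In naive implementations, memory traffic for the n × n score matrix can dominate runtime instead, which is why optimized kernels fuse the matmuls and the softmax into a single pass.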

The scaled dot-product formula

  1. Project. For input X of shape (n, d_model), compute Q = XW_Q, K = XW_K, V = XW_V where W_* are learned matrices of shape (d_model, d_head).
  2. Score. S = QKᵀ / √d_head — an n × n matrix of pairwise relevance scores.
  3. Mask (optional). Add −∞ to scores that should not be attended to, so they receive zero weight after the softmax. Causal mask: −∞ above the diagonal, so position i cannot see positions j > i. Padding mask: −∞ in the columns of padding positions.
  4. Softmax. Apply softmax row-wise to get attention weights A — non-negative, rows sum to 1.
  5. Mix. Output = AV — for each query position, a convex combination of value vectors weighted by attention.
  6. Multi-head. Repeat steps 1-5 h times in parallel with separate W_Q, W_K, W_V projections; concatenate the outputs; project with W_O of shape (h·d_head, d_model), which is (d_model, d_model) when d_head = d_model / h.
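The steps above can be sketched in plain Python (lists of rows, no NumPy) so the arithmetic stays visible. This is a minimal single-head illustration under our own assumptions: the helper names `matmul`, `softmax`, `causal_mask`, and `attention` are hypothetical, the mask is a boolean matrix, and the 2 × 2 projection matrices are toy values rather than learned weights.

```python
import math

def matmul(a, b):
    """Multiply an (n, k) matrix by a (k, m) matrix; matrices are lists of rows."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def transpose(m):
    return [list(col) for col in zip(*m)]

def softmax(row):
    # Subtract the row max for stability; math.exp(-inf) is 0.0, so masked entries get zero weight.
    # Assumes every row keeps at least one unmasked entry (true for causal and padding masks).
    mx = max(row)
    exps = [math.exp(x - mx) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def causal_mask(n):
    """allowed[i][j] is True when query i may attend to key j, i.e. j <= i."""
    return [[j <= i for j in range(n)] for i in range(n)]

def attention(q, k, v, allowed=None):
    """Steps 2-5 for one head: softmax(Q K^T / sqrt(d_head)) V with an optional boolean mask."""
    d_head = len(q[0])
    scores = [[s / math.sqrt(d_head) for s in row] for row in matmul(q, transpose(k))]  # step 2
    if allowed is not None:                                                              # step 3
        scores = [[s if ok else -math.inf for s, ok in zip(row, ok_row)]
                  for row, ok_row in zip(scores, allowed)]
    weights = [softmax(row) for row in scores]                                           # step 4
    return matmul(weights, v), weights                                                   # step 5

# Step 1 with hypothetical toy projections (identity for Q and K, a coordinate swap for V).
x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
w_q = w_k = [[1.0, 0.0], [0.0, 1.0]]
w_v = [[0.0, 1.0], [1.0, 0.0]]
out, weights = attention(matmul(x, w_q), matmul(x, w_k), matmul(x, w_v), causal_mask(3))
print(weights[0])  # [1.0, 0.0, 0.0]: the first token can only see itself
print(out[0])      # [0.0, 1.0]: exactly the first value vector
```

Step 6 would call `attention` once per head with that head's projections and concatenate the results before multiplying by W_O. Production code would use batched tensor operations (for example PyTorch's `torch.nn.functional.scaled_dot_product_attention`) rather than Python loops.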

A short history

  • 2014. Bahdanau et al. introduce neural attention for sequence-to-sequence translation: a small alignment network scores every encoder position, and each decoder step reads a softmax-weighted average of the encoder states.
  • 2015. Luong et al. simplify Bahdanau's attention with multiplicative (dot-product) scoring and global vs. local variants.
  • 2017. Vaswani et al. publish "Attention Is All You Need". They drop recurrence entirely and introduce scaled dot-product self-attention, multi-head attention, and the encoder-decoder Transformer. With over 100,000 citations, it is one of the most cited ML papers ever.
  • 2018. BERT (Devlin et al.) pretrains an encoder-only transformer with masked language modeling; GPT (Radford et al.) does the same for decoder-only with next-token prediction.
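  • 2019-2022. Scaling: GPT-2 (1.5B parameters, 2019), GPT-3 (175B, 2020), PaLM (540B, 2022). Parameter counts grew by orders of magnitude, and some capabilities were reported to emerge only at the largest scales (>100B parameters).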
  • 2022. FlashAttention (Dao et al.) cracks the IO bottleneck; ChatGPT launches.
  • 2023-2026. Llama, Mistral, Claude, Gemini, GPT-4, GPT-4o: transformer attention remains the central primitive. Sliding-window attention, GQA, and MLA are attention variants; MoE instead replaces the dense MLP with routed experts and leaves attention unchanged.

Why multi-head works

  • Subspace specialization. Different heads learn distinct attention patterns: previous-token, syntactic-dependency, induction (Olsson et al. 2022), name-mover, copy-suppressor.
  • Parallelizable. The h heads are independent, so they run as one batched matmul on GPUs; the only step that combines heads is the final concatenation and output projection.
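  • Better than one big head. Empirically, h heads of dimension d/h beat one head of dimension d. In the Vaswani ablation, single-head attention scored 0.9 BLEU below the best setting, and quality also dropped with too many heads; the base model used h=8, and later models scaled h with model size.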
  • Grouped-query attention (GQA). Llama 2 (70B), Llama 3, and Mistral share each K and V head across a group of query heads, so there are fewer KV heads than query heads. This cuts KV-cache memory and inference cost without losing much quality.
  • Multi-head Latent Attention (MLA, DeepSeek-V2, 2024). Caches a low-dimensional latent vector per token and reconstructs K and V from it, another KV-cache saving.
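A minimal sketch of how query heads share K/V heads under grouped-query attention, in plain Python. Assumptions (ours, not any particular model's): Q, K, and V arrive already projected and split per head as lists of (n, d_head) matrices, and query head i reads KV head i // (h_q / h_kv). With as many KV heads as query heads this reduces to ordinary multi-head attention; with a single KV head it is multi-query attention. `attend` and `grouped_query_attention` are hypothetical names.

```python
import math

def attend(q, k, v):
    """One head of unmasked scaled dot-product attention on lists of row vectors."""
    d = len(q[0])
    out = []
    for qi in q:
        scores = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) for kj in k]
        mx = max(scores)
        exps = [math.exp(s - mx) for s in scores]
        total = sum(exps)
        weights = [e / total for e in exps]
        out.append([sum(w * vj[c] for w, vj in zip(weights, v)) for c in range(len(v[0]))])
    return out

def grouped_query_attention(q_heads, k_heads, v_heads):
    """q_heads: h_q matrices; k_heads, v_heads: h_kv matrices each, with h_q divisible by h_kv.
    Returns each token's concatenation of all h_q head outputs (before W_O)."""
    h_q, h_kv = len(q_heads), len(k_heads)
    if h_q % h_kv:
        raise ValueError("query heads must split evenly into KV groups")
    group = h_q // h_kv
    # Query head i shares the K/V head of its group.
    outs = [attend(q, k_heads[i // group], v_heads[i // group]) for i, q in enumerate(q_heads)]
    n = len(q_heads[0])
    return [[x for o in outs for x in o[t]] for t in range(n)]

# 4 query heads sharing 2 KV heads; sequence length 3, d_head = 2.
q = [[[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]] for _ in range(4)]
k = [[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 0.0], [0.2, 0.3]]]
out = grouped_query_attention(q, k, k)
print(len(out), len(out[0]))  # 3 8
```

Only the h_kv K and V heads have to be cached during decoding, which is where the memory saving comes from; the output is the same as full multi-head attention with each KV head repeated across its group.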

The KV cache

  • Autoregressive decoding bottleneck. When generating tokens one at a time, recomputing K and V for the entire prefix at each step is wasteful.
  • Cache. Store K and V for all prior tokens. At step t, compute Q_t, K_t, and V_t for the new token only, append K_t and V_t to the cache, then attend Q_t against the cached K_{1..t} and V_{1..t}.
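  • Memory cost. KV cache = 2 × n_tokens × n_kv_heads × d_head × n_layers × bytes_per_value per sequence. A 70B-class model with 80 layers and 64 KV heads of dimension 128 in fp16 would need 2 × 8192 × 64 × 128 × 80 × 2 bytes ≈ 21.5 GB (20 GiB) for one 8K-token sequence. Llama 2/3 70B use GQA with 8 KV heads, cutting that to about 2.7 GB; multiplied across concurrent requests, the cache is often the dominant inference memory.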
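  • Optimizations. GQA cuts the number of KV heads. MLA compresses what is cached. PagedAttention (vLLM 2023) stores the cache in fixed-size blocks, like virtual-memory pages, which nearly eliminates fragmentation and lets one GPU batch many more concurrent requests.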
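The caching logic and the memory formula can be sketched in a few lines of plain Python. Assumptions: one head and one layer for the decoding part; for the arithmetic, an 80-layer model with head dimension 128, fp16 values (2 bytes), and an 8192-token sequence, evaluated once with 64 KV heads and once with 8 grouped-query KV heads. `attend_one`, `decode_step`, and `kv_cache_bytes` are hypothetical helpers.

```python
import math

def attend_one(q, keys, values):
    """Attend a single query vector against lists of cached key and value vectors."""
    d = len(q)
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in keys]
    mx = max(scores)
    exps = [math.exp(s - mx) for s in scores]
    total = sum(exps)
    return [sum(e / total * v[c] for e, v in zip(exps, values)) for c in range(len(values[0]))]

def decode_step(cache, q_t, k_t, v_t):
    """Append the new token's K and V to the cache, then attend its query to positions 1..t."""
    cache["k"].append(k_t)
    cache["v"].append(v_t)
    return attend_one(q_t, cache["k"], cache["v"])

def kv_cache_bytes(n_tokens, n_kv_heads, d_head, n_layers, bytes_per_value):
    # Factor 2: one K vector and one V vector per token, per KV head, per layer.
    return 2 * n_tokens * n_kv_heads * d_head * n_layers * bytes_per_value

qs = [[0.1, 0.2], [0.3, -0.1], [0.0, 0.5]]
ks = [[0.2, 0.1], [-0.4, 0.3], [0.5, 0.5]]
vs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
cache = {"k": [], "v": []}
incremental = [decode_step(cache, qs[t], ks[t], vs[t]) for t in range(3)]
from_scratch = [attend_one(qs[t], ks[: t + 1], vs[: t + 1]) for t in range(3)]
print(incremental == from_scratch)                    # True
print(kv_cache_bytes(8192, 64, 128, 80, 2) / 2**30)   # 20.0 (GiB) with 64 KV heads
print(kv_cache_bytes(8192, 8, 128, 80, 2) / 2**30)    # 2.5 (GiB) with 8 KV heads
```

Each decoding step does O(t) work against the cache instead of recomputing K and V for the whole prefix, at the price of memory that grows linearly with context length.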

FlashAttention and IO-aware kernels

  • The bottleneck. Standard attention reads and writes the n × n score and probability matrices to GPU HBM. HBM bandwidth (about 3.35 TB/s on an H100 SXM) becomes the limiter, not compute (about 989 dense FP16/BF16 TFLOPS).
  • FlashAttention. Dao et al. 2022. Tiles Q, K, V into blocks small enough to fit in on-chip SRAM (shared memory, at most a few hundred KB per SM). Computes softmax(QKᵀ) V incrementally via online softmax. Never materializes the full n × n matrix; it writes only the n × d output, plus per-row softmax statistics for the backward pass.
  • Performance. 2-4× faster than vanilla PyTorch attention; memory drops from O(n²) to O(n).
  • FlashAttention-2 (2023). Better work partitioning across warps; another 2× speedup.
  • FlashAttention-3 (2024). Hopper-specific (H100); uses TMA, async warpgroup matmul, and FP8.
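  • Standard everywhere. Available as a backend of PyTorch's torch.nn.functional.scaled_dot_product_attention, which selects it automatically when the inputs allow; Hugging Face Transformers, vLLM, and TGI use it by default on supported GPUs.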
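The online-softmax idea at the heart of FlashAttention can be shown for a single query in plain Python. This is a sketch of the math only, not the CUDA kernel: `block_size` is a hypothetical stand-in for an SRAM tile, and the function keeps just a running max `m`, a running normalizer `l`, and an unnormalized output accumulator, rescaling all three whenever a new block raises the max. `naive_attention_row` and `online_attention_row` are hypothetical names.

```python
import math

def naive_attention_row(q, keys, values):
    """Reference: materialize every score for one query, then softmax and mix."""
    scale = 1 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    mx = max(scores)
    w = [math.exp(s - mx) for s in scores]
    z = sum(w)
    return [sum(wi * v[c] for wi, v in zip(w, values)) / z for c in range(len(values[0]))]

def online_attention_row(q, keys, values, block_size=2):
    """Same result, computed one key/value block at a time with online softmax."""
    scale = 1 / math.sqrt(len(q))
    m = -math.inf                   # running max of the scores seen so far
    l = 0.0                         # running sum of exp(score - m)
    acc = [0.0] * len(values[0])    # running sum of exp(score - m) * v
    for start in range(0, len(keys), block_size):
        kb = keys[start:start + block_size]
        vb = values[start:start + block_size]
        s = [scale * sum(a * b for a, b in zip(q, k)) for k in kb]
        m_new = max(m, max(s))
        correction = math.exp(m - m_new)  # rescales old statistics; 0.0 on the first block
        p = [math.exp(si - m_new) for si in s]
        l = l * correction + sum(p)
        acc = [a * correction + sum(pi * v[c] for pi, v in zip(p, vb))
               for c, a in enumerate(acc)]
        m = m_new
    return [a / l for a in acc]       # normalize once at the end

q = [0.3, -1.2, 0.5]
keys = [[0.1 * i, -0.2 * i, 0.05 * i * i] for i in range(7)]
values = [[float(i), float(i % 3)] for i in range(7)]
reference = naive_attention_row(q, keys, values)
for bs in (1, 2, 3, 7):
    tiled = online_attention_row(q, keys, values, block_size=bs)
    print(bs, max(abs(a - b) for a, b in zip(reference, tiled)))  # rounding-level differences
```

The tiled version never holds more than one block of scores at a time, yet it matches the naive result up to floating-point rounding for every block size, which is why FlashAttention is exact rather than an approximation.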

Linear and sparse attention

  • Linformer. Project K and V along the sequence axis from length n to a fixed length k; attention becomes O(n · k).
  • Performer. Approximate the softmax kernel with random feature maps (FAVOR+); cost is O(n · r · d) for r random features, linear in n.
  • Reformer. LSH bucketing approximates which Q-K pairs have high similarity, for O(n log n) cost.
  • Longformer / BigBird. Sparse attention patterns: sliding window + global tokens.
  • Mamba (2023). Selective state-space model; linear-time sequence mixing without softmax. Competes with attention on some tasks.
  • RWKV / RetNet. Recurrent reformulations of attention with linear-time training and O(1) per-step inference.
  • Practical reality. Despite years of research, vanilla full attention with FlashAttention plus sliding-window or GQA remains dominant for production LLMs through 2026.
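The reassociation that makes kernel-based linear attention O(n) can be checked directly. Assumption: this sketch uses the simple feature map φ(x) = elu(x) + 1 from Katharopoulos et al. (2020) rather than Performer's FAVOR+ random features, and it replaces softmax outright instead of approximating it. Computing φ(Q)(φ(K)ᵀV) through a d × d_v summary gives the same output as building the n × n matrix φ(Q)φ(K)ᵀ first. `phi`, `quadratic_form`, and `linear_form` are hypothetical names.

```python
import math

def phi(vec):
    # elu(x) + 1: strictly positive, so the normalizers below are never zero.
    return [x + 1.0 if x > 0 else math.exp(x) for x in vec]

def quadratic_form(q, k, v):
    """Build the n x n similarity matrix phi(Q) phi(K)^T explicitly: O(n^2) in n."""
    fq, fk = [phi(r) for r in q], [phi(r) for r in k]
    out = []
    for qi in fq:
        sims = [sum(a * b for a, b in zip(qi, kj)) for kj in fk]
        z = sum(sims)
        out.append([sum(s * vj[c] for s, vj in zip(sims, v)) / z for c in range(len(v[0]))])
    return out

def linear_form(q, k, v):
    """Summarize keys and values once, then reuse the summary for every query: O(n) in n."""
    fk = [phi(r) for r in k]
    d, dv, n = len(fk[0]), len(v[0]), len(v)
    kv = [[sum(fk[t][i] * v[t][c] for t in range(n)) for c in range(dv)] for i in range(d)]  # phi(K)^T V
    ksum = [sum(fk[t][i] for t in range(n)) for i in range(d)]                               # phi(K)^T 1
    out = []
    for qi in (phi(r) for r in q):
        z = sum(a * b for a, b in zip(qi, ksum))
        out.append([sum(qi[i] * kv[i][c] for i in range(d)) / z for c in range(dv)])
    return out

q = [[0.5, -1.0], [1.5, 0.2], [-0.3, 0.8], [0.0, 0.0]]
k = [[1.0, 0.3], [-0.7, 0.4], [0.2, -1.1], [0.9, 0.9]]
v = [[1.0, 2.0, 0.0], [0.5, -1.0, 1.0], [0.0, 0.3, 2.0], [-1.0, 0.0, 0.5]]
slow, fast = quadratic_form(q, k, v), linear_form(q, k, v)
print(max(abs(x - y) for rs, rf in zip(slow, fast) for x, y in zip(rs, rf)))  # rounding-level
```

Because the summary φ(K)ᵀV can also be updated one token at a time, the same reassociation is what gives causal linear attention its RNN-style O(1) per-step decoding.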

Positional encoding

  • Sinusoidal (original). Vaswani 2017 added fixed sin/cos position vectors to token embeddings.
  • Learned. BERT and GPT-2 use learned position embeddings, which cannot extrapolate beyond the training length.
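  • RoPE (rotary position embeddings). Su et al. 2021; rotates pairs of Q and K dimensions by angles proportional to position, so each query-key score depends only on their relative offset. Used by Llama, Mistral, GPT-NeoX. On its own it extrapolates poorly past the training length, but rescaling methods such as position interpolation and YaRN extend it to longer contexts.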
  • ALiBi. Press et al. 2021; adds a head-specific bias proportional to query-key distance directly to the attention scores, with no position embeddings. BLOOM and MPT used it.
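RoPE's defining property, that the rotated query-key score depends only on the offset between positions, can be checked numerically. A minimal sketch, assuming the RoFormer-style pairing of adjacent dimensions (2i, 2i+1) with frequency base^(-2i/d) and base 10000; some implementations pair the first and second halves of the vector instead, which does not change the property. `rope` and `dot` are hypothetical helpers.

```python
import math

def rope(vec, pos, base=10000.0):
    """Rotate each pair (x[2i], x[2i+1]) by angle pos * base**(-2i/d). Assumes even d."""
    d = len(vec)
    out = list(vec)
    for i in range(0, d, 2):
        theta = pos * base ** (-i / d)
        c, s = math.cos(theta), math.sin(theta)
        x, y = vec[i], vec[i + 1]
        out[i] = x * c - y * s
        out[i + 1] = x * s + y * c
    return out

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

q = [0.3, -0.7, 1.1, 0.2]
k = [0.5, 0.4, -0.6, 0.9]
near = dot(rope(q, 5), rope(k, 2))       # query at position 5, key at position 2
far = dot(rope(q, 105), rope(k, 102))    # same offset of 3, shifted by 100 positions
print(abs(near - far) < 1e-9)            # True: the score depends only on the offset
```

Because rotations preserve vector norms, RoPE changes only the angle between q and k, never their lengths, which is why it can be applied to Q and K without touching V.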

Common misconceptions

  • "Attention is the only transformer trick." Transformers also rely on LayerNorm/RMSNorm, residual connections, position-wise MLPs (often 4× wider than d_model), and careful initialization. Removing any one typically destabilizes training or badly degrades quality.
  • "O(n²) rules out long contexts." IO-aware kernels such as FlashAttention, KV-cache reductions such as GQA, and, in some models, sliding-window attention make long contexts practical: Gemini 1.5 supports 1M tokens, Claude 3.5 200K, and Llama 3.1 128K. Quadratic compute remains, but the constants and memory are now manageable.
  • "Attention attends to one token." Softmax produces a soft-weighted average over all positions. Attention can be sharp (one position dominates) or diffuse (many positions contribute roughly equally) — both happen in practice.
  • "Attention is always interpretable." Some heads compute clean human-interpretable patterns; many do not. Attention weights alone are not a faithful explanation of model output (Jain & Wallace 2019).
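  • "Cross-attention is more powerful than self-attention." Cross-attention is the same operation with Q taken from one sequence and K, V from another. Decoder-only models using only self-attention have eclipsed encoder-decoder models for most tasks; the simpler architecture has proven easier to scale.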
  • "Self-attention knows token order on its own." Without position information, unmasked self-attention is permutation-equivariant: shuffling the input tokens only shuffles the outputs. Sinusoidal, learned, RoPE, and ALiBi encodings all break this symmetry; a causal mask also breaks it, since each position sees a different prefix.
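The symmetry at issue can be checked in a few lines of plain Python. Assumptions: a single head with Q = K = V = X and no learned projections or positional encodings, which is enough to show the effect. Without a mask, permuting the input tokens only permutes the outputs; with a causal mask, the same permutation changes the result, because the mask itself encodes order. `self_attention`, `permute`, and `close` are hypothetical helpers.

```python
import math

def self_attention(x, causal=False):
    """Single head with Q = K = V = x (no projections, no positional encoding)."""
    d = len(x[0])
    out = []
    for i, qi in enumerate(x):
        visible = x[: i + 1] if causal else x   # causal: token i sees only tokens 0..i
        scores = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) for kj in visible]
        mx = max(scores)
        w = [math.exp(s - mx) for s in scores]
        z = sum(w)
        out.append([sum(wj * vj[c] for wj, vj in zip(w, visible)) / z for c in range(d)])
    return out

def permute(rows, order):
    return [rows[j] for j in order]

def close(a, b, tol=1e-12):
    return all(abs(p - r) < tol for row_a, row_b in zip(a, b) for p, r in zip(row_a, row_b))

x = [[1.0, 0.0], [0.0, 2.0], [0.5, -1.0]]
order = [2, 0, 1]
print(close(self_attention(permute(x, order)), permute(self_attention(x), order)))  # True
print(close(self_attention(permute(x, order), causal=True),
            permute(self_attention(x, causal=True), order)))                       # False
```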

Frequently asked questions

What is the role of Q, K, V?

Each input token x is projected three times via learned linear maps: Q = xW_Q, K = xW_K, V = xW_V. Q (query) represents what this token is looking for. K (key) represents what this token offers as a match target. V (value) is the information to mix in if matched. The dot product of a query with a key measures relevance: it is high when the two align in the learned subspace. Softmax over keys gives attention weights; the weighted sum of V vectors becomes the output.

Why divide by √d before the softmax?

For two independent vectors with iid components of mean 0 and variance 1, their dot product has variance d. As d grows (often 64 or 128 per head), unscaled dot products have very large magnitudes; softmax saturates, producing near one-hot attention with vanishing gradients. Dividing by sqrt(d) brings the variance back to 1, keeping softmax gradients well-behaved. The 1/sqrt(d) factor is the only difference between the paper's scaled dot-product attention and earlier (Luong-style) dot-product attention; Bahdanau-style attention is additive, scoring query-key pairs with a small feed-forward network.
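A quick Monte Carlo check of the variance argument, assuming standard-normal components drawn with Python's `random.Random.gauss`; the head dimension 64, the 5,000 trials, and the helper name `dot_product_variance` are arbitrary illustrative choices.

```python
import math
import random
import statistics

def dot_product_variance(d, trials=5000, scaled=False, seed=0):
    """Sample variance of q.k (or q.k / sqrt(d)) for random q, k with iid N(0, 1) components."""
    rng = random.Random(seed)
    samples = []
    for _ in range(trials):
        q = [rng.gauss(0.0, 1.0) for _ in range(d)]
        k = [rng.gauss(0.0, 1.0) for _ in range(d)]
        s = sum(a * b for a, b in zip(q, k))
        samples.append(s / math.sqrt(d) if scaled else s)
    return statistics.pvariance(samples)

print(round(dot_product_variance(64)))                  # close to 64
print(round(dot_product_variance(64, scaled=True), 2))  # close to 1.0
```

Unscaled scores with a standard deviation around 8 push softmax toward one-hot outputs; after scaling, typical score differences stay on the order of 1.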

How does multi-head attention work?

Instead of one big attention with d_model dimensions, split into h heads each of dimension d_head = d_model / h. Run h independent attention operations in parallel, then concatenate and project back. BERT-base uses h=12, d_head=64; GPT-3 uses h=96, d_head=128. Different heads learn to attend to different patterns: syntactic dependencies, semantic relations, position-only patterns, induction heads. Because the projections have the same total size, multi-head attention costs about the same as one full-width head, but it is not equivalent to one: each head computes its own softmax, so h heads can attend to h different sets of positions, while a single head produces only one attention pattern per query.

What is the O(n²) bottleneck?

Computing QK-transpose produces an n x n matrix of attention scores, requiring O(n^2 * d) FLOPs and O(n^2) memory. At sequence length 8192, the score matrix is 8192 x 8192 per head per sequence. For a GPT-3-sized model with 96 heads and batch 32, that is about 2.06 × 10^11 scores, roughly 412 GB in fp16 for a single layer if materialized. This quadratic scaling has driven years of research into linear-attention variants (Performer, Linformer, Reformer), sliding-window approximations (Longformer, Mistral), and IO-aware kernels (FlashAttention).
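The memory figure is straightforward arithmetic. Assumptions: one layer, one 8192 × 8192 score matrix per head per sequence, and 2 bytes per score (fp16); a real training run stores more than this (softmax outputs, every other layer), so treat it as a floor. `score_matrix_bytes` is a hypothetical helper.

```python
def score_matrix_bytes(seq_len, n_heads, batch, bytes_per_value=2):
    """Bytes to materialize one n x n attention-score matrix per head per sequence (one layer)."""
    return seq_len * seq_len * n_heads * batch * bytes_per_value

entries = score_matrix_bytes(8192, 96, 32, bytes_per_value=1)  # 1 byte each, so this is the entry count
fp16 = score_matrix_bytes(8192, 96, 32)
print(f"{entries:,} scores, {fp16 / 1e9:.0f} GB in fp16 per layer")
# 206,158,430,208 scores, 412 GB in fp16 per layer
```

Doubling the sequence length quadruples both numbers, which is the quadratic scaling in concrete terms.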

How does FlashAttention reduce memory?

FlashAttention (Dao et al. 2022) realizes that the n x n attention matrix never needs to fully materialize. It tiles Q, K, V into blocks small enough to fit in GPU SRAM, computes softmax(QK^T) V incrementally using the online-softmax algorithm, and writes only the n x d output to global memory. Memory drops from O(n^2) to O(n); throughput rises 2-4x because most of the work stays in fast SRAM. FlashAttention-2 (2023) and FlashAttention-3 (2024) push it further; standard in PyTorch via scaled_dot_product_attention.

What is the difference between encoder, decoder, and cross-attention?

Encoder self-attention (BERT): each token attends to every other in the sequence — bidirectional context. Decoder self-attention (GPT): causal mask prevents attention to future tokens — required for autoregressive generation. Cross-attention (T5, original Vaswani 2017): Q comes from the decoder, K and V from the encoder — the decoder reads the encoder representation. Modern decoder-only LLMs (GPT-3+, Llama, Claude) use only causal self-attention and have eclipsed encoder-decoder architectures for most generation tasks.
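The three variants differ only in where Q, K, and V come from and which mask applies. A plain-Python sketch, assuming already-projected inputs (no learned weights) and a hypothetical `attention` helper: encoder self-attention uses no mask, decoder self-attention uses a causal mask, and cross-attention takes m queries from the decoder and n keys and values from the encoder, producing one output row per decoder token.

```python
import math

def attention(q, k, v, causal=False):
    """Scaled dot-product attention; q has m rows, k and v have n rows."""
    d = len(q[0])
    out = []
    for i, qi in enumerate(q):
        # Causal: query i may only see keys 0..i (q and k must come from the same sequence).
        n_visible = i + 1 if causal else len(k)
        scores = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) for kj in k[:n_visible]]
        mx = max(scores)
        w = [math.exp(s - mx) for s in scores]
        z = sum(w)
        out.append([sum(wj * vj[c] for wj, vj in zip(w, v)) / z for c in range(len(v[0]))])
    return out

encoder_states = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # n = 3 source tokens
decoder_states = [[0.5, 0.5], [1.0, -1.0]]              # m = 2 target tokens

bidirectional = attention(encoder_states, encoder_states, encoder_states)         # BERT-style
causal = attention(decoder_states, decoder_states, decoder_states, causal=True)  # GPT-style
cross = attention(decoder_states, encoder_states, encoder_states)                # encoder-decoder
print(len(bidirectional), len(causal), len(cross))  # 3 2 2
print(causal[0])                                    # [0.5, 0.5]: the first token sees only itself
```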