FlashAttention
Exact attention that never builds the matrix it's famous for
FlashAttention computes exact attention without ever materializing the N×N score matrix, tiling queries, keys, and values through fast on-chip SRAM with an online softmax — cutting memory from O(N²) to O(N) and running 2–4× faster on long sequences.
- Extra memory: O(N), not O(N²)
- FLOPs: same as standard attention
- HBM traffic: O(N²d²/M)
- Result: exact (not approximate)
- Speedup (long sequences): 2–4×
The problem: attention is memory-bound, not compute-bound
Self-attention is the heart of every transformer, and the textbook recipe is three lines: score every query against every key (S = QKᵀ), softmax each row into probabilities (P = softmax(S)), then mix the values (O = PV). The catch hides in the middle: S and P are both N×N, where N is the sequence length. At N = 8,192 tokens, one head's score matrix is 67 million floats — 256 MB in fp32 — and a model has dozens of heads across dozens of layers.
The instinct is that this is a FLOP problem. It isn't. The matrix multiplies are cheap relative to what kills you: writing those 256 MB of scores out to HBM (the GPU's large but slow high-bandwidth memory), reading them back to softmax, writing the probabilities, reading them again to multiply by V. On an A100, HBM delivers ~1.5–2 TB/s while the on-chip SRAM runs at ~19 TB/s — an order of magnitude faster, but only ~20 MB total. Standard attention spends most of its wall-clock time shuttling the N×N matrix across that slow lane. The arithmetic units sit idle waiting for memory.
FlashAttention, introduced by Tri Dao and colleagues at Stanford in 2022 ("FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness"), is the realization that you can compute the exact same output while never writing the N×N matrix to HBM at all. You tile Q, K, and V into blocks small enough to fit in SRAM, stream them through, and fold each block's contribution into a running result using an online softmax. The matrix exists only as transient tiles in fast memory — it is never assembled in full.
The key trick: online (streaming) softmax
The obstacle to tiling is softmax. The standard softmax over a row needs the whole row at once, because the denominator sums exp(s_j) over all keys, and the numerically safe version subtracts the row maximum first to avoid exp overflow. If you only have one tile of keys, you don't yet know the global maximum or the global sum.
Online softmax (the same idea Milakov & Gimelshein described in 2018) fixes this by carrying two running statistics per query row as you walk left to right across key tiles: the running max m and the running denominator ℓ. When a new tile reveals a larger maximum, you rescale everything computed so far by exp(m_old − m_new) so the bookkeeping stays consistent:
m_old, ℓ_old, O = -∞, 0, 0                      # per query row, before any tile
for each key tile j:
    m_new = max(m_old, rowmax(S_j))             # update running max
    P_j   = exp(S_j - m_new)                    # safe exponent for this tile
    ℓ_new = exp(m_old - m_new)·ℓ_old + rowsum(P_j)
    O     = exp(m_old - m_new)·O + P_j · V_j    # rescale old output, add new
    m_old, ℓ_old = m_new, ℓ_new
O = O / ℓ_final                                 # one final normalization (ℓ_final = ℓ_old after the last tile)
Because the rescale factor exp(m_old − m_new) is applied uniformly to both the accumulated output and the denominator, the algebra is exact: after the last tile, O / ℓ equals softmax(QKᵀ)V to floating-point precision. There is no approximation anywhere — this is the single most misunderstood fact about FlashAttention.
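The invariant behind the recurrence can be checked on a single row of scalar scores. The minimal Python sketch below (function names are ours, for illustration) streams the scores one at a time, which is the tile-size-1 case of the recurrence, and compares the resulting m and ℓ with a conventional two-pass safe softmax.

```python
import math

def online_softmax_stats(scores):
    """One pass over a score row, keeping running max m and denominator l.

    Invariant after each step: l == sum(exp(s - m) for every s seen so far)."""
    m, l = -math.inf, 0.0
    for s in scores:
        m_new = max(m, s)
        # Rescale the old sum to the new max, then add this score's term.
        # On the first score m = -inf, so exp(m - m_new) = 0 and the old l drops out.
        l = math.exp(m - m_new) * l + math.exp(s - m_new)
        m = m_new
    return m, l

def two_pass_stats(scores):
    """Conventional safe softmax statistics: find the max first, then sum."""
    m = max(scores)
    return m, sum(math.exp(s - m) for s in scores)

row = [0.5, 3.0, -1.2, 7.5, 2.0, 7.4]
m1, l1 = online_softmax_stats(row)
m2, l2 = two_pass_stats(row)
probs = [math.exp(s - m1) / l1 for s in row]   # softmax recovered from (m, l)
print(m1 == m2, math.isclose(l1, l2), math.isclose(sum(probs), 1.0))  # True True True
```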
The mechanism: tiling through the memory hierarchy
Concretely, the forward kernel partitions the sequence into blocks. With SRAM size M and head dimension d, FlashAttention-1 chose a key/value block size Bc = ⌈M/(4d)⌉ and a query block size Br = min(⌈M/(4d)⌉, d). The outer loop walks over key/value blocks; the inner loop walks over query blocks (FlashAttention-2 later flipped these; see the variants below). For each (query-block, key-block) pair, the kernel:
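- Loads Q_i, K_j, V_j (and, in FlashAttention-1's loop order, the current O_i, m_i, ℓ_i) from HBM into SRAM.
- Computes the tile scores S_ij = Q_i K_jᵀ entirely in SRAM.
- Updates the running m, ℓ, and output accumulator O_i using the online-softmax recurrence.
- Writes the updated O_i (a Br×d block of the N×d output) and its per-row statistics back to HBM; the Br×Bc score tile itself is never written out.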
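Plugging numbers into the FlashAttention-1 block-size rule makes it concrete. This sketch assumes M is counted in elements rather than bytes and uses a hypothetical budget of M = 49,152 (192 KB holding fp32 values); the helper names `fa1_block_sizes` and `tile_footprint` are ours.

```python
import math

def fa1_block_sizes(M, d):
    """Bc = ceil(M/(4d)), Br = min(ceil(M/(4d)), d); M in elements (assumption)."""
    bc = math.ceil(M / (4 * d))
    br = min(bc, d)
    return br, bc

def tile_footprint(br, bc, d):
    """Elements resident for one step: Q_i and O_i (Br x d), K_j and V_j (Bc x d), S_ij (Br x Bc)."""
    return 2 * br * d + 2 * bc * d + br * bc

M = 49_152  # hypothetical SRAM budget: 192 KB of fp32 values
for d in (64, 128):
    br, bc = fa1_block_sizes(M, d)
    print(f"d={d}: Br={br}, Bc={bc}, footprint ~{tile_footprint(br, bc, d)} elements (M={M})")
```

Treat the footprint as a rough sanity check rather than an exact fit: under these assumptions the five blocks land under M for d = 64 (45,056 elements) but about 19% over for d = 128 (58,368), so the constant factors in the rule matter when sizing real tiles.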
The headline complexity result: standard attention makes Θ(Nd + N²) HBM accesses, while FlashAttention makes Θ(N²d²/M). For typical sizes M is several times larger than d²: M is around 100 KB (roughly 50K fp16 elements), while d² is 4,096 for d = 64 and 16,384 for d = 128. The N² term is therefore multiplied by d²/M < 1, a large constant-factor reduction in slow-memory traffic. The arithmetic is unchanged (both do Θ(N²d) FLOPs), so the speedup comes purely from the IO. The paper also proves a matching lower bound: no exact-attention algorithm can make asymptotically fewer than Θ(N²d²/M) HBM accesses for all SRAM sizes M in the range [d, Nd].
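To get a feel for the d²/M factor, the sketch below evaluates both expressions for a single head as plain element counts with every constant factor dropped, so only the rough ratio means anything. The SRAM size M = 50,000 elements (about 100 KB of fp16) is an assumption.

```python
def hbm_traffic(N, d, M):
    """Asymptotic HBM access counts with constants dropped: (standard, FlashAttention)."""
    standard = N * d + N * N          # Theta(Nd + N^2)
    flash = N * N * d * d / M         # Theta(N^2 d^2 / M)
    return standard, flash

M = 50_000  # assumed SRAM budget in elements (~100 KB of fp16)
for N, d in [(4096, 64), (4096, 128), (16384, 64)]:
    std, fl = hbm_traffic(N, d, M)
    print(f"N={N:>5}, d={d:>3}: standard ~{std:.2e}, flash ~{fl:.2e}, ratio ~{std / fl:.1f}x")
```

For large N the ratio approaches M/d²: about 12 for d = 64 and about 3 for d = 128, which is why the formula predicts a smaller IO saving at larger head dimensions.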
The backward pass uses the same idea, plus a second one: recomputation. Rather than stashing the N×N probability matrix for gradients (the usual autograd approach), FlashAttention saves only Q, K, V, the output O, and the per-row statistics, then recomputes the score and probability tiles on the fly during the backward pass. Extra FLOPs, but it avoids reading a giant matrix from HBM — and since attention is bandwidth bound, recompute-and-stay-in-SRAM still wins.
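A minimal NumPy sketch of the recomputation idea follows, assuming unscaled scores S = QKᵀ and no masking or dropout for brevity. It keeps only O and one per-row statistic, the log-sum-exp L = m + log ℓ (FlashAttention-2 packs m and ℓ this way), and rebuilds each probability tile as exp(S_ij − L_i) inside the backward loop. The gradients use the standard softmax-attention formulas: dV = PᵀdO, dP = dO Vᵀ, dS = P ∘ (dP − D) with D_i = rowsum(dO_i ∘ O_i), dQ = dS K, dK = dSᵀQ. The reference forward is untiled purely to keep the example short; the function names are ours.

```python
import numpy as np

def forward_with_lse(Q, K, V):
    """Reference forward (untiled for brevity) returning O and per-row log-sum-exp L."""
    S = Q @ K.T
    m = S.max(axis=1)
    l = np.exp(S - m[:, None]).sum(axis=1)
    L = m + np.log(l)                         # one statistic standing in for (m, l)
    P = np.exp(S - L[:, None])                # softmax rows
    return P @ V, L

def flash_backward(Q, K, V, O, L, dO, Br=16, Bc=16):
    """Tiled backward: recompute each P tile from Q, K and L instead of loading a stored N x N P."""
    N, d = Q.shape
    dQ, dK, dV = np.zeros_like(Q), np.zeros_like(K), np.zeros_like(V)
    D = (dO * O).sum(axis=1)                  # D_i = rowsum(dO_i * O_i) = sum_k P_ik dP_ik
    for j in range(0, N, Bc):                 # outer: key/value blocks
        Kj, Vj = K[j:j+Bc], V[j:j+Bc]
        for i in range(0, N, Br):             # inner: query blocks
            Qi, dOi = Q[i:i+Br], dO[i:i+Br]
            P = np.exp(Qi @ Kj.T - L[i:i+Br, None])   # recomputed (Br, Bc) probability tile
            dV[j:j+Bc] += P.T @ dOi
            dP = dOi @ Vj.T
            dS = P * (dP - D[i:i+Br, None])           # row-wise softmax backward
            dQ[i:i+Br] += dS @ Kj
            dK[j:j+Bc] += dS.T @ Qi
    return dQ, dK, dV

rng = np.random.default_rng(0)
N, d = 24, 8
Q, K, V, G = (rng.standard_normal((N, d)) for _ in range(4))
O, L = forward_with_lse(Q, K, V)
dQ, dK, dV = flash_backward(Q, K, V, O, L, dO=G)     # gradients of sum(G * O)
```

With N = 24 and 16-row blocks, the last block in each loop is partial, so the sketch also exercises uneven tiling.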
When FlashAttention helps (and when it doesn't)
- Long sequences. The win grows with N. At N = 512 the gain is modest; at N = 8K–64K it can be the difference between fitting in memory and an out-of-memory crash. Long-context LLMs (32K–1M tokens) are impractical without it or an equivalent memory-efficient attention kernel.
- Training and prefill. Both process many query tokens at once, so the N×N score traffic dominates and tiling pays off handsomely.
- Causal / masked attention. FlashAttention skips fully-masked key blocks entirely, roughly halving work on causal masks where the upper triangle contributes nothing.
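The causal block skipping can be sketched in NumPy as follows. This is an illustrative sketch, not the library kernel: it uses the FlashAttention-2 loop order (query blocks outer), omits the 1/√d scale, masks every entry whose key index exceeds its query index, and stops the inner loop at the first key block that lies entirely in the future. It also counts the tiles it actually computes.

```python
import numpy as np

def causal_flash_attention(Q, K, V, Br=32, Bc=32):
    """Tiled causal attention, softmax(QK^T + causal mask) V, no 1/sqrt(d) scale.

    Returns (O, tiles), where tiles counts the (query-block, key-block) tiles computed."""
    N, d = Q.shape
    O = np.zeros((N, d))
    tiles = 0
    for i in range(0, N, Br):                       # outer: query blocks (FA-2 order)
        Qi = Q[i:i+Br]
        rows = np.arange(i, i + len(Qi))            # absolute query indices in this block
        m = np.full(len(Qi), -np.inf)
        l = np.zeros(len(Qi))
        acc = np.zeros((len(Qi), d))
        for j in range(0, N, Bc):                   # inner: key/value blocks, left to right
            if j > rows[-1]:                        # this block is in the future for every row,
                break                               # and so is every later block: skip them all
            tiles += 1
            cols = np.arange(j, min(j + Bc, N))     # absolute key indices in this block
            S = Qi @ K[j:j+Bc].T
            S = np.where(cols[None, :] > rows[:, None], -np.inf, S)  # hide future keys
            # Key 0 is visible to every query, so block 0 gives each row a finite max
            # and later fully masked rows in a tile leave m unchanged (no NaN).
            m_new = np.maximum(m, S.max(axis=1))
            P = np.exp(S - m_new[:, None])
            corr = np.exp(m - m_new)
            l = corr * l + P.sum(axis=1)
            acc = corr[:, None] * acc + P @ V[j:j+Bc]
            m = m_new
        O[i:i+Br] = acc / l[:, None]                # written once per query block
    return O, tiles

rng = np.random.default_rng(1)
Q, K, V = (rng.standard_normal((128, 16)) for _ in range(3))
O, tiles = causal_flash_attention(Q, K, V)
print(tiles)  # 10 of the 16 tiles
```

With N = 128 and 32×32 tiles, only 10 of the 16 tiles are computed (4 + 3 + 2 + 1); with n blocks per side the fraction is n(n+1)/(2n²), which tends to one half.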
Where it helps least: very short sequences (N ≤ 128), where the N×N matrix already fits in cache and kernel-launch overhead dominates; and single-token decode steps, where N_query = 1 and the bottleneck shifts from the score matrix to streaming the KV cache. That regime is the target of FlashDecoding and of paged KV-cache schemes such as vLLM's PagedAttention.
FlashAttention vs other ways to scale attention
| | FlashAttention | Standard attention | Sparse (BigBird/Longformer) | Linear (Performer) | Linformer | Multi-Query / GQA |
|---|---|---|---|---|---|---|
| Output exactness | Exact | Exact | Approximate (sparsity) | Approximate (kernel) | Approximate (low-rank) | Exact (fewer KV heads) |
| Compute | O(N²d) | O(N²d) | O(N·w·d) | O(Nd²) | O(Nkd) | O(N²d) |
| Extra memory | O(N) | O(N²) | O(N·w) | O(Nd) | O(Nk) | O(N²) |
| HBM traffic | O(N²d²/M) | O(N² + Nd) | O(N·w) | O(Nd²) | O(Nkd) | O(N² + Nd) |
| Quality vs full attention | Identical | Baseline | Small drop | Noticeable drop | Task-dependent | Tiny drop |
| Solves | Memory + speed | — | Asymptotic compute | Asymptotic compute | Asymptotic compute | Decode KV size |
| Composes with FlashAttention? | — | — | Partly (block-sparse variant) | No (different math) | No (different math) | Yes (orthogonal) |
The crucial distinction: the sparse and linear methods change the math to cut the asymptotic N² and accept some accuracy loss. FlashAttention keeps the math exactly the same and attacks the memory system. That's why it became a drop-in default — there's no quality trade-off to argue about. Multi-Query and Grouped-Query Attention are orthogonal (they shrink the KV cache for decode) and routinely run on top of FlashAttention.
What the numbers actually say
- Memory: O(N) vs O(N²). For one head at N = 64K and d = 64, the standard score matrix is 64K × 64K × 2 bytes (bf16) ≈ 8.6 GB per head. FlashAttention stores the N×d output plus two N-length statistic vectors — about 8.4 MB. That's a ~1,000× reduction, and it's what makes 64K-token context fit on a single GPU.
- Speed: 2–4× end-to-end. The original paper reported ~3× faster GPT-2 training and 2.4× on long-range benchmarks. FlashAttention-2 roughly doubled the kernel throughput again, reaching ~50–73% of A100 peak FLOPs (standard attention sits well under 20%).
- HBM bytes moved. Standard attention on N = 4K, d = 64 writes the 4K×4K matrix to HBM and reads it back, once for S and again for P. FlashAttention never writes either N×N matrix to HBM; its traffic is repeated passes over the Q, K, V, and output blocks, and the paper measured up to a 9× reduction in HBM accesses on representative shapes.
- Recomputation cost. The backward pass recomputes the QKᵀ score tiles and their softmax, which adds one N²d-scale matmul to the four the backward already needs (for dV, dP, dQ, and dK), roughly 25% more backward matmul FLOPs. It still runs faster overall because it no longer waits on HBM reads of an N×N matrix.
- FlashAttention-3 on H100. Using Hopper's async copy (TMA), warp-specialization, and FP8, FA-3 reaches up to ~75% of H100 peak (≈740 TFLOP/s in FP16, and ~1.2 PFLOP/s in FP8), a further ~1.5–2× over FA-2 on the same hardware.
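The memory figures in the first bullet can be re-derived in a few lines. This sketch assumes bf16 (2 bytes) for the score matrix and the output, and fp32 (4 bytes) for the two per-row statistics; the statistics' precision is our assumption, and counting them moves the total from about 8.4 MB to about 8.9 MB without changing the roughly 1,000× conclusion.

```python
N, d = 65_536, 64
BF16, FP32 = 2, 4                     # bytes per element (fp32 stats is an assumption)

score_matrix = N * N * BF16           # standard attention: one N x N matrix per head
output = N * d * BF16                 # N x d output, needed by both methods
stats = 2 * N * FP32                  # running max m and denominator l per row

print(f"N x N scores: {score_matrix / 1e9:.1f} GB")          # 8.6 GB
print(f"output + stats: {(output + stats) / 1e6:.1f} MB")   # 8.9 MB
print(f"ratio: {score_matrix / (output + stats):.0f}x")      # 964x
```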
JavaScript implementation (the online-softmax core)
You won't run real FlashAttention in JavaScript — it's a CUDA kernel — but the algorithm is short and the value is seeing that the streaming version produces exactly the same numbers as the naive one. Here a single query streams over key/value tiles:
// One query row q (length d) attends over keys K (n×d) and values V (n×d),
// processed Bc keys at a time, never building the full length-n score row.
function flashAttentionRow(q, K, V, Bc = 4) {
  const n = K.length, d = q.length;
  let m = -Infinity;                  // running max of scores
  let l = 0;                          // running denominator (sum of exp)
  const o = new Array(d).fill(0);     // running output accumulator
  for (let start = 0; start < n; start += Bc) {
    const end = Math.min(start + Bc, n);
    // 1) tile scores s_j = q · k_j (lives only in this loop)
    const s = [];
    for (let j = start; j < end; j++) {
      let dot = 0;
      for (let t = 0; t < d; t++) dot += q[t] * K[j][t];
      s.push(dot);
    }
    // 2) new running max
    const tileMax = Math.max(...s);
    const mNew = Math.max(m, tileMax);
    // 3) exp this tile against the NEW max; rescale factor for old state
    const correction = Math.exp(m - mNew); // = exp(m_old - m_new); 0 on the first tile, where m = -Infinity
    const p = s.map(v => Math.exp(v - mNew));
    const tileSum = p.reduce((a, b) => a + b, 0);
    // 4) fold tile into running denominator and output
    l = correction * l + tileSum;
    for (let t = 0; t < d; t++) {
      let acc = 0;
      for (let j = start; j < end; j++) acc += p[j - start] * V[j][t];
      o[t] = correction * o[t] + acc; // rescale old, add new
    }
    m = mNew;
  }
  return o.map(v => v / l); // single final normalization
}
Compare the output of flashAttentionRow(q, K, V, 2) with a naive softmax-then-multiply over the whole row and you'll get the same vector (modulo last-bit floating-point reordering) regardless of the tile size Bc. The tile size is purely a memory/performance knob, never a correctness knob — that's the whole point.
Python implementation (full tiled forward pass)
Here is the tiled forward over a whole batch of queries with NumPy, mirroring the two-loop structure of the real kernel (outer loop over key/value blocks, inner over query blocks) and tracking the per-row statistics:
import numpy as np

def flash_attention(Q, K, V, Br=32, Bc=32):
    """Exact attention via tiling + online softmax. Q,K,V: (N, d)."""
    N, d = Q.shape
    O = np.zeros((N, d), dtype=np.float32)
    m = np.full(N, -np.inf, dtype=np.float32)   # running row max
    l = np.zeros(N, dtype=np.float32)           # running denominator
    for j in range(0, N, Bc):                   # outer: key/value blocks
        Kj, Vj = K[j:j+Bc], V[j:j+Bc]
        for i in range(0, N, Br):               # inner: query blocks
            Qi = Q[i:i+Br]
            Sij = Qi @ Kj.T                     # (Br, Bc) tile, never stored globally
            m_old = m[i:i+Br]
            m_new = np.maximum(m_old, Sij.max(axis=1))
            P = np.exp(Sij - m_new[:, None])    # (Br, Bc)
            corr = np.exp(m_old - m_new)        # rescale factor for old state
            l[i:i+Br] = corr * l[i:i+Br] + P.sum(axis=1)
            O[i:i+Br] = corr[:, None] * O[i:i+Br] + P @ Vj
            m[i:i+Br] = m_new
    return O / l[:, None]                       # final normalization

# sanity: identical to the naive version
def naive_attention(Q, K, V):
    S = Q @ K.T
    P = np.exp(S - S.max(axis=1, keepdims=True))
    return (P / P.sum(axis=1, keepdims=True)) @ V

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((128, 64)).astype(np.float32) for _ in range(3))
assert np.allclose(flash_attention(Q, K, V), naive_attention(Q, K, V), atol=1e-4)
The assert passes for any block sizes Br, Bc. In production you'd add causal masking (skip key blocks with j > i_max), softmax scaling by 1/√d, and dropout — but the skeleton above is the algorithm.
Variants worth knowing
FlashAttention-2 (2023). Same exact result, but re-engineered for the GPU. It swaps the loop order so each query block is owned by one thread block and parallelizes across the sequence dimension; it cuts the number of non-matmul operations (the rescaling work), because non-matmul FLOPs run far slower than Tensor Core matmuls (on A100, 19.5 TFLOP/s of non-matmul FP32 versus 312 TFLOP/s of FP16/BF16 matmul); and it improves warp partitioning to reduce shared-memory traffic. Net effect: roughly 2× the throughput of FA-1, hitting 50–73% of A100 peak.
FlashAttention-3 (2024). Targets the Hopper architecture (H100). It overlaps the GEMM (matrix multiply) and softmax across warps using warp-specialization and the Tensor Memory Accelerator (TMA) for asynchronous loads, and supports FP8. It reaches ~75% of H100 peak and brings low-precision attention with controlled error.
FlashDecoding / FlashDecoding++. For autoregressive decode, where there's a single query token but a long KV cache, the original parallelization starves the GPU. FlashDecoding adds a parallel reduction over the key dimension, splitting the long KV sequence across many thread blocks and combining partial softmaxes — exactly the online-softmax merge, applied across blocks.
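The cross-block merge is small enough to show directly. This NumPy sketch (function names hypothetical, no 1/√d scale) splits one decode query's KV cache into chunks, computes an unnormalized partial (m, ℓ, o) per chunk as independent thread blocks would, then combines the partials with the same rescaling online softmax uses. It illustrates the reduction only; it is not FlashDecoding's kernel or API.

```python
import numpy as np

def partial_attention(q, K_chunk, V_chunk):
    """One chunk's unnormalized result: local max m, local denominator l, local weighted sum o."""
    s = K_chunk @ q                           # scores for this chunk only
    m = s.max()
    p = np.exp(s - m)
    return m, p.sum(), p @ V_chunk

def merge_partials(parts):
    """Combine per-chunk (m, l, o) triples into the exact softmax-weighted output."""
    m_glob = max(m for m, _, _ in parts)
    l_glob = sum(np.exp(m - m_glob) * l for m, l, _ in parts)
    o_glob = sum(np.exp(m - m_glob) * o for m, _, o in parts)
    return o_glob / l_glob

def split_kv_decode(q, K, V, num_splits=4):
    chunks = zip(np.array_split(K, num_splits), np.array_split(V, num_splits))
    parts = [partial_attention(q, Kc, Vc) for Kc, Vc in chunks]   # parallel in a real kernel
    return merge_partials(parts)

rng = np.random.default_rng(2)
q = rng.standard_normal(64)
K, V = rng.standard_normal((1000, 64)), rng.standard_normal((1000, 64))
out = split_kv_decode(q, K, V, num_splits=8)

s = K @ q
p = np.exp(s - s.max())
print(np.allclose(out, (p / p.sum()) @ V))   # True
```

Each chunk only ever rescales by its own local max, so the chunks need no communication until the final merge, which is what lets the split run across many thread blocks.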
Block-sparse FlashAttention. Combines tiling with a sparsity mask at the block level: entire (query-block, key-block) pairs that the mask zeroes out are skipped. This is how FlashAttention plugs into sparse-attention patterns while keeping the IO-aware kernel.
Memory-efficient attention (Rabe & Staats, 2021). A concurrent, framework-level idea that achieved O(N) memory via the same chunked online softmax in pure JAX, without a custom kernel — it cut memory but not wall-clock time, since it didn't control the SRAM tiling. FlashAttention's contribution was making it IO-optimal and fast.
Common bugs and edge cases
- Forgetting the rescale on the output, not just the denominator. When the running max changes, you must multiply both ℓ and the accumulated O by exp(m_old − m_new). Rescaling only the denominator silently produces wrong outputs that still look plausible.
- An all-masked query row. With causal or padding masks, a query row can see no valid keys in a tile. If m is still −∞ when such a tile arrives, exp(m_old − m_new) evaluates exp(−∞ − (−∞)), which is NaN; and a row masked in every tile ends with ℓ = 0, so the final divide is 0/0. Production kernels guard this by treating fully-masked rows specially.
- Dropping the 1/√d scale. Scores must be divided by √d before the softmax. With tiling it's easy to forget to fold the scale into the tile scores consistently; apply it to S_ij before computing the tile max.
- Tile size as a correctness lever. A surprising number of buggy reimplementations get different answers at different block sizes. If your output depends on Bc beyond last-bit rounding, your rescaling is wrong: the math is block-size invariant by construction.
- Using it for tiny sequences. Below ~256 tokens the kernel-launch and statistics overhead can make FlashAttention slower than a fused naive kernel. It's a long-sequence win.
- Expecting it to reduce FLOPs. It doesn't, and the backward pass actually does more (recomputation). The win is entirely in memory traffic and peak memory — benchmark wall-clock, not FLOP counts.
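The first bug, and the block-size check that catches it, fit in a few lines. The sketch below (the function and its `rescale_output` flag are ours, for illustration) runs one query through a streaming row with and without the output rescale. The data are chosen so the scores rise steadily, forcing the running max to change at every tile, and V is all ones, so the correct output is exactly [1, 1] whatever the weights are.

```python
import numpy as np

def streaming_row(q, K, V, Bc, rescale_output=True):
    """Single-query online-softmax attention; rescale_output=False reproduces the bug."""
    m, l, o = -np.inf, 0.0, np.zeros(V.shape[1])
    for start in range(0, len(K), Bc):
        s = K[start:start+Bc] @ q
        m_new = max(m, s.max())
        corr = np.exp(m - m_new)
        p = np.exp(s - m_new)
        l = corr * l + p.sum()                          # denominator is always rescaled
        o = (corr * o if rescale_output else o) + p @ V[start:start+Bc]
        m = m_new
    return o / l

q = np.array([1.0, 0.0])
K = np.stack([np.linspace(0.0, 3.0, 8), np.zeros(8)], axis=1)  # scores rise steadily: 0 ... 3
V = np.ones((8, 2))                                             # correct output is exactly [1, 1]
for Bc in (2, 4, 8):
    print(Bc, streaming_row(q, K, V, Bc), streaming_row(q, K, V, Bc, rescale_output=False))
```

The correct version returns [1, 1] (up to rounding) for every tile size; the broken one matches only when Bc = 8 puts all keys in one tile, and otherwise overshoots because early tiles keep weights computed against a stale, smaller max.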
Frequently asked questions
Is FlashAttention an approximation of attention?
No. FlashAttention computes the same output as standard attention, differing only by floating-point rounding from the changed order of operations. It is not a sparse or low-rank approximation like Linformer or Performer; it is an IO-aware re-implementation of the same math that never materializes the full N×N score matrix.
How does FlashAttention compute softmax without the full score row?
It uses an online (streaming) softmax. As it walks across key/value tiles, it keeps a running row maximum m and a running denominator ℓ, and rescales the partial output by exp(m_old − m_new) whenever a new tile pushes the maximum higher. The final result is numerically identical to softmax over the whole row.
Why is FlashAttention faster if it does the same FLOPs?
Attention is memory-bandwidth bound, not compute bound. Standard attention writes the N×N scores and softmax probabilities to slow HBM (high-bandwidth memory) and reads them back — that traffic dominates the runtime. FlashAttention keeps tiles in on-chip SRAM, so it does the same multiply-adds but moves O(N²d²/M) bytes through HBM instead of O(N² + Nd), where M is the SRAM size.
What is the memory complexity of FlashAttention?
O(N) extra memory for the forward pass: you only store the N×d output plus the per-row softmax statistics (m and ℓ), never the N×N matrix. Standard attention needs O(N²). For N = 64K and d = 64 in bf16, this is the difference between under 10 MB and about 8.6 GB per head.
What is the difference between FlashAttention-1, 2, and 3?
FlashAttention-1 (2022) introduced tiling plus online softmax. FlashAttention-2 (2023) cut non-matmul FLOPs, parallelized over the sequence dimension, and improved warp partitioning, roughly doubling throughput to 50–73% of peak FLOPs on A100. FlashAttention-3 (2024) targets Hopper (H100) with asynchrony (TMA, warp-specialization) and FP8, reaching up to ~75% of peak on H100.
How does FlashAttention save memory on the backward pass?
It does not store the attention matrix for the backward pass. Instead it recomputes the score and probability tiles on the fly from Q, K, V and the saved softmax statistics. This recomputation costs extra FLOPs but avoids reading a huge matrix from HBM, which is still a net win because attention is bandwidth bound.