Neural Network Pruning
Throw away most of the model and keep almost all the accuracy
Neural network pruning deletes the weights that contribute least to a model's output, leaving a sparse network that runs faster and fits in less memory — often removing 80–90% of parameters with almost no loss in accuracy.
- Typical sparsity: 80–95%
- Accuracy loss: < 1% (iterative)
- Cheapest criterion: |weight| magnitude
- Speedup on dense HW: structured only
- Key result: Lottery Ticket, 2019
The idea: most of a network is dead weight
Train a modern neural network and a strange thing happens: when you look at the trained weights, most of them are tiny. A histogram of the weight magnitudes piles up near zero, with a thin tail of large values doing the real work. The intuition behind pruning is to take that observation literally — if a weight is close to zero, the signal flowing through it barely affects the output, so set it to zero and delete the connection.
This isn't a fringe trick. Over-parameterization is how deep networks train well in the first place: extra weights make the loss landscape easier to optimize. But once training has found a good solution, the scaffolding is no longer load-bearing. Pruning removes the scaffolding. Optimal Brain Damage (LeCun, Denker, and Solla, 1990) showed that you could delete a large fraction of a network's parameters and recover the original accuracy with a little retraining, decades before "model compression" became a deployment necessity.
Three things make pruning attractive on real hardware: a smaller model file (fewer non-zero weights to store), less memory bandwidth at inference (the bottleneck for most deployed models), and — if the sparsity is the right shape — fewer arithmetic operations. That last "if" is the whole game, and we'll return to it.
How pruning works: score, mask, fine-tune
Almost every pruning method follows the same three-step loop:
- Score every weight (or filter, or neuron) with an importance criterion. The cheapest is the absolute value |w|.
- Mask: set the lowest-scoring fraction to zero and freeze them there. The mask is a binary tensor m ∈ {0,1} the same shape as the weights; the effective weight is w ⊙ m (elementwise product).
- Fine-tune the surviving weights for a few epochs so they compensate for what was removed.
The crucial design choice is one-shot versus iterative. One-shot prunes to the target sparsity in a single step; it's fast but brittle. Iterative magnitude pruning (IMP) removes a small fraction per round — say 20% — fine-tunes to recover, then repeats until it hits the target. After k rounds at rate p per round, the density is (1 − p)^k: ten rounds at 20% leaves 0.8^10 ≈ 10.7% of the weights. Iterative pruning consistently reaches far higher sparsity at the same accuracy, because the surviving weights gradually take over the deleted ones' job.
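The density formula above is easy to turn around: given a target sparsity you can solve for either the number of rounds or the per-round rate. The sketch below does both; the function names are illustrative and not taken from any library.

```python
import math


def density_after(rounds: int, rate: float) -> float:
    """Fraction of weights left after `rounds` rounds, each pruning `rate` of the survivors."""
    return (1.0 - rate) ** rounds


def rounds_needed(target_sparsity: float, rate: float) -> int:
    """Smallest number of rounds whose density is at or below 1 - target_sparsity."""
    keep = 1.0 - target_sparsity
    return math.ceil(math.log(keep) / math.log(1.0 - rate))


def rate_for(target_sparsity: float, rounds: int) -> float:
    """Per-round rate p such that (1 - p) ** rounds == 1 - target_sparsity."""
    keep = 1.0 - target_sparsity
    return 1.0 - keep ** (1.0 / rounds)


print(density_after(10, 0.2))    # about 0.107, the 10.7% quoted above
print(rounds_needed(0.9, 0.2))   # 11: ten rounds at 20% stop just short of 90% sparsity
print(rate_for(0.9, 10))         # about 0.206 per round to land on 90% in ten rounds
```

Ten rounds at exactly 20% leave slightly more than 10% of the weights, so hitting 90% sparsity takes either an eleventh round or a slightly higher per-round rate.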
The magnitude criterion as math. To reach global sparsity s (a fraction between 0 and 1), gather all n weights, find the threshold τ equal to the s-quantile of |w|, and set m_i = [ |w_i| ≥ τ ]. Finding the threshold is a selection problem over n weights: O(n) on average with quickselect, or O(n log n) if you sort. Applying the mask is O(n). The expensive part is the fine-tuning that follows, which costs a few epochs of ordinary training.
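As a small worked instance of the threshold rule, the sketch below computes τ and the mask for a flat list of weights. It sorts for clarity (the O(n log n) route) and uses the zero-indexed k-th smallest magnitude as τ; the weights are made-up numbers and `magnitude_mask` is an illustrative name.

```python
def magnitude_mask(weights, sparsity):
    """Return (tau, mask) for a flat list of floats.

    tau is the k-th smallest magnitude (zero-indexed) with k = floor(n * sparsity);
    mask[i] is 1 where |w_i| >= tau and 0 otherwise.
    """
    mags = sorted(abs(w) for w in weights)   # sorting is O(n log n); quickselect avoids it
    k = int(len(mags) * sparsity)
    if k >= len(mags):                       # sparsity 1.0: nothing survives
        return float("inf"), [0] * len(weights)
    tau = mags[k]
    return tau, [1 if abs(w) >= tau else 0 for w in weights]


weights = [0.9, -0.05, 0.02, -0.7, 0.01, 0.3, -0.05, 0.6]

tau, mask = magnitude_mask(weights, 0.5)
print(tau, mask)         # tau = 0.3, half of the weights survive

tau, mask = magnitude_mask(weights, 0.375)
print(tau, sum(mask))    # tau = 0.05; both tied weights are kept, so only 2 of 8 are pruned
```

Because the comparison is `>=`, weights tied with the threshold survive, so the achieved sparsity can land slightly below the target. That is why it is worth reporting the actual sparsity rather than the requested one.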
Structured vs unstructured — where the speedup hides
This is the single most misunderstood point about pruning, and the reason a "90% pruned" model often runs at exactly the same speed as the original.
Unstructured pruning zeros individual weights anywhere in the matrix. You get a sparse matrix riddled with holes. The problem: a standard dense matrix-multiply kernel on a GPU touches every entry regardless of whether it's zero. So unless you use a specialized sparse kernel (and many GPUs gain nothing below ~95% sparsity because of indexing overhead), an unstructured-pruned model is smaller on disk but exactly as slow at inference.
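A rough way to see the indexing overhead is to count bytes. The sketch below assumes the CSR (compressed sparse row) layout with 32-bit values and 32-bit indices; that format and those sizes are illustrative assumptions, not something the text above prescribes, and the numbers cover storage only, not kernel speed.

```python
def dense_bytes(rows, cols, value_bytes=4):
    """Storage for a dense matrix: one value per entry."""
    return rows * cols * value_bytes


def csr_bytes(rows, cols, sparsity, value_bytes=4, index_bytes=4):
    """Storage for CSR: a value and a column index per non-zero, plus rows + 1 row pointers."""
    nnz = round(rows * cols * (1.0 - sparsity))
    return nnz * (value_bytes + index_bytes) + (rows + 1) * index_bytes


dense = dense_bytes(1024, 1024)
for s in (0.5, 0.9, 0.95):
    print(f"sparsity {s:.2f}: CSR is {csr_bytes(1024, 1024, s) / dense:.2f}x the dense size")
```

Under these assumptions a 50% sparse matrix takes slightly more space in CSR than in dense form, because every surviving value drags a column index along with it. The savings only become large at high sparsity, and a sparse kernel additionally pays for irregular memory access that this byte count does not capture.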
Structured pruning removes whole units — entire convolutional filters, channels, attention heads, or neurons. When you delete a filter, the output tensor genuinely loses a channel, the next layer's input shrinks, and the matrix is smaller in both dimensions. That gives real wall-clock speedup on any hardware, no special kernel required. The cost is that you have less granular control, so structured pruning usually can't reach the same sparsity at the same accuracy as unstructured.
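To make the shape change concrete, here is a minimal sketch on a two-layer network with ReLU and no biases (an assumption made to keep the example short). Removing hidden unit j drops row j of the first matrix and column j of the second, and the smaller network computes the same output as the masked one. All names and weights are illustrative.

```python
def matvec(W, x):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def relu(v):
    return [max(0.0, a) for a in v]


def forward(W1, W2, x):
    """Two-layer network without biases: W2 @ relu(W1 @ x)."""
    return matvec(W2, relu(matvec(W1, x)))


def zero_hidden_unit(W1, j):
    """Masking: keep the shapes, zero the unit's incoming weights."""
    return [[0.0] * len(row) if i == j else list(row) for i, row in enumerate(W1)]


def remove_hidden_unit(W1, W2, j):
    """Structured pruning: drop row j of W1 and column j of W2."""
    W1_small = [row for i, row in enumerate(W1) if i != j]
    W2_small = [[w for c, w in enumerate(row) if c != j] for row in W2]
    return W1_small, W2_small


W1 = [[0.5, -1.0], [0.2, 0.3], [-0.4, 0.8]]    # 3 hidden units, 2 inputs
W2 = [[1.0, -2.0, 0.5], [0.3, 0.1, -0.7]]      # 2 outputs, 3 hidden units
x = [1.0, 2.0]

masked_out = forward(zero_hidden_unit(W1, 1), W2, x)
W1_small, W2_small = remove_hidden_unit(W1, W2, 1)
small_out = forward(W1_small, W2_small, x)
print(masked_out, small_out)   # same values, but the second network does less work
```

The masked network still multiplies by the zeroed row and column; the smaller one never touches them, which is where the hardware-independent speedup comes from.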
A middle ground that ships in real silicon is 2:4 semi-structured sparsity: in every contiguous group of four weights, exactly two must be zero. NVIDIA's Ampere and later Tensor Cores have hardware that skips the zeros in this exact pattern, delivering up to a 2× matmul speedup at a fixed 50% sparsity. It's the rare case where unstructured-ish sparsity gets hardware acceleration.
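The 2:4 pattern itself is simple to produce. The sketch below zeros the two smallest-magnitude entries in each contiguous group of four; it only builds the pattern, and any speedup still depends on running the result through the vendor's sparse kernels. The function name and the tie-breaking rule (earlier position is dropped first) are assumptions of this example.

```python
def prune_2_of_4(row):
    """Zero the two smallest-magnitude entries in every contiguous group of four."""
    if len(row) % 4 != 0:
        raise ValueError("row length must be a multiple of 4")
    out = list(row)
    for start in range(0, len(row), 4):
        group = range(start, start + 4)
        # sorted() is stable, so ties are broken by position within the group
        drop = sorted(group, key=lambda i: abs(row[i]))[:2]
        for i in drop:
            out[i] = 0.0
    return out


row = [0.1, -0.8, 0.3, 0.05, 0.5, 0.5, -0.5, 0.5]
print(prune_2_of_4(row))   # [0.0, -0.8, 0.3, 0.0, 0.0, 0.0, -0.5, 0.5]
```

Every group ends up with exactly two non-zeros, so the overall sparsity is fixed at 50% no matter how the magnitudes are distributed.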
Pruning approaches compared
| | Magnitude (unstructured) | Structured (filter/channel) | 2:4 semi-structured | Optimal Brain Damage / Surgeon | Movement pruning |
|---|---|---|---|---|---|
| Granularity | single weight | whole filter / neuron | 2 of every 4 weights | single weight | single weight |
| Criterion | \|w\| | filter norm / scale | \|w\| within each group | Hessian saliency | how fast \|w\| → 0 while fine-tuning |
| Max sparsity at <1% loss | ~90% | ~50–70% | fixed 50% | high (but costly) | high (best for transfer/fine-tune) |
| Speedup on dense HW | none (needs sparse kernel) | real, any HW | ~2× on Ampere+ Tensor Cores | none | none |
| Cost to compute scores | O(n), trivial | O(n) | O(n) | O(n) once the Hessian diagonal is known (OBD); OBS needs the inverse Hessian | training-time tracking |
| Best for | storage, research baselines | actual latency on GPU/CPU/mobile | NVIDIA inference | when accuracy is sacred | pruning while fine-tuning a pretrained model |
The headline trade-off: unstructured magnitude pruning maximizes the fraction of weights you can remove, but structured pruning maximizes the speedup you actually get. Choose based on whether your real constraint is model size or inference latency.
What the numbers actually say
- AlexNet / VGG-16 on ImageNet, Han et al. (2015). Their "Deep Compression" pipeline pruned AlexNet by 9× and VGG-16 by 13× with no loss of accuracy, then combined pruning + quantization + Huffman coding to shrink the models 35–49×, taking VGG-16 from 552 MB to 11.3 MB.
- Iterative beats one-shot, decisively. Pruning a network to 90% sparsity in one shot with no fine-tuning typically collapses accuracy to near-random. The same 90% via iterative magnitude pruning with fine-tuning often loses under 1% — the entire difference is the score → mask → fine-tune loop run repeatedly.
- Lottery tickets at 10–20% density. Frankle & Carbin found winning sub-networks at 10–20% of the original weight count that matched or exceeded the full network's test accuracy when retrained from the original initialization.
- 2:4 sparsity gives ~2× matmul throughput on NVIDIA Ampere+ Tensor Cores at a fixed 50% weight sparsity — one of the only sparsity patterns with first-class hardware support.
- Memory bandwidth, not FLOPs, is usually the bottleneck. For most deployed transformers and CNNs, inference is memory-bound — so simply storing fewer non-zero weights and moving less data can speed things up even before you cut arithmetic.
JavaScript implementation
A self-contained global magnitude prune. We flatten every weight, find the magnitude threshold for the target sparsity with a quickselect-style percentile, then build the binary mask and apply it.
// weights: array of Float32Array layer tensors (flattened)
// sparsity: fraction of weights to zero, e.g. 0.9 for 90%
function globalMagnitudePrune(weights, sparsity) {
  // 1. SCORE — collect |w| across all layers
  const mags = [];
  for (const layer of weights)
    for (let i = 0; i < layer.length; i++) mags.push(Math.abs(layer[i]));

  // 2. find the threshold = the (sparsity)-quantile of magnitudes
  const k = Math.floor(mags.length * sparsity);
  const tau = quickselect(mags, k); // O(n) average

  // 3. MASK + apply — zero anything below threshold, return the masks
  const masks = [];
  let kept = 0, total = 0;
  for (const layer of weights) {
    const mask = new Uint8Array(layer.length);
    for (let i = 0; i < layer.length; i++) {
      total++;
      if (Math.abs(layer[i]) >= tau) { mask[i] = 1; kept++; }
      else { layer[i] = 0; } // prune
    }
    masks.push(mask);
  }
  return { masks, actualSparsity: 1 - kept / total };
}

// Hoare quickselect: k-th smallest value, average O(n), no full sort
function quickselect(a, k) {
  let lo = 0, hi = a.length - 1;
  while (lo < hi) {
    const pivot = a[(lo + hi) >> 1];
    let i = lo, j = hi;
    while (i <= j) {
      while (a[i] < pivot) i++;
      while (a[j] > pivot) j--;
      if (i <= j) { [a[i], a[j]] = [a[j], a[i]]; i++; j--; }
    }
    if (k <= j) hi = j; else if (k >= i) lo = i; else break;
  }
  return a[k];
}

// During fine-tuning, re-apply the mask after every optimizer step so
// pruned weights stay dead (gradients would otherwise revive them):
function reapplyMasks(weights, masks) {
  for (let l = 0; l < weights.length; l++)
    for (let i = 0; i < weights[l].length; i++)
      if (masks[l][i] === 0) weights[l][i] = 0;
}
The detail that trips people up is the last function. The mask isn't a one-time deletion — backprop will happily push a "dead" weight back to a non-zero value on the next step, silently un-pruning your model. You must re-apply the mask after every gradient update so the zeros stay zero.
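The revival is easy to reproduce on paper. The sketch below takes one plain SGD step on a pruned weight vector with made-up gradients, then shows that re-applying the mask restores the zeros; the helper names are illustrative.

```python
def sgd_step(weights, grads, lr):
    """One vanilla SGD update: w <- w - lr * g."""
    return [w - lr * g for w, g in zip(weights, grads)]


def apply_mask(weights, mask):
    """Elementwise product w * m; masked positions go back to zero."""
    return [w * m for w, m in zip(weights, mask)]


mask = [1, 0, 1, 0]
weights = apply_mask([0.8, 0.02, -0.5, 0.01], mask)   # positions 1 and 3 are pruned
grads = [0.1, -0.3, 0.2, 0.4]                         # hypothetical gradients

after_step = sgd_step(weights, grads, lr=0.1)
print(after_step)                     # positions 1 and 3 are non-zero again

after_mask = apply_mask(after_step, mask)
print(after_mask)                     # pruned positions are back to zero
```

The gradient at a pruned position is generally non-zero, because the loss would still change if that connection came back, so a single update is enough to undo the pruning unless the mask is enforced.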
Python implementation (iterative magnitude pruning)
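The full IMP loop as a PyTorch sketch, including the lottery-ticket rewind. The model is assumed to be trained before the first round, and train and evaluate are callables you supply. This is the recipe that reaches high sparsity at low accuracy cost.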
import torch


def magnitude_mask(weight, sparsity):
    """Binary mask that zeros the k = floor(n * sparsity) smallest |w|."""
    flat = weight.abs().flatten()
    k = int(flat.numel() * sparsity)
    if k == 0:
        return torch.ones_like(weight)
    tau = flat.kthvalue(k).values        # k-th smallest magnitude (1-indexed)
    return (weight.abs() > tau).float()  # strict, so the k-th smallest is pruned too


def iterative_magnitude_prune(model, train, evaluate,
                              target_sparsity=0.9,
                              rounds=10, rewind_state=None):
    """Assumes `model` is already trained; `train(model, masks)` fine-tunes it."""
    masks = {n: torch.ones_like(p)
             for n, p in model.named_parameters() if p.dim() > 1}
    # per-round rate so (1 - p)**rounds == 1 - target_sparsity
    keep = 1 - target_sparsity
    p_round = 1 - keep ** (1 / rounds)
    for r in range(rounds):
        for n, param in model.named_parameters():
            if n not in masks:
                continue
            # prune p_round of the *currently surviving* weights
            live = masks[n] > 0
            k = int(live.sum().item() * p_round)
            if k > 0:
                mags = param.detach().abs()
                tau = mags[live].kthvalue(k).values   # ignore dead weights
                masks[n] = ((mags > tau) & live).float()
            with torch.no_grad():
                if rewind_state is not None:
                    # LOTTERY TICKET: rewind survivors to their init, then retrain.
                    # Only masked tensors are rewound here; biases and norm
                    # parameters keep their trained values.
                    param.copy_(rewind_state[n] * masks[n])
                else:
                    param.mul_(masks[n])              # just zero pruned
        train(model, masks)                           # fine-tune
        acc = evaluate(model)
        kept = sum(m.sum().item() for m in masks.values())
        total = sum(m.numel() for m in masks.values())
        print(f"round {r}: sparsity={1 - kept / total:.3f} acc={acc:.4f}")
    return model, masks


# train() must re-apply the mask after each optimizer.step():
#     optimizer.step()
#     with torch.no_grad():
#         for n, p in model.named_parameters():
#             if n in masks: p.mul_(masks[n])
Setting rewind_state to the network's original initialization turns plain iterative pruning into the lottery-ticket procedure: you keep the mask you discovered by training, but reset the surviving weights to their day-zero values before retraining. Frankle and Carbin's surprise was that this reset-and-retrain version still hits full accuracy, implying the winning sub-network was "lucky" at initialization, not just well-trained.
Variants worth knowing
Optimal Brain Damage / Optimal Brain Surgeon. The original 1990/1993 saliency methods. Instead of magnitude, they estimate each weight's effect on the loss using the Hessian (second derivatives). More accurate than magnitude, but computing or approximating the Hessian for millions of parameters is expensive — which is why magnitude pruning won by default.
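With a diagonal Hessian approximation, the Optimal Brain Damage saliency of weight i is s_i = ½ · h_ii · w_i². The sketch below compares that ranking with plain magnitude on three weights; the weights and curvature values are made up for illustration, and in practice the Hessian diagonal has to be estimated from the loss.

```python
def obd_saliency(weights, hessian_diag):
    """OBD saliency per weight: 0.5 * h_ii * w_i ** 2."""
    return [0.5 * h * w * w for w, h in zip(weights, hessian_diag)]


def prune_order(scores):
    """Indices sorted from least to most important (first is pruned first)."""
    return sorted(range(len(scores)), key=lambda i: scores[i])


weights = [0.05, 0.5, -0.2]
hessian_diag = [400.0, 0.5, 2.0]    # hypothetical curvature of the loss per weight

by_magnitude = prune_order([abs(w) for w in weights])
by_saliency = prune_order(obd_saliency(weights, hessian_diag))
print(by_magnitude)   # [0, 2, 1]: magnitude would delete the smallest weight first
print(by_saliency)    # [2, 1, 0]: OBD rates that same weight as the most important
```

The smallest weight sits on a steep part of the loss surface, so removing it costs more than removing either larger weight. That is the case where magnitude and saliency disagree.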
Movement pruning (Sanh et al., 2020). When you prune during fine-tuning of a pretrained model, magnitude is the wrong signal — a weight that started large and is shrinking toward zero is actually becoming unimportant. Movement pruning scores weights by how their magnitude is changing, and beats magnitude pruning in transfer-learning settings like fine-tuning BERT.
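A simplified way to picture the movement score is to accumulate −(∂L/∂w)·w over fine-tuning steps: the sum is positive for weights moving away from zero and negative for weights shrinking toward it. The sketch below uses that accumulation on hypothetical recorded weights and gradients; it leaves out the learned scores and the straight-through estimator of the published method, so treat it as an illustration of the signal and not as the full algorithm.

```python
def movement_scores(weight_history, grad_history):
    """Accumulate -grad * weight per position over the recorded steps."""
    n = len(weight_history[0])
    scores = [0.0] * n
    for ws, gs in zip(weight_history, grad_history):
        for i in range(n):
            scores[i] -= gs[i] * ws[i]
    return scores


# Position 0 starts large and shrinks; position 1 starts small and grows.
weight_history = [[0.9, 0.1], [0.8, 0.2], [0.7, 0.3]]
grad_history = [[1.0, -1.0], [1.0, -1.0], [1.0, -1.0]]   # hypothetical gradients

scores = movement_scores(weight_history, grad_history)
final_magnitudes = [abs(w) for w in weight_history[-1]]
print(scores)             # negative for position 0, positive for position 1
print(final_magnitudes)   # magnitude alone would still prefer position 0
```

Magnitude pruning would keep position 0 because it is still the larger weight, while the movement signal keeps position 1 because fine-tuning is actively growing it.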
SNIP and GraSP — pruning at initialization. These prune before any training, using a single batch to estimate which connections matter. Cheap, but generally weaker than train-then-prune.
The Lottery Ticket Hypothesis (Frankle & Carbin, 2019). Not a method so much as a discovery: dense networks contain sparse "winning ticket" sub-networks that train to full accuracy from the original initialization. It reframed pruning as finding a good sub-network rather than damaging a trained one.
Structured / channel pruning. Remove whole filters ranked by their L1/L2 norm or by a learned per-channel scaling factor (e.g. reusing the gamma of a batch-norm layer). The only family that reliably gives latency wins on commodity hardware.
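Ranking filters by L1 norm can be sketched in a few lines. Each filter is represented here as a flat list of its weights, and the values are made up; `filters_to_keep` is an illustrative helper, not a library function.

```python
def l1_norms(filters):
    """Sum of absolute weights for each filter."""
    return [sum(abs(w) for w in f) for f in filters]


def filters_to_keep(filters, prune_fraction):
    """Indices of the filters that survive after dropping the lowest-norm fraction."""
    norms = l1_norms(filters)
    n_prune = int(len(filters) * prune_fraction)
    order = sorted(range(len(filters)), key=lambda i: norms[i])
    pruned = set(order[:n_prune])
    return [i for i in range(len(filters)) if i not in pruned]


filters = [
    [0.5, -0.4, 0.3],
    [0.01, 0.02, -0.01],
    [-0.9, 0.8, 0.7],
    [0.05, -0.05, 0.0],
]
print(filters_to_keep(filters, 0.5))   # [0, 2]
```

The surviving indices are then used to slice both this layer's output channels and the next layer's input channels, which is what makes the tensors physically smaller.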
Common bugs and edge cases
- Forgetting to re-apply the mask during fine-tuning. The number-one bug. Gradients revive pruned weights; your "90% sparse" model quietly becomes dense again. Re-multiply by the mask after every optimizer.step().
- Expecting a speedup from unstructured sparsity on a dense GPU. A masked dense matmul does the same FLOPs as before. Without a sparse kernel or structured pruning, you save storage, not time.
- Per-layer vs global thresholds. A single global threshold can wipe out an entire small layer whose weights all happen to be tiny, breaking the forward pass. Either prune per-layer, or protect the first/last layers, which are usually the most sensitive.
- Pruning normalization and bias terms. Magnitude pruning should target weight matrices, not batch-norm scales or biases; the code above only prunes parameters with dim > 1 for exactly this reason.
- One-shot pruning with no fine-tuning. Tempting because it's fast, but at high sparsity it usually destroys the model. Budget for the recovery training.
- Reporting sparsity but not real metrics. "95% sparse" is meaningless without the matching accuracy and the actual measured latency on the target device. Sparsity is the input; speed and accuracy are the outputs that matter.
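The per-layer versus global pitfall from the list above is easy to reproduce. In the sketch below a layer with uniformly tiny weights is wiped out by one global threshold but keeps half its weights under per-layer thresholds; the layers are made-up numbers and the helper names are illustrative.

```python
def global_masks(layers, sparsity):
    """One threshold computed over all layers together."""
    mags = sorted(abs(w) for layer in layers for w in layer)
    k = int(len(mags) * sparsity)
    tau = mags[k] if k < len(mags) else float("inf")
    return [[1 if abs(w) >= tau else 0 for w in layer] for layer in layers]


def per_layer_masks(layers, sparsity):
    """A separate threshold for each layer."""
    return [global_masks([layer], sparsity)[0] for layer in layers]


layers = [
    [0.9, -0.8, 0.7, -0.6, 0.5, -0.4, 0.3, -0.2],   # large layer, large weights
    [0.04, -0.03, 0.02, -0.01],                     # small layer, tiny weights
]

g = global_masks(layers, 0.5)
p = per_layer_masks(layers, 0.5)
print([sum(m) for m in g])   # [6, 0]: the small layer has no weights left
print([sum(m) for m in p])   # [4, 2]: every layer keeps half of its weights
```

Both runs remove half of the weights overall, but only the global version disconnects the network. Checking the per-layer survivor counts after masking is a cheap guard against this.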
Frequently asked questions
Does pruning actually make a model faster?
Only if the hardware can exploit the sparsity. Unstructured pruning zeros out individual weights, but a dense matrix multiply still touches every entry — so on a normal GPU a 90% sparse model runs at the same speed unless you use a sparse kernel. Structured pruning, which removes whole channels or filters, shrinks the actual tensor dimensions and gives real wall-clock speedups on any hardware.
What is the lottery ticket hypothesis?
Frankle and Carbin (2019) showed that inside a dense randomly-initialized network there exists a small sub-network — a winning ticket — that, when trained in isolation from the same initial weights, matches the full network's accuracy. You find it by training, pruning the small-magnitude weights, then rewinding the survivors to their original initialization and retraining.
How much can you prune before accuracy drops?
It varies by architecture, but with iterative magnitude pruning plus fine-tuning, many over-parameterized vision and language models tolerate 80–95% of weights removed with under 1% accuracy loss. Push past that and accuracy falls off a cliff. One-shot pruning to the same sparsity, with no fine-tuning, typically destroys the model.
What's the difference between pruning and quantization?
Pruning removes weights (sets them to zero and ideally drops them); quantization keeps every weight but stores each in fewer bits, e.g. 8-bit integers instead of 32-bit floats. They are complementary: a typical compression pipeline prunes first, then quantizes the survivors, then sometimes distills into a smaller student.
Why prune iteratively instead of all at once?
Removing a large fraction of weights in a single step shocks the network — the remaining weights have never been trained to compensate. Iterative pruning removes a small slice (say 20%), fine-tunes to recover, then repeats. The surviving weights gradually absorb the load, which is why iterative pruning reaches far higher sparsity at the same accuracy than one-shot pruning.
Is the smallest-magnitude weight always the least important?
No — magnitude is a cheap proxy, not the truth. A weight can be small yet sit on a steep part of the loss surface, so deleting it hurts a lot. Saliency methods like Optimal Brain Damage rank weights by their effect on the loss using second-order (Hessian) information, which is more accurate but far more expensive. In practice magnitude pruning is the default because it is nearly free and usually good enough.