Knowledge Distillation
Compress a giant network into a tiny one by copying its hesitation, not just its answers
Knowledge distillation trains a small student network to mimic a large teacher's soft probability distribution — logits softened by a temperature T — transferring the teacher's "dark knowledge" so the student keeps most of the accuracy at a fraction of the size and latency.
- Introduced: Hinton, Vinyals & Dean, 2015
- Core trick: Soft targets at temperature T
- Loss: α·T²·KL(soft) + (1−α)·CE(hard)
- Typical T: 2–20
- DistilBERT result: ~97% of BERT, 40% smaller, 60% faster
The intuition: copy the teacher's hesitation
Train a small neural network on hard labels — the one-hot vectors that say "this image is a 7, period" — and it learns a thin slice of what's knowable. The label [0,0,0,0,0,0,0,1,0,0] tells the model nothing about how a 7 relates to a 1, or why a sloppy 7 might look like a 1 but never like an 8.
A big, well-trained teacher network knows all of this. Feed it that same 7 and it outputs something like P(7)=0.92, P(1)=0.05, P(9)=0.02, …. The teacher is 92% sure it's a 7 — but the structure of its uncertainty, the fact that its second guess is "1" and not "8", carries real information about the shape of the data. Geoffrey Hinton calls this dark knowledge.
Knowledge distillation is the idea, formalized by Hinton, Vinyals, and Dean in their 2015 paper Distilling the Knowledge in a Neural Network, that you can train a small student network to match the teacher's full probability vector, hesitation included, instead of just the hard label. The student typically ends up more accurate than the same network trained from scratch on hard labels, because it is learning from a much richer target.
How it works: softening with temperature
The problem with a confident teacher is that its softmax output is almost one-hot anyway. If P(7)=0.999, the wrong-class probabilities are too small to contribute much gradient signal. The fix is temperature.
An ordinary softmax over logits z is p_i = exp(z_i) / Σ_j exp(z_j). The tempered softmax divides every logit by a temperature T first:
p_i(T) = exp(z_i / T) / Σ_j exp(z_j / T)
At T = 1 this is the normal softmax. As T grows the distribution flattens — the gaps between logits shrink, the small probabilities swell, and the relative ratios between wrong classes become visible. As T → ∞ the output approaches uniform. A typical choice is T between 2 and 20.
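To make the flattening concrete, here is a minimal pure-Python sketch of the tempered softmax. The logits are made-up values and the name `softmax_t` is only illustrative; the formula itself is the one given above.

```python
import math

def softmax_t(logits, T=1.0):
    """Tempered softmax: divide each logit by T, then apply a stable softmax."""
    scaled = [z / T for z in logits]
    m = max(scaled)  # subtracting the max keeps exp() from overflowing
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical teacher logits for a 4-class problem (illustrative values only).
logits = [8.0, 3.0, 1.0, -2.0]

for T in (1.0, 4.0, 20.0):
    probs = softmax_t(logits, T)
    print(f"T={T:>4}: {[round(p, 3) for p in probs]}")
```

The ranking of the classes is the same at every temperature; only the spread changes, with all four probabilities drifting toward 1/4 as T grows.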
The recipe is then:
- Compute the teacher's soft targets q = softmax(z_teacher / T) using the high temperature.
- Compute the student's soft predictions p = softmax(z_student / T) at the same T.
- Minimize the divergence between p and q (a "soft" loss).
- Optionally also fit the true hard labels with an ordinary T=1 cross-entropy (a "hard" loss).
At inference time you throw the teacher away and run the student at T = 1.
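The recipe can be traced for a single example. This is a minimal sketch in plain Python; the logits, the label, and the helper names (`softmax_t`, `kl`, `cross_entropy`) are illustrative assumptions rather than anything from a library.

```python
import math

def softmax_t(logits, T=1.0):
    scaled = [z / T for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def kl(q, p, eps=1e-12):
    """KL(q || p), with the teacher distribution q as the reference."""
    return sum(qi * math.log(qi / max(pi, eps)) for qi, pi in zip(q, p) if qi > 0)

def cross_entropy(p, label, eps=1e-12):
    return -math.log(max(p[label], eps))

# Illustrative logits for one example whose true class is index 0.
z_teacher = [6.0, 2.0, 1.0]
z_student = [2.0, 1.5, 0.5]
T = 4.0

q = softmax_t(z_teacher, T)   # step 1: teacher soft targets
p = softmax_t(z_student, T)   # step 2: student soft predictions, same T
soft_loss = kl(q, p)          # step 3: divergence between p and q
hard_loss = cross_entropy(softmax_t(z_student, 1.0), 0)  # step 4: CE at T=1

# Inference: the teacher is gone and the student is read out at T=1.
prediction = max(range(len(z_student)), key=lambda i: z_student[i])
print(soft_loss, hard_loss, prediction)
```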
The distillation loss and the T² factor
The combined loss is a weighted sum of the two objectives:
L = α · T² · KL( softmax(z_t/T) ‖ softmax(z_s/T) )
+ (1 − α) · CE( softmax(z_s), y_hard )
Two parts deserve explanation. First, the soft term is usually written as a KL divergence. A cross-entropy between the soft distributions differs from it only by the teacher's entropy, which is constant with respect to the student, so both give the same gradients. Minimizing it pulls the student's tempered output toward the teacher's. Second, and this trips many people up, the soft term is multiplied by T².
Why T²? When you divide logits by T, the gradient of the soft loss with respect to a student logit scales as roughly 1/T²: the chain rule contributes one factor of 1/T, and at high temperature the difference between the two flattened distributions shrinks by about another 1/T. Without correction, cranking up the temperature would silently shrink the soft loss into irrelevance next to the hard loss. Multiplying by T² restores the gradient magnitude so the balance between soft and hard terms (set by α) stays meaningful as you tune T. This factor comes straight from the original paper, and forgetting it is a common distillation bug.
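A quick numerical check makes the scaling visible. The sketch below uses the analytic gradient of the KL term with respect to the student logits, (p − q)/T, and compares its norm with and without the T² prefactor. The logits are arbitrary illustrative values.

```python
import math

def softmax_t(logits, T=1.0):
    scaled = [z / T for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def soft_grad(z_student, z_teacher, T, scale_by_t2):
    """Gradient of KL(q || p) w.r.t. the student logits, optionally times T**2.

    d/dz_s KL(softmax(z_t/T) || softmax(z_s/T)) = (p - q) / T.
    """
    q = softmax_t(z_teacher, T)
    p = softmax_t(z_student, T)
    factor = T * T if scale_by_t2 else 1.0
    return [factor * (pi - qi) / T for pi, qi in zip(p, q)]

def l2_norm(v):
    return math.sqrt(sum(x * x for x in v))

# Arbitrary illustrative logits.
z_teacher = [5.0, 2.0, -1.0, -3.0]
z_student = [3.0, 2.5, 0.0, -2.0]

for T in (5.0, 10.0, 20.0):
    raw = l2_norm(soft_grad(z_student, z_teacher, T, scale_by_t2=False))
    fixed = l2_norm(soft_grad(z_student, z_teacher, T, scale_by_t2=True))
    print(f"T={T:>4}: unscaled={raw:.5f}  scaled={fixed:.5f}")
```

For these values, doubling T from 10 to 20 cuts the unscaled gradient norm by roughly a factor of four, while the T²-scaled norm stays nearly constant.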
Hinton et al. showed that in the high-temperature limit, with logits zero-meaned per example, matching tempered softmax outputs is approximately equivalent to matching logits directly. Distillation therefore generalizes the older "logit matching" idea from Caruana and collaborators (Buciluǎ et al., 2006), who first showed you could compress an ensemble into a single net.
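The high-temperature claim can be checked numerically. The sketch below assumes zero-mean logits, as in the paper's derivation, and compares the T²-scaled soft gradient, T·(p − q), with the gradient of a squared-error logit-matching loss, (z_s − z_t)/N. All values and function names are illustrative.

```python
import math

def softmax_t(logits, T=1.0):
    scaled = [z / T for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def scaled_soft_grad(z_student, z_teacher, T):
    """Gradient of T**2 * KL(q || p) w.r.t. student logits: T * (p - q)."""
    q = softmax_t(z_teacher, T)
    p = softmax_t(z_student, T)
    return [T * (pi - qi) for pi, qi in zip(p, q)]

def logit_matching_grad(z_student, z_teacher):
    """Gradient of sum((z_s - z_t)**2) / (2 * N) w.r.t. student logits."""
    n = len(z_student)
    return [(zs - zt) / n for zs, zt in zip(z_student, z_teacher)]

# Illustrative zero-mean logits (each list sums to 0).
z_teacher = [4.0, 1.0, -2.0, -3.0]
z_student = [2.0, 1.5, -1.0, -2.5]

target = logit_matching_grad(z_student, z_teacher)
for T in (1.0, 10.0, 1000.0):
    grad = scaled_soft_grad(z_student, z_teacher, T)
    gap = max(abs(g - t) for g, t in zip(grad, target))
    print(f"T={T:>6}: max gap to logit-matching gradient = {gap:.4f}")
```

The gap shrinks as T grows, which is the sense in which distillation at high temperature reduces to logit matching.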
Complexity. Distillation adds one forward pass through the teacher per batch on top of the student's usual training cost. If the teacher has cost C_t per example and the student C_s, training one batch costs O(C_t + C_s) for the forward passes plus the student's backward pass; the teacher is frozen, so it never backpropagates. When the training inputs are fixed (no random augmentation), the teacher's soft targets can also be precomputed once and cached, dropping the per-epoch teacher cost to zero on subsequent passes, at the price of storing one probability vector per training example.
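The caching idea can be sketched with a small wrapper that counts teacher forward passes. Everything here is hypothetical scaffolding: `toy_teacher` stands in for an expensive frozen model, and the sketch assumes each example is presented identically every epoch and that T stays fixed.

```python
import math

def softmax_t(logits, T=1.0):
    scaled = [z / T for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

class CachedTeacher:
    """Wraps a teacher so each example's soft target is computed at most once."""

    def __init__(self, teacher_fn, T):
        self.teacher_fn = teacher_fn
        self.T = T
        self.cache = {}          # example_id -> soft target
        self.forward_calls = 0   # how many real teacher passes were run

    def soft_target(self, example_id, x):
        if example_id not in self.cache:
            self.forward_calls += 1
            self.cache[example_id] = softmax_t(self.teacher_fn(x), self.T)
        return self.cache[example_id]

def toy_teacher(x):
    """Hypothetical stand-in for an expensive frozen teacher forward pass."""
    return [2.0 * x, 1.0, -x]

dataset = [(0, 0.5), (1, 1.5), (2, -0.3)]  # (example_id, input) pairs
teacher = CachedTeacher(toy_teacher, T=4.0)

for epoch in range(3):
    for example_id, x in dataset:
        q = teacher.soft_target(example_id, x)  # a student update would use q here

print(teacher.forward_calls)  # one teacher pass per example, not per epoch
```

Caching logits instead of probabilities would be a reasonable variation if you expect to retune T later.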
When to use it — and the tradeoffs
- Edge and mobile deployment. When the teacher is too big to run on a phone or in a tight latency budget, distill it into a student that fits.
- Ensemble compression. A 10-model ensemble is too expensive to serve; distill all ten into one student that captures most of the ensemble's accuracy.
- Label-scarce regimes. The teacher can label a large pool of unlabeled data with soft targets, giving the student more to learn from than the original labeled set.
- Faster, cheaper serving. A 40–60% latency cut at near-iso accuracy is often worth far more in production than the last fraction of a percent of accuracy.
The costs are real, though. You need a trained teacher first, so total compute goes up, not down — distillation trades training cost for inference cost. The student rarely matches the teacher exactly; you accept a small accuracy gap. And the technique is finicky: T and α need tuning, and a teacher that's too much larger than the student ("capacity gap") can actually hurt — the student can't follow a target it has no room to represent.
Distillation vs other compression methods
| | Distillation | Pruning | Quantization | Low-rank factorization | NAS / from-scratch small |
|---|---|---|---|---|---|
| What it edits | Trains a new small net | Removes weights/channels | Lowers weight precision | Factors weight matrices | Designs a small net directly |
| Needs a teacher? | Yes | No (uses original) | No | No | No |
| Architecture freedom | Full — any student shape | Constrained to original | Constrained to original | Constrained to original | Full |
| Typical size cut | 2–10× | 2–5× (structured) | 2–4× (FP32→INT8) | 1.5–3× | arbitrary |
| Latency benefit | High (fewer FLOPs) | Medium (sparsity HW-dependent) | High (INT8 kernels) | Medium | High |
| Accuracy retention | Often 95–99% of teacher | High at moderate sparsity | High to INT8, drops below | Moderate | Lower — no teacher signal |
| Stacks with the others? | Yes — distill then quantize | Yes | Yes | Yes | n/a |
These are complementary, not competing. A common production pipeline is: distill the big model into a compact student, then quantize that student to INT8, then prune any dead channels. Each step compounds the savings.
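As a rough illustration of how the savings compound, the arithmetic below takes 110M and 66M parameter counts (the BERT-base and DistilBERT figures) and assumes 4 bytes per FP32 weight and 1 byte per INT8 weight. The INT8 step is hypothetical here, and the estimate ignores quantization metadata and any pruning.

```python
def model_size_mb(num_params, bytes_per_param):
    """Weight storage in megabytes (1 MB = 1e6 bytes), weights only."""
    return num_params * bytes_per_param / 1e6

teacher_fp32 = model_size_mb(110e6, 4)  # large model in FP32
student_fp32 = model_size_mb(66e6, 4)   # after distillation
student_int8 = model_size_mb(66e6, 1)   # after distillation plus INT8 quantization

total_reduction = teacher_fp32 / student_int8
print(teacher_fp32, student_fp32, student_int8, round(total_reduction, 2))
```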
What the numbers actually say
- DistilBERT (Sanh et al., 2019) distills BERT-base into a 6-layer model with 40% fewer parameters (66M vs 110M), runs ~60% faster at inference, and retains about 97% of BERT-base's GLUE benchmark performance.
- The original MNIST experiment (2015) distilled a large net into a small one and recovered most of the accuracy gap. A striking demo from the paper: a student that never saw a single 3 during distillation still classified most test 3s correctly once the learned bias for the 3 class was adjusted, because the dark knowledge of how 3s relate to other digits leaked through the soft targets of the other classes.
- Ensemble compression. The 2015 paper distilled a 10-model speech ensemble into a single model and recovered most of the ensemble's frame-accuracy gain over a single baseline model — at one-tenth the serving cost.
- Caching soft targets turns the per-epoch teacher cost to zero after the first pass, at a storage cost of one float vector of length num_classes per training example: trivial for 10 classes, heavy for a 50,000-token vocabulary (where you store top-k logits instead).
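The storage trade-off in the last bullet is simple arithmetic. The sketch below assumes 4-byte floats and 4-byte class indices; the dataset sizes and k = 32 are hypothetical values chosen only to show the orders of magnitude.

```python
def cache_size_bytes(num_examples, values_per_example, bytes_per_value=4):
    """Bytes needed to store one dense float vector per example."""
    return num_examples * values_per_example * bytes_per_value

def topk_cache_size_bytes(num_examples, k, bytes_per_value=4, bytes_per_index=4):
    """Bytes needed to store the top-k logits plus their class indices."""
    return num_examples * k * (bytes_per_value + bytes_per_index)

small = cache_size_bytes(60_000, 10)            # 10-class image dataset
dense = cache_size_bytes(1_000_000, 50_000)     # 1M token positions, full vocabulary
sparse = topk_cache_size_bytes(1_000_000, 32)   # same positions, top-32 only

print(small / 1e6, "MB")
print(dense / 1e9, "GB")
print(sparse / 1e6, "MB")
```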
JavaScript implementation
The arithmetic core fits in a few functions. Here's the tempered softmax, the KL term, and the combined gradient for one example — the pieces you'd wire into a training loop.
// Tempered softmax: divide logits by T, then softmax.
function softmaxT(logits, T = 1) {
  const scaled = logits.map(z => z / T);
  const max = Math.max(...scaled); // numerical stability
  const exps = scaled.map(z => Math.exp(z - max));
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map(e => e / sum);
}

// KL( q ‖ p ): teacher q against student p, both tempered.
function klDivergence(q, p) {
  let kl = 0;
  for (let i = 0; i < q.length; i++) {
    if (q[i] > 0) kl += q[i] * Math.log(q[i] / Math.max(p[i], 1e-12));
  }
  return kl;
}

// Combined distillation loss for one example.
function distillLoss(studentLogits, teacherLogits, hardLabel, { T = 4, alpha = 0.9 } = {}) {
  const q = softmaxT(teacherLogits, T);     // teacher soft targets
  const pSoft = softmaxT(studentLogits, T); // student at temperature T
  const pHard = softmaxT(studentLogits, 1); // student at T=1 for hard CE
  const soft = alpha * (T * T) * klDivergence(q, pSoft); // note the T² scaling
  const hard = (1 - alpha) * -Math.log(Math.max(pHard[hardLabel], 1e-12));
  return soft + hard;
}

// Gradient of the SOFT term w.r.t. student logits.
// For cross-entropy between softmaxes, the gradient is (p - q),
// and the T² prefactor combined with the 1/T from d(z/T) leaves a net T.
function softGrad(studentLogits, teacherLogits, T = 4, alpha = 0.9) {
  const q = softmaxT(teacherLogits, T);
  const p = softmaxT(studentLogits, T);
  return p.map((pi, i) => alpha * T * (pi - q[i])); // T² · (1/T) = T
}
Two subtleties worth flagging. The - max subtraction inside softmaxT prevents exp overflow — never skip it. And the softGrad prefactor is T, not T²: the loss carries T², but differentiating the tempered softmax brings down a factor of 1/T, leaving a net T. Getting this off by a factor of T is a classic silent bug.
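One way to catch a prefactor mistake is a finite-difference check of the soft-term gradient. The sketch below is a small Python check under illustrative logits; it compares α·T·(p − q) against a central-difference estimate of the gradient of α·T²·KL. The function names are hypothetical and mirror the pieces above.

```python
import math

def softmax_t(logits, T=1.0):
    scaled = [z / T for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def soft_loss(z_student, z_teacher, T, alpha):
    """alpha * T**2 * KL(q || p) for one example."""
    q = softmax_t(z_teacher, T)
    p = softmax_t(z_student, T)
    kl = sum(qi * math.log(qi / pi) for qi, pi in zip(q, p))
    return alpha * T * T * kl

def analytic_grad(z_student, z_teacher, T, alpha):
    """Claimed gradient of the soft term: alpha * T * (p - q)."""
    q = softmax_t(z_teacher, T)
    p = softmax_t(z_student, T)
    return [alpha * T * (pi - qi) for pi, qi in zip(p, q)]

def numeric_grad(z_student, z_teacher, T, alpha, h=1e-5):
    """Central-difference estimate of the same gradient."""
    grads = []
    for i in range(len(z_student)):
        up = list(z_student)
        down = list(z_student)
        up[i] += h
        down[i] -= h
        diff = soft_loss(up, z_teacher, T, alpha) - soft_loss(down, z_teacher, T, alpha)
        grads.append(diff / (2 * h))
    return grads

# Illustrative logits.
z_teacher = [5.0, 2.0, -1.0]
z_student = [1.0, 0.5, -0.5]

a = analytic_grad(z_student, z_teacher, T=4.0, alpha=0.9)
n = numeric_grad(z_student, z_teacher, T=4.0, alpha=0.9)
print(max(abs(x - y) for x, y in zip(a, n)))
```

If the analytic version used T² instead of T, the two gradients would disagree by a factor of T and the printed gap would be far from zero.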
Python (PyTorch) implementation
In PyTorch the whole loss is a few lines, and the framework handles the gradients. Note log_softmax for the student and plain softmax for the teacher, which is what KLDivLoss expects.
import torch
import torch.nn.functional as F


def distillation_loss(student_logits, teacher_logits, hard_labels,
                      T=4.0, alpha=0.9):
    # Soft term: KL between tempered distributions, scaled by T**2.
    soft_student = F.log_softmax(student_logits / T, dim=-1)
    soft_teacher = F.softmax(teacher_logits / T, dim=-1)
    soft_loss = F.kl_div(soft_student, soft_teacher,
                         reduction="batchmean") * (T * T)
    # Hard term: ordinary cross-entropy on the true labels at T=1.
    hard_loss = F.cross_entropy(student_logits, hard_labels)
    return alpha * soft_loss + (1.0 - alpha) * hard_loss


def train_step(student, teacher, x, y, optimizer, T=4.0, alpha=0.9):
    teacher.eval()
    with torch.no_grad():  # teacher is frozen: no grad
        teacher_logits = teacher(x)
    student.train()
    student_logits = student(x)
    loss = distillation_loss(student_logits, teacher_logits, y, T, alpha)
    optimizer.zero_grad()
    loss.backward()  # only the student updates
    optimizer.step()
    return loss.item()
The torch.no_grad() block on the teacher is not optional: without it you waste memory building a graph for a network you never update, and large teachers can run out of memory. For multi-epoch training over a fixed set, cache teacher_logits to disk after epoch one and skip the teacher forward pass entirely.
Variants worth knowing
Feature / hint distillation (FitNets, 2015). Instead of only matching final outputs, match intermediate activations — the student's hidden layer is trained to regress onto a "hint" layer of the teacher (via a small projection). This gives a richer signal and helps very deep, thin students.
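A hint loss can be sketched as a squared distance between the teacher's hint activation and a linear projection of the student's hidden activation. This is a minimal illustration, not the FitNets implementation: the vectors and the projection matrix `W` are made-up values, and in practice `W` would be learned jointly with the student.

```python
def project(W, h):
    """Apply a linear projection W (list of rows) to a vector h."""
    return [sum(w * x for w, x in zip(row, h)) for row in W]

def hint_loss(h_student, h_teacher, W):
    """0.5 * squared L2 distance between projected student and teacher features."""
    projected = project(W, h_student)
    return 0.5 * sum((a - b) ** 2 for a, b in zip(projected, h_teacher))

h_student = [0.5, -1.0]                     # 2-dim student hidden activation
h_teacher = [1.0, 0.0, -0.5]                # 3-dim teacher hint activation
W = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]    # 3x2 projection (hypothetical values)

loss = hint_loss(h_student, h_teacher, W)
print(loss)
```

The projection is what lets a thin student layer be compared against a wider teacher layer.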
Attention transfer. For CNNs and transformers, match the teacher's attention maps rather than (or in addition to) logits. The student learns where the teacher looks, which transfers spatial/positional structure that logits alone don't capture.
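For CNN feature maps, one common activation-based formulation builds a spatial attention map by summing squared activations over channels and normalizing it, then penalizes the distance between the student's and teacher's maps. The sketch below assumes that formulation and represents a feature map as a list of channels, each a flat list of spatial positions; the values are illustrative.

```python
import math

def attention_map(feature_map):
    """Sum squared activations over channels, then L2-normalize over positions."""
    num_positions = len(feature_map[0])
    raw = [sum(channel[i] ** 2 for channel in feature_map)
           for i in range(num_positions)]
    norm = math.sqrt(sum(v * v for v in raw)) or 1.0  # avoid dividing by zero
    return [v / norm for v in raw]

def attention_transfer_loss(student_fm, teacher_fm):
    """Squared L2 distance between normalized attention maps."""
    a_s = attention_map(student_fm)
    a_t = attention_map(teacher_fm)
    return sum((s - t) ** 2 for s, t in zip(a_s, a_t))

# 2-channel student and 3-channel teacher over the same 4 spatial positions.
student_fm = [[1.0, 2.0, 0.0, 1.0],
              [0.0, 1.0, 1.0, 0.0]]
teacher_fm = [[2.0, 3.0, 0.5, 0.0],
              [0.0, 1.0, 0.5, 0.0],
              [1.0, 2.0, 0.0, 0.5]]

loss = attention_transfer_loss(student_fm, teacher_fm)
print(loss)
```

Because channels are summed out, the student and teacher can have different widths as long as the spatial sizes match.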
Self-distillation. Teacher and student share the same architecture; a model distills into a copy of itself (sometimes across training generations, "Born-Again Networks"). Surprisingly, the student often beats the teacher, because the soft targets act as a regularizer.
Online / mutual distillation (Deep Mutual Learning). No pre-trained teacher: a cohort of students train together and distill from each other simultaneously. Cheaper end-to-end since you skip the separate teacher-training phase.
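For a cohort of two, the mutual-learning objective can be sketched as each student's hard-label cross-entropy plus a KL term pulling it toward its peer's current prediction. This is an illustrative sketch with made-up logits; in a real training loop each student's update would treat the peer's distribution as a fixed target.

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def kl(q, p, eps=1e-12):
    """KL(q || p)."""
    return sum(qi * math.log(qi / max(pi, eps)) for qi, pi in zip(q, p) if qi > 0)

def cross_entropy(p, label, eps=1e-12):
    return -math.log(max(p[label], eps))

def mutual_losses(logits_a, logits_b, label):
    """Return (loss for student A, loss for student B) on one example."""
    p_a = softmax(logits_a)
    p_b = softmax(logits_b)
    loss_a = cross_entropy(p_a, label) + kl(p_b, p_a)  # A mimics B
    loss_b = cross_entropy(p_b, label) + kl(p_a, p_b)  # B mimics A
    return loss_a, loss_b

loss_a, loss_b = mutual_losses([2.0, 0.5, -1.0], [1.0, 1.2, -0.5], label=0)
print(loss_a, loss_b)
```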
Data-free distillation. When the teacher's training data is private, synthesize inputs (e.g. by inverting the teacher or with a GAN) so the student can be distilled without ever touching the real dataset.
Common bugs and edge cases
- Forgetting the T² scale. Without it, raising the temperature silently kills the soft loss's contribution. Symptom: distillation "does nothing" at high T.
- Mismatched temperatures. The teacher and student soft outputs must use the same T. Tempering only one side compares incompatible distributions.
- Leaving the teacher in train mode. Dropout and batch-dependent BatchNorm statistics in the teacher will jitter its soft targets from batch to batch. Always call teacher.eval() and wrap inference in torch.no_grad().
- Using hard-label CE at temperature T. The hard-label cross-entropy must be computed at T = 1, against the real labels, not on the tempered logits.
- Capacity gap. A student that's far too small to represent the teacher's function distills poorly. If accuracy stalls, the student may simply lack capacity, or you may need an intermediate "teacher assistant" model to bridge the gap.
- log(0) and numerical underflow. Clamp probabilities away from zero (e.g. max(p, 1e-12)) before taking logs, and subtract the max logit before exponentiating, or you'll get NaNs at high confidence.
- Evaluating the student at temperature T. Temperature is a training-time device. At deployment, run the student at T = 1, which gives the calibrated probabilities you actually want.
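The mismatched-temperature bug is easy to demonstrate. In the sketch below the student's logits are identical to the teacher's, so a correct soft loss should be zero; tempering only the teacher side reports a large divergence anyway. The logits are illustrative values.

```python
import math

def softmax_t(logits, T=1.0):
    scaled = [z / T for z in logits]
    m = max(scaled)  # subtract the max logit before exponentiating
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def kl(q, p, eps=1e-12):
    """KL(q || p), with p clamped away from zero before the log."""
    return sum(qi * math.log(qi / max(pi, eps)) for qi, pi in zip(q, p) if qi > 0)

z = [4.0, 1.0, -1.0]  # student logits equal to teacher logits
T = 4.0

matched = kl(softmax_t(z, T), softmax_t(z, T))       # same T on both sides
mismatched = kl(softmax_t(z, T), softmax_t(z, 1.0))  # student left at T=1

print(matched, mismatched)
```

A perfect student is penalized in the mismatched case, which is why both sides must share the same T.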
Frequently asked questions
What is the role of temperature in knowledge distillation?
Temperature T divides the logits before the softmax. A high T (typically 2–20) flattens the distribution so small probabilities — the relative likelihoods of the wrong classes — grow large enough for the student to learn from. At T=1 you recover the ordinary softmax; as T→∞ the distribution approaches uniform.
Why is the distillation loss multiplied by T squared?
Softening the logits by T shrinks the gradients of the soft-target loss by a factor of 1/T². Multiplying the soft loss by T² rescales those gradients back to roughly the same magnitude as the hard-label gradients, so the two terms stay balanced when you change T. This factor comes directly from the Hinton, Vinyals & Dean 2015 paper.
What is dark knowledge?
Dark knowledge is the information hidden in a teacher's wrong-class probabilities — for example, a digit "2" that the teacher thinks is 0.001 likely a "7" and 0.0001 likely a "3". That ratio encodes which classes look similar, a far richer signal than the one-hot label "2". Matching it is why distillation beats training the small model alone.
Does the student have to be the same architecture as the teacher?
No. Distillation only requires that both produce a probability vector over the same classes. A 6-layer transformer can distill from a 24-layer one, a CNN from an ensemble, or DistilBERT (6 layers) from BERT-base (12 layers). The architectures are independent; only the output space must match.
How much accuracy do you lose by distilling?
It depends on the compression ratio, but the headline results are modest. DistilBERT keeps about 97% of BERT-base's GLUE score with 40% fewer parameters and 60% faster inference. On ImageNet, a distilled ResNet often closes most of the gap to its teacher that direct training leaves open.
Is distillation the same as pruning or quantization?
No — they are complementary compression techniques. Pruning removes weights and quantization lowers their precision, both editing an existing network. Distillation trains a brand-new smaller network from scratch using the teacher as a soft-label oracle. In practice you often distill first, then quantize the student.