Transfer Learning
Don't train from scratch — borrow a model that already learned the hard part
Transfer learning reuses a model pretrained on a giant dataset and adapts it to a new, smaller task by freezing the learned feature layers and retraining only the top — reaching high accuracy with 10-1000× less data and compute.
- Data needed: 10-1000× less
- Trainable params (feature extraction): just the head
- Fine-tune learning rate: 10-100× smaller
- What transfers: early/generic layers
- Main failure mode: catastrophic forgetting
The intuition: don't relearn edges
Training a deep network from scratch means learning everything — that pixels form edges, that edges form textures, that textures form eyes and wheels and faces — purely from your labeled examples. For a 1.2-million-image dataset like ImageNet, that's reasonable. For your 800-photo dataset of factory defects, it's hopeless: the network will memorize the 800 photos and generalize to nothing.
Transfer learning sidesteps the problem. Someone already paid the cost of learning "what edges and textures look like" by training a network on ImageNet (or, for text, training BERT on 3.3 billion words, or GPT-style models on trillions of tokens). Those low-level features are generic — a network that recognizes the fur texture of a cat is most of the way to recognizing the fur of a dog, a wolf, or a fox it has never seen. You download that pretrained network, throw away its final classification layer, bolt on a fresh layer sized for your classes, and train. Because the hard, data-hungry feature learning is already done, your small dataset only has to learn the easy part: how to combine existing features into your specific labels.
The empirical foundation comes from Yosinski et al. (2014), "How transferable are features in deep neural networks?", which measured layer-by-layer transferability and showed that the first layers of a vision CNN are almost task-independent while the last layers are highly specific. That paper is why "freeze the bottom, retrain the top" is the default recipe.
The mechanism, layer by layer
A deep network is a stack of layers f₁, f₂, …, f_L, where the output is f_L(…f₂(f₁(x))). Pretraining on a source task learns weights θ₁…θ_L. Transfer learning splits this stack at some cut point k:
- Backbone (layers 1…k): the pretrained feature extractor. Its weights are frozen (requires_grad = False), so backprop computes no gradients for them and they never change.
- Head (layers k+1 onward): everything above the cut, meaning any pretrained layers above k plus a freshly initialized output layer (or small MLP) sized to your task's number of outputs. Only these weights are trainable.
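A minimal PyTorch sketch of the cut point, assuming a toy `nn.Sequential` stands in for the pretrained network (the layer sizes, `block`, `split_at`, and `count_params` are hypothetical names for illustration). Layers 1..3 play the pretrained backbone, so L = 3 here, and the fourth entry is the new 3-class head:

```python
import torch.nn as nn

def block(d_in: int, d_out: int) -> nn.Sequential:
    return nn.Sequential(nn.Linear(d_in, d_out), nn.ReLU())

# Toy stand-in for a pretrained network: three "backbone" blocks + a fresh head.
model = nn.Sequential(block(32, 64), block(64, 64), block(64, 64), nn.Linear(64, 3))

def split_at(net: nn.Sequential, k: int) -> None:
    """Freeze layers 1..k (positions 0..k-1); make layers k+1 onward trainable."""
    for i, layer in enumerate(net):
        for p in layer.parameters():
            p.requires_grad = i >= k

def count_params(net: nn.Module) -> tuple:
    """Return (trainable, total) parameter counts."""
    total = sum(p.numel() for p in net.parameters())
    trainable = sum(p.numel() for p in net.parameters() if p.requires_grad)
    return trainable, total

split_at(model, 3)           # k = L: whole backbone frozen, only the head trains
print(count_params(model))   # (195, 10627)
split_at(model, 2)           # k < L: top backbone block unfrozen too
print(count_params(model))   # (4355, 10627)
```

Moving the cut point is just a loop over `requires_grad`; the head's 195 parameters (64×3 weights + 3 biases) are all that train in the feature-extraction setting.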
Two knobs define the spectrum of approaches:
- Feature extraction: set k = L. Freeze the whole backbone and train only the new head. The backbone is now a fixed function that maps each input to an embedding; with a single linear head, you're really just training a logistic regression on top of frozen embeddings. You can even precompute every embedding once and cache it, turning training into seconds.
- Fine-tuning: set k < L. Unfreeze the top few backbone layers (or all of them) and train them together with the head, at a much smaller learning rate. This lets the high-level features specialize to your data instead of staying generic.
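Feature extraction with a linear head on cached embeddings can be sketched in plain NumPy as multinomial logistic regression. This is a minimal sketch under stated assumptions: random class-clustered vectors stand in for embeddings a frozen backbone would produce in one forward pass, and `train_softmax_head`, the learning rate, and the epoch count are hypothetical illustrative choices.

```python
import numpy as np

def train_softmax_head(emb, labels, num_classes, lr=0.1, epochs=200):
    """Multinomial logistic regression trained by full-batch gradient descent."""
    n, d = emb.shape
    W = np.zeros((d, num_classes))
    b = np.zeros(num_classes)
    onehot = np.eye(num_classes)[labels]
    for _ in range(epochs):
        logits = emb @ W + b
        logits -= logits.max(axis=1, keepdims=True)   # numerical stability
        probs = np.exp(logits)
        probs /= probs.sum(axis=1, keepdims=True)
        grad = (probs - onehot) / n                    # d(mean CE)/d(logits)
        W -= lr * emb.T @ grad
        b -= lr * grad.sum(axis=0)
    return W, b

rng = np.random.default_rng(0)
# Stand-in for cached backbone embeddings: 800 images, 64-D, 3 classes.
labels = rng.integers(0, 3, size=800)
class_means = rng.normal(0.0, 1.0, size=(3, 64))
emb = class_means[labels] + 0.5 * rng.normal(size=(800, 64))

W, b = train_softmax_head(emb, labels, num_classes=3)
acc = ((emb @ W + b).argmax(axis=1) == labels).mean()
print(f"train accuracy: {acc:.3f}")
```

Because the backbone never changes, the embeddings are computed once and every training epoch is just a couple of small matrix products.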
The cost asymmetry is the whole point. In feature extraction the backbone still runs a full forward pass — so inference cost is unchanged — but the backward pass and optimizer step touch only the head's parameters. If the backbone is 25M parameters and the head is 2,000, you're optimizing 0.008% of the model. Gradient computation through frozen layers is skipped entirely, so a training step is dominated by the forward pass alone.
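A small PyTorch check of this asymmetry, assuming a toy two-layer backbone and linear head (sizes are arbitrary): after `backward()`, frozen parameters have no gradient at all, and the 0.008% figure is simply 2,000 / 25,000,000.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
backbone = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 32), nn.ReLU())
head = nn.Linear(32, 3)
for p in backbone.parameters():
    p.requires_grad = False  # frozen: autograd builds no graph through these weights

x = torch.randn(8, 16)
y = torch.randint(0, 3, (8,))
loss = nn.functional.cross_entropy(head(backbone(x)), y)
loss.backward()

frozen_grads = [p.grad for p in backbone.parameters()]  # all None
head_grads = [p.grad for p in head.parameters()]        # populated

# The ratio from the text: a 2,000-parameter head on a 25M-parameter backbone.
fraction = 2_000 / 25_000_000
print(f"{fraction:.3%}")  # 0.008%
```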
Why small learning rate when fine-tuning? The pretrained weights sit in a good basin of the loss landscape. A large gradient step — driven by a freshly initialized, randomly wrong head — would knock the backbone out of that basin before it ever stabilizes. The standard fix is a learning rate 10-100× smaller than pretraining (e.g. 1e-5 vs 1e-3), often with discriminative learning rates: tiny for early layers, larger toward the head, because the early layers need to move least.
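Discriminative learning rates map directly onto optimizer parameter groups. A minimal sketch, assuming the model is split into an ordered list of blocks (input first, head last) and the rate shrinks geometrically toward the input; the decay factor 2.6 is the value ULMFiT popularized, used here as a tunable assumption, and `discriminative_param_groups` is a hypothetical helper.

```python
import torch
import torch.nn as nn

def discriminative_param_groups(blocks, head_lr=1e-4, decay=2.6):
    """One param group per block; LR is divided by `decay` per step toward the input."""
    n = len(blocks)
    return [
        {"params": list(block.parameters()), "lr": head_lr / (decay ** (n - 1 - i))}
        for i, block in enumerate(blocks)
    ]

# Toy model: two "backbone" blocks followed by a head (input -> output order).
blocks = [nn.Linear(4, 4), nn.Linear(4, 4), nn.Linear(4, 2)]
opt = torch.optim.AdamW(discriminative_param_groups(blocks, head_lr=1e-4))
print([g["lr"] for g in opt.param_groups])  # smallest LR for the earliest block
```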
When to use which approach
The decision is driven by two variables: how much target data you have, and how similar the target domain is to the source. The classic 2×2 (from the CS231n course notes) is:
- Small data, similar domain → feature extraction. Freeze everything, train a linear head. A small dataset can't safely update millions of weights, but the features already fit.
- Large data, similar domain → fine-tune the whole network. You have enough data to update everything without overfitting, and you'll squeeze out extra accuracy.
- Small data, different domain → freeze early layers only, train head + a few mid layers. The generic edges still help; the specific features need to change.
- Large data, different domain → fine-tune everything, but pretrained init still beats random init by speeding convergence — even a distant source gives a better starting point than noise.
Skip transfer learning entirely when your domain shares no low-level statistics with any available pretrained model (raw tabular data, certain scientific signals) — there the frozen features are noise, and a compact from-scratch model wins. That's negative transfer.
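The 2×2 can be written down as a tiny decision function. This is only a sketch: the `large_threshold` cut-off between "small" and "large" data is a hypothetical value, and no rule like this can detect negative transfer, which still needs a from-scratch baseline to rule out.

```python
def choose_strategy(examples_per_class: int, similar_domain: bool,
                    large_threshold: int = 1000) -> str:
    """Map (data size, domain similarity) onto the four CS231n-style recipes.

    large_threshold is an illustrative assumption, not a published cut-off.
    """
    large = examples_per_class >= large_threshold
    if similar_domain and not large:
        return "feature extraction: freeze backbone, train a linear head"
    if similar_domain and large:
        return "fine-tune the whole network at a small learning rate"
    if not large:
        return "freeze early layers, train head + a few mid layers"
    return "fine-tune everything, starting from pretrained weights"

print(choose_strategy(50, similar_domain=True))
```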
Transfer learning vs from-scratch vs other adaptation methods
| | From scratch | Feature extraction | Full fine-tuning | LoRA / adapters | Linear probing |
|---|---|---|---|---|---|
| Trainable params | 100% (all) | head only (<0.1%) | 100% | 0.1-1% (low-rank deltas) | 1 linear layer |
| Target data needed | millions | 20-1000/class | thousands+/class | hundreds-thousands | 20-500/class |
| Training cost | days-weeks (GPUs) | seconds-minutes | minutes-hours | minutes | seconds |
| Peak accuracy | high (if huge data) | good | highest | near full fine-tune | baseline |
| Storage per task | full model | small head | full model copy | tiny delta (MBs) | tiny head |
| Catastrophic forgetting risk | n/a | none (frozen) | high | low | none |
| Best for | novel domain, big data | tiny dataset, fast iteration | moderate data, max accuracy | many tasks, one big model | quick baseline / probing |
The headline trade-off: feature extraction is the cheapest and safest but caps your accuracy at "whatever a linear classifier on frozen features can do." Full fine-tuning reaches the highest accuracy but risks catastrophic forgetting and stores a full model copy per task. Parameter-efficient methods like LoRA were invented to get most of the fine-tuning accuracy while storing only megabytes of weight deltas per task — critical when serving one base model adapted to hundreds of customers.
What the numbers actually say
- Data efficiency: 10-1000× less. A from-scratch ImageNet CNN needs ~1.2M labeled images. A fine-tuned one reaches strong accuracy on a downstream task like Oxford Flowers (102 classes) from roughly 1,000-2,000 labeled images (about 10-20 per class), three orders of magnitude fewer.
- Compute: pretraining is the expensive part, and you skip it. Pretraining BERT-Large took ~4 days on 16 Cloud TPUs; the BERT authors report that fine-tuning for a downstream task takes at most about an hour on a single Cloud TPU, or a few hours on a GPU. You amortize one massive pretraining run across thousands of downstream fine-tunes.
- Feature extraction is near-free. With the backbone frozen, you can precompute embeddings once (one forward pass over your data) and then train the head in seconds — for an 800-image dataset, the head trains in well under a second on a CPU.
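- Storage with LoRA: roughly 1,000-10,000× smaller per task. A 7-billion-parameter model is ~14 GB in fp16; a typical LoRA adapter for it is a few to a few tens of megabytes, depending on the rank and which weight matrices are adapted, so you can store hundreds of task-specific adapters in the space of one base model. Hu et al. (2021) report a ~10,000× checkpoint reduction (350 GB to 35 MB) for GPT-3 175B.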
- Fine-tune learning rate: typically 1e-5 to 5e-5 for transformers, vs 1e-3 to 1e-4 for pretraining — a 10-100× reduction that prevents the pretrained weights from being washed out.
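The storage ratio is plain arithmetic: each adapted d×d matrix gets two low-rank factors, B (d×r) and A (r×d). A sketch under a hypothetical 7B-class configuration (32 layers, hidden size 4096, rank 8, adapting two projection matrices per layer); the exact ratio depends on rank and on which matrices you adapt.

```python
def lora_params(num_layers: int, d_model: int, rank: int, matrices_per_layer: int) -> int:
    """Parameter count of LoRA factors B (d x r) and A (r x d) for square d x d weights."""
    return num_layers * matrices_per_layer * 2 * d_model * rank

BYTES_FP16 = 2
base_bytes = 7_000_000_000 * BYTES_FP16  # ~14 GB, as in the text

# Hypothetical config: 32 layers, d_model 4096, rank 8, two matrices per layer.
adapter = lora_params(num_layers=32, d_model=4096, rank=8, matrices_per_layer=2)
adapter_bytes = adapter * BYTES_FP16

print(f"{adapter:,} params, {adapter_bytes / 1e6:.1f} MB, "
      f"{base_bytes / adapter_bytes:,.0f}x smaller than the base model")
```

For this configuration the adapter is about 8.4 MB, roughly 1,700× smaller than the base weights; lower ranks or fewer adapted matrices push the ratio higher.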
JavaScript implementation
With TensorFlow.js you load a pretrained MobileNet, chop off its classifier, and train a small head on the frozen embeddings — the canonical "feature extraction" recipe that powers in-browser "teachable machine" demos.
```javascript
import * as tf from '@tensorflow/tfjs';
import * as mobilenet from '@tensorflow-models/mobilenet';

// 1. Load the pretrained backbone (MobileNet v2, trained on 1.2M ImageNet images).
const base = await mobilenet.load({ version: 2, alpha: 1.0 });

// 2. Use it as a frozen feature extractor. infer(img, true) returns the
//    embedding that feeds the 1000-class classifier (1280-D for v2, alpha 1.0)
//    instead of the logits. The loaded backbone is an inference-only graph
//    model, so its weights are never updated: it is frozen by construction.
//    infer() also applies MobileNet's own resizing and input normalization.
function embed(img) {
  return base.infer(img, true); // [1, 1280]
}

// 3. A fresh head sized for OUR task (say 3 defect classes).
const NUM_CLASSES = 3;
const head = tf.sequential();
head.add(tf.layers.dense({ inputShape: [1280], units: 64, activation: 'relu' }));
head.add(tf.layers.dense({ units: NUM_CLASSES, activation: 'softmax' }));
head.compile({ optimizer: tf.train.adam(1e-3), loss: 'categoricalCrossentropy', metrics: ['accuracy'] });

// 4. Embed every image once (the expensive forward pass), then train ONLY the
//    head on the cached embeddings: seconds, not hours.
//    images: array of <img>/<canvas>/<video> elements or [h, w, 3] tensors.
//    labels: array of integer class ids in [0, NUM_CLASSES).
async function train(images, labels) {
  const xs = tf.tidy(() => tf.concat(images.map(embed)));  // [N, 1280]
  const ys = tf.tidy(() => tf.oneHot(tf.tensor1d(labels, 'int32'), NUM_CLASSES));
  await head.fit(xs, ys, { epochs: 20, batchSize: 16, shuffle: true });
  xs.dispose(); ys.dispose();
}

// 5. Inference: backbone embedding -> head prediction (class probabilities).
function predict(img) {
  return tf.tidy(() => head.predict(embed(img)));
}
```

The key design choice is that the backbone only ever runs forward: base.infer(img, true) produces embeddings, and only head is passed to fit. If you instead trained the full 3.5M-parameter MobileNet on your handful of images, it would quickly overfit. Here you're fitting a 64-unit layer on top of features that already know what textures look like.
Python implementation
In PyTorch the two-phase recipe — freeze and feature-extract, then unfreeze the top and fine-tune — is just a matter of toggling requires_grad and using two optimizers with different learning rates.
```python
import torch
import torch.nn as nn
from torchvision import models

NUM_CLASSES = 3

# 1. Load a pretrained ResNet-50 (ImageNet weights).
net = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)

# 2. FREEZE the whole backbone.
for p in net.parameters():
    p.requires_grad = False

# 3. Replace the 1000-class head with OUR head. New layers default to
#    requires_grad=True, so only these will train.
in_features = net.fc.in_features  # 2048 for ResNet-50
net.fc = nn.Linear(in_features, NUM_CLASSES)

criterion = nn.CrossEntropyLoss()

def train_one_epoch(loader, opt):
    net.train()
    # Keep BatchNorm layers whose parameters are frozen in eval mode, so their
    # running mean/var stay at the ImageNet values instead of drifting toward
    # the small target dataset.
    for m in net.modules():
        if isinstance(m, nn.BatchNorm2d) and not m.weight.requires_grad:
            m.eval()
    for x, y in loader:
        opt.zero_grad()
        loss = criterion(net(x), y)
        loss.backward()  # gradients reach only params with requires_grad=True
        opt.step()

# ---- Phase 1: feature extraction (backbone frozen) ----
opt = torch.optim.Adam(net.fc.parameters(), lr=1e-3)
# for epoch in range(...): train_one_epoch(train_loader, opt)  # train_loader: your DataLoader

# ---- Phase 2: fine-tune the top block at a TINY learning rate ----
for p in net.layer4.parameters():  # unfreeze the last residual block
    p.requires_grad = True
opt = torch.optim.Adam([
    {"params": net.layer4.parameters(), "lr": 1e-5},  # 100x below phase 1's 1e-3
    {"params": net.fc.parameters(), "lr": 1e-4},
])
# ...continue with train_one_epoch(train_loader, opt); the early layers stay
# frozen, layer4 specializes, and the head keeps learning.
```
Two details are the classic gotchas. First, BatchNorm mode matters: if your backbone has BatchNorm layers and you leave them in train() mode while frozen, their running statistics drift toward your tiny dataset and silently degrade the features, so call .eval() on the frozen BatchNorm layers (or the whole frozen backbone). Second, the discriminative learning rates (1e-5 for layer4, 1e-4 for the head) encode the principle directly: the closer a layer sits to the input, the less it should move.
Variants worth knowing
Linear probing vs full fine-tuning. Linear probing trains only a single linear layer on frozen features — it's the strictest form of feature extraction and a standard way to measure how good the pretrained representation is. Kumar et al. (2022) showed that fine-tuning can actually underperform linear probing out-of-distribution because it distorts the pretrained features; their fix, "LP-FT," does linear probing first and then fine-tunes, getting the best of both.
Parameter-efficient fine-tuning (PEFT). Instead of updating the backbone, inject small trainable modules. Adapters (Houlsby et al., 2019) add tiny bottleneck layers between frozen blocks. LoRA (Hu et al., 2021) learns a low-rank update ΔW = BA to frozen weight matrices, training <1% of parameters while matching full fine-tuning on many tasks. Prompt/prefix tuning learns input vectors and touches no model weights at all.
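A minimal PyTorch sketch of the LoRA idea, wrapping a frozen `nn.Linear` with trainable low-rank factors. It follows the initialization described by Hu et al. (Gaussian A, zero B, scaling α/r); the class name `LoRALinear` and the defaults r = 8, α = 16 are illustrative assumptions, not a library API.

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Frozen linear layer plus a trainable low-rank update: W x + (alpha / r) * B A x."""

    def __init__(self, base: nn.Linear, r: int = 8, alpha: float = 16.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False  # pretrained W and bias stay frozen
        self.A = nn.Parameter(torch.randn(r, base.in_features) * 0.01)
        self.B = nn.Parameter(torch.zeros(base.out_features, r))  # zero: delta starts at 0
        self.scale = alpha / r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + (x @ self.A.T @ self.B.T) * self.scale

base = nn.Linear(64, 64)
lora = LoRALinear(base, r=4)
trainable = sum(p.numel() for p in lora.parameters() if p.requires_grad)
print(trainable)  # 512 trainable vs 4,160 frozen in the base layer
```

Because B starts at zero, the wrapped layer initially computes exactly what the pretrained layer did; training moves only the 2·r·d low-rank parameters, and only ΔW = BA needs to be stored per task.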
Domain adaptation. A flavor of transfer where the task is the same but the data distribution shifts (synthetic → real images, one hospital's scans → another's). Techniques like domain-adversarial training (DANN) align the source and target feature distributions explicitly.
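The core of DANN is a gradient reversal layer: identity on the forward pass, negated (and scaled) gradient on the backward pass, so the feature extractor learns to confuse a domain classifier placed after it. A minimal sketch using `torch.autograd.Function`; the names `GradReverse`, `grad_reverse`, and `lambd` are illustrative.

```python
import torch

class GradReverse(torch.autograd.Function):
    """Identity forward; multiplies the incoming gradient by -lambd backward."""

    @staticmethod
    def forward(ctx, x, lambd):
        ctx.lambd = lambd
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Second return value is the (non-existent) gradient for lambd.
        return -ctx.lambd * grad_output, None

def grad_reverse(x: torch.Tensor, lambd: float = 1.0) -> torch.Tensor:
    return GradReverse.apply(x, lambd)

# Usage: domain_logits = domain_classifier(grad_reverse(features, lambd))
x = torch.ones(3, requires_grad=True)
grad_reverse(x, 0.5).sum().backward()
print(x.grad)  # each entry is -0.5
```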
Self-supervised pretraining. The modern source of backbones. Instead of labeled ImageNet, models pretrain on unlabeled data with proxy objectives: masked-token prediction (BERT), contrastive learning (SimCLR, CLIP), next-token prediction (GPT). The resulting features often transfer as well as or better than supervised ones, and need no labels to create.
Knowledge distillation. A related-but-distinct idea: instead of transferring a model's weights, transfer its behavior by training a small "student" to mimic a big "teacher's" output probabilities.
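A common form of the distillation objective blends a temperature-softened KL term against the teacher with ordinary cross-entropy on the true labels. A minimal PyTorch sketch; `T = 4.0` and `alpha = 0.5` are illustrative assumptions, and the T² factor keeps the soft-target gradients on roughly the same scale as the hard-label ones.

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.5):
    """alpha * T^2 * KL(teacher_T || student_T) + (1 - alpha) * CE(student, labels)."""
    soft = F.kl_div(
        F.log_softmax(student_logits / T, dim=-1),  # student log-probs at temperature T
        F.softmax(teacher_logits / T, dim=-1),      # teacher probs at temperature T
        reduction="batchmean",
    ) * (T * T)
    hard = F.cross_entropy(student_logits, labels)
    return alpha * soft + (1 - alpha) * hard

teacher = torch.tensor([[3.0, 1.0, -1.0]])
student = torch.tensor([[0.5, 0.4, 0.1]])
labels = torch.tensor([0])
print(distillation_loss(student, teacher, labels))
```

When the student's logits match the teacher's exactly, the KL term vanishes and only the weighted cross-entropy remains.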
Common bugs and edge cases
- Forgetting to freeze. The number-one beginner mistake: leaving the backbone trainable and using a normal learning rate, so the first batch's gradients destroy the pretrained weights. Symptom: training loss spikes then accuracy is no better than from scratch.
- BatchNorm in train mode while frozen. Frozen weights but live BatchNorm running stats means your tiny dataset's statistics overwrite ImageNet's. Call .eval() on the backbone, or freeze BN explicitly.
- Mismatched preprocessing. The pretrained model expects the exact normalization it was trained with (ImageNet mean/std, specific resize). Feed it raw 0-255 pixels and the features are garbage: a silent accuracy killer with no error message.
- Learning rate too high when fine-tuning. This is catastrophic forgetting in slow motion: the model trains, accuracy looks okay, but you've thrown away generalization. Use 1e-5 range and watch validation, not training, loss.
- Negative transfer. Blindly fine-tuning an ImageNet model on a domain with no visual overlap (spectrograms, tabular features turned into images) can do worse than a small dedicated model. Always benchmark against a from-scratch baseline.
- Data leakage through the backbone. If your target classes overlapped the source's pretraining data, your "few-shot" results are inflated — the model already saw those concepts. Check that your benchmark isn't secretly in ImageNet/LAION.
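For the preprocessing pitfall, a minimal NumPy sketch of the standard ImageNet normalization used by torchvision's pretrained classifiers (per-channel mean 0.485, 0.456, 0.406 and std 0.229, 0.224, 0.225 after scaling to [0, 1]). Resizing and cropping are omitted, and `preprocess` is a hypothetical helper; in torchvision, `weights.transforms()` returns the exact pipeline for a given set of weights and is the safer source of truth.

```python
import numpy as np

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])

def preprocess(img_uint8: np.ndarray) -> np.ndarray:
    """HWC uint8 image in [0, 255] -> CHW float32, scaled to [0, 1] then mean/std normalized."""
    x = img_uint8.astype(np.float32) / 255.0
    x = (x - IMAGENET_MEAN) / IMAGENET_STD
    return x.transpose(2, 0, 1).astype(np.float32)

white = np.full((2, 2, 3), 255, dtype=np.uint8)
print(preprocess(white)[:, 0, 0])  # per-channel values for a pure-white pixel
```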
Frequently asked questions
What is the difference between feature extraction and fine-tuning?
Feature extraction freezes the entire pretrained backbone and trains only a new classifier head on top — the backbone acts as a fixed feature function. Fine-tuning unfreezes some or all backbone layers and updates them too, usually at a much smaller learning rate. Feature extraction is faster and safer on tiny datasets; fine-tuning reaches higher accuracy when you have a few thousand examples or more.
Why freeze the early layers and retrain only the top?
Early layers learn generic features — edges, textures, color blobs in vision; subword statistics and syntax in language — that transfer across almost any task. Late layers learn task-specific combinations that don't transfer. Freezing the generic early layers preserves what's already correct and limits the number of trainable parameters, so a small dataset can't overfit and ruin them.
How much data do you need for transfer learning?
Often 100-1000 labeled examples per class is enough for strong results, versus the millions a from-scratch model needs. A frozen-backbone classifier can give a usable signal from as few as 20-50 examples per class, because only the small head is learning.
What is catastrophic forgetting in transfer learning?
Catastrophic forgetting is when fine-tuning at too high a learning rate overwrites the pretrained weights with noisy gradients from the small new dataset, destroying the general features that made transfer work. The fix is a small learning rate (often 10-100× smaller than pretraining), gradual unfreezing, or freezing early layers entirely.
When does transfer learning hurt instead of help?
Negative transfer happens when the source and target domains are too different: a model pretrained on natural photos transfers poorly to raw radio spectrograms or tabular medical data. If the low-level statistics don't overlap, the frozen features are irrelevant noise and a smaller from-scratch model can outperform the transferred one.
Is transfer learning the same as a foundation model?
Pretrained backbones such as ResNet, and foundation models like BERT, CLIP, and GPT, are the source models that transfer learning adapts. Transfer learning is the technique; the pretrained model is the asset you transfer from. Modern practice has shifted from training your own backbone to downloading a foundation model and fine-tuning it.