
Batch Normalization

Re-center and rescale every layer's activations per mini-batch — and train deeper, faster

Batch normalization standardizes each layer's activations using the mean and variance of the current mini-batch, then rescales them with learned γ and β parameters — letting you train deeper networks faster and with higher learning rates.

  • Introduced: Ioffe & Szegedy, 2015
  • Normalizes over: the batch dimension
  • Learned params per feature: γ (scale), β (shift)
  • Compute cost: O(N·D) per layer
  • Speed-up on ImageNet: ≈ 14× fewer steps

How batch normalization works

Train a deep network and you hit a wall: the activations partway through drift. As the early layers update their weights, the distribution of numbers feeding the later layers shifts under their feet — sometimes blowing up toward infinity, sometimes collapsing toward zero. Every layer has to keep re-adapting to a moving target, so you're forced to use a tiny learning rate and the whole thing crawls.

Sergey Ioffe and Christian Szegedy proposed a fix in 2015: just standardize the activations as they flow through. For each feature (each neuron, or each channel in a conv net), look at its values across the whole mini-batch, subtract the batch mean, and divide by the batch standard deviation. Now that feature has zero mean and unit variance regardless of what the earlier layers did. The target stops moving.

But forcing every activation to mean 0, variance 1 throws away information: what if a sigmoid downstream actually wants a wide, off-center input? So batch norm adds two learned parameters per feature: a scale γ and a shift β. The final output is γ·x̂ + β. If the network decides normalization was a bad idea for some feature, it can learn γ = σ and β = μ (that feature's standard deviation and mean) and recover the original activation. Because the normalization can be undone, it doesn't cost representational power.
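A minimal NumPy sketch of the undo argument (NumPy and the helper name `batchnorm_train` are this sketch's assumptions). It plugs one batch's own standard deviation and mean in as γ and β, so the affine step inverts the standardization for that batch. In real training γ and β are shared across batches while μ_B and σ²_B move, so the recovery holds only as closely as the batch statistics stay put.

```python
import numpy as np

def batchnorm_train(x, gamma, beta, eps=1e-5):
    # Standardize each column (feature) with its own batch mean and biased
    # variance, then apply the per-feature scale and shift.
    mu = x.mean(axis=0)
    var = x.var(axis=0)
    xhat = (x - mu) / np.sqrt(var + eps)
    return gamma * xhat + beta

rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(64, 4))   # wide, off-center activations

# gamma = sqrt(var + eps) and beta = mu for THIS batch: the scale-and-shift
# step exactly cancels the standardization, so the output equals the input.
gamma = np.sqrt(x.var(axis=0) + 1e-5)
beta = x.mean(axis=0)
y = batchnorm_train(x, gamma, beta)
print(np.allclose(y, x))   # True
```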

The precise mechanism

For a mini-batch of size m and a single feature with values x₁ … xₘ, batch norm computes:

  μ_B  = (1/m) · Σ xᵢ                  # batch mean
  σ²_B = (1/m) · Σ (xᵢ − μ_B)²         # batch variance (biased)
  x̂ᵢ  = (xᵢ − μ_B) / √(σ²_B + ε)       # normalize
  yᵢ  = γ · x̂ᵢ + β                     # scale and shift (learned)

The ε (typically 1e-5) is a tiny constant in the denominator that prevents division by zero when a feature is constant across the batch. Each feature gets its own μ_B, σ²_B, γ, and β; for a layer with D features that's 2D learned parameters and 2D running statistics.
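A worked example of the four formulas on a tiny batch, as a NumPy sketch; the values of x, γ and β are made up for illustration. Column 1 is column 0 scaled by 10 and column 2 is constant, so the example also shows that each feature gets its own statistics and what ε does.

```python
import numpy as np

# Batch of m = 4 examples (rows), 3 features (columns).
x = np.array([[1.0, 10.0, 7.0],
              [2.0, 20.0, 7.0],
              [3.0, 30.0, 7.0],
              [4.0, 40.0, 7.0]])
gamma = np.array([2.0, 1.0, 1.0])   # illustrative learned scale, one per feature
beta = np.array([1.0, 0.0, 0.5])    # illustrative learned shift, one per feature
eps = 1e-5

mu = x.mean(axis=0)                 # μ_B  = [2.5, 25.0, 7.0]
var = x.var(axis=0)                 # σ²_B = [1.25, 125.0, 0.0]  (biased: divide by m)
xhat = (x - mu) / np.sqrt(var + eps)
y = gamma * xhat + beta

# Features 0 and 1 differ by a factor of 10 but normalize to (almost) the same
# x̂ ≈ [-1.342, -0.447, 0.447, 1.342]. Feature 2 has zero variance: ε turns
# what would be 0/0 into x̂ = 0, so its output is just β = 0.5.
```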

Complexity. The forward pass is two reductions and an elementwise transform over an N×D activation tensor (batch size N, written m above; D features), so it's O(N·D) time and O(D) extra parameters, negligible next to the matrix multiplies it sits between. The backward pass is also O(N·D), but it is not a simple elementwise gradient: because μ_B and σ²_B depend on every element of the batch, the gradient of each xᵢ couples to all the others. That coupling is also the source of batch norm's mild regularizing effect: each example's output carries noise from whichever examples happen to share its batch.
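A small NumPy sketch of that coupling, with γ = 1 and β = 0 assumed for brevity: change one example in the batch and the normalized output of a different, untouched example moves too.

```python
import numpy as np

def bn(x, eps=1e-5):
    # Standardize each feature over the batch axis (gamma = 1, beta = 0 assumed).
    return (x - x.mean(axis=0)) / np.sqrt(x.var(axis=0) + eps)

batch = np.array([[0.0], [1.0], [2.0], [3.0]])   # 4 examples, 1 feature
out_a = bn(batch)

batch_b = batch.copy()
batch_b[3, 0] = 30.0                             # change ONLY the last example
out_b = bn(batch_b)

# Example 0's input is identical in both batches, yet its output differs.
print(out_a[0, 0], out_b[0, 0])
```

In a plain linear layer with no normalization, example 0's output would be the same in both batches.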

The crucial subtlety: at training time the statistics come from the live mini-batch. At inference time you might be classifying a single image, where a "batch mean" is meaningless. So during training batch norm also maintains a running exponential moving average of the mean and variance, and uses those frozen estimates at test time:

  running_μ ← (1 − momentum)·running_μ + momentum·μ_B
  running_σ² ← (1 − momentum)·running_σ² + momentum·σ²_B
  # inference: ŷ = γ·(x − running_μ)/√(running_σ² + ε) + β
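A sketch of the running-statistics update and the inference path, assuming NumPy, the momentum convention in the formulas above (momentum weights the new batch; 0.1 here), and synthetic Gaussian batches whose true means and variances are known. The helper names are hypothetical. The sketch tracks the biased batch variance exactly as written above; frameworks may differ in that detail.

```python
import numpy as np

def update_running_stats(running_mu, running_var, x, momentum=0.1):
    # Same convention as the formulas above: momentum is the weight on the NEW batch.
    mu_b = x.mean(axis=0)
    var_b = x.var(axis=0)                  # biased batch variance (divide by N)
    running_mu = (1 - momentum) * running_mu + momentum * mu_b
    running_var = (1 - momentum) * running_var + momentum * var_b
    return running_mu, running_var

def bn_inference(x, gamma, beta, running_mu, running_var, eps=1e-5):
    # Frozen statistics: the output does not depend on the rest of the batch.
    return gamma * (x - running_mu) / np.sqrt(running_var + eps) + beta

rng = np.random.default_rng(0)
true_mu = np.array([1.0, -2.0, 5.0])
true_std = np.array([1.0, 0.5, 3.0])
running_mu, running_var = np.zeros(3), np.ones(3)   # typical starting values

for step in range(500):                              # simulated training steps
    x = rng.normal(loc=true_mu, scale=true_std, size=(64, 3))
    running_mu, running_var = update_running_stats(running_mu, running_var, x)

# The estimates settle near the true statistics (the variance slightly low,
# because the biased batch variance averages (N-1)/N of the true value).
one_example = np.array([[1.0, -2.0, 5.0]])
print(running_mu, running_var)
print(bn_inference(one_example, np.ones(3), np.zeros(3), running_mu, running_var))
```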

When to use it — and the tradeoffs

  • Deep feed-forward and convolutional nets: this is batch norm's home turf. ResNet, Inception, and most CNN image classifiers from 2015 onward depend on it.
  • When you want a large learning rate. By smoothing the loss landscape, batch norm lets you crank the learning rate 5–30× higher without diverging, which is most of the training-time win.
  • As a mild regularizer. The batch-dependent noise in each normalization acts like a stochastic perturbation, often letting you reduce or drop dropout.

The tradeoffs are real, though:

  • It couples examples in a batch. An image's normalized value depends on the other images it happened to share a batch with. That's usually fine for classification, but it can cause trouble in contrastive learning (batch statistics can leak information between examples) and in any setup where examples should be processed independently.
  • It's fragile at small batch sizes. With a batch of 2–4 the statistics are too noisy; accuracy degrades sharply. Layer norm or group norm are the fix.
  • It's awkward for recurrent and sequence models. Sequence length varies and statistics shift along the time axis, which is why Transformers and RNNs use layer norm instead.
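To see why small batches hurt, this NumPy sketch (a sampling illustration, not a training experiment; the Gaussian population and the trial count are assumptions) measures how much one feature's batch mean and batch variance fluctuate at different batch sizes. That fluctuation is noise fed straight into every normalized activation.

```python
import numpy as np

rng = np.random.default_rng(0)
population = rng.normal(0.0, 1.0, size=100_000)   # one feature's "true" activations

def spread_of_batch_stats(batch_size, trials=2000):
    # Sample many mini-batches and measure how far their statistics wander.
    idx = rng.integers(0, population.size, size=(trials, batch_size))
    batches = population[idx]                       # shape (trials, batch_size)
    return batches.mean(axis=1).std(), batches.var(axis=1).std()

for b in (2, 4, 32, 256):
    mean_sd, var_sd = spread_of_batch_stats(b)
    print(f"batch {b:>3}: std of batch mean {mean_sd:.3f}, std of batch variance {var_sd:.3f}")
```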

Batch norm vs the other normalizers

| | Batch Norm | Layer Norm | Group Norm | Instance Norm | RMSNorm | Weight Norm |
|---|---|---|---|---|---|---|
| Normalizes across | batch (per feature) | features (per example) | feature groups (per example) | each channel (per example) | features, no mean subtraction | the weight vector, not activations |
| Depends on batch size? | Yes (fragile if small) | No | No | No | No | No |
| Train ≠ test behavior? | Yes (running stats) | No | No | No | No | No |
| Learned params | γ, β per feature | γ, β per feature | γ, β per feature | γ, β per channel | γ per feature | g (magnitude) + direction |
| Best for | CNNs, large batches | Transformers, RNNs | vision, tiny batches | style transfer, GANs | LLMs (cheaper LayerNorm) | tasks where BN hurts |
| Extra inference cost | 0 (foldable into conv) | mean + variance per token | mean + variance per group | mean + variance per channel | one mean-square per token | 0 |

The headline distinction is which axis the statistics are computed over. Batch norm reduces over the batch, so a single example's output depends on its batchmates and on whether you're training or testing. Layer norm and its relatives reduce over the feature axis within one example: fully self-contained, batch-size-agnostic, and identical at train and test time. That independence is a major reason modern large language models use layer norm or RMSNorm rather than batch norm.
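The axis distinction in code, as a NumPy sketch with γ and β omitted (the function names are illustrative): batch norm reduces over axis 0 of an (N, D) array, layer norm over axis 1.

```python
import numpy as np

def batch_norm(x, eps=1e-5):
    # Statistics per feature, computed ACROSS the batch (axis 0).
    mu, var = x.mean(axis=0, keepdims=True), x.var(axis=0, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def layer_norm(x, eps=1e-5):
    # Statistics per example, computed ACROSS its features (axis 1).
    mu, var = x.mean(axis=1, keepdims=True), x.var(axis=1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
x = rng.normal(size=(8, 16))            # N = 8 examples, D = 16 features

# Layer norm gives example 0 the same output alone or inside the batch...
ln_alone = layer_norm(x[:1])
ln_batched = layer_norm(x)[:1]
# ...while batch norm's output for example 0 depends on its batchmates.
bn_small = batch_norm(x[:4])[:1]
bn_full = batch_norm(x)[:1]
```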

What the numbers actually say

  • ≈ 14× fewer training steps on ImageNet. In the original paper, a batch-normalized Inception matched the baseline's accuracy in about 1/14 of the steps, and went on to beat it. That's the canonical "batch norm makes training faster" number.
  • Learning rate 5–30× higher. Without batch norm, large learning rates make activations explode within a few layers. The paper trained stably at rates that diverged immediately for the un-normalized baseline.
  • Compute overhead under 5% per layer. A 256×1024 activation matrix needs ≈ 0.5M extra FLOPs for the two reductions, versus ≈ 268M multiply-adds (≈ 537M FLOPs) for the 1024×1024 matmul it follows. That's well under 1% there, and it folds to zero at inference.
  • 2× memory at training time. Backprop through batch norm needs the normalized activations cached for the backward pass, roughly doubling the stored activations for that layer — a real cost on memory-bound training.
  • Small-batch cliff. Going from batch 32 to batch 2 on ImageNet-scale models can cost 5–10 accuracy points with batch norm; group norm closes nearly all of that gap.
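The arithmetic behind the compute-overhead bullet, as a short Python sketch. The counting conventions are assumptions: one FLOP per element for each of the two reduction passes, and two FLOPs (a multiply and an add) per multiply-add in the matmul, so the 1024×1024 matmul on a 256×1024 input is about 268M multiply-adds, or about 537M FLOPs.

```python
N, D = 256, 1024

bn_reduction_flops = 2 * N * D      # one pass for the mean, one for the variance
matmul_macs = N * D * D             # (N x D) @ (D x D) multiply-adds
matmul_flops = 2 * matmul_macs      # counting each multiply-add as 2 FLOPs

print(bn_reduction_flops)                         # 524288  (~0.5M)
print(matmul_macs, matmul_flops)                  # 268435456 536870912
print(f"{bn_reduction_flops / matmul_flops:.4%}")  # 0.0977%
```

Charging the variance pass and the normalize-scale-shift step several FLOPs per element instead of one still leaves batch norm at a fraction of a percent of this layer's cost.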

JavaScript implementation

A minimal batch norm over a 2-D array of shape [N][D] (N examples, D features), normalizing each feature column independently:

function batchNormForward(X, gamma, beta, eps = 1e-5) {
  const N = X.length, D = X[0].length;
  const mean = new Array(D).fill(0);
  const varc = new Array(D).fill(0);

  // 1. per-feature batch mean
  for (let i = 0; i < N; i++)
    for (let j = 0; j < D; j++) mean[j] += X[i][j] / N;

  // 2. per-feature batch variance (biased — divide by N, not N-1)
  for (let i = 0; i < N; i++)
    for (let j = 0; j < D; j++) {
      const d = X[i][j] - mean[j];
      varc[j] += (d * d) / N;
    }

  // 3. normalize, then scale (gamma) and shift (beta)
  const out = [];
  const xhat = [];                       // cache for the backward pass
  for (let i = 0; i < N; i++) {
    out[i] = []; xhat[i] = [];
    for (let j = 0; j < D; j++) {
      const xh = (X[i][j] - mean[j]) / Math.sqrt(varc[j] + eps);
      xhat[i][j] = xh;
      out[i][j] = gamma[j] * xh + beta[j];
    }
  }
  return { out, cache: { xhat, mean, varc, eps } };
}

Note that the variance uses the biased estimator (divide by N, not N − 1). That's the standard convention for the training-time normalization, and it matters when you compare implementations; PyTorch, for example, still stores the unbiased N − 1 estimate in its running variance.
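The gap between the two estimators on a four-element batch, as a NumPy sketch. NumPy's np.var divides by N by default (ddof=0), matching the JavaScript above; ddof=1 gives the unbiased N − 1 form.

```python
import numpy as np

x = np.array([1.0, 2.0, 3.0, 4.0])
biased = x.var()            # ddof=0: divide by N      -> 1.25
unbiased = x.var(ddof=1)    # ddof=1: divide by N - 1  -> 1.666...
ratio = unbiased / biased   # exactly N / (N - 1) = 4/3; large for small batches
```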

Python implementation — forward and backward

The forward pass is short; the backward pass is the part people get wrong, because the loss flows back through three paths (the normalized value, the mean, and the variance). Here is the full vectorized version with NumPy:

import numpy as np

def batchnorm_forward(x, gamma, beta, eps=1e-5):
    mu  = x.mean(axis=0)                  # (D,)
    var = x.var(axis=0)                   # (D,) biased
    xhat = (x - mu) / np.sqrt(var + eps)  # (N, D)
    out = gamma * xhat + beta
    cache = (xhat, gamma, x - mu, var, eps)
    return out, cache

def batchnorm_backward(dout, cache):
    xhat, gamma, xmu, var, eps = cache
    N = dout.shape[0]
    std_inv = 1.0 / np.sqrt(var + eps)

    # gradients for the learned parameters
    dgamma = np.sum(dout * xhat, axis=0)
    dbeta  = np.sum(dout, axis=0)

    # gradient through the normalization — every xi couples via mu and var
    dxhat = dout * gamma
    dvar  = np.sum(dxhat * xmu, axis=0) * -0.5 * std_inv**3
    dmu   = np.sum(dxhat * -std_inv, axis=0) + dvar * np.mean(-2.0 * xmu, axis=0)
    dx = (dxhat * std_inv) + (dvar * 2.0 * xmu / N) + (dmu / N)
    return dx, dgamma, dbeta

The reason dx has three additive terms is the chain rule fanning out through x̂ᵢ, through σ²_B, and through μ_B, all three of which depend on every xᵢ in the batch. In real code you'd let autograd handle this, but writing it once by hand is the clearest way to see why batch norm makes examples interact.
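One way to check the hand-written backward pass is a central finite-difference gradient check. This sketch repeats the two functions above so it runs on its own, picks an arbitrary scalar loss L = Σ w · out so that dL/dout = w, and compares analytic and numerical gradients; the helper numerical_grad and the step size h = 1e-6 are this sketch's choices.

```python
import numpy as np

def batchnorm_forward(x, gamma, beta, eps=1e-5):
    mu = x.mean(axis=0)
    var = x.var(axis=0)
    xhat = (x - mu) / np.sqrt(var + eps)
    return gamma * xhat + beta, (xhat, gamma, x - mu, var, eps)

def batchnorm_backward(dout, cache):
    xhat, gamma, xmu, var, eps = cache
    N = dout.shape[0]
    std_inv = 1.0 / np.sqrt(var + eps)
    dgamma = np.sum(dout * xhat, axis=0)
    dbeta = np.sum(dout, axis=0)
    dxhat = dout * gamma
    dvar = np.sum(dxhat * xmu, axis=0) * -0.5 * std_inv**3
    dmu = np.sum(dxhat * -std_inv, axis=0) + dvar * np.mean(-2.0 * xmu, axis=0)
    dx = dxhat * std_inv + dvar * 2.0 * xmu / N + dmu / N
    return dx, dgamma, dbeta

def numerical_grad(f, arr, h=1e-6):
    # Central differences: nudge one entry of arr in place, re-evaluate f, restore.
    g = np.zeros_like(arr)
    for idx in np.ndindex(arr.shape):
        old = arr[idx]
        arr[idx] = old + h
        f_plus = f()
        arr[idx] = old - h
        f_minus = f()
        arr[idx] = old
        g[idx] = (f_plus - f_minus) / (2 * h)
    return g

rng = np.random.default_rng(0)
x = rng.normal(size=(5, 3))
gamma, beta = rng.normal(size=3), rng.normal(size=3)
w = rng.normal(size=(5, 3))          # L = sum(w * out), so dL/dout = w

def loss():
    out, _ = batchnorm_forward(x, gamma, beta)
    return np.sum(w * out)

_, cache = batchnorm_forward(x, gamma, beta)
dx, dgamma, dbeta = batchnorm_backward(w, cache)
print(np.max(np.abs(dx - numerical_grad(loss, x))))   # should be close to zero
```

A useful sanity property: each column of dx sums to (numerically) zero, because adding the same constant to every example of a feature leaves the batch-normalized output unchanged.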

And the inference path, which never touches the batch — it uses the running statistics frozen during training:

def batchnorm_inference(x, gamma, beta, run_mu, run_var, eps=1e-5):
    xhat = (x - run_mu) / np.sqrt(run_var + eps)
    return gamma * xhat + beta            # works for a batch of one
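Why the inference path matters for a batch of one, in a NumPy sketch with made-up γ, β and running statistics. On an (N, D) activation, normalizing a single example by its own batch statistics subtracts the example from itself, so every feature collapses to β whatever the input; the running-statistics path gives an input-dependent answer. (With convolutional feature maps, the spatial positions still supply some statistics, so the collapse is less total there.)

```python
import numpy as np

def batchnorm_train_mode(x, gamma, beta, eps=1e-5):
    mu, var = x.mean(axis=0), x.var(axis=0)
    return gamma * (x - mu) / np.sqrt(var + eps) + beta

def batchnorm_inference(x, gamma, beta, run_mu, run_var, eps=1e-5):
    return gamma * (x - run_mu) / np.sqrt(run_var + eps) + beta

gamma, beta = np.array([2.0, 0.5]), np.array([0.1, -0.3])
run_mu, run_var = np.array([1.0, 4.0]), np.array([0.25, 9.0])  # made-up frozen stats
single = np.array([[1.5, 1.0]])                                 # a batch of ONE example

train_out = batchnorm_train_mode(single, gamma, beta)   # every feature equals beta
infer_out = batchnorm_inference(single, gamma, beta, run_mu, run_var)
print(train_out, infer_out)
```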

Variants worth knowing

Layer Normalization (2016). Normalizes across the feature dimension within a single example instead of across the batch. Batch-size-agnostic and identical at train and test, which makes it the default for Transformers and RNNs. No running statistics.

Group Normalization (2018). Splits the channels into groups and normalizes within each group, per example. Bridges batch norm and layer norm; the go-to when batches are tiny (object detection, segmentation, video) where batch norm's statistics are too noisy.

Instance Normalization. Group norm with a group of one channel — normalizes each channel of each example independently. Born in style transfer, where you want to strip per-image contrast statistics.
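A NumPy sketch of group norm on a convolutional activation of shape (N, C, H, W), with the learned per-channel γ and β omitted; the function and parameter names are illustrative. Setting one channel per group (num_groups = C) reproduces instance norm.

```python
import numpy as np

def group_norm(x, num_groups, eps=1e-5):
    # x has shape (N, C, H, W); C must be divisible by num_groups.
    N, C, H, W = x.shape
    g = x.reshape(N, num_groups, C // num_groups, H, W)
    mu = g.mean(axis=(2, 3, 4), keepdims=True)     # one mean per example, per group
    var = g.var(axis=(2, 3, 4), keepdims=True)
    return ((g - mu) / np.sqrt(var + eps)).reshape(N, C, H, W)

def instance_norm(x, eps=1e-5):
    # Each channel of each example normalized over its own H x W positions.
    mu = x.mean(axis=(2, 3), keepdims=True)
    var = x.var(axis=(2, 3), keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 8, 4, 4))
gn = group_norm(x, num_groups=4)                  # 4 groups of 2 channels
same = np.allclose(group_norm(x, num_groups=8), instance_norm(x))  # True
```

At the other extreme, num_groups = 1 normalizes each example over all of (C, H, W), which is layer norm applied to a feature map.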

RMSNorm. Drops the mean subtraction entirely and divides only by the root-mean-square of the features, with a learned per-feature scale and no shift. Cheaper than layer norm and now standard in large language models like LLaMA.
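A NumPy sketch of RMSNorm next to layer norm; the function names and ε = 1e-6 are assumptions (implementations differ on ε). On inputs whose features already average to zero per example, the two coincide; on other inputs they differ.

```python
import numpy as np

def rms_norm(x, gamma, eps=1e-6):
    # Divide by the root-mean-square over the feature axis: no mean subtraction, no beta.
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return gamma * (x / rms)

def layer_norm(x, gamma, beta, eps=1e-6):
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return gamma * (x - mu) / np.sqrt(var + eps) + beta

rng = np.random.default_rng(0)
D = 8
x = rng.normal(loc=0.5, size=(3, D))        # 3 tokens, D features each
gamma = np.ones(D)
out = rms_norm(x, gamma)                     # each row now has RMS close to 1

x_centered = x - x.mean(axis=-1, keepdims=True)
match_centered = np.allclose(rms_norm(x_centered, gamma),
                             layer_norm(x_centered, gamma, np.zeros(D)))   # True
match_raw = np.allclose(rms_norm(x, gamma), layer_norm(x, gamma, np.zeros(D)))  # False
```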

Weight Normalization. Reparameterizes the weight vector as magnitude × direction rather than touching activations at all. Sidesteps the batch dependence completely, at the cost of weaker conditioning benefits.
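A NumPy sketch of a weight-normalized linear layer, where v (direction) and g (magnitude) are the parameters and the effective weight for each output unit is w = g · v / ‖v‖. Names are illustrative and the bias is omitted. Nothing depends on the batch, so it works unchanged on a single example.

```python
import numpy as np

def weight_norm_linear(x, v, g):
    # v: (out, in) direction parameters; g: (out,) magnitudes.
    # Row k of the effective weight is g[k] * v[k] / ||v[k]||.
    w = g[:, None] * v / np.linalg.norm(v, axis=1, keepdims=True)
    return x @ w.T, w

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 5))
v = rng.normal(size=(3, 5))
g = np.array([0.5, 1.0, 2.0])

y, w = weight_norm_linear(x, v, g)
y_scaled, _ = weight_norm_linear(x, 10.0 * v, g)   # rescaling v changes nothing
```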

Common bugs and edge cases

  • Forgetting model.eval() at inference. The number-one batch-norm bug. Left in train mode, the model normalizes each test batch by that batch's own statistics and keeps updating the running averages with test data. Results can look fine for large eval batches but are badly wrong for a single example, and they change depending on which examples share the batch.
  • Leaving a bias in the layer before batch norm. Batch norm subtracts the mean, so any bias in the preceding linear/conv layer is immediately cancelled — it does nothing but waste parameters and a tiny bit of compute. Set bias=False there.
  • Tiny batch sizes. Below ~8 the statistics get too noisy and training destabilizes. Switch to group norm rather than fighting it.
  • Improvising batch-norm placement around residual connections. Placement relative to skip connections and activations matters; copy a proven ordering, such as ResNet's conv → BN → ReLU with each block's last batch norm applied before the residual add.
  • Mismatched running-stat momentum conventions. PyTorch's momentum is the weight on the new batch (default 0.1); Keras/TensorFlow define it as the decay on the old average (default 0.99), so PyTorch's 0.1 corresponds to 0.9 there. Porting a training setup between frameworks without converting this silently skews the running statistics.
  • Distributed training without synchronized batch norm. Each GPU computes statistics on its own shard, so the effective batch is the per-GPU batch, not the global one. Use SyncBatchNorm when the per-device batch is small.
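Two of these bugs in a few lines of NumPy, as a sketch: a bias added before training-mode batch norm vanishes in the mean subtraction, and converting between the two momentum conventions is a one-liner. The helper pytorch_to_keras_momentum is hypothetical; it only encodes PyTorch's update (momentum weights the new batch) against Keras's (momentum weights the old average).

```python
import numpy as np

def bn_train(x, eps=1e-5):
    # Training-mode batch norm with gamma = 1, beta = 0.
    return (x - x.mean(axis=0)) / np.sqrt(x.var(axis=0) + eps)

rng = np.random.default_rng(0)
x = rng.normal(size=(32, 10))
W = rng.normal(size=(10, 6))
b = rng.normal(size=6)                # bias in the layer feeding batch norm

with_bias = bn_train(x @ W + b)
without_bias = bn_train(x @ W)
print(np.allclose(with_bias, without_bias))   # True: mean subtraction removes b

def pytorch_to_keras_momentum(m_torch):
    # PyTorch: running = (1 - m) * running + m * batch
    # Keras:   running = m * running + (1 - m) * batch
    return 1.0 - m_torch
```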

Frequently asked questions

What are the γ and β parameters in batch normalization for?

After standardizing each activation to zero mean and unit variance, batch norm rescales it as γ·x̂ + β, where γ (scale) and β (shift) are learned per channel. They let the network undo the normalization if that's actually optimal — for example, recover the full range a sigmoid or tanh needs — so normalization never costs the model representational power.

Why does batch normalization behave differently at training time and test time?

During training, batch norm normalizes using the mean and variance of the current mini-batch. At inference you may have a batch of one, so it instead uses running (exponential moving) averages of the mean and variance accumulated during training. Forgetting to switch into eval mode — model.eval() in PyTorch — is the single most common batch-norm bug.

Does batch normalization really reduce internal covariate shift?

That was the original 2015 explanation by Ioffe and Szegedy, but a 2018 MIT paper (Santurkar et al.) showed you can inject noise after batch norm to deliberately worsen covariate shift and still train about as well. Their alternative explanation, now widely cited, is that batch norm mainly smooths the loss landscape: it makes gradients more predictable, which is why you can use much larger learning rates.

Where do you put batch norm — before or after the activation function?

The original paper placed it before the activation (Linear → BatchNorm → ReLU), and that ordering remains the most common; ResNet uses it. Placing it after the activation also trains well in practice. When batch norm directly follows a linear/conv layer, drop that layer's bias term: batch norm's mean subtraction cancels it, and β already supplies a learnable shift.

Why does batch normalization fail with small batch sizes?

The batch mean and variance are noisy estimates of the true statistics. With a batch of 2–4 (common when an image is huge or memory is tight) that noise swamps the signal and accuracy collapses. Group norm and layer norm fix this by normalizing across feature dimensions instead of the batch, so they're independent of batch size.

Can you fold batch normalization into the previous layer for faster inference?

Yes. At inference batch norm is just an affine transform y = a·x + b with fixed a and b per feature. You can fuse those constants into the weights and bias of the preceding convolution or linear layer, eliminating the batch-norm op entirely. Frameworks call this BN folding; it's standard before deploying a model and costs no accuracy, since the fused layer computes the same function up to floating-point rounding.
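A NumPy sketch of BN folding for a linear layer (the function name and random parameters are illustrative). With a = γ / √(running_σ² + ε), inference-mode batch norm on z = Wx + b becomes a single linear layer with weights a·W (row by row) and bias a·(b − running_μ) + β. For a convolution the same per-output-channel scaling applies to each filter.

```python
import numpy as np

def fold_bn_into_linear(W, b, gamma, beta, run_mu, run_var, eps=1e-5):
    a = gamma / np.sqrt(run_var + eps)   # per-output-feature scale
    W_folded = W * a[:, None]            # scale row k of W by a[k]
    b_folded = a * (b - run_mu) + beta
    return W_folded, b_folded

rng = np.random.default_rng(0)
d_in, d_out = 6, 4
W, b = rng.normal(size=(d_out, d_in)), rng.normal(size=d_out)
gamma, beta = rng.normal(size=d_out), rng.normal(size=d_out)
run_mu, run_var = rng.normal(size=d_out), rng.uniform(0.5, 2.0, size=d_out)
x = rng.normal(size=(3, d_in))

# Reference: linear layer followed by inference-mode batch norm.
z = x @ W.T + b
y_ref = gamma * (z - run_mu) / np.sqrt(run_var + 1e-5) + beta

# Folded: one linear layer with adjusted weights and bias.
W_f, b_f = fold_bn_into_linear(W, b, gamma, beta, run_mu, run_var)
y_folded = x @ W_f.T + b_f
```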