Group Normalization
Normalize per sample, not per batch — so a batch of one trains just as well as a batch of 256
Group Normalization splits a layer's channels into fixed groups and normalizes each group per sample, so it computes the same statistics whether the batch holds 32 images or one — unlike Batch Norm, its accuracy doesn't collapse at tiny batch sizes.
- Introduced: Wu & He, 2018
- Normalizes over: channels-in-group × H × W
- Batch dependence: none
- Default groups: G = 32
- Learnable params: 2C (γ, β)
The intuition: normalize across the wrong axis and the batch breaks you
Every deep network has a recurring problem: the distribution of activations drifts as the weights update, and activations that grow very large or very small make the next layer's gradients unstable. Normalization layers fix this by shifting and rescaling activations to roughly zero mean and unit variance before the next operation. The only real question is which numbers you average over.
Picture the activation tensor inside a convolutional layer. It has four axes: N samples in the batch, C channels, and a spatial grid of height H by width W. Batch Norm averages over N × H × W — across the whole batch, per channel. That coupling to the batch is the catch: the statistics are only as stable as the number of samples you have. Train an object detector on two 800×800 images per GPU and the per-channel mean is estimated from just two samples — noisy, and worse, it changes meaning between training (batch statistics) and inference (frozen running averages).
Group Normalization, introduced by Yuxin Wu and Kaiming He at FAIR in 2018, simply moves the averaging axis off the batch. It splits the C channels into G groups and, for each sample independently, normalizes over the channels-in-a-group together with all spatial positions. The batch axis never enters the calculation. So whether you feed it one image or a thousand, every sample is normalized by its own statistics — and the train/test mismatch disappears entirely.
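The claim that the batch axis never enters the calculation is easy to check numerically. The sketch below is plain Python. The helper names (`gn_no_affine`, `bn_train_no_affine`) are hypothetical, and both layers omit γ, β and running statistics. It normalizes one sample on its own, then again inside a batch of four, and measures how far each layer's output for that sample moves.

```python
import math
import random

def gn_no_affine(x, G, eps=1e-5):
    """Group Norm without gamma/beta for ONE sample.
    x: list of C channels, each a flat list of H*W floats."""
    C = len(x)
    assert C % G == 0, f"C={C} not divisible by G={G}"
    cpg = C // G
    out = []
    for g in range(G):
        group = x[g * cpg:(g + 1) * cpg]
        vals = [v for ch in group for v in ch]          # channels-in-group x H x W
        mean = sum(vals) / len(vals)
        var = sum((v - mean) ** 2 for v in vals) / len(vals)
        inv = 1.0 / math.sqrt(var + eps)
        out.extend([[(v - mean) * inv for v in ch] for ch in group])
    return out

def bn_train_no_affine(batch, eps=1e-5):
    """Training-mode Batch Norm without gamma/beta: per-channel stats over N x H x W."""
    C = len(batch[0])
    out = [[None] * C for _ in batch]
    for c in range(C):
        vals = [v for sample in batch for v in sample[c]]
        mean = sum(vals) / len(vals)
        var = sum((v - mean) ** 2 for v in vals) / len(vals)
        inv = 1.0 / math.sqrt(var + eps)
        for n, sample in enumerate(batch):
            out[n][c] = [(v - mean) * inv for v in sample[c]]
    return out

def max_abs_diff(a, b):
    return max(abs(u - v) for ca, cb in zip(a, b) for u, v in zip(ca, cb))

random.seed(0)
C, HW, G = 8, 4, 4
# four samples with different means and spreads
batch = [[[random.gauss(n, 1 + n) for _ in range(HW)] for _ in range(C)] for n in range(4)]
sample = batch[0]

gn_gap = max_abs_diff(gn_no_affine(sample, G), [gn_no_affine(s, G) for s in batch][0])
bn_gap = max_abs_diff(bn_train_no_affine([sample])[0], bn_train_no_affine(batch)[0])
print(f"GN output change: {gn_gap}, BN output change: {bn_gap:.3f}")
```

Group Norm's output for the sample is identical whether it is alone or in the batch. Training-mode Batch Norm gives a different answer as soon as other samples share its statistics.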
The precise mechanism and math
Let the input be x with shape (N, C, H, W). Reshape the channel axis into G groups of C/G channels each — so conceptually the tensor becomes (N, G, C/G, H, W). For each sample n and each group g, compute the mean and variance over the set S = all C/G channels in that group times all H·W positions:
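$$
\mu(n,g) = \frac{1}{m}\sum_{i \in S} x_i, \qquad
\sigma^2(n,g) = \frac{1}{m}\sum_{i \in S}\bigl(x_i - \mu(n,g)\bigr)^2, \qquad
m = \frac{C}{G}\cdot H \cdot W
$$
$$
\hat{x}_i = \frac{x_i - \mu(n,g)}{\sqrt{\sigma^2(n,g) + \epsilon}}, \qquad
y_i = \gamma_c\,\hat{x}_i + \beta_c \quad (\gamma, \beta \text{ are per-channel; } c \text{ is the channel of element } i)
$$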
The normalization is computed independently per (sample, group) pair, then a learnable per-channel scale γ and shift β are applied — exactly as in Batch Norm — so the network can undo the normalization if it needs to. The constant ε (typically 1e-5) guards against division by zero.
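To make the formulas concrete, here is a small worked example. It uses one sample with C = 4 channels, G = 2 groups and H·W = 2, written as a plain-Python sketch (`group_norm_sample` is a hypothetical name). ε is set to 0 only so the numbers can be checked by hand.

```python
import math

def group_norm_sample(x, gamma, beta, G, eps=1e-5):
    """Group Norm for ONE sample.
    x: list of C channels, each a flat list of H*W floats; gamma, beta: length-C lists."""
    C = len(x)
    if C % G != 0:
        raise ValueError(f"C={C} not divisible by G={G}")
    cpg = C // G                                       # channels per group
    stats = []                                         # (mu, sigma^2) for each group
    for g in range(G):
        members = [v for ch in x[g * cpg:(g + 1) * cpg] for v in ch]
        m = len(members)                               # m = (C/G) * H * W
        mu = sum(members) / m
        var = sum((v - mu) ** 2 for v in members) / m  # 1/m (biased) variance
        stats.append((mu, var))
    y = []
    for c, ch in enumerate(x):
        mu, var = stats[c // cpg]                      # statistics of channel c's group
        # the affine step uses the per-channel gamma_c and beta_c
        y.append([gamma[c] * (v - mu) / math.sqrt(var + eps) + beta[c] for v in ch])
    return y

x = [[1.0, 3.0], [5.0, 7.0], [2.0, 2.0], [4.0, 4.0]]  # C = 4, H*W = 2
y = group_norm_sample(x, gamma=[1, 1, 2, 2], beta=[0, 0, 0, 1], G=2, eps=0.0)
# group 0 (channels 0-1): mu = 4, sigma^2 = 5 -> x_hat = [-3, -1, 1, 3] / sqrt(5)
# group 1 (channels 2-3): mu = 3, sigma^2 = 1 -> x_hat = [-1, -1, 1, 1]
print(y[2], y[3])  # [-2.0, -2.0] [3.0, 3.0]
```

Channels 2 and 3 share one set of group statistics but end up with different outputs. That difference comes entirely from their own γ and β.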
The whole forward pass is a couple of reductions and an elementwise affine map, so the cost is O(N·C·H·W) — linear in the tensor size, the same order as Batch Norm or a single elementwise op. There are 2C learnable parameters and, crucially, zero running-statistics buffers: because the statistics come from the input at every step, training and inference run the identical code path.
When to reach for Group Norm
- Detection and segmentation. Mask R-CNN, Faster R-CNN and friends process one or two high-resolution images per GPU. Batch Norm's statistics are garbage at that batch size; Group Norm is unaffected. This is the workload it was designed for.
- Video and 3D models. A single video clip eats GPU memory, forcing tiny batches. Group Norm sidesteps the batch entirely.
- Anything where train and test batch sizes differ. Group Norm uses the same formula in both modes, so there is no running-average to go stale.
- Memory-constrained or distributed training where you can't afford SyncBatchNorm's cross-device all-reduce of statistics on every step.
- Diffusion U-Nets. Group Norm (often combined with a SiLU activation) is the standard normalization inside the residual blocks of modern image-generation networks, where batches are small and resolution is high.
If you are doing plain large-batch image classification with 32+ images per GPU, Batch Norm is usually the slightly better and cheaper choice — see the comparison below.
Group Norm vs the other normalization layers
| | Group Norm | Batch Norm | Layer Norm | Instance Norm |
|---|---|---|---|---|
| Averages over | (C/G)·H·W per sample | N·H·W per channel | C·H·W per sample | H·W per sample, per channel |
| Depends on batch size | No | Yes (needs ≳8–16) | No | No |
| Train = test code path | Yes | No (running stats) | Yes | Yes |
| Running-stat buffers | None | 2C floats | None | None |
| Foldable into conv at test | No | Yes (free at inference) | No | No |
| Best on (vision) | Detection, segmentation, small batch | Large-batch classification | Rarely used in conv nets | Style transfer |
| Best on (sequence) | — | — | Transformers, RNNs | — |
| Special case of GN | — | — | G = 1 | G = C |
The unifying view is the punchline of the Group Norm paper: Layer Norm and Instance Norm are the two endpoints of Group Norm. Set G = 1 and every channel goes into one group — that's Layer Norm. Set G = C and each channel is its own group — that's Instance Norm. Group Norm picks an intermediate G (32 by default), which keeps related feature channels grouped without forcing the network to share one statistic across hundreds of unrelated channels.
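The endpoint claim can be checked directly. This sketch uses plain Python with no affine step, and the function names are hypothetical. It runs one generic group-normalization routine with G = 1 and G = C, then compares the results against separately written Layer Norm (over C·H·W) and Instance Norm (over H·W) routines.

```python
import math
import random

def _normalize(vals, eps):
    mu = sum(vals) / len(vals)
    var = sum((v - mu) ** 2 for v in vals) / len(vals)
    return [(v - mu) / math.sqrt(var + eps) for v in vals]

def gn_plain(x, G, eps=1e-5):
    """Group Norm (no affine) on one sample; x is a list of C channels (flat H*W lists)."""
    cpg, hw = len(x) // G, len(x[0])
    out = []
    for g in range(G):
        block = x[g * cpg:(g + 1) * cpg]
        flat = _normalize([v for ch in block for v in ch], eps)
        out.extend(flat[i * hw:(i + 1) * hw] for i in range(cpg))
    return out

def ln_plain(x, eps=1e-5):
    """Layer Norm over all of (C, H, W) for one sample (no affine)."""
    hw = len(x[0])
    flat = _normalize([v for ch in x for v in ch], eps)
    return [flat[c * hw:(c + 1) * hw] for c in range(len(x))]

def in_plain(x, eps=1e-5):
    """Instance Norm: each channel of the sample normalized on its own (no affine)."""
    return [_normalize(ch, eps) for ch in x]

def close(a, b, tol=1e-12):
    return all(abs(u - v) <= tol for ca, cb in zip(a, b) for u, v in zip(ca, cb))

random.seed(1)
x = [[random.uniform(-3, 3) for _ in range(6)] for _ in range(8)]  # C = 8, H*W = 6
print(close(gn_plain(x, G=1), ln_plain(x)))   # True: G = 1 is Layer Norm
print(close(gn_plain(x, G=8), in_plain(x)))   # True: G = C is Instance Norm
```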
What the numbers actually say
- The headline result. On ResNet-50 / ImageNet with a batch of 32, Batch Norm reaches about 23.6% error and Group Norm about 24.1% — Batch Norm wins by ~0.5%. But shrink the batch to 2 images and Batch Norm's error explodes to roughly 34% while Group Norm stays near 24%. That ~10-point gap at batch-2 is the entire reason Group Norm exists.
- Group count is forgiving. Sweeping G from 16 to 64 on ResNet-50 moved error by only about 0.5% (24.1–24.6%). The extremes are clearly worse: G = 1 (Layer Norm) lands at 25.3%, about 1.2% behind G = 32, and G = C (Instance Norm) collapses to 28.4%, over 4% behind.
- Detection gains are large. Swapping frozen Batch Norm for Group Norm in Mask R-CNN's backbone and heads improved box AP by roughly 1.5–2 points on COCO. Two things changed: the backbone's normalization no longer had to be frozen, and normalization could be added to the heads as well.
- The inference cost. Because Group Norm can't be folded into the preceding convolution, it runs as a live reduction at test time — a few percent of layer latency that Batch Norm avoids entirely. On a throughput-bound classification deployment that is real money; on a detector running one image at a time it's noise.
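The reason Batch Norm folds and Group Norm does not comes down to algebra. In inference mode, Batch Norm applies a fixed per-channel scale s = γ / sqrt(σ² + ε) and a fixed shift. Both can be absorbed into the weights W and bias b of the layer before it. The sketch below assumes a 1×1 convolution, which at a single pixel is just a matrix-vector product. The function names and the "running statistics" are made up for illustration. The same per-output-channel rescaling applies to a k×k convolution.

```python
import math

def linear(W, b, x):
    """A 1x1 convolution at one pixel: out[o] = sum_i W[o][i] * x[i] + b[o]."""
    return [sum(w * xi for w, xi in zip(row, x)) + bo for row, bo in zip(W, b)]

def bn_eval(z, mean, var, gamma, beta, eps=1e-5):
    """Inference-mode Batch Norm using frozen per-channel running statistics."""
    return [g * (zi - m) / math.sqrt(v + eps) + bt
            for zi, m, v, g, bt in zip(z, mean, var, gamma, beta)]

def fold_bn(W, b, mean, var, gamma, beta, eps=1e-5):
    """Return (W', b') such that linear(W', b', x) == bn_eval(linear(W, b, x), ...)."""
    W2, b2 = [], []
    for row, bo, m, v, g, bt in zip(W, b, mean, var, gamma, beta):
        s = g / math.sqrt(v + eps)            # fixed scale for this output channel
        W2.append([s * w for w in row])
        b2.append(s * (bo - m) + bt)
    return W2, b2

W = [[0.5, -1.0, 2.0], [1.5, 0.25, -0.5]]    # 2 output channels, 3 input channels
b = [0.1, -0.2]
mean, var = [0.3, -1.0], [2.0, 0.5]           # made-up frozen running statistics
gamma, beta = [1.2, 0.8], [0.0, 0.5]

Wf, bf = fold_bn(W, b, mean, var, gamma, beta)
x = [1.0, 2.0, -3.0]
unfused = bn_eval(linear(W, b, x), mean, var, gamma, beta)
fused = linear(Wf, bf, x)
print(max(abs(u - f) for u, f in zip(unfused, fused)) < 1e-12)  # True
```

There is no equivalent `fold_gn`. Group Norm's μ(n,g) and σ²(n,g) are recomputed from each input, so its scale would have to change with every image and cannot be baked into W.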
JavaScript implementation
A direct, readable Group Norm forward pass over a single sample's C × H × W tensor stored as a flat array. (Batched code just wraps this in a loop over N.)
```javascript
// x: Float32Array of length C*H*W, channel-major (c*H*W + h*W + w)
// gamma, beta: Float32Array of length C (per-channel affine)
function groupNorm(x, C, H, W, G, gamma, beta, eps = 1e-5) {
  if (C % G !== 0) throw new Error(`C=${C} not divisible by G=${G}`);
  const cpg = C / G;           // channels per group
  const spatial = H * W;
  const m = cpg * spatial;     // elements normalized together
  const y = new Float32Array(x.length);
  for (let g = 0; g < G; g++) {
    const c0 = g * cpg;        // first channel in this group
    // --- pass 1: mean and variance over the whole group ---
    let sum = 0, sumSq = 0;
    for (let c = c0; c < c0 + cpg; c++) {
      const base = c * spatial;
      for (let i = 0; i < spatial; i++) {
        const v = x[base + i];
        sum += v; sumSq += v * v;
      }
    }
    const mean = sum / m;
    const variance = sumSq / m - mean * mean; // E[x²] − E[x]²
    const invStd = 1 / Math.sqrt(variance + eps);
    // --- pass 2: normalize, then per-channel affine ---
    for (let c = c0; c < c0 + cpg; c++) {
      const base = c * spatial, gc = gamma[c], bc = beta[c];
      for (let i = 0; i < spatial; i++) {
        y[base + i] = (x[base + i] - mean) * invStd * gc + bc;
      }
    }
  }
  return y;
}
```
Two details matter. First, γ and β are indexed by the absolute channel c, not by the group: the affine is per-channel even though the statistics are per-group. Second, the one-pass variance via E[x²] − E[x]² is usually adequate here, because JavaScript accumulates `sum` and `sumSq` in 64-bit doubles. It is still numerically fragile, especially in lower precision, so production kernels typically use Welford's algorithm or a two-pass computation to avoid catastrophic cancellation.
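For reference, here is Welford's single-pass update next to the E[x²] − E[x]² shortcut, sketched in Python (function names are hypothetical). The input is contrived: a group whose values sit around 10⁸ with a tiny spread. That is enough to make the shortcut cancel catastrophically even in 64-bit doubles. In fp16 the same failure appears at much smaller offsets.

```python
def welford_mean_var(values):
    """Welford's single-pass update: numerically stable running mean and variance."""
    count, mean, m2 = 0, 0.0, 0.0
    for v in values:
        count += 1
        delta = v - mean
        mean += delta / count
        m2 += delta * (v - mean)       # second factor uses the updated mean
    return mean, m2 / count            # divide by m: Group Norm's biased (1/m) variance

def naive_mean_var(values):
    """The E[x^2] - E[x]^2 shortcut used in the JavaScript example."""
    n = len(values)
    mean = sum(values) / n
    return mean, sum(v * v for v in values) / n - mean * mean

# Large offset, tiny spread: the true variance of the spread (0.1, 0.2, 0.3, 0.4) is 0.0125.
group = [1e8 + d for d in (0.1, 0.2, 0.3, 0.4)]
print(welford_mean_var(group)[1])   # close to 0.0125
print(naive_mean_var(group)[1])     # cancellation destroys the answer
```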
Python / NumPy implementation
The same computation, vectorized — and this is essentially what torch.nn.GroupNorm does under the hood.
```python
import numpy as np

def group_norm(x, gamma, beta, G, eps=1e-5):
    # x: (N, C, H, W); gamma, beta: (C,)
    N, C, H, W = x.shape
    assert C % G == 0, f"C={C} not divisible by G={G}"
    # 1. reshape so each group is a contiguous block of channels
    xg = x.reshape(N, G, C // G, H, W)
    # 2. mean/var over (channels-in-group, H, W), NOT over N
    mean = xg.mean(axis=(2, 3, 4), keepdims=True)
    var = xg.var(axis=(2, 3, 4), keepdims=True)
    # 3. normalize, reshape back to (N, C, H, W)
    xg = (xg - mean) / np.sqrt(var + eps)
    x_hat = xg.reshape(N, C, H, W)
    # 4. per-channel affine (broadcast gamma/beta over N, H, W)
    return gamma[None, :, None, None] * x_hat + beta[None, :, None, None]

# PyTorch one-liner equivalent:
# norm = torch.nn.GroupNorm(num_groups=32, num_channels=256)
# y = norm(x)  # x: (N, 256, H, W)
# Layer Norm statistics:    GroupNorm(num_groups=1, num_channels=C)
# Instance Norm statistics: GroupNorm(num_groups=C, num_channels=C)
# (GroupNorm's gamma/beta are always per-channel; nn.LayerNorm over (C, H, W) is per-element.)
```
Notice that the axis tuple (2, 3, 4) deliberately excludes axis 0 (the batch) and axis 1 (the group index). That single choice, averaging over everything except the batch within each group, is the entire idea of Group Norm.
Variants and where the family lives
Layer Normalization (G = 1). Pools all channels into one group. Dominant in Transformers and RNNs, where the channel/feature axis is the meaningful one and there's no spatial grid to exploit. The 2016 precursor to Group Norm.
Instance Normalization (G = C). Normalizes each channel of each sample on its own. The workhorse of neural style transfer, where you want to wipe out per-image contrast/brightness statistics so style and content separate cleanly.
Switchable Normalization. Learns a softmax-weighted mixture of Batch, Layer, and Instance statistics per layer, letting the network choose. More flexible, but heavier and rarely used in practice now.
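A compact sketch of the mixing idea follows, assuming a batch stored as nested lists. It computes Instance, Layer, and Batch statistics for each (sample, channel) and blends them with softmax weights. In this sketch the logits are fixed arguments, whereas the real layer learns them per layer. The sketch follows the published method in keeping separate weights for means and variances. Function names are hypothetical, and the affine step is omitted.

```python
import math

def softmax(logits):
    top = max(logits)
    exps = [math.exp(z - top) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def mean_var(vals):
    """Mean and biased (1/m) variance of a flat list."""
    mu = sum(vals) / len(vals)
    return mu, sum((v - mu) ** 2 for v in vals) / len(vals)

def switchable_norm(batch, mean_logits, var_logits, eps=1e-5):
    """Switchable Normalization sketch (no affine).
    batch: N samples, each a list of C channels (flat H*W lists).
    Logits are ordered (instance, layer, batch)."""
    wm, wv = softmax(mean_logits), softmax(var_logits)
    C = len(batch[0])
    bn = [mean_var([v for s in batch for v in s[c]]) for c in range(C)]  # per channel, over N*H*W
    out = []
    for sample in batch:
        ln = mean_var([v for ch in sample for v in ch])                    # per sample, over C*H*W
        normed = []
        for c, ch in enumerate(sample):
            stats = (mean_var(ch), ln, bn[c])                              # instance: over H*W
            mu = sum(w * s[0] for w, s in zip(wm, stats))
            var = sum(w * s[1] for w, s in zip(wv, stats))
            normed.append([(v - mu) / math.sqrt(var + eps) for v in ch])
        out.append(normed)
    return out

batch = [[[1.0, 2.0, 6.0], [0.0, -1.0, 4.0]],
         [[3.0, 3.5, -2.0], [5.0, 1.0, 0.5]]]   # N = 2, C = 2, H*W = 3
mixed = switchable_norm(batch, [0.0, 0.0, 0.0], [0.0, 0.0, 0.0])       # equal mixture
inst_only = switchable_norm(batch, [50.0, 0.0, 0.0], [50.0, 0.0, 0.0]) # ~pure Instance Norm
```

With near-one-hot logits the layer reduces to one of its three components. That is how the network can "choose" a normalizer per layer.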
Filter Response Normalization (FRN). A 2019 batch-independent alternative that drops the mean-subtraction entirely and normalizes by the mean-square, paired with a learnable thresholded activation (TLU). Competitive with Group Norm at batch-1, and avoids GN's slight degradation at very small group sizes.
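A minimal FRN + TLU sketch for one sample is shown below. Two details come from the FRN paper rather than this article: the mean square is taken per channel over H·W, and the default ε is 1e-6. The function name is hypothetical. Note that nothing is subtracted before normalizing.

```python
import math

def frn_tlu(x, gamma, beta, tau, eps=1e-6):
    """Filter Response Normalization followed by a TLU, for ONE sample.
    x: list of C channels (flat H*W lists); gamma, beta, tau: length-C lists."""
    out = []
    for c, ch in enumerate(x):
        nu2 = sum(v * v for v in ch) / len(ch)   # mean square over H*W; no mean subtraction
        inv = 1.0 / math.sqrt(nu2 + eps)
        # per-channel affine, then TLU: max(y, tau) with a learnable threshold tau
        out.append([max(gamma[c] * v * inv + beta[c], tau[c]) for v in ch])
    return out

y = frn_tlu([[3.0, 4.0]], gamma=[1.0], beta=[0.0], tau=[1.0], eps=0.0)
# nu^2 = (9 + 16) / 2 = 12.5, so the responses are about 0.85 and 1.13;
# the TLU lifts the first one up to tau = 1.0
```

Because the statistic is the mean square of the input itself, scaling a channel by a constant leaves the output unchanged (up to ε), with no batch involved.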
Weight Standardization. Not a feature-normalization at all — it standardizes the convolution weights instead. Pairs extremely well with Group Norm: the GN+WS combination closes most of the remaining gap to large-batch Batch Norm on ImageNet and is used in BiT and other transfer-learning backbones.
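A minimal Weight Standardization sketch, assuming a conv weight stored as a list of output filters, each flattened over its fan-in (in_channels × kH × kW). Each filter is shifted to zero mean and scaled to unit variance. The function name and the ε placement inside the square root are assumptions of this sketch.

```python
import math

def weight_standardize(weights, eps=1e-5):
    """Standardize each output filter over its fan-in to zero mean and unit variance.
    weights: list of output filters, each a flat list of fan-in values."""
    out = []
    for filt in weights:
        mu = sum(filt) / len(filt)
        var = sum((w - mu) ** 2 for w in filt) / len(filt)
        out.append([(w - mu) / math.sqrt(var + eps) for w in filt])
    return out

W = [[0.2, -0.1, 0.4, 0.3], [1.0, 2.0, 3.0, 4.0]]   # 2 output filters, fan-in of 4
W_hat = weight_standardize(W)
```

In the GN+WS pairing, the convolution runs with the standardized weights, and Group Norm then normalizes the convolution's output. One step acts on the weights and the other on the features.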
Common bugs and edge cases
- C not divisible by G. The single most common error. GroupNorm(32, 100) throws because 100 isn't a multiple of 32. Either pick a G that divides the channel count, or switch to a fixed channels-per-group scheme and derive G = C / cpg.
- Expecting Batch-Norm-style regularization. Batch Norm's batch noise acts as a mild regularizer; Group Norm has none of that, so you may need slightly stronger explicit regularization (dropout, weight decay) to match generalization on large-batch tasks.
- Forgetting that γ/β are per-channel. A frequent reimplementation bug is applying one scale per group instead of per channel. The statistics are per-group; the affine is per-channel. Mixing them up quietly hurts accuracy.
- Numerically unstable variance. The E[x²] − E[x]² shortcut can cancel catastrophically in fp16. Use a two-pass or Welford computation, or compute statistics in fp32 even when the tensor is fp16.
- Assuming it folds into conv at inference. Unlike Batch Norm, Group Norm can't be fused into the preceding convolution, because its statistics depend on the live input. If you profiled a fused-BN model and swapped in GN, expect the normalization to reappear as a measurable inference cost.
- Ending up with one channel per group by accident. If a layer has only 32 channels and you keep G = 32, then C/G = 1. Each group holds a single channel, which is Instance Norm in disguise and usually worse for classification.
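One way to avoid both the divisibility error and accidental one-channel groups is a small helper that searches for the largest valid G. This is a hypothetical helper, not a PyTorch API. The preferred G = 32 and the two-channels-per-group floor are assumptions you may want to tune.

```python
def pick_num_groups(C, preferred=32, min_channels_per_group=2):
    """Hypothetical helper: the largest G <= preferred that divides C and leaves
    at least `min_channels_per_group` channels in every group."""
    for G in range(min(preferred, C), 0, -1):
        if C % G == 0 and C // G >= min_channels_per_group:
            return G
    return 1  # fall back to a single group (Layer Norm statistics)

print(pick_num_groups(100))  # 25
```

The result can be passed straight to `torch.nn.GroupNorm(pick_num_groups(C), C)`. A 32-channel layer gets G = 16 (two channels per group) instead of silently becoming Instance Norm.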
Frequently asked questions
Why does Group Normalization work at a batch size of one?
Group Norm computes its mean and variance over the channels and spatial positions of a single sample, never across the batch. Whether the batch holds one image or a thousand, each sample is normalized with exactly the same statistics, so there is no small-batch noise and no train/test mismatch. Batch Norm averages over the batch dimension, so its estimates get noisy and unreliable below roughly 8–16 samples per device.
How is Group Normalization different from Layer Normalization?
Both normalize per sample, never over the batch. Layer Norm pools all channels into one group; Group Norm splits the channels into G groups and normalizes each independently. Layer Norm is the special case G = 1, and Instance Norm is the special case G = C (one channel per group). The intermediate G — typically 32 — is what makes Group Norm work well for convolutional vision models.
How many groups should I use in Group Normalization?
The original paper used G = 32 as the default and found 16–64 all worked within about 0.5% accuracy on ResNet-50/ImageNet. The hard constraint is that the channel count C must be divisible by G. For a layer with 64 channels, G = 32 puts 2 channels per group; for 256 channels it puts 8. If a layer has an awkward channel count, fall back to a fixed channels-per-group instead of fixed G.
Does Group Normalization have learnable parameters?
Yes. Like Batch Norm, every Group Norm layer learns a per-channel scale gamma and shift beta — 2C parameters in total. The normalization removes the mean and variance, then gamma and beta let the network restore whatever scale and offset it actually needs. Unlike Batch Norm, there are no running mean/variance buffers, because the statistics are recomputed from the input at both train and test time.
When does Batch Normalization still beat Group Normalization?
On large-batch image classification (ImageNet with 32+ images per GPU), Batch Norm is usually about 0.5–1% more accurate, because the noise in its batch statistics acts as a mild regularizer. It is also cheaper at inference, since it folds into the preceding convolution for zero extra cost. Group Norm wins precisely where Batch Norm breaks: detection and segmentation (a batch of 1–2 high-res images), video, and any transfer-learning setting with tiny effective batches.
Can Group Normalization be fused into the convolution like Batch Norm?
No, and this is its main inference cost. Batch Norm at test time is a fixed affine transform with frozen running statistics, so it folds into the preceding conv's weights for free. Group Norm recomputes mean and variance from each input, so it must run as a live reduction at inference — a few percent of layer latency that can't be optimized away by fusion.