Mixture of Experts (MoE)
Hold a trillion parameters, pay for a billion — route each token to a few experts
A Mixture of Experts (MoE) replaces a dense feed-forward layer with many expert subnetworks plus a router that sends each token to just the top-k experts, scaling total parameters far beyond the FLOPs spent per token.
- Total parameters ∝ number of experts
- FLOPs per token ∝ k (experts chosen)
- Typical routing: top-1 or top-2
- Router cost: O(d · E) per token
- Memory at inference: all experts resident
The intuition: a router and a roomful of specialists
A dense transformer pushes every token through the same giant feed-forward network. Want a smarter model? Make that network wider — but every extra parameter now costs extra multiply-adds on every single token. Compute and parameter count are welded together, and you pay the full price on every word.
Mixture of Experts breaks that weld. Replace the one big feed-forward block with E smaller ones — the experts — and add a tiny router (also called the gating network) in front. For each token, the router scores all E experts, picks the best k (usually 1 or 2), and sends the token only to those. The other E − k experts sit idle for that token. The layer now holds the parameters of all E experts but only spends the compute of k.
The idea goes back to Jacobs, Jordan, Nowlan, and Hinton in 1991 — "Adaptive Mixtures of Local Experts." It sat mostly dormant until Shazeer et al. revived it at scale in 2017 with a 137-billion-parameter sparsely-gated MoE LSTM, and it became mainstream when Google's Switch Transformer (Fedus, Zoph, Shazeer, 2021) simplified routing to top-1 and pushed past a trillion parameters. Today MoE underpins many frontier LLMs precisely because it decouples capacity from cost.
The precise mechanism
Let a token's hidden vector be x ∈ ℝ^d. The router is a single learned matrix W_r ∈ ℝ^{d×E}. It produces one logit per expert:
```
logits = x · W_r                              # shape (E,)
gates  = softmax(logits)                      # a probability over experts
top    = argtopk(gates, k)                    # indices of the k largest gates
y      = Σ_{i ∈ top} gates[i] · Expert_i(x)   # weighted sum of the chosen experts
```
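Each Expert_i is itself a small feed-forward network, typically the same two-layer MLP shape as the dense block it replaced. The token is processed only by the k selected experts, and their outputs are combined weighted by their gate values, which many top-k (k ≥ 2) implementations renormalize over the chosen experts so they sum to 1. Switch Transformer's top-1 special case drops the sum entirely: one expert, one gate scalar. There the gate is the raw softmax probability; renormalizing a single gate would always give 1 and leave the router with no gradient from the task loss.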
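The four-line recipe above can be run directly. A minimal plain-Python sketch, assuming tiny made-up sizes (d = 2, E = 3), a hand-picked router matrix, renormalized top-k gates, and toy "experts" that just scale their input; `moe_token` is a hypothetical helper name, not a library call.

```python
import math

def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]

def moe_token(x, W_r, experts, k):
    """Route one token through a top-k MoE layer (hypothetical helper).

    x: list of d floats; W_r: d x E router matrix as a list of d rows;
    experts: E functions, each mapping a d-vector to a d-vector.
    """
    E = len(experts)
    logits = [sum(x[j] * W_r[j][e] for j in range(len(x))) for e in range(E)]  # x · W_r
    gates = softmax(logits)
    top = sorted(range(E), key=lambda e: gates[e], reverse=True)[:k]  # argtopk
    z = sum(gates[e] for e in top)  # renormalize over the chosen experts
    y = [0.0] * len(x)
    for e in top:  # only the k chosen experts run
        out = experts[e](x)
        y = [yi + gates[e] / z * oi for yi, oi in zip(y, out)]
    return y, top

# Toy setup: d = 2, E = 3, each "expert" scales its input by 1, 2 or 3.
experts = [lambda v, s=s: [s * a for a in v] for s in (1.0, 2.0, 3.0)]
W_r = [[2.0, 0.0, 0.5],
       [0.0, 2.0, 0.0]]
y, top = moe_token([1.0, 0.5], W_r, experts, k=2)
print(top, y)  # [0, 1] and y ≈ [1.269, 0.634]
```

Expert 2 is never called for this token, which is the sparsity the formula describes.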
Complexity. The router costs O(d · E) per token — cheap, since E is small (8–256) and the experts are where the FLOPs live. Each chosen expert with hidden dimension d_ff costs O(d · d_ff), so the layer is O(k · d · d_ff + d · E) per token. Crucially this is independent of E for the expert work — adding experts grows total parameters (O(E · d · d_ff) resident) without growing per-token FLOPs. That is the entire value proposition.
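The complexity claim is easy to check with a few lines of arithmetic. A sketch, assuming d = 4096, d_ff = 16384 and top-2 routing, and counting multiply-adds for the two expert matmuls and the router while ignoring biases and activations; `moe_cost` is a hypothetical helper.

```python
def moe_cost(d, d_ff, E, k):
    """Per-token multiply-adds and resident weights for one MoE layer (hypothetical helper).

    Counts the two expert matmuls (d -> d_ff -> d) and the router (d -> E);
    biases and activations are ignored. Each weight is used once per token,
    so weight counts and multiply-add counts coincide.
    """
    expert_macs = 2 * d * d_ff                 # one expert's two matmuls
    router_macs = d * E                        # one logit per expert
    per_token = k * expert_macs + router_macs  # only k experts run
    resident = E * expert_macs + router_macs   # every expert's weights stay in memory
    return per_token, resident

for E in (8, 64, 256):
    per_token, resident = moe_cost(d=4096, d_ff=16384, E=E, k=2)
    print(f"E={E:4d}  MACs/token={per_token:>13,}  resident weights={resident:>14,}")
```

Going from 8 to 256 experts multiplies the resident weights by about 32, while per-token work grows by under 0.4%, all of it from the router.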
Expert capacity, dropped tokens, and the load-balancing loss
In the classic formulation, experts can't process a ragged, variable number of tokens: the math has to fit in fixed-shape dense tensors (Switch Transformer was built for TPUs, which fix tensor shapes at compile time). So each expert gets a capacity:
capacity = capacity_factor × (tokens_per_batch / num_experts)
With a capacity factor of 1.25, each expert's buffer holds 25% more tokens than its fair share under perfectly uniform routing, which leaves some headroom for imbalance before it overflows. Tokens routed to an expert that is already full are dropped: they skip the MoE layer and flow through the residual connection unchanged. Too low a capacity factor drops many tokens and hurts quality; too high wastes memory and compute on padding.
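A plain-Python sketch of top-1 dispatch into fixed buffers, assuming tokens are admitted first-come, first-served in batch order and that capacity is rounded down; real systems differ on both choices (some prioritize tokens by router score instead), and `dispatch_with_capacity` is a hypothetical name.

```python
import math

def dispatch_with_capacity(assignments, num_experts, capacity_factor):
    """Top-1 dispatch into fixed-size expert buffers (hypothetical helper).

    assignments: chosen expert per token, in batch order. Tokens are admitted
    first-come, first-served; one that finds its expert full is dropped (None)
    and would pass through the residual connection unchanged.
    """
    capacity = math.floor(capacity_factor * len(assignments) / num_experts)
    load = [0] * num_experts
    placed = []
    for e in assignments:
        if load[e] < capacity:
            load[e] += 1
            placed.append(e)
        else:
            placed.append(None)  # overflow: token skips the expert
    return capacity, placed

assignments = [0, 0, 0, 0, 0, 1, 1, 2, 3, 0, 2, 1]  # 12 tokens, skewed toward expert 0
cap, placed = dispatch_with_capacity(assignments, num_experts=4, capacity_factor=1.25)
dropped = sum(p is None for p in placed)
print(f"capacity={cap}, dropped {dropped} of {len(placed)} tokens")  # capacity=3, dropped 3 of 12 tokens
```

Every dropped token belongs to the overloaded expert 0, while experts 2 and 3 run with spare buffer slots: padding waste and dropping happen at the same time.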
The deeper problem is routing collapse. Early in training the router has no reason to spread tokens out, and a self-reinforcing loop kicks in: an expert that gets slightly more tokens trains slightly faster, becomes slightly better, and so attracts even more tokens. Left alone, two or three experts hoard everything while the rest never learn — you paid for 64 experts and got 3. The fix is an auxiliary load-balancing loss added to the task loss:
L_balance = α · E · Σ_i f_i · P_i
where f_i is the fraction of tokens dispatched to expert i and P_i is the mean router probability for expert i over the batch. This product is minimized when both are uniform (= 1/E), so the loss nudges the router toward even utilization. The coefficient α is small (≈0.01) — strong enough to balance, weak enough not to override the main objective.
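To see what the loss rewards, here is a plain-Python sketch that evaluates α · E · Σ f_i · P_i for a perfectly balanced batch and for a collapsed one. It assumes top-1 assignments and hand-written router probabilities, and `balance_loss` is a hypothetical helper. In a real framework f_i comes from the hard, non-differentiable assignment, so the gradient reaches the router only through P_i.

```python
def balance_loss(assignments, router_probs, alpha=0.01):
    """Switch-style auxiliary loss alpha * E * sum_i f_i * P_i (hypothetical helper).

    assignments: chosen expert index per token (top-1 for simplicity).
    router_probs: one softmax probability vector over the E experts per token.
    """
    E, N = len(router_probs[0]), len(assignments)
    f = [assignments.count(i) / N for i in range(E)]             # fraction dispatched to i
    P = [sum(p[i] for p in router_probs) / N for i in range(E)]  # mean router prob for i
    return alpha * E * sum(fi * Pi for fi, Pi in zip(f, P))

E = 4
uniform = balance_loss([0, 1, 2, 3] * 2, [[0.25] * E] * 8)
collapsed = balance_loss([0] * 8, [[0.97, 0.01, 0.01, 0.01]] * 8)
print(uniform)    # ≈ 0.01, i.e. alpha at perfectly uniform routing
print(collapsed)  # ≈ 0.0388, i.e. 3.88 * alpha
```

At uniform routing the loss equals α exactly, because E · Σ (1/E)(1/E) = 1. A router that collapses onto one expert pushes it toward α · E.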
When MoE is worth it — and when it isn't
- Large-scale pretraining on a fixed FLOPs budget. If you have the data and want more capacity per training-FLOP, MoE buys parameters cheaply. This is its home turf.
- You have lots of VRAM but limited compute. MoE trades memory (all experts resident) for FLOPs (only k run). If memory is your bottleneck, MoE makes it worse, not better.
- Throughput-oriented serving with expert parallelism. Sharding experts across devices and batching many tokens amortizes the routing and all-to-all communication.
- Avoid for small models or tight memory. Below a few billion parameters the routing overhead, load-balancing headaches, and dropped tokens rarely pay for themselves; a dense model is simpler and often better per byte of VRAM.
- Avoid when you need single-stream low latency. A batch of one token still has to touch the full all-to-all routing machinery, and the weight bandwidth dominates.
MoE vs dense and other scaling strategies
| | Sparse MoE | Dense FFN | Dense MoE (1991) | Ensemble |
|---|---|---|---|---|
| Experts run per input | top-k (1–2) | the one block | all E | all models |
| Total parameters | ∝ E (large) | baseline | ∝ E | ∝ #models |
| FLOPs per token | ∝ k (small) | baseline | ∝ E (high) | ∝ #models (high) |
| Memory resident | all E experts | baseline | all E | all models |
| Specialization | yes, learned by router | n/a | soft, learned | none — redundant voters |
| Trained end-to-end | yes, with router | yes | yes | usually separately |
| Main failure mode | routing collapse, dropped tokens | compute cost of width | no compute saving | cost multiplies |
| Where used | frontier LLMs | most transformers | historical | Kaggle, calibration |
The headline contrast is the FLOPs-vs-parameters split. A dense layer ties them together; sparse MoE decouples them; the old dense MoE and ensembles get the parameters but throw away the compute saving by running everything. MoE's specialization is also genuinely different from ensembling — experts divide the input space rather than redundantly voting on all of it.
What the numbers actually say
- Switch Transformer. Fedus et al. (2021) trained a 1.6-trillion-parameter top-1 MoE (Switch-C) that matched the quality of the dense T5-XXL while reaching the target loss roughly 4× faster in wall-clock time on the same hardware (the smaller Switch-Base/Large variants hit up to 7× over T5-Base/Large) — the parameters were nearly free per FLOP.
- Mixtral 8×7B. 8 experts, top-2 routing. It has about 47B total parameters (experts share attention and embeddings, so it's not 8 × 7B = 56B) but activates only ≈13B per token. It needs memory for all ~47B parameters yet spends roughly the per-token compute of a ~13B dense model.
- Router is genuinely cheap. For d = 4096 and E = 64, the router is a 4096×64 matmul, about 0.26M multiply-adds per token, versus roughly 2·d·d_ff ≈ 134M multiply-adds for a single expert with d_ff = 4d = 16384. The gate costs well under 1% of one expert.
- Capacity factor 1.25 is the common default (Switch). Dropping it toward 1.0 reduces padding waste but loses tokens; raising it to 2.0 drops almost none, at 1.6× the expert buffer compute and memory of the 1.25 setting.
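The gap between ~47B and the naive 56B comes from shared weights. A back-of-envelope sketch, assuming Mixtral-style configuration values (width 4096, SwiGLU experts of width 14336, 32 layers, a 32,000-token vocabulary, grouped-query attention with 8 key/value heads, untied input and output embeddings) and ignoring norm weights. Treat the constants as illustrative inputs, not an authoritative spec.

```python
# Illustrative Mixtral-8x7B-style configuration values (assumed, not an authoritative spec).
HIDDEN, FFN, LAYERS = 4096, 14336, 32
EXPERTS, TOP_K = 8, 2
VOCAB, HEADS, KV_HEADS = 32000, 32, 8
HEAD_DIM = HIDDEN // HEADS

expert = 3 * HIDDEN * FFN                                           # SwiGLU: gate, up, down projections
attention = 2 * HIDDEN * HIDDEN + 2 * HIDDEN * KV_HEADS * HEAD_DIM  # Q, O full width; K, V grouped
router = HIDDEN * EXPERTS
embeddings = 2 * VOCAB * HIDDEN                                     # input embedding + untied output head

shared = LAYERS * (attention + router) + embeddings  # used by every token (norms ignored)
total = shared + LAYERS * EXPERTS * expert           # all experts resident in memory
active = shared + LAYERS * TOP_K * expert            # weights touched per token
print(f"naive 8 x 7B = 56B, total ≈ {total / 1e9:.1f}B, active ≈ {active / 1e9:.1f}B")
# naive 8 x 7B = 56B, total ≈ 46.7B, active ≈ 12.9B
```

Under these assumptions the experts account for about 45B of the total. Attention and embeddings (about 1.6B) are counted once rather than eight times.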
JavaScript implementation
A minimal top-k MoE layer over a batch of token vectors. No GPU tricks — just the math, so the routing and weighted combination are explicit.
```javascript
// Each expert: a tiny MLP d -> d_ff -> d (ReLU). Random weights for illustration.
const randMat = (r, c) => Array.from({ length: r }, () =>
  Array.from({ length: c }, () => (Math.random() - 0.5) * 0.1));
const matVec = (M, v) => M.map(row => row.reduce((s, w, j) => s + w * v[j], 0));
const relu = v => v.map(x => Math.max(0, x));

function makeExpert(d, dFF) {
  const W1 = randMat(dFF, d), W2 = randMat(d, dFF);
  return x => matVec(W2, relu(matVec(W1, x))); // d -> d
}

function softmax(v) {
  const m = Math.max(...v);
  const e = v.map(z => Math.exp(z - m));
  const s = e.reduce((a, b) => a + b, 0);
  return e.map(z => z / s);
}

function topKIndices(v, k) {
  return v.map((val, i) => [val, i])
    .sort((a, b) => b[0] - a[0])
    .slice(0, k)
    .map(p => p[1]);
}

function moeLayer(tokens, { d, dFF, numExperts: E, k }) {
  const Wr = randMat(E, d); // router: d -> E
  const experts = Array.from({ length: E }, () => makeExpert(d, dFF));
  const counts = new Array(E).fill(0); // for load monitoring
  const out = tokens.map(x => {
    const gates = softmax(matVec(Wr, x)); // probability over experts
    const chosen = topKIndices(gates, k);
    // renormalize the gates of the chosen experts so they sum to 1
    const z = chosen.reduce((s, i) => s + gates[i], 0);
    const y = new Array(d).fill(0);
    for (const i of chosen) {
      counts[i]++;
      const w = gates[i] / z;
      const ei = experts[i](x);
      for (let j = 0; j < d; j++) y[j] += w * ei[j];
    }
    return y;
  });
  return { out, counts }; // counts reveals routing balance
}

// 4 tokens, d=8, d_ff=16, 4 experts, top-2
const tokens = Array.from({ length: 4 }, () =>
  Array.from({ length: 8 }, () => Math.random()));
const { counts } = moeLayer(tokens, { d: 8, dFF: 16, numExperts: 4, k: 2 });
console.log('tokens per expert:', counts);
```
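Two things to notice. First, only the chosen experts are ever invoked: the inner loop runs k times per token, not E times, which is the whole compute saving. Second, counts is your window into routing health. With 4 tokens and top-2 routing the counts always sum to 8 and no expert can exceed 4, so a reading like [4, 4, 0, 0] (every token picking the same two experts) means the router has collapsed and you need the load-balancing loss.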
Python / PyTorch implementation
The same layer in PyTorch, plus the Switch-style auxiliary load-balancing loss that real training needs.
```python
import torch, torch.nn as nn, torch.nn.functional as F

class Expert(nn.Module):
    def __init__(self, d, d_ff):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(d, d_ff), nn.ReLU(), nn.Linear(d_ff, d))

    def forward(self, x):
        return self.net(x)

class MoE(nn.Module):
    def __init__(self, d, d_ff, num_experts, k=2, alpha=0.01):
        super().__init__()
        self.k, self.E, self.alpha = k, num_experts, alpha
        self.router = nn.Linear(d, num_experts, bias=False)
        self.experts = nn.ModuleList(Expert(d, d_ff) for _ in range(num_experts))

    def forward(self, x):                                  # x: (N, d), N = batch*seq tokens
        logits = self.router(x)                            # (N, E)
        gates = F.softmax(logits, dim=-1)                  # (N, E)
        topv, topi = gates.topk(self.k, dim=-1)            # (N, k) each
        # renormalize chosen gates; with k=1 every weight becomes 1 and the router gets no task gradient
        topv = topv / topv.sum(dim=-1, keepdim=True)
        y = torch.zeros_like(x)
        for slot in range(self.k):
            idx = topi[:, slot]                            # which expert each token uses
            w = topv[:, slot].unsqueeze(-1)                # its gate weight
            for e in range(self.E):                        # gather tokens for expert e
                mask = idx == e
                if mask.any():
                    y[mask] += w[mask] * self.experts[e](x[mask])
        # Switch-style load-balancing aux loss
        f = torch.zeros(self.E, device=x.device)           # token fraction per expert
        f.scatter_add_(0, topi.reshape(-1),
                       torch.ones(topi.numel(), device=x.device))
        f = f / topi.numel()
        P = gates.mean(dim=0)                              # mean router prob per expert
        aux = self.alpha * self.E * torch.dot(f, P)
        return y, aux

moe = MoE(d=512, d_ff=2048, num_experts=8, k=2)
x = torch.randn(64, 512)                                   # 64 tokens
y, aux = moe(x)
task_loss = y.pow(2).mean()  # hypothetical stand-in: substitute your real objective
loss = task_loss + aux       # add aux to the real objective during training
loss.backward()
```
The nested loop here is the readable, not the fast, version: production kernels dispatch tokens to experts with a single grouped matmul or an all-to-all across devices. But the contract is identical — gather each expert's tokens, run that expert once on its slice, scatter the weighted results back, and add aux to the loss so the router learns to spread tokens out.
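The gather / run / scatter contract can be sketched without any framework. A top-1 version in plain Python follows. It assumes each expert is a function over a whole batch of vectors, so it is called once per forward pass, and the stable sort stands in for the permutation that a grouped matmul or all-to-all would consume. `moe_forward_grouped` and the toy experts are hypothetical.

```python
def moe_forward_grouped(tokens, assignments, weights, experts):
    """Top-1 MoE forward via sort-based gather / run / scatter (hypothetical helper).

    tokens: list of d-vectors; assignments[t]: expert index for token t;
    weights[t]: token t's gate value; experts: functions over a batch of vectors.
    """
    order = sorted(range(len(tokens)), key=lambda t: assignments[t])  # gather permutation
    counts = [0] * len(experts)
    for e in assignments:
        counts[e] += 1
    out = [None] * len(tokens)
    start = 0
    for e, expert in enumerate(experts):
        idx = order[start:start + counts[e]]  # this expert's contiguous slice
        start += counts[e]
        if not idx:
            continue
        batch_out = expert([tokens[t] for t in idx])  # one call per expert
        for t, o in zip(idx, batch_out):  # scatter back, scaled by the gate
            out[t] = [weights[t] * v for v in o]
    return out

experts = [lambda batch, s=s: [[s * v for v in x] for x in batch] for s in (1.0, 2.0, 3.0)]
tokens = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
assignments = [2, 0, 1, 0, 2]
weights = [0.9, 0.8, 0.7, 0.6, 0.5]
out = moe_forward_grouped(tokens, assignments, weights, experts)
naive = [[w * ((e + 1) * v) for v in x] for x, e, w in zip(tokens, assignments, weights)]
print(out == naive)  # True: same result as routing each token on its own
```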
Variants worth knowing
Switch Transformer (top-1). Routes each token to exactly one expert. Simpler, lower communication, and surprisingly strong — it showed that k = 1 is enough if you balance carefully, and pushed MoE past a trillion parameters in 2021.
Expert Choice routing. Flips the assignment: instead of each token picking its top-k experts, each expert picks its top tokens up to capacity. This guarantees perfect load balance by construction and no expert ever overflows its buffer, at the cost that coverage varies per token: some tokens get several experts, others fewer or none at all.
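A plain-Python sketch of the selection step follows. It assumes the token-to-expert scores are already softmax-normalized per token, and that each expert's capacity is tokens × c / experts with c = 1 expert per token on average. `expert_choice` and the score matrix are made up for illustration.

```python
def expert_choice(scores, capacity):
    """Each expert picks its `capacity` highest-scoring tokens (hypothetical helper).

    scores[t][e]: affinity of token t for expert e.
    Returns one list of token indices per expert.
    """
    num_tokens, num_experts = len(scores), len(scores[0])
    picks = []
    for e in range(num_experts):
        ranked = sorted(range(num_tokens), key=lambda t: scores[t][e], reverse=True)
        picks.append(ranked[:capacity])
    return picks

# 6 tokens, 3 experts; each row is one token's (softmaxed) scores over experts.
scores = [[0.45, 0.45, 0.10],
          [0.40, 0.10, 0.50],
          [0.30, 0.30, 0.40],
          [0.35, 0.05, 0.60],
          [0.20, 0.50, 0.30],
          [0.34, 0.33, 0.33]]
capacity = len(scores) * 1 // 3  # tokens * c / experts, with c = 1
picks = expert_choice(scores, capacity)
experts_per_token = [sum(t in p for p in picks) for t in range(len(scores))]
print(picks)              # [[0, 1], [4, 0], [3, 1]]: every expert exactly full
print(experts_per_token)  # [2, 2, 0, 1, 1, 0]: tokens 2 and 5 get no expert
```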
BASE / Sinkhorn / optimal-transport routing. Treat assignment as a balanced linear-assignment problem, solved exactly (BASE layers use an auction algorithm) or approximately with Sinkhorn iterations, so balance is enforced by the assignment itself rather than nudged by an auxiliary loss.
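A plain-Python sketch of the Sinkhorn variant follows, assuming a small tokens × experts logit matrix and a fixed iteration count. It alternately rescales columns (so each expert receives tokens / experts total mass) and rows (so each token's weights sum to 1), then assigns each token to the argmax of the balanced matrix. `sinkhorn` is a hypothetical helper, and in general the final argmax only approximately balances loads.

```python
import math

def sinkhorn(logits, iters=50):
    """Sinkhorn balancing of a tokens x experts score matrix (hypothetical helper).

    Starts from exp(logits), then alternately rescales columns so every expert
    receives n / e total mass and rows so every token's weights sum to 1.
    """
    n, e = len(logits), len(logits[0])
    P = [[math.exp(v) for v in row] for row in logits]
    for _ in range(iters):
        for j in range(e):  # column step: equal total mass per expert
            col = sum(P[i][j] for i in range(n))
            for i in range(n):
                P[i][j] *= (n / e) / col
        for i in range(n):  # row step: each token distributes mass 1
            s = sum(P[i])
            P[i] = [v / s for v in P[i]]
    return P

def argmax(row):
    return max(range(len(row)), key=lambda j: row[j])

logits = [[3.0, 0.0], [2.0, 0.0], [1.0, 0.0], [0.5, 0.0]]  # every token prefers expert 0
greedy = [argmax(row) for row in logits]
balanced = [argmax(row) for row in sinkhorn(logits)]
print(greedy)    # [0, 0, 0, 0]: expert 0 would take everything
print(balanced)  # [0, 0, 1, 1]: the weakest preferences move to expert 1
```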
Soft MoE. Avoids hard discrete routing entirely: each "slot" is a learned weighted average of all tokens, processed by an expert, then scattered back. It is fully differentiable, with no token dropping and no balancing loss. The catch is that every slot mixes all tokens in the input, so it does not directly fit causal autoregressive decoding.
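A plain-Python sketch of the Soft MoE dispatch and combine steps follows. It assumes one slot per expert, a given token-slot score matrix standing in for the learned projection, and toy experts; `soft_moe` is a hypothetical name. Dispatch weights are a softmax over tokens for each slot; combine weights are a softmax over slots for each token.

```python
import math

def softmax(v):
    m = max(v)
    e = [math.exp(x - m) for x in v]
    s = sum(e)
    return [x / s for x in e]

def soft_moe(tokens, logits, experts, slots_per_expert=1):
    """Soft MoE dispatch / combine sketch (hypothetical helper).

    tokens: n vectors of width d.
    logits: n x S token-slot scores, S = len(experts) * slots_per_expert
            (standing in for the learned projection of each token onto each slot).
    experts: functions mapping a d-vector to a d-vector.
    """
    n, S, d = len(tokens), len(logits[0]), len(tokens[0])
    # Dispatch: each slot is a softmax-over-tokens weighted average of all tokens.
    dispatch = [softmax([logits[t][s] for t in range(n)]) for s in range(S)]
    slot_in = [[sum(dispatch[s][t] * tokens[t][j] for t in range(n)) for j in range(d)]
               for s in range(S)]
    slot_out = [experts[s // slots_per_expert](slot_in[s]) for s in range(S)]
    # Combine: each token takes a softmax-over-slots weighted average of slot outputs.
    combine = [softmax(logits[t]) for t in range(n)]
    return [[sum(combine[t][s] * slot_out[s][j] for s in range(S)) for j in range(d)]
            for t in range(n)]

tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
experts = [lambda v: [2 * a for a in v], lambda v: [-a for a in v]]  # toy experts
logits = [[0.0, 0.0] for _ in tokens]  # uniform scores, for a hand-checkable result
print(soft_moe(tokens, logits, experts))  # every token ≈ [1/3, 1/3]
```

With uniform scores both slots see the token mean [2/3, 2/3]. The experts map it to [4/3, 4/3] and [-2/3, -2/3], and every token receives their average. No token is ever dropped, but every output depends on every input token, which is the causality problem noted above.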
Shared + routed experts (DeepSeekMoE style). Keep one or two always-on shared experts for common knowledge plus many fine-grained routed experts for specialization, which improves balance and reduces redundancy across experts.
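A plain-Python sketch of this forward pass for one token follows. It assumes one always-on shared expert, four routed experts chosen top-2, and raw (not renormalized) softmax gates over the routed experts, with the residual connection omitted. Gating details differ between implementations, and `shared_plus_routed` is a hypothetical name.

```python
import math

def softmax(v):
    m = max(v)
    e = [math.exp(x - m) for x in v]
    s = sum(e)
    return [x / s for x in e]

def shared_plus_routed(x, shared, routed, router_logits, k):
    """Always-on shared experts plus top-k gated routed experts (hypothetical helper).

    The residual connection is omitted; gates are raw softmax probabilities
    over the routed experts only (an assumption: implementations differ).
    """
    y = [0.0] * len(x)
    for expert in shared:  # every token pays for the shared experts
        y = [a + b for a, b in zip(y, expert(x))]
    gates = softmax(router_logits)
    top = sorted(range(len(routed)), key=lambda i: gates[i], reverse=True)[:k]
    for i in top:  # only k of the routed experts run
        y = [a + gates[i] * b for a, b in zip(y, routed[i](x))]
    return y, top

shared = [lambda v: list(v)]  # toy shared expert: identity
routed = [lambda v, s=s: [s * a for a in v] for s in (1.0, 2.0, 3.0, 4.0)]
router_logits = [math.log(p) for p in (0.1, 0.3, 0.2, 0.4)]  # softmax recovers these probs
y, top = shared_plus_routed([1.0, 2.0], shared, routed, router_logits, k=2)
print(top, y)  # [3, 1] and y ≈ [3.2, 6.4]
```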
Common bugs and edge cases
- Forgetting the load-balancing loss. The single most common failure. Without it the router collapses onto a few experts and most of your parameters never train — a quiet bug that just looks like "MoE didn't help."
- Not renormalizing the top-k gates. For k ≥ 2 the chosen gates don't sum to 1 after top-k selection; if your design expects renormalized weights and you skip the divide-by-sum, the layer's output is scaled down by the dropped probability mass and the residual stream drifts. (Top-1 Switch deliberately keeps the raw gate, because a renormalized single gate is always 1 and gives the router no gradient.)
- Capacity factor too low. A tight capacity silently drops tokens. Watch the drop rate: a few percent is fine, double digits wrecks quality, and the symptom (slightly worse loss) is easy to misattribute.
- Assuming MoE saves memory. It saves FLOPs, not VRAM. All experts must be resident; a "13B-active" MoE still needs the full ~47B parameters in memory. Sizing a serving box by active parameters runs you out of memory.
- Argmax / top-k is non-differentiable. Gradients flow through the gate values (the softmax weights), not through which experts were selected. The router learns from the weighting of chosen experts, plus the auxiliary balance signal — not from a gradient on the discrete pick.
- Ignoring all-to-all communication. With expert parallelism, tokens must be shuffled to the device holding their expert and back. On small batches this network round-trip, not the matmul, is the bottleneck.
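The differentiability point can be checked numerically. Below is a plain-Python sketch for one token with scalar toy experts. It uses central finite differences with a step small enough that the top-k choice never changes, and `moe_scalar` and `grad_wrt_logits` are hypothetical helpers. With renormalized gates, only the chosen experts' logits receive a signal, and a renormalized top-1 gate receives none. The raw top-1 gate (Switch style) does receive one.

```python
import math

def moe_scalar(logits, expert_outputs, k, renormalize=True):
    """One token's MoE output with scalar toy experts (hypothetical helper)."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    gates = [e / total for e in exps]
    top = sorted(range(len(gates)), key=lambda i: gates[i], reverse=True)[:k]
    z = sum(gates[i] for i in top) if renormalize else 1.0
    return sum(gates[i] / z * expert_outputs[i] for i in top)

def grad_wrt_logits(logits, expert_outputs, k, renormalize=True, h=1e-6):
    """Central finite differences; h is small enough that the top-k set never changes here."""
    grads = []
    for i in range(len(logits)):
        up, down = list(logits), list(logits)
        up[i] += h
        down[i] -= h
        diff = (moe_scalar(up, expert_outputs, k, renormalize)
                - moe_scalar(down, expert_outputs, k, renormalize))
        grads.append(diff / (2 * h))
    return grads

logits = [2.0, 1.0, 0.0, -1.0]  # router logits for E = 4 experts
outputs = [1.0, 3.0, 5.0, 7.0]  # scalar outputs of the toy experts
print(grad_wrt_logits(logits, outputs, k=2))                     # ≈ [-0.393, 0.393, 0, 0]
print(grad_wrt_logits(logits, outputs, k=1))                     # [0.0, 0.0, 0.0, 0.0]
print(grad_wrt_logits(logits, outputs, k=1, renormalize=False))  # nonzero for every expert
```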
Frequently asked questions
How does MoE add parameters without adding compute?
Only the top-k experts a token is routed to actually run. With 64 experts and top-2 routing, the layer holds 64× the parameters of a single expert but does the math of just 2. Total parameters scale with the expert count; FLOPs per token scale with k, not with the count.
What is the difference between sparse MoE and dense MoE?
Dense MoE runs every expert and combines them by a weighted average — full compute, used in the original 1991 formulation. Sparse MoE runs only the top-k experts the router picks, so most experts are skipped per token. Modern LLM MoE layers are sparse; that sparsity is the whole point.
Why do MoE models need a load-balancing loss?
Routers tend to collapse — a few popular experts get most tokens while the rest starve and never train. An auxiliary load-balancing loss penalizes uneven assignment, pushing the router toward roughly equal expert utilization so all the parameters you paid for actually learn.
What is expert capacity and what happens when it overflows?
Each expert has a fixed buffer — capacity = capacity_factor × (tokens / num_experts) — so the work fits in dense tensors on a GPU. Tokens beyond an expert's capacity are dropped: they skip the layer and pass through the residual connection unchanged. A higher capacity factor drops fewer tokens but wastes more memory and compute.
Is MoE the same as ensembling?
No. An ensemble runs every model and averages their outputs to reduce variance. MoE runs only a chosen few experts per input and is trained end-to-end with the router, so the experts specialize and divide the input space rather than redundantly voting on the same thing.
Why is MoE memory-bound rather than compute-bound at inference?
All experts' weights must live in memory even though only k run per token, so a sparse MoE needs as much VRAM as its dense parameter count — a 47B-parameter MoE still needs ~47B parameters resident. You save FLOPs, not memory, which is why MoE serving is dominated by weight bandwidth and expert-parallel sharding.