Federated Learning
Train one model on a billion phones — without their data ever leaving
Federated learning trains one shared model across millions of devices by sending model updates instead of raw data: each device runs local SGD, a server combines the resulting weight updates with FedAvg, and the data never leaves the phone.
- Introduced: McMahan et al., 2016
- Core algorithm: FedAvg
- What's uploaded: Model deltas, not data
- Per-round cost: ≈ 2× model size / client
- Central challenge: Non-IID client drift
The idea: move the model, not the data
Every machine-learning textbook assumes you can pile your data into one place and train on it. But the most valuable data lives where you can't move it — the photos on a phone, the words a doctor types into a chart, the steering inputs from a car. Privacy law, bandwidth, and plain user expectation all say: that data stays put.
Federated learning inverts the usual flow. Instead of shipping data to the model, you ship the model to the data. A central server holds a shared model. Each round it sends a copy of the model to a sample of participating devices. Every device trains that copy on its own local data for a few steps, then sends back only the change in weights — a vector the same shape as the model. The server averages those changes into a new global model, and the cycle repeats. The raw examples never leave the device.
The term was coined by Brendan McMahan and colleagues at Google in 2016, in the paper that introduced Federated Averaging (FedAvg). The motivating product was Gboard, Android's keyboard: it learns to predict your next word from your typing, but Google never wants to see your keystrokes. Federated learning made that contradiction tractable.
The FedAvg algorithm, precisely
FedAvg is a synchronous, round-based protocol. Let $w_t$ be the global weights at round $t$, and let client $k$ hold $n_k$ examples, with $n = \sum_k n_k$ the total across the $K$ clients that report in the round. One round goes:
- Select. The server samples a fraction C of available clients, say 100 out of a million currently online.
- Broadcast. Each selected client downloads the current global model $w_t$.
- Local train. Client k runs E epochs of mini-batch SGD on its own data, producing local weights $w_{t+1}^{k}$.
- Aggregate. The server forms the new global model as a data-weighted average of the returned models:
$$w_{t+1} = \sum_{k=1}^{K} \frac{n_k}{n}\, w_{t+1}^{k}$$
The weighting by $n_k/n$ matters: a client that trained on 10,000 examples should count more than one that trained on 50. When E = 1 and the batch is the full local dataset, FedAvg reduces exactly to distributed (synchronous) gradient descent. Increasing E is the whole trick: it lets each client do meaningful work between expensive communication rounds.
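The claim that E = 1 with a full local batch collapses to synchronous gradient descent can be checked numerically. The sketch below assumes a squared-error linear model and made-up client sizes; the helper names `local_full_batch_step` and `fedavg_aggregate` are illustrative, not part of any library.

```python
import numpy as np

def local_full_batch_step(w, X, y, lr):
    """One full-batch gradient step on a mean squared-error loss."""
    grad = X.T @ (X @ w - y) / len(y)
    return w - lr * grad

def fedavg_aggregate(local_weights, counts):
    """Data-weighted average: sum_k (n_k / n) * w_k."""
    counts = np.asarray(counts, dtype=float)
    return np.average(np.stack(local_weights), axis=0, weights=counts)

rng = np.random.default_rng(0)
sizes = [50, 10_000, 300]                 # n_k per client (made-up values)
clients = [(rng.normal(size=(n, 3)), rng.normal(size=n)) for n in sizes]
w_t = np.zeros(3)
lr = 0.1

# FedAvg with E = 1 and batch = the full local dataset.
w_fedavg = fedavg_aggregate(
    [local_full_batch_step(w_t, X, y, lr) for X, y in clients], sizes)

# One centralized gradient step on the pooled data.
X_all = np.vstack([X for X, _ in clients])
y_all = np.concatenate([y for _, y in clients])
w_central = local_full_batch_step(w_t, X_all, y_all, lr)

print(np.allclose(w_fedavg, w_central))   # the two updates coincide
```

The equality holds because the $n_k/n$ weights turn the average of per-client mean gradients back into the pooled mean gradient. With more than one local step the two updates stop matching.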
The complexity that actually matters is communication, not computation. Local training is effectively free, since it happens overnight while the phone charges. The scarce resource is rounds. With R rounds and K clients per round, total uploaded volume is O(R · K · |w|), where |w| is the model size, and crucially it is independent of the dataset size. Compute per client per round is O(E · n_k · cost-per-example), the same as ordinary SGD. The art of the field is driving R down.
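To make the O(R · K · |w|) figure concrete, here is a back-of-the-envelope calculator. The round count, cohort size, and 32-bit parameter width are assumptions chosen for illustration, and the function names are hypothetical.

```python
def upload_bytes(rounds, clients_per_round, n_params, bytes_per_param=4):
    """Total uploaded volume R * K * |w|, with |w| expressed in bytes."""
    return rounds * clients_per_round * n_params * bytes_per_param

def per_client_round_bytes(n_params, bytes_per_param=4):
    """One download of the model plus one upload of the delta."""
    return 2 * n_params * bytes_per_param

n_params = 1_000_000   # assumed: a 1M-parameter model stored as float32
print(per_client_round_bytes(n_params) / 1e6, "MB per client per round")
print(upload_bytes(rounds=500, clients_per_round=100, n_params=n_params) / 1e9,
      "GB uploaded in total")
```

Nothing in either function depends on how many examples a client holds, which is the point of the paragraph above: halving R halves the bill, while growing the local datasets costs nothing on the network.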
When to reach for federated learning
- The data is legally or physically immovable. Health records under HIPAA, financial data, anything that would trip GDPR if centralized. Federated learning is a data-minimization design: you never collect what you don't need.
- The data is naturally distributed and huge. Hundreds of millions of phones each holding a little data, where uploading all of it would saturate networks and storage.
- Personalization on top of a shared base is the goal. A global keyboard model that each device then fine-tunes locally.
- Cross-silo collaboration. A handful of hospitals or banks who want a joint model but can't share patient or customer rows with each other.
Conversely, skip it when you can centralize cheaply. Federated learning is strictly harder to train (non-IID drift, stragglers, partial participation) and harder to debug (you can't inspect the data). If a centralized dataset is legal and affordable, centralize — you'll converge faster and more reliably.
Federated learning vs. the alternatives
| | Centralized training | Distributed SGD (data-center) | Federated (cross-device) | Federated (cross-silo) | Split learning |
|---|---|---|---|---|---|
| Raw data location | One server | Sharded across workers you control | Stays on user devices | Stays in each organization | Stays on the client, which holds the early layers |
| What's communicated | — | Gradients every step | Weight deltas every E epochs | Weight deltas per round | Activations + gradients at cut |
| Participants | 1 | 10s–1000s reliable workers | up to 10⁶–10⁹ unreliable clients | 2–100 reliable orgs | 1 client + server |
| Data distribution | IID by construction | IID (you shuffle) | Heavily non-IID | Non-IID across silos | Varies |
| Availability | Always | Always | Intermittent (charging + Wi-Fi) | Mostly available | Always |
| Privacy posture | None inherent | None inherent | Data-minimizing; +secure agg / DP | Data-minimizing | Raw data hidden, activations leak |
The headline distinction from data-center distributed SGD is who controls the participants. In a data center you shuffle data into IID shards and every worker is fast and online. In cross-device federated learning each client is someone's phone: non-IID, frequently offline, and liable to drop out mid-round. Every hard part of the field flows from that single fact.
What the numbers actually say
- Communication savings of 10–100×. McMahan's original experiments showed FedAvg reaching a target accuracy on MNIST and a language model in roughly 10–100× fewer communication rounds than per-step distributed SGD, by raising local epochs E and the number of mini-batch updates per round.
- Per-round payload ≈ 2× model size. A client downloads the model and uploads one delta. For a 4 MB on-device model that's about 8 MB per participating round, which is trivial on Wi-Fi; production deployments typically restrict training to unmetered networks for this reason.
- A single gradient can leak the data. "Deep Leakage from Gradients" (Zhu et al., 2019) reconstructed pixel-accurate training images from one shared gradient. This is why production systems add secure aggregation, so the server does not see an individual update in the clear.
- Secure aggregation costs are modest. Bonawitz et al.'s 2017 protocol lets the server learn only the sum of updates, tolerating clients that drop out. Per-client communication grows with the cohort size, but for practical parameters the paper reports less than 2× expansion over sending the update in the clear.
- Differential privacy has a measurable accuracy tax. Clipping each update to an L2 norm bound and adding Gaussian noise to the aggregate buys a formal (ε, δ) guarantee — but for tight ε the accuracy hit on next-word prediction is a few points, which production teams budget for explicitly.
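The leakage point is easiest to see in a toy case. The sketch below is not the optimization-based attack from Zhu et al.; it assumes a linear model with a bias term, squared-error loss, and a gradient computed from a single example, where the input can be read off analytically.

```python
import numpy as np

rng = np.random.default_rng(1)
x_private = rng.normal(size=5)       # the example that "never leaves" the device
y_private = 3.0
w, b = rng.normal(size=5), 0.2       # current model: pred = w . x + b

# Gradient of 0.5 * (pred - y)^2 for this one example.
err = w @ x_private + b - y_private
grad_w, grad_b = err * x_private, err

# Anyone holding the gradient can divide out the shared scalar (needs err != 0).
x_recovered = grad_w / grad_b
print(np.allclose(x_recovered, x_private))
```

Deep networks and multi-example batches make recovery harder than one division, but the structure is the same: the update is a function of the data, so it carries information about the data.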
JavaScript: one round of FedAvg
The aggregation step is the heart of FedAvg, and it's strikingly simple — a data-weighted average of weight vectors. This standalone simulation makes the data flow concrete: clients train locally, the server only ever touches deltas.
```javascript
// A "model" is just a flat weight vector here.
const dot = (a, b) => a.reduce((s, x, i) => s + x * b[i], 0);

// One client: run E epochs of SGD on its PRIVATE data, return the delta only.
function clientUpdate(globalW, data, { E = 5, lr = 0.1 } = {}) {
  let w = globalW.slice();               // local copy of the model
  for (let e = 0; e < E; e++) {
    for (const { x, y } of data) {       // data never leaves this function
      const pred = dot(w, x);
      const err = pred - y;
      for (let j = 0; j < w.length; j++) w[j] -= lr * err * x[j];
    }
  }
  // Return ONLY the change in weights (not w, and never `data`).
  return { delta: w.map((wj, j) => wj - globalW[j]), n: data.length };
}

// The server: data-weighted average of client deltas (the FedAvg sum).
function federatedAveraging(globalW, clients, opts) {
  const updates = clients.map(d => clientUpdate(globalW, d, opts));
  const total = updates.reduce((s, u) => s + u.n, 0);
  const newW = globalW.slice();
  for (const { delta, n } of updates) {
    const weight = n / total;            // n_k / n
    for (let j = 0; j < newW.length; j++) newW[j] += weight * delta[j];
  }
  return newW;                           // new global model
}

// Three devices, each with its own local data; train for several rounds.
let w = [0, 0, 0];
const devices = [
  [{ x: [1, 2, 1], y: 6 }, { x: [2, 1, 1], y: 5 }],
  [{ x: [0, 3, 1], y: 7 }, { x: [3, 0, 1], y: 4 }],
  [{ x: [1, 1, 1], y: 4 }],
];
for (let round = 0; round < 50; round++) {
  w = federatedAveraging(w, devices, { E: 3, lr: 0.02 });
}
console.log(w); // converges toward weights fitting the pooled targets
```
Two things are worth noticing. First, clientUpdate receives data as an argument and returns only delta and a row count, so the simulation cannot leak rows because nothing carries them out. Second, the server arithmetic is just the weighted sum from the math section; the privacy and robustness machinery wraps around this core without changing it.
Python: FedAvg with client drift made visible
This version uses NumPy so the weighting and the non-IID problem are explicit. Notice that each client is handed a deliberately skewed slice of the data — the source of all the difficulty.
```python
import numpy as np


def client_update(global_w, X, y, epochs=5, lr=0.05, batch=8):
    """Run local SGD on PRIVATE (X, y); return weight delta + sample count."""
    w = global_w.copy()
    n = len(y)
    for _ in range(epochs):
        idx = np.random.permutation(n)
        for s in range(0, n, batch):
            b = idx[s:s + batch]
            grad = X[b].T @ (X[b] @ w - y[b]) / len(b)
            w -= lr * grad
    return w - global_w, n  # delta, n_k (X, y stay here)


def fedavg(global_w, clients, **kw):
    """Server: data-weighted average of client deltas. Sees no raw data."""
    deltas, counts = zip(*[client_update(global_w, X, y, **kw)
                           for (X, y) in clients])
    total = sum(counts)
    agg = sum((c / total) * d for d, c in zip(deltas, counts))
    return global_w + agg


# Build a NON-IID split: each client sees a biased region of feature space.
rng = np.random.default_rng(0)
true_w = np.array([2.0, -1.0, 0.5])


def make_client(center):
    X = rng.normal(center, 0.5, size=(40, 3))
    y = X @ true_w + rng.normal(0, 0.1, size=40)
    return X, y


clients = [make_client(c) for c in (-2.0, 0.0, 3.0)]  # skewed per device
w = np.zeros(3)
for r in range(200):
    w = fedavg(w, clients, epochs=2, lr=0.02)
print("recovered:", w.round(2), " true:", true_w)  # close, despite the skew
```
Run it and FedAvg recovers true_w closely — but raise the skew between client centers, push epochs high, and you'll watch convergence degrade as the local optima pull apart. That degradation is "client drift," and the variants below exist to fight it.
Variants worth knowing
FedProx (Li et al., 2018). Adds a proximal term $(\mu/2)\,\lVert w - w_t \rVert^2$ to each client's local objective, penalizing drift away from the global model. It tames stragglers and heterogeneous compute without changing the round structure.
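As a rough illustration of the proximal term, the sketch below adds its gradient, μ·(w − w_t), to plain full-batch gradient descent on a squared-error loss. The function name, the data, and the values of μ, the learning rate, and the epoch count are all assumptions for the example.

```python
import numpy as np

def fedprox_client_update(global_w, X, y, mu=0.1, epochs=5, lr=0.02):
    """Local full-batch GD on  loss(w) + (mu / 2) * ||w - global_w||^2."""
    w = global_w.copy()
    for _ in range(epochs):
        grad_loss = X.T @ (X @ w - y) / len(y)
        grad_prox = mu * (w - global_w)      # gradient of the proximal term
        w -= lr * (grad_loss + grad_prox)
    return w - global_w                       # delta, as in FedAvg

rng = np.random.default_rng(0)
X = rng.normal(3.0, 0.5, size=(40, 3))        # one skewed client (made-up data)
y = X @ np.array([2.0, -1.0, 0.5])
w_t = np.zeros(3)

drift_plain = np.linalg.norm(fedprox_client_update(w_t, X, y, mu=0.0))
drift_prox = np.linalg.norm(fedprox_client_update(w_t, X, y, mu=1.0))
print(drift_prox < drift_plain)               # the penalty shortens the local move
```

Setting `mu=0.0` recovers the ordinary local update, so the server-side aggregation needs no change. The first local step is identical either way, because the proximal gradient is zero at the starting point; the penalty only bites from the second step on.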
SCAFFOLD (Karimireddy et al., 2020). Maintains per-client and server "control variates" that estimate the drift direction and subtract it from each local gradient. It provably corrects client drift and converges in far fewer rounds on non-IID data, at the cost of doubling the communicated state.
FedAdam / adaptive server optimizers (Reddi et al., 2020). Treats the averaged delta as a pseudo-gradient and applies Adam, Adagrad, or Yogi on the server. A drop-in upgrade that often beats vanilla FedAvg with no client-side change.
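A server-side Adam step might look like the sketch below. The class name `ServerAdam` is hypothetical, the hyperparameter values are illustrative, and the update omits bias correction; treat it as one plausible reading of the idea rather than the paper's exact procedure.

```python
import numpy as np

class ServerAdam:
    """Server-side Adam applied to the averaged client delta (pseudo-gradient)."""

    def __init__(self, dim, lr=0.1, beta1=0.9, beta2=0.99, tau=1e-3):
        self.lr, self.beta1, self.beta2, self.tau = lr, beta1, beta2, tau
        self.m = np.zeros(dim)   # first-moment estimate
        self.v = np.zeros(dim)   # second-moment estimate

    def step(self, global_w, avg_delta):
        self.m = self.beta1 * self.m + (1 - self.beta1) * avg_delta
        self.v = self.beta2 * self.v + (1 - self.beta2) * avg_delta ** 2
        # The delta already points downhill, so it is added, not subtracted.
        return global_w + self.lr * self.m / (np.sqrt(self.v) + self.tau)

opt = ServerAdam(dim=3)
w0 = np.zeros(3)
avg_delta = np.array([1.0, -2.0, 0.0])   # output of the usual weighted average
w1 = opt.step(w0, avg_delta)
print(w1)
```

Clients run exactly as in FedAvg. Only the line that used to read `global_w + avg_delta` is replaced by `opt.step(global_w, avg_delta)`.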
Secure Aggregation (Bonawitz et al., 2017). A cryptographic protocol where clients mask their updates with pairwise-cancelling random vectors, so the server recovers only the sum — never any individual update — and the masks still cancel even when some clients drop out.
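The cancellation idea can be shown with a toy that leaves out everything cryptographic. In this sketch a shared seeded generator stands in for pairwise key agreement, floats stand in for the finite-ring arithmetic a real protocol would use, and dropout recovery is omitted entirely.

```python
import numpy as np

def pairwise_masks(n_clients, dim, seed=0):
    """For each pair i < j draw a random vector; i adds it, j subtracts it."""
    rng = np.random.default_rng(seed)    # stand-in for pairwise key agreement
    masks = np.zeros((n_clients, dim))
    for i in range(n_clients):
        for j in range(i + 1, n_clients):
            r = rng.normal(scale=100.0, size=dim)
            masks[i] += r
            masks[j] -= r
    return masks

updates = np.array([[0.1, -0.2], [0.3, 0.0], [-0.1, 0.4]])   # private deltas
masked = updates + pairwise_masks(len(updates), updates.shape[1])

print(masked[0])                                              # looks like noise
print(np.allclose(masked.sum(axis=0), updates.sum(axis=0)))   # masks cancel in the sum
```

Each masked row is dominated by the random vectors, yet the column sums match the true sums because every mask appears once with each sign.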
DP-FedAvg. Clips each client update to a fixed L2 norm and adds calibrated Gaussian noise to the aggregate, yielding a formal differential-privacy guarantee that bounds how much any single user's data can influence the final model.
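The clip-then-noise step can be sketched in a few lines. The function names, the uniform (unweighted) average, and the parameter values below are assumptions; turning a noise multiplier into an actual (ε, δ) figure needs a privacy accountant, which this example does not include.

```python
import numpy as np

def clip_update(delta, clip_norm):
    """Scale delta down so its L2 norm is at most clip_norm."""
    norm = np.linalg.norm(delta)
    return delta * min(1.0, clip_norm / max(norm, 1e-12))

def dp_aggregate(deltas, clip_norm=1.0, noise_multiplier=1.0, rng=None):
    """Average clipped deltas after adding Gaussian noise to their sum."""
    if rng is None:
        rng = np.random.default_rng()
    clipped = [clip_update(d, clip_norm) for d in deltas]
    total = np.sum(clipped, axis=0)
    noise = rng.normal(0.0, noise_multiplier * clip_norm, size=total.shape)
    return (total + noise) / len(deltas)

rng = np.random.default_rng(0)
deltas = [rng.normal(size=4) * s for s in (0.1, 1.0, 25.0)]   # one outlier client
noisy_avg = dp_aggregate(deltas, clip_norm=1.0, noise_multiplier=0.5, rng=rng)
print(noisy_avg)
```

Clipping bounds what any one client can contribute, which is what lets a fixed noise scale cover every participant. The outlier with scale 25 ends up with the same maximum influence as everyone else.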
Personalized / clustered federated learning. Instead of one global model, learn a shared base plus per-client heads, or cluster clients with similar distributions. This is an acknowledgement that "one model for everyone" is sometimes the wrong target on heavily non-IID data.
Common bugs and edge cases
- Averaging weights instead of deltas, inconsistently. Averaging final weights and averaging deltas are mathematically equivalent only when every client started from the same global model. Mix updates from different global versions (a "stale" client) and the average is meaningless — version-tag every update and reject stale ones.
- Forgetting the $n_k/n$ weighting. A plain unweighted mean lets a client with 3 examples sway the model as much as one with 30,000. On skewed cohorts this visibly hurts accuracy.
- Too many local epochs on non-IID data. Large E saves communication but amplifies client drift; clients over-fit their local distribution and the average lands in a worse place. There's a sweet spot, and it's data-dependent.
- Treating federated learning as encryption. A raw gradient can leak training examples (gradient inversion). Federated learning alone is data-minimization, not privacy; you still need secure aggregation and/or differential privacy for a real guarantee.
- Ignoring stragglers and dropouts. Phones go offline mid-round. A protocol that blocks on all selected clients never finishes; you must aggregate over whoever reports back and tolerate partial participation.
- BatchNorm statistics across clients. Per-client running means and variances don't average sensibly under non-IID data; many systems switch to GroupNorm or keep BatchNorm stats local to avoid corrupting the global model.
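Two of the bugs above, stale updates and blocking on absent clients, can be guarded against in the aggregation step itself. The sketch below assumes a hypothetical `ClientReport` record and a `min_reports` threshold; neither is taken from a real framework.

```python
import numpy as np
from dataclasses import dataclass

@dataclass
class ClientReport:
    """Hypothetical wire format for one client's upload."""
    model_version: int    # global round the client started from
    delta: np.ndarray     # weight change, same shape as the model
    n_examples: int       # n_k

def aggregate_reports(global_w, version, reports, min_reports=2):
    """Weighted average over fresh reports; skip the round if too few arrive."""
    fresh = [r for r in reports if r.model_version == version]
    if len(fresh) < min_reports:
        return global_w, version              # not enough clients reported back
    total = sum(r.n_examples for r in fresh)
    agg = sum((r.n_examples / total) * r.delta for r in fresh)
    return global_w + agg, version + 1

reports = [
    ClientReport(7, np.array([1.0, 0.0]), 30),
    ClientReport(7, np.array([0.0, 1.0]), 10),
    ClientReport(6, np.array([100.0, 100.0]), 1000),   # stale: trained on round 6
]
new_w, new_version = aggregate_reports(np.zeros(2), 7, reports)
print(new_w, new_version)
```

The server aggregates whoever reported against the current version and normalizes by their example counts only, so a dropout shrinks the cohort instead of stalling the round, and the stale report never touches the average.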
Frequently asked questions
What exactly leaves the device in federated learning?
Only the model update, a vector of weight deltas the same size as the model, leaves the device; the raw training examples never do. For a 1-million-parameter model stored as 32-bit floats that's roughly 4 MB per round, uploaded once and then discarded by the server after aggregation. The photos, keystrokes, or health records that produced the update stay on the phone.
How is FedAvg different from ordinary distributed SGD?
Distributed SGD averages gradients after every single step, so workers must communicate constantly. FedAvg lets each client run E local epochs of SGD before sending anything, then averages the resulting weights. That cut the number of communication rounds by roughly 10–100× in the original experiments, but because clients drift apart on non-IID data, each averaged update is a less exact descent step than a true pooled gradient.
Does federated learning actually keep my data private?
Not by itself. A raw gradient can leak training examples — gradient-inversion attacks have reconstructed pixel-perfect images from a single update. Real systems add secure aggregation (the server only sees the sum of updates, never an individual one) and differential privacy (clipped, noised updates) to close those gaps. Federated learning is a data-minimization technique, not encryption.
Why is non-IID data the central problem in federated learning?
Each device's data reflects one user, so distributions differ wildly — your phone may only ever see English, mine only Korean. Local models therefore pull toward conflicting optima, and naive averaging can stall or oscillate. This 'client drift' is why FedProx, SCAFFOLD, and adaptive server optimizers exist; they correct the drift that vanilla FedAvg ignores.
How does Google use federated learning in production?
Gboard, Android's keyboard, trains next-word prediction and emoji suggestion this way: phones learn from your typing overnight while charging on Wi-Fi, upload secure-aggregated updates, and never ship keystrokes to Google. The technique debuted in the 2016–2017 papers by McMahan and colleagues and now drives on-device suggestion models across hundreds of millions of phones.
What is the communication cost of a training round?
Each selected client downloads the global model and uploads one update, so per-client cost is about 2× the model size per round. Total cost scales with rounds × clients-per-round, not with the dataset size — which is exactly why compression, quantization, and sparsification of the update matter so much when models reach hundreds of megabytes.
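As one example of shrinking the upload, the sketch below applies top-k sparsification: the client sends only the k largest-magnitude entries of its delta as (index, value) pairs. The function names and the choice of k are illustrative, and the scheme is lossy.

```python
import numpy as np

def top_k_sparsify(delta, k):
    """Keep the k largest-magnitude entries; return (indices, values)."""
    idx = np.argpartition(np.abs(delta), -k)[-k:]
    return idx, delta[idx]

def densify(idx, values, dim):
    """Server side: rebuild a dense vector with zeros elsewhere."""
    out = np.zeros(dim)
    out[idx] = values
    return out

delta = np.array([0.01, -0.9, 0.02, 0.5, -0.03])
idx, values = top_k_sparsify(delta, k=2)
restored = densify(idx, values, len(delta))
print(sorted(idx.tolist()), restored)
```

Here two of five entries are sent. The saving is real only when k is a small fraction of the model size, since each kept entry also costs an index.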