Recurrent Neural Network (LSTM)

A network with a memory — and the gates that keep it from forgetting

A recurrent neural network processes sequences one step at a time, carrying a hidden state forward so each output depends on the whole past; LSTM gates add a protected cell state that lets it remember context across hundreds of steps.

  • Invented: RNN 1986 · LSTM 1997
  • Forward cost (length n): O(n) sequential
  • LSTM gates: forget · input · output
  • Params (hidden H, input D): 4·H·(H+D+1)
  • Vanilla RNN memory: ≈ 10 steps

How a recurrent network carries memory

A feed-forward network has no sense of order. Show it the word "bank" and it has no idea whether the previous word was "river" or "savings" — every input is processed in isolation. A recurrent neural network fixes this by adding a loop: it processes a sequence one element at a time, and at every step it feeds its own previous output back into itself.

Concretely, an RNN keeps a hidden state vector h that summarizes everything it has seen so far. At step t it combines the new input x_t with the old state h_{t-1}:

h_t = tanh(W_x · x_t  +  W_h · h_{t-1}  +  b)
y_t = W_y · h_t  +  b_y

The crucial detail is that W_x, W_h, and W_y are the same weights reused at every time step. The network doesn't grow with the sequence; it loops. "Unrolling" that loop for a length-n sequence gives you a chain n copies deep, and you train it with ordinary backpropagation applied through that chain — a procedure called backpropagation through time (BPTT).
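
The recurrence above can be sketched in a few lines of NumPy. This is a minimal illustration, not a training-ready implementation: the function name rnn_forward, the toy sizes, and the random weights are assumptions made for the example.

```python
import numpy as np

def rnn_forward(xs, W_x, W_h, W_y, b, b_y):
    """Run a vanilla RNN over a sequence, reusing the same weights at every step."""
    h = np.zeros(W_h.shape[0])              # h_0: empty memory
    ys = []
    for x_t in xs:                          # one loop iteration per sequence element
        h = np.tanh(W_x @ x_t + W_h @ h + b)
        ys.append(W_y @ h + b_y)
    return np.array(ys), h

rng = np.random.default_rng(0)
D, H, O, n = 3, 5, 2, 7                     # toy sizes, chosen only for illustration
xs = rng.normal(size=(n, D))
ys, h_last = rnn_forward(
    xs,
    W_x=rng.normal(scale=0.5, size=(H, D)),
    W_h=rng.normal(scale=0.5, size=(H, H)),
    W_y=rng.normal(scale=0.5, size=(O, H)),
    b=np.zeros(H),
    b_y=np.zeros(O),
)
print(ys.shape, h_last.shape)               # (7, 2) (5,)
```

The parameter count does not depend on n: a sequence ten times longer runs the same three matrices ten times as often.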

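That weight-sharing is also the RNN's downfall. The gradient that reaches step 1 from step 50 is a product of 49 recurrent Jacobians, each built from the same W_h, so its magnitude tends to shrink or grow geometrically with distance: it vanishes or it explodes. The LSTM was designed to fix the vanishing half of this problem; exploding gradients are usually handled by clipping.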
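
To see the geometric decay numerically, here is a small sketch under a simplifying assumption: the recurrence is treated as linear (the tanh derivative is ignored), and the Jacobian is a scaled orthogonal matrix so that every singular value is exactly 0.9. Real Jacobians vary from step to step, so this shows only the mechanism.

```python
import numpy as np

rng = np.random.default_rng(0)
H = 8

# Orthogonal matrix from a QR decomposition, scaled so every singular value is 0.9.
Q, _ = np.linalg.qr(rng.normal(size=(H, H)))
J = 0.9 * Q                       # stand-in for the recurrent Jacobian (assumed linear RNN)

grad = np.ones(H) / np.sqrt(H)    # unit-norm gradient arriving at step 50
for _ in range(49):               # carry it back to step 1
    grad = J.T @ grad

print(np.linalg.norm(grad))       # 0.9**49, roughly 0.0057
```

Replacing 0.9 with 1.1 in the same loop makes the norm grow instead, which is the exploding case.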

The LSTM: a protected memory and three gates

The Long Short-Term Memory cell, introduced by Hochreiter and Schmidhuber in 1997, adds a second state vector — the cell state c_t — that runs straight through the chain like a conveyor belt. The key idea: the cell state is modified only by addition and elementwise multiplication, never squashed by a tanh, so gradients can ride it back through time without shrinking.

Three gates — each a tiny sigmoid layer outputting numbers in [0, 1] — control what flows onto and off of the belt. Let [h_{t-1}, x_t] be the previous hidden state concatenated with the current input:

f_t = σ(W_f · [h_{t-1}, x_t] + b_f)      forget gate  — what to erase
i_t = σ(W_i · [h_{t-1}, x_t] + b_i)      input  gate  — how much to write
g_t = tanh(W_g · [h_{t-1}, x_t] + b_g)   candidate    — what to write
o_t = σ(W_o · [h_{t-1}, x_t] + b_o)      output gate  — what to expose

c_t = f_t ⊙ c_{t-1}  +  i_t ⊙ g_t        update the memory belt
h_t = o_t ⊙ tanh(c_t)                    new hidden state

Read it as a sentence: forget some of the old memory, write a gated slice of the new candidate, then output a filtered view of the result. The forget gate is the piece a plain RNN never had — it lets the network actively decide to keep a fact ("the subject of this sentence is plural") across dozens of intervening words, or to dump it the moment a new clause starts.

When to reach for an RNN or LSTM

  • Streaming / online inference — when tokens arrive one at a time and you need bounded, constant memory per step (speech transcription, live sensor data, keyboard prediction on a phone).
  • Very long sequences — attention is O(n²) in sequence length; an LSTM is O(n). For sequences of tens of thousands of steps (genomics, long audio), the linear cost can win.
  • Small data — a 1M-parameter LSTM with strong inductive bias often beats an over-parameterized transformer when you have thousands, not billions, of examples.
  • Time-series forecasting — demand, energy load, and financial series with clear temporal structure remain a strong LSTM use case.

If you're modeling natural language at scale, have abundant data, and can afford parallel training on GPUs, a transformer almost always wins. Reach for the LSTM when the sequence is long, the data is small, or the inference is strictly streaming.

RNN vs LSTM vs GRU vs Transformer

                          | Vanilla RNN             | LSTM                    | GRU                           | Transformer
Effective memory          | ≈ 10 steps              | 100s of steps           | 100s of steps                 | full window
Gradient stability        | vanishes/explodes       | stable (additive cell)  | stable                        | stable
Weight matrices / cell    | 1                       | 4                       | 3                             | QKV + FFN per layer
Training parallelism      | sequential              | sequential              | sequential                    | fully parallel
Cost in sequence length   | O(n)                    | O(n)                    | O(n)                          | O(n²) attention
Inference memory          | O(1) state              | O(1) state              | O(1) state                    | O(n) KV cache
Best at                   | teaching, short signals | long streaming series   | fast, compact sequence models | large-scale NLP

The headline trade-off: recurrent models pay O(n) sequential cost but use O(1) memory at inference and handle arbitrarily long streams; transformers pay O(n²) attention but train in parallel and read any position in one hop. LSTM and GRU sit between vanilla RNNs and transformers — they solve the gradient problem without giving up the streaming, constant-memory property.

What the numbers actually say

  • Vanishing gradient is geometric. If the recurrent Jacobian's largest singular value is 0.9, the gradient after 50 steps is scaled by at most 0.9⁵⁰ ≈ 0.005, a 200× attenuation. At 0.5 it's 0.5⁵⁰ ≈ 10⁻¹⁵, effectively zero. This is why a vanilla RNN's useful context is usually quoted as roughly 10 steps.
  • The LSTM's additive path has derivative ≈ 1. Because ∂c_t/∂c_{t-1} = f_t and the forget gate sits near 1 when memory is wanted, the gradient is multiplied by ≈1 per step instead of by a squashing factor — that's the whole trick.
  • Parameter count. An LSTM layer has 4 · H · (H + D + 1) parameters for hidden size H and input size D — 4× a vanilla RNN's H · (H + D + 1). A GRU has 3×. For H = D = 512 that's about 2.1M parameters per LSTM layer vs 1.6M per GRU layer.
  • Gradient clipping is standard. Exploding gradients are tamed by clipping the global norm to a threshold (commonly 1–5); without it, a single bad batch can send weights to NaN.
  • Truncated BPTT. You rarely backprop through the full sequence; you chop it into windows of, say, 35–200 steps. Full BPTT over a 10,000-step sequence has to store the activations of all 10,000 steps, roughly 10,000× the activation memory of a single step.
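
The arithmetic in the list above is easy to check. The sketch below recomputes the attenuation factors and parameter counts, then shows one common way to clip by global norm. The helper names gated_params and clip_global_norm are made up for this example; frameworks ship their own clipping utilities.

```python
import numpy as np

# Geometric attenuation from the first bullet.
print(0.9 ** 50)    # about 0.00515
print(0.5 ** 50)    # about 8.9e-16

def gated_params(n_blocks, H, D):
    """n_blocks weight blocks of shape (H, H + D), each with a length-H bias."""
    return n_blocks * H * (H + D + 1)

print(gated_params(4, 512, 512))   # LSTM:        2099200
print(gated_params(3, 512, 512))   # GRU:         1574400
print(gated_params(1, 512, 512))   # vanilla RNN: 524800

def clip_global_norm(grads, max_norm):
    """Rescale a list of gradient arrays so their joint L2 norm is at most max_norm."""
    total = np.sqrt(sum(float(np.sum(g * g)) for g in grads))
    scale = min(1.0, max_norm / (total + 1e-12))   # never scale gradients up
    return [g * scale for g in grads], total

grads = [np.full((2, 2), 3.0), np.full(4, 4.0)]    # joint norm = sqrt(4*9 + 4*16) = 10
clipped, norm_before = clip_global_norm(grads, max_norm=5.0)
norm_after = np.sqrt(sum(float(np.sum(g * g)) for g in clipped))
print(norm_before, norm_after)                     # 10.0, then about 5.0
```

Clipping the joint norm rather than each array separately preserves the direction of the overall update and changes only its length.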

JavaScript implementation

A single LSTM cell forward pass. Vectors are plain arrays; this is for clarity, not speed — a real implementation uses matrix libraries and batches.

const sigmoid = x => 1 / (1 + Math.exp(-x));
const dot = (W, v) => W.map(row => row.reduce((s, w, j) => s + w * v[j], 0));
const add = (a, b) => a.map((x, i) => x + b[i]);
const had = (a, b) => a.map((x, i) => x * b[i]);   // elementwise (Hadamard)

// W* are [H x (H+D)] matrices, b* are length-H bias vectors.
function lstmStep(xt, hPrev, cPrev, W, b) {
  const z = [...hPrev, ...xt];                 // concat: length H + D

  const f = add(dot(W.f, z), b.f).map(sigmoid);     // forget gate
  const i = add(dot(W.i, z), b.i).map(sigmoid);     // input  gate
  const g = add(dot(W.g, z), b.g).map(Math.tanh);   // candidate
  const o = add(dot(W.o, z), b.o).map(sigmoid);     // output gate

  const c = add(had(f, cPrev), had(i, g));          // memory belt: erase + write
  const h = had(o, c.map(Math.tanh));               // exposed hidden state
  return { h, c };
}

// Run a whole sequence, carrying (h, c) forward.
function runLSTM(sequence, W, b, H) {
  let h = new Array(H).fill(0);
  let c = new Array(H).fill(0);
  for (const xt of sequence) ({ h, c } = lstmStep(xt, h, c, W, b));
  return h;   // final summary of the sequence
}

Note the two things that make it an LSTM and not an RNN: the cell state c is carried separately from h, and it's updated by f ⊙ c + i ⊙ g — pure addition and multiplication, no tanh wrapping the recurrence. That additive belt is what keeps gradients alive.

Python implementation

The same cell in NumPy, plus a classic application: generating text one character at a time by sampling from the output and feeding it back in.

import numpy as np

def sigmoid(x): return 1.0 / (1.0 + np.exp(-x))

class LSTMCell:
    def __init__(self, input_dim, hidden_dim):
        H, D = hidden_dim, input_dim
        k = 1.0 / np.sqrt(H)
        # One stacked matrix for all 4 gates: shape (4H, H + D)
        self.W = np.random.uniform(-k, k, (4 * H, H + D))
        self.b = np.zeros(4 * H)
        self.b[:H] = 1.0          # forget-gate bias = 1: remember by default
        self.H = H

    def step(self, x, h, c):
        z = np.concatenate([h, x])
        gates = self.W @ z + self.b
        H = self.H
        f = sigmoid(gates[0:H])         # forget
        i = sigmoid(gates[H:2*H])       # input
        g = np.tanh(gates[2*H:3*H])     # candidate
        o = sigmoid(gates[3*H:4*H])     # output
        c = f * c + i * g               # additive memory update
        h = o * np.tanh(c)
        return h, c

# Famous problem: character-level text generation by feeding output back in.
def sample(cell, W_out, b_out, seed_onehot, n, vocab):
    h = np.zeros(cell.H); c = np.zeros(cell.H)
    x = seed_onehot; out = []
    for _ in range(n):
        h, c = cell.step(x, h, c)
        logits = W_out @ h + b_out
        p = np.exp(logits - logits.max()); p /= p.sum()
        idx = np.random.choice(len(p), p=p)   # sample, don't argmax
        out.append(vocab[idx])
        x = np.eye(len(vocab))[idx]           # feed the sampled char back in
    return "".join(out)

The forget-gate bias is deliberately set to 1 in __init__ (the line self.b[:H] = 1.0); see the FAQ for why that one line matters. The sample function is the heart of Karpathy's "unreasonably effective" char-RNN: predict a distribution, sample one token, feed it back, repeat.

Variants worth knowing

GRU (Gated Recurrent Unit). Cho et al., 2014. Merges the forget and input gates into a single update gate and drops the separate cell state, exposing the memory directly as the hidden state. Three weight matrices instead of four — ~25% fewer parameters, faster, and competitive with LSTM on most tasks.
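
A single GRU step might look like the sketch below. It follows one common formulation (update gate z, reset gate r, candidate n). The variable names, the choice to concatenate [h, x], and the convention that z weights the new candidate are assumptions of this example; some references and libraries flip the role of z.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gru_step(x, h, W_z, W_r, W_n, b_z, b_r, b_n):
    """One GRU step. Each W_* has shape (H, H + D); each b_* has length H."""
    hx = np.concatenate([h, x])
    z = sigmoid(W_z @ hx + b_z)                           # update gate: how much to overwrite
    r = sigmoid(W_r @ hx + b_r)                           # reset gate: how much past feeds the candidate
    n = np.tanh(W_n @ np.concatenate([r * h, x]) + b_n)   # candidate state
    return (1.0 - z) * h + z * n                          # blend old state with the candidate

rng = np.random.default_rng(0)
H, D = 4, 3                                               # toy sizes for illustration
Ws = [rng.normal(scale=0.5, size=(H, H + D)) for _ in range(3)]
bs = [np.zeros(H) for _ in range(3)]

h = np.zeros(H)
for x in rng.normal(size=(6, D)):                         # run six steps
    h = gru_step(x, h, *Ws, *bs)
print(h.shape)                                            # (4,)
```

There is no separate cell state here: the blended vector is both the memory and the output, which is where the three-matrices-instead-of-four saving comes from.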

Bidirectional RNN/LSTM. Run one LSTM left-to-right and another right-to-left, then concatenate their hidden states. Each position sees both past and future context. Standard in tasks where the whole sequence is available up front (named-entity recognition, speech recognition) — but impossible for streaming, since the backward pass needs the end of the sequence.
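
The wiring can be sketched independently of the cell type. To keep the example short it uses plain tanh RNN steps as stand-ins (an assumption; an LSTM would carry a (h, c) pair instead), and the helper names are made up for illustration.

```python
import numpy as np

def run_direction(xs, step, H):
    """Run a recurrent step function over xs and return every hidden state."""
    h = np.zeros(H)
    states = []
    for x in xs:
        h = step(x, h)
        states.append(h)
    return np.stack(states)

def bidirectional(xs, step_fwd, step_bwd, H):
    fwd = run_direction(xs, step_fwd, H)
    bwd = run_direction(xs[::-1], step_bwd, H)[::-1]   # reverse again to re-align positions
    return np.concatenate([fwd, bwd], axis=1)          # shape (n, 2H)

def make_step(rng, H, D):
    """Stand-in recurrent cell: a plain tanh RNN step with its own random weights."""
    W_x = rng.normal(scale=0.5, size=(H, D))
    W_h = rng.normal(scale=0.5, size=(H, H))
    return lambda x, h: np.tanh(W_x @ x + W_h @ h)

rng = np.random.default_rng(0)
H, D, n = 4, 3, 6
xs = rng.normal(size=(n, D))
step_f, step_b = make_step(rng, H, D), make_step(rng, H, D)
out = bidirectional(xs, step_f, step_b, H)
print(out.shape)                                       # (6, 8)
```

The backward half of position 0 already depends on the last input, which is why this layout cannot run on a live stream.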

Stacked / deep LSTM. Feed one LSTM layer's hidden states as inputs to a second LSTM layer. Two-to-four layers is typical; depth captures hierarchy (phonemes → words → phrases) the way CNN depth captures visual hierarchy.
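
Stacking is mostly bookkeeping: each layer consumes the full sequence of hidden states from the layer below. The sketch below assumes the same gate ordering as the NumPy cell earlier (forget, input, candidate, output) and uses illustrative names and random weights.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x, h, c, W, b):
    """One LSTM step with stacked gates in the order forget, input, candidate, output."""
    H = h.size
    gates = W @ np.concatenate([h, x]) + b
    f, i, o = (sigmoid(gates[k * H:(k + 1) * H]) for k in (0, 1, 3))
    g = np.tanh(gates[2 * H:3 * H])
    c = f * c + i * g
    return o * np.tanh(c), c

def stacked_lstm(xs, layers, H):
    """layers: list of (W, b). Layer 0 has W of shape (4H, H + D); later layers (4H, 2H)."""
    seq = xs
    for W, b in layers:
        h, c = np.zeros(H), np.zeros(H)     # every layer keeps its own state
        out = []
        for x in seq:
            h, c = lstm_step(x, h, c, W, b)
            out.append(h)
        seq = np.stack(out)                 # this layer's outputs feed the next layer
    return seq

rng = np.random.default_rng(0)
H, D, n = 4, 3, 5
layers = [
    (rng.normal(scale=0.5, size=(4 * H, H + D)), np.zeros(4 * H)),
    (rng.normal(scale=0.5, size=(4 * H, 2 * H)), np.zeros(4 * H)),
]
top = stacked_lstm(rng.normal(size=(n, D)), layers, H)
print(top.shape)                            # (5, 4)
```

Only the hidden state h is passed upward; each layer's cell state c stays private to that layer.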

Peephole LSTM. Lets the gates see the cell state directly, not just the hidden state. Helps with precise timing tasks (counting, rhythm) but rarely used today.

Seq2seq with attention. An encoder LSTM compresses the input into a state; a decoder LSTM generates the output. Attention — added by Bahdanau et al. in 2014 — let the decoder look back at every encoder state, and that mechanism, freed from recurrence entirely, became the transformer.
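
The "look back at every encoder state" step can be sketched as additive attention: score each encoder state against the current decoder state, softmax the scores, and take the weighted average. This is a simplified illustration in the spirit of that mechanism; the parameter names (W_s, W_e, v) and sizes are assumptions of the example.

```python
import numpy as np

def additive_attention(s, enc_states, W_s, W_e, v):
    """s: decoder state (H_dec,). enc_states: (n, H_enc).
    Returns the context vector (H_enc,) and the attention weights (n,)."""
    scores = np.tanh(enc_states @ W_e.T + W_s @ s) @ v   # one score per encoder position
    w = np.exp(scores - scores.max())
    w /= w.sum()                                         # softmax over positions
    return w @ enc_states, w

rng = np.random.default_rng(0)
n, H_enc, H_dec, A = 5, 4, 3, 6                          # toy sizes; A is the scoring width
enc_states = rng.normal(size=(n, H_enc))
s = rng.normal(size=H_dec)
context, weights = additive_attention(
    s, enc_states,
    W_s=rng.normal(size=(A, H_dec)),
    W_e=rng.normal(size=(A, H_enc)),
    v=rng.normal(size=A),
)
print(context.shape, weights.sum())                      # (4,) and a sum of about 1.0
```

The decoder recomputes these weights at every output step, so it is no longer limited to a single fixed-size summary of the input.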

Common bugs and edge cases

  • Forgetting to clip gradients. RNNs explode as easily as they vanish. Clip the global gradient norm (1–5 is typical); without it, a single exploding batch can push the weights to NaN and training diverges.
  • Not resetting state between sequences. Carrying the hidden state from one independent sample into the next leaks information across the batch boundary and corrupts training. Zero out h and c at each new sequence (unless you deliberately want stateful, continuous training).
  • Forget-gate bias at 0. Initialized to zero, the gate opens to 0.5, so the cell halves its memory every step and the network starts amnesiac. Initialize the forget-gate bias to 1.
  • argmax instead of sampling when generating. Always taking the most-likely token collapses generation into repetitive loops. Sample from the distribution (optionally with a temperature) to get varied, coherent output.
  • Padding without masking. When you pad variable-length sequences to a fixed length for batching, the network will happily learn from the padding tokens unless you mask them out of the loss and (ideally) the recurrence.
  • Using an LSTM where a transformer fits. If you have lots of data, GPUs, and a fixed context window, the LSTM's sequential training is the bottleneck — a transformer trains far faster on the same hardware.
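
Two of the fixes above are short enough to sketch: sampling with a temperature, and masking padded positions out of the loss. Function names and shapes are assumptions made for the example, not a specific library's API.

```python
import numpy as np

def sample_with_temperature(logits, temperature, rng):
    """Lower temperature sharpens the distribution; higher temperature flattens it."""
    scaled = logits / temperature
    p = np.exp(scaled - scaled.max())
    p /= p.sum()
    return rng.choice(len(p), p=p), p

def masked_cross_entropy(logits, targets, mask):
    """logits (T, V), targets (T,), mask (T,) with 1 for real tokens and 0 for padding."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_p = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    nll = -log_p[np.arange(len(targets)), targets]
    return (nll * mask).sum() / mask.sum()      # average over real tokens only

rng = np.random.default_rng(0)
logits = np.array([2.0, 1.0, 0.0])
_, p_cold = sample_with_temperature(logits, 0.5, rng)
_, p_hot = sample_with_temperature(logits, 2.0, rng)
print(p_cold[0] > p_hot[0])                     # True: the top token dominates more when cold

seq_logits = rng.normal(size=(4, 3))
targets = np.array([0, 2, 1, 1])
mask = np.array([1.0, 1.0, 0.0, 0.0])           # the last two positions are padding
loss = masked_cross_entropy(seq_logits, targets, mask)
```

Masking the loss stops padding from contributing gradients, but the recurrence still steps over the padded positions unless the state update is masked as well.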

Frequently asked questions

Why do plain RNNs forget long-range context?

Backpropagation through time multiplies the gradient by the recurrent Jacobian, built from the same weights, at every step. If that factor's magnitude is below 1 the signal shrinks geometrically: after 50 steps a factor of 0.9 leaves 0.9^50 ≈ 0.005 of the gradient, so early inputs barely move the weights. That is the vanishing gradient problem, and it limits a vanilla RNN's useful memory to roughly 10 steps in practice.

How does an LSTM remember things a plain RNN can't?

An LSTM adds a cell state — a memory conveyor belt that is changed only by addition and elementwise multiplication, never by a squashing nonlinearity. The forget gate decides what to erase, the input gate decides what to write, and because the path from cell state at step t to step t+1 has a derivative near 1, gradients flow back hundreds of steps without vanishing.

What are the three gates in an LSTM?

Forget gate (sigmoid: what fraction of each cell value to keep), input gate (sigmoid: what fraction of the new candidate to write) paired with a tanh candidate vector, and output gate (sigmoid: what fraction of the cell to expose as the hidden state). Each gate is its own small linear layer over the concatenation of the previous hidden state and the current input.

What's the difference between an LSTM and a GRU?

A GRU merges the forget and input gates into one update gate and drops the separate cell state, so it has 3 weight matrices instead of 4 — about 25% fewer parameters and faster to train. LSTMs sometimes edge it out on tasks needing very long, precisely-controlled memory, but on most benchmarks the two are within a point of each other.

Have transformers made LSTMs obsolete?

For large-scale language modeling, yes — transformers process a whole sequence in parallel and attend to any position in O(1) hops, while an LSTM is inherently sequential. But LSTMs still win when sequences are extremely long (attention is O(n²) in length), when you stream one token at a time with bounded memory, or when you train on small data where a 1M-parameter LSTM beats an over-parameterized transformer.

Why initialize the LSTM forget-gate bias to 1?

A forget-gate bias of 1 makes the sigmoid open to about 0.73 at the start, so the cell remembers by default and the network can learn what to forget rather than starting amnesiac. Jozefowicz et al. (2015) found this single trick consistently improved training. Framework defaults vary: Keras applies it by default (unit_forget_bias=True), while PyTorch's nn.LSTM does not, so check before relying on it.
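
A quick check of those numbers, assuming the gate's pre-activation is just its bias at initialization (weights near zero) and nothing new is written to the cell:

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

for bias in (0.0, 1.0):
    f = sigmoid(bias)          # initial forget-gate value
    kept = f ** 10             # fraction of a cell value left after 10 steps
    print(f"bias={bias}: gate={f:.3f}, kept after 10 steps={kept:.4f}")
# bias=0.0: gate=0.500, kept after 10 steps=0.0010
# bias=1.0: gate=0.731, kept after 10 steps=0.0436
```

Even with a bias of 1 the initial memory decays quickly; the point is that the network starts closer to remembering and can learn to push the gate toward 1 where it matters.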