Vision Transformers (ViT)

Chop an image into patches, call them words, and let attention do the rest

A Vision Transformer (ViT) splits an image into fixed-size patches, linearly embeds each patch as a token, and feeds the sequence through a standard Transformer encoder — letting global self-attention replace the local receptive fields of convolutions.

  • Introduced: Dosovitskiy et al., 2020
  • Tokens (224px, P=16): 196 + 1 [CLS]
  • Self-attention cost: O(N²·d)
  • Receptive field at layer 1: whole image
  • Sweet spot: large-scale pre-training

The one idea: an image is a sentence of patches

Transformers conquered language by treating a sentence as a sequence of tokens and letting every token attend to every other one. For a decade, images were the holdout — convolutional neural networks owned vision because convolutions encode exactly the right prior: pixels near each other are related, and a feature is the same feature wherever it appears. The 2020 paper "An Image Is Worth 16×16 Words" by Dosovitskiy and colleagues at Google asked a deceptively simple question: what if we just chopped the image into 16×16 squares, called each square a word, and ran the language Transformer unchanged?

That is the entire trick. Take a 224×224 RGB image, cut it into a 14×14 grid of non-overlapping 16×16 patches, and you have 196 little tiles. Flatten each tile into a vector, run it through one learned linear layer, and now each patch is a token — a point in embedding space, indistinguishable in type from a word embedding. Prepend a special learnable [CLS] token, add positional embeddings so the model knows which tile came from where, and feed the 197-token sequence into a vanilla Transformer encoder. The output vector at the [CLS] position is your image representation. No convolutions in the body of the network at all.

The radical consequence: at the very first layer, every patch can attend to every other patch. A CNN at layer one sees a 3×3 neighborhood; a ViT at layer one already has a global receptive field. The model can relate the dog's ear in the top-left to its tail in the bottom-right in a single attention step — no stack of convolutions and pooling required to grow the receptive field.

The forward pass, step by step

Let the input image have height H, width W, and C channels, with patch size P. The pipeline is five stages:

  1. Patchify. Reshape the image into N = HW / P² patches, each a flat vector of length P²·C. For 224×224×3 with P=16, that's N=196 patches of length 768.
  2. Linear projection (patch embedding). Multiply each flattened patch by a learned matrix E ∈ ℝ^{(P²·C) × D} to get a D-dimensional token. In ViT-Base, D=768 — so the projection happens to be square, but in general D is whatever embedding width you choose.
  3. Prepend [CLS] and add positions. Stick a learnable class token at the front (now N+1 = 197 tokens) and add a learnable positional embedding to every token. Without positions, self-attention treats its input as an unordered set (shuffling the tokens just shuffles the outputs the same way), so the grid structure is lost.
  4. Transformer encoder. Run the sequence through L identical blocks. Each block is pre-norm: x = x + MSA(LN(x)) then x = x + MLP(LN(x)), where MSA is multi-head self-attention and MLP is a two-layer feed-forward net (usually 4× width with GELU).
  5. Classify. Take the final [CLS] vector, push it through a classification head (an MLP with one hidden layer at pre-training time, a single linear layer when fine-tuning), and softmax to class probabilities.
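
The five stages above can be followed purely as shape bookkeeping. The sketch below is illustrative only: the function name `vit_shape_trace` and its defaults (ViT-Base-like sizes, 1000 classes) are assumptions chosen to match the numbers in the text, and it tracks shapes for a single image without doing any real computation.

```python
def vit_shape_trace(H=224, W=224, C=3, P=16, D=768, n_classes=1000):
    """Return (stage, shape) pairs for one image passing through a ViT.

    Hypothetical helper for illustration; only shapes are tracked.
    """
    if H % P or W % P:
        raise ValueError("image size must be divisible by patch size")
    N = (H // P) * (W // P)      # number of patches
    patch_dim = P * P * C        # length of one flattened patch
    return [
        ("input image",            (H, W, C)),
        ("1. patchify",            (N, patch_dim)),
        ("2. linear projection",   (N, D)),
        ("3. prepend [CLS] + pos", (N + 1, D)),
        ("4. encoder output",      (N + 1, D)),   # encoder blocks keep the shape
        ("5. class logits",        (n_classes,)),
    ]


for stage, shape in vit_shape_trace():
    print(f"{stage:<24} {shape}")
```

Note that the encoder never changes the sequence shape; only the patchify step and the final head do.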

The mathematical heart is scaled dot-product attention. Pack the N+1 tokens into a matrix and project them three ways into queries Q, keys K, and values V, each of dimension d_k:

Attention(Q, K, V) = softmax( Q·Kᵀ / √d_k ) · V

The Q·Kᵀ term is an (N+1)×(N+1) matrix — every token scoring its relevance to every other token. The √d_k divisor keeps the dot products from saturating the softmax. Multi-head attention runs h of these in parallel on lower-dimensional slices and concatenates the results, letting different heads specialize (some attend locally, some globally).
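
To see why the √d_k divisor matters, consider what softmax does to large scores. If query and key components have roughly unit variance, their dot product has a standard deviation of about √d_k, so raw scores grow with the head width. The sketch below uses three hand-picked, purely illustrative scores (not taken from any real model) and d_k = 64 to compare the attention weights with and without scaling.

```python
import math


def softmax(row):
    """Numerically stable softmax over a list of floats."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


d_k = 64
raw_scores = [8.0, 0.0, -8.0]                              # illustrative q.k values
scaled_scores = [s / math.sqrt(d_k) for s in raw_scores]   # [1.0, 0.0, -1.0]

unscaled_weights = softmax(raw_scores)
scaled_weights = softmax(scaled_scores)

print("without scaling:", [round(w, 4) for w in unscaled_weights])
print("with scaling:   ", [round(w, 4) for w in scaled_weights])
```

Without scaling, almost all of the weight lands on one token and the gradient through the softmax becomes tiny; with scaling, the other tokens still receive a meaningful share.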

Where the time and memory go

For a sequence of N tokens of embedding dimension D, one self-attention layer costs:

  • Time: O(N²·D) — the Q·Kᵀ product is N×N entries, each a D-dimensional dot product, and the weighted sum over V is another N×N×D pass.
  • Memory: O(N² + N·D) — you must materialize the N×N attention matrix (unless you use FlashAttention-style fused kernels that avoid storing it).
  • The MLP block is O(N·D²), linear in N. So at the patch counts ViT uses, the per-layer cost is often dominated by the MLP, not attention — a fact people forget when they blame attention for everything.
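
A rough multiply-add count makes the comparison concrete. This is a back-of-the-envelope sketch, not a profiler: the helper `layer_macs` is hypothetical, it ignores softmax, LayerNorm, biases, and residual adds, and it assumes the usual 4× MLP width.

```python
def layer_macs(N, D, mlp_ratio=4):
    """Approximate multiply-adds for one encoder layer (hypothetical helper)."""
    attn_core = 2 * N * N * D            # Q.K^T plus the weighted sum over V
    projections = 4 * N * D * D          # Q, K, V and output projections
    mlp = 2 * N * D * (mlp_ratio * D)    # two linear layers of the feed-forward net
    return {"attn_core": attn_core, "projections": projections, "mlp": mlp}


macs = layer_macs(N=197, D=768)          # ViT-Base at 224px, P=16
for name, value in macs.items():
    print(f"{name:<12} {value / 1e6:8.1f} M multiply-adds")
```

Under this count the quadratic attention core only overtakes the MLP once N exceeds 4·D, which is about 3,072 tokens for D = 768.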

The quadratic-in-N term is the structural enemy. With P=16 a 224×224 image is N=196, perfectly affordable. Drop to P=8 and N jumps to 784, so the N² attention matrix grows 16×. Push to a 1024×1024 medical scan at P=16 and you have N=4,096 tokens, an attention matrix of ~16.8 million entries per head per layer. This is why high-resolution and dense-prediction ViTs abandon global attention for windowed schemes (Swin) or linear-attention approximations.
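
The arithmetic in this paragraph is easy to reproduce. The helper below is a hypothetical sketch that assumes square images and counts patch tokens only (the [CLS] token is left out to match the figures quoted above).

```python
def token_stats(side, P):
    """Return (N, N*N): patch tokens and attention-matrix entries per head."""
    if side % P:
        raise ValueError("image side must be divisible by patch size")
    N = (side // P) ** 2
    return N, N * N


for side, P in [(224, 32), (224, 16), (224, 8), (1024, 16)]:
    N, entries = token_stats(side, P)
    print(f"{side}px, P={P:<2} -> N={N:<5} attention entries={entries:,}")
```

Halving the patch size quadruples N and multiplies the attention matrix by 16, which is the jump from the second row to the third.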

When to reach for a ViT — and when not to

  • You have a lot of data (or a big pre-trained checkpoint). ViTs shine after pre-training on 100M+ images or when you fine-tune a public checkpoint. From scratch on a small dataset, they underperform CNNs badly.
  • You need global context early. Scene understanding, relating distant objects, fine-grained classification where the discriminative cue could be anywhere — attention's day-one global receptive field helps.
  • You want one architecture across modalities. The same Transformer backbone handles text, image patches, and audio frames, which is why multimodal models (CLIP, Flamingo, and their descendants) lean on ViT-style image encoders.
  • Avoid it when data is scarce, latency is tight, or you're on the edge. A compact CNN or ConvNeXt usually wins on small datasets and on milliwatt budgets; ViT's lack of inductive bias is dead weight there.

ViT vs CNN and the windowed variants

| Property | ViT (vanilla) | CNN (ResNet) | Swin Transformer | ConvNeXt | Hybrid (CNN→ViT) |
|---|---|---|---|---|---|
| Layer-1 receptive field | whole image | 3×3 local | local window | 7×7 local | local, then global |
| Inductive bias | almost none | locality + equivariance | locality (shifted windows) | locality + equivariance | partial |
| Attention cost vs tokens | O(N²) | n/a | O(N) (fixed windows) | n/a | O(N²) on reduced N |
| Data efficiency | poor without pre-training | good | good | good | better than vanilla ViT |
| Multi-scale features | single scale | pyramidal | pyramidal | pyramidal | pyramidal |
| Good fit for | large-scale classification | edge, small data | detection, segmentation | CNN-era drop-in upgrade | medium-data classification |

The headline is that vanilla ViT trades all the convolutional priors for raw flexibility, and pays for it in data hunger. Swin reintroduces locality and a feature pyramid to make Transformers practical for detection and segmentation; ConvNeXt goes the other direction, redesigning a pure CNN with Transformer-era tricks (large kernels, LayerNorm, GELU) and matching ViT accuracy. The two families have converged.

What the numbers actually say

  • The data threshold is real. In the original paper, ViT-Base trained on ImageNet-1k (1.3M images) lost to a comparable ResNet. Pre-trained on JFT-300M (300M images), ViT-Huge hit 88.55% top-1 on ImageNet — beating the best CNN of the day, at a fraction of the pre-training compute.
  • Patch count drives cost quadratically. P=16 → N=196; P=8 → N=784 (16× bigger attention matrix); P=32 → N=49 (faster but coarser). Patch size is the single most consequential hyperparameter.
  • ViT-Base is ~86M parameters; ViT-Large ~307M; ViT-Huge ~632M. Most of the parameters live in the MLP blocks and the QKV projections, not the patch embedder.
  • DeiT changed the small-data story. In 2021, Touvron et al. trained a ViT to 83.1% top-1 on ImageNet-1k alone, with no JFT, using heavy augmentation and regularization; adding distillation from a CNN teacher pushed accuracy higher still. So "ViTs always need 300M images" is a 2020 statement, not a permanent law.
  • Self-attention FLOPs at 224px, P=16: the 197×197 attention matrix is ~38.8k scores per head; summed over 12 heads and 12 layers, the Q·Kᵀ and weights·V products come to roughly 0.7 billion multiply-adds, small next to the roughly 11 billion multiply-adds in the MLP blocks.
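
The parameter figures can be sanity-checked by counting weights directly. The sketch below is an assumption-laden tally, not an official count: the function `vit_param_count` is hypothetical, and it assumes a 1000-class linear head, 224px inputs, biases on every linear layer, and the standard block layout (fused QKV projection, output projection, 4× MLP, two LayerNorms). Small differences from the quoted figures can come from the head size and patch size assumed.

```python
def vit_param_count(D, depth, patch=16, img=224, in_ch=3,
                    mlp_ratio=4, n_classes=1000):
    """Tally learnable parameters of a plain ViT classifier (hypothetical helper)."""
    n_tokens = (img // patch) ** 2 + 1              # patches + [CLS]
    hidden = mlp_ratio * D
    patch_embed = patch * patch * in_ch * D + D     # projection weights + bias
    cls_and_pos = D + n_tokens * D                  # [CLS] token + position table
    attn = (3 * D * D + 3 * D) + (D * D + D)        # QKV projection + output projection
    mlp = (D * hidden + hidden) + (hidden * D + D)  # two linear layers
    norms = 2 * (2 * D)                             # two LayerNorms per block
    final_norm = 2 * D
    head = D * n_classes + n_classes
    total = (patch_embed + cls_and_pos
             + depth * (attn + mlp + norms) + final_norm + head)
    return {"total": total, "patch_embed": patch_embed,
            "attn": depth * attn, "mlp": depth * mlp}


base = vit_param_count(D=768, depth=12)     # ViT-Base/16
large = vit_param_count(D=1024, depth=24)   # ViT-Large/16
print(f"ViT-Base  total: {base['total'] / 1e6:.1f}M")
print(f"ViT-Large total: {large['total'] / 1e6:.1f}M")
print(f"Base MLP share:         {base['mlp'] / base['total']:.1%}")
print(f"Base patch-embed share: {base['patch_embed'] / base['total']:.1%}")
```

Under these assumptions the MLP blocks hold about two thirds of ViT-Base's weights, while the patch embedder accounts for well under one percent.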

JavaScript: patchify and a single attention head

The two operations that make a ViT a ViT are patch extraction and scaled dot-product attention. Here they are in plain JavaScript, with no framework, to show there's no magic:

// Cut an HxW grayscale image (Float32Array, row-major) into PxP patches,
// each flattened to a length-P*P vector. Returns an array of patch vectors.
function patchify(img, H, W, P) {
  if (H % P || W % P) throw new Error("image size must be divisible by patch size");
  const patches = [];
  for (let py = 0; py < H; py += P) {
    for (let px = 0; px < W; px += P) {
      const v = new Float32Array(P * P);
      for (let r = 0; r < P; r++)
        for (let c = 0; c < P; c++)
          v[r * P + c] = img[(py + r) * W + (px + c)];
      patches.push(v);                 // one token, pre-projection
    }
  }
  return patches;                      // length N = (H/P) * (W/P)
}

// Numerically stable softmax over a row.
function softmax(row) {
  const m = Math.max(...row);
  const exps = row.map(x => Math.exp(x - m));
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map(e => e / sum);
}

// One self-attention head. Q, K, V are N-by-dk matrices (arrays of rows).
// Returns N output rows. This is the O(N^2 * dk) core of every ViT layer.
function attentionHead(Q, K, V) {
  const N = Q.length, dk = Q[0].length, scale = 1 / Math.sqrt(dk);
  const out = [];
  for (let i = 0; i < N; i++) {
    const scores = new Array(N);
    for (let j = 0; j < N; j++) {       // every token scores every token
      let dot = 0;
      for (let d = 0; d < dk; d++) dot += Q[i][d] * K[j][d];
      scores[j] = dot * scale;
    }
    const w = softmax(scores);           // attention weights for token i
    const o = new Array(dk).fill(0);
    for (let j = 0; j < N; j++)
      for (let d = 0; d < dk; d++) o[d] += w[j] * V[j][d];
    out.push(o);
  }
  return out;
}

Two things to notice. First, patchify never mixes pixels across patch boundaries — each tile becomes an independent token, which is exactly why a ViT has to learn spatial relationships rather than getting them for free. Second, the doubly-nested loop over i and j in attentionHead is the N² you can never escape in vanilla attention.

Python / PyTorch: the embedder and an encoder block

In real code, the patch embedding is a single strided convolution — the kernel size and stride both equal the patch size, so the conv windows tile the image without overlap and each window produces one token. This is the one convolution a "convolution-free" ViT keeps:

import torch, torch.nn as nn

class PatchEmbed(nn.Module):
    """Conv with kernel=stride=P is exactly 'flatten each patch, then linear-project'."""
    def __init__(self, img_size=224, patch=16, in_ch=3, dim=768):
        super().__init__()
        self.n_patches = (img_size // patch) ** 2          # 196 for 224/16
        self.proj = nn.Conv2d(in_ch, dim, kernel_size=patch, stride=patch)

    def forward(self, x):                                  # x: [B, C, H, W]
        x = self.proj(x)                                   # [B, dim, H/P, W/P]
        return x.flatten(2).transpose(1, 2)                # [B, N, dim]

class ViTEncoderBlock(nn.Module):
    """Pre-norm Transformer block: residual MSA, then residual MLP."""
    def __init__(self, dim=768, heads=12, mlp_ratio=4.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn  = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        hidden = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden), nn.GELU(), nn.Linear(hidden, dim))

    def forward(self, x):
        y = self.norm1(x)
        x = x + self.attn(y, y, y, need_weights=False)[0]  # self-attention
        x = x + self.mlp(self.norm2(x))
        return x

class ViT(nn.Module):
    def __init__(self, img_size=224, patch=16, dim=768, depth=12,
                 heads=12, n_classes=1000):
        super().__init__()
        self.patch_embed = PatchEmbed(img_size, patch, 3, dim)
        n = self.patch_embed.n_patches
        self.cls = nn.Parameter(torch.zeros(1, 1, dim))             # [CLS] token
        self.pos = nn.Parameter(torch.zeros(1, n + 1, dim))         # positions
        self.blocks = nn.ModuleList(ViTEncoderBlock(dim, heads) for _ in range(depth))
        self.norm = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, n_classes)

    def forward(self, x):
        x = self.patch_embed(x)                                     # [B, N, dim]
        cls = self.cls.expand(x.size(0), -1, -1)                    # [B, 1, dim]
        x = torch.cat([cls, x], dim=1) + self.pos                   # prepend + add pos
        for blk in self.blocks:
            x = blk(x)
        return self.head(self.norm(x)[:, 0])                        # read [CLS]

The whole architecture is barely 40 lines because it reuses the language Transformer wholesale. The only vision-specific code is PatchEmbed and the choice to read the [CLS] token at the end. Everything in between — LayerNorm, multi-head attention, GELU MLP, residuals — is identical to the encoder in a text Transformer.

Variants worth knowing

DeiT (Data-efficient Image Transformer, 2021). Trains a ViT on ImageNet-1k alone using strong augmentation and a distillation token that learns from a CNN teacher. Demolished the "ViTs need 300M images" assumption.

Swin Transformer (2021). Computes attention inside local windows and shifts the window grid every other layer so information leaks across window boundaries. For a fixed window size this brings attention cost down to O(N), and the hierarchical design adds a feature pyramid, making Transformers a popular backbone for object detection and segmentation.
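
The saving from windowing is simple to count. The sketch below compares the number of attention scores for global versus windowed attention; the helper name is hypothetical, and the example numbers assume Swin's first stage on a 224px image (4×4 patches giving a 56×56 token grid, with 7×7 windows).

```python
def attention_scores(n_tokens, window_tokens=None):
    """Attention scores per head: global if window_tokens is None, else windowed."""
    if window_tokens is None:
        return n_tokens * n_tokens                  # every token scores every token
    if n_tokens % window_tokens:
        raise ValueError("tokens must split evenly into windows")
    n_windows = n_tokens // window_tokens
    return n_windows * window_tokens * window_tokens  # = n_tokens * window_tokens


n = 56 * 56                                          # 3,136 tokens
global_cost = attention_scores(n)
window_cost = attention_scores(n, window_tokens=7 * 7)
print(f"global:   {global_cost:,}")
print(f"windowed: {window_cost:,}  ({global_cost // window_cost}x fewer)")
```

Because the window size is fixed, the windowed count grows linearly with the number of tokens, while the global count grows quadratically.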

MAE (Masked Autoencoders, 2021). Self-supervised pre-training: mask out 75% of the patches and ask a lightweight decoder to reconstruct them from the encoded visible patches. Because masked patches are dropped before the encoder, it only processes the remaining 25% of tokens, making pre-training fast and scalable without labels.
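
The masking step itself is just a random split of patch indices. The following is a minimal sketch of that idea, not the reference MAE code: the function `random_masking`, its seed argument, and the choice to return sorted index lists are all assumptions made for illustration.

```python
import random


def random_masking(n_patches, mask_ratio=0.75, seed=0):
    """Split patch indices into (kept, masked) lists for MAE-style pre-training."""
    rng = random.Random(seed)             # seeded only to make the example repeatable
    n_keep = int(n_patches * (1 - mask_ratio))
    order = list(range(n_patches))
    rng.shuffle(order)
    keep = sorted(order[:n_keep])         # tokens the encoder actually processes
    masked = sorted(order[n_keep:])       # tokens the decoder must reconstruct
    return keep, masked


keep, masked = random_masking(196)
print(len(keep), "visible tokens go through the encoder")
print(len(masked), "masked tokens are reconstructed by the decoder")
```

With 196 patches the encoder sees 49 tokens, so its attention matrix is 16× smaller than it would be on the full sequence.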

DINO / DINOv2. Self-distillation without labels that produces ViT features whose attention maps segment objects for free — the [CLS] token's attention lands on the foreground object with no segmentation supervision at all.

ConvNeXt (the counter-argument). Not a ViT, but the direct rebuttal: a pure ConvNet modernized with large kernels, LayerNorm, and GELU that matches ViT accuracy, showing much of ViT's gain was the training recipe, not the attention.

Common bugs and edge cases

  • Image size not divisible by patch size. 224/16 is clean, but feed a 225×225 image to a P=16 model and patchify either crops, pads, or crashes. Resize or pad to a multiple of P before patchifying.
  • Forgetting positional embeddings. Drop them and the model becomes a permutation-invariant bag of patches — it can still hit decent accuracy on some datasets, which masks the bug, but spatial reasoning collapses.
  • Position-embedding mismatch on fine-tuning at a new resolution. Pre-train at 224px (196 positions), fine-tune at 384px (576 positions), and the learned position table no longer fits. You must 2-D interpolate the positional embeddings to the new grid — a step beginners skip and then wonder why accuracy tanks.
  • Reading the wrong output token. The classification head reads index 0 (the [CLS] token), not a patch token. Off-by-one here silently trains on a single patch's representation.
  • Attention memory blowups. Materializing the N×N matrix at high resolution OOMs the GPU. Use a fused/FlashAttention kernel, or switch to windowed attention, before scaling up token count.
  • Training from scratch on a tiny dataset. Without pre-training or distillation, a ViT will be soundly beaten by a ResNet a tenth its size. The fix is almost never "more ViT" — it's "more data, or a pre-trained checkpoint, or a CNN."
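
The position-embedding resize mentioned above can be sketched in plain Python. This is a simplified illustration under stated assumptions: it uses bilinear interpolation with corner-aligned sampling, takes the position table as a list of vectors with [CLS] first and patches in row-major order, and the function name `resize_pos_embed` is hypothetical. Real code would typically call a library interpolation routine (often bicubic) on a tensor instead.

```python
def resize_pos_embed(pos, old_grid, new_grid):
    """Bilinearly resize patch position embeddings; the [CLS] entry is copied as-is."""
    cls_pos, grid = pos[0], pos[1:]
    if len(grid) != old_grid * old_grid:
        raise ValueError("position table does not match old_grid")
    dim = len(cls_pos)
    scale = (old_grid - 1) / (new_grid - 1) if new_grid > 1 else 0.0

    def split(t):
        """Return lower index, upper index, and fractional offset for coordinate t."""
        lo = min(int(t), old_grid - 1)
        hi = min(lo + 1, old_grid - 1)
        return lo, hi, t - lo

    out = [cls_pos]
    for i in range(new_grid):
        y0, y1, fy = split(i * scale)
        for j in range(new_grid):
            x0, x1, fx = split(j * scale)
            a, b = grid[y0 * old_grid + x0], grid[y0 * old_grid + x1]
            c, d = grid[y1 * old_grid + x0], grid[y1 * old_grid + x1]
            out.append([
                (1 - fy) * ((1 - fx) * a[k] + fx * b[k])
                + fy * ((1 - fx) * c[k] + fx * d[k])
                for k in range(dim)
            ])
    return out


# 224px -> 384px at P=16: a 14x14 grid becomes 24x24.
old_table = [[0.0, 0.0] for _ in range(1 + 14 * 14)]   # placeholder 2-D embeddings
new_table = resize_pos_embed(old_table, old_grid=14, new_grid=24)
print(len(old_table), "->", len(new_table), "position vectors")
```

The key design point is that only the patch grid is interpolated; the [CLS] position has no spatial location, so it is carried over unchanged.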

Frequently asked questions

Why does a ViT need so much more training data than a CNN?

Convolutions hard-code two inductive biases, locality and translation equivariance, so a CNN already "knows" that nearby pixels matter and that a cat is a cat wherever it appears. A ViT has neither bias baked in; it must learn them from data. On ImageNet-1k alone (1.3M images) a ViT-Base loses to a ResNet, but pre-trained on JFT-300M (300M images) ViT overtook the best CNNs of the time. The data is the price of throwing away the priors.

What exactly is the [CLS] token and why prepend it?

The [CLS] token is a single learnable vector prepended to the patch sequence, borrowed from BERT. It owns no image content of its own — its only job is to attend to every patch across the encoder layers and aggregate a whole-image summary. The classification head reads the final-layer [CLS] vector. You can skip it and global-average-pool the patch tokens instead, which works about as well once you re-tune the learning rate.

What is the computational cost of self-attention on image patches?

Self-attention is O(N²·d) in time and O(N²) in memory for N tokens of dimension d, because every patch attends to every other patch. With 16×16 patches a 224×224 image gives N=196 tokens — cheap. But halve the patch size to 8×8 and N quadruples to 784, so the attention matrix grows 16×. That quadratic blow-up is why high-resolution ViTs need windowed attention like Swin (O(N) with fixed windows).

Do Vision Transformers need positional embeddings?

Yes. Self-attention has no built-in notion of order (permute the input tokens and the outputs are permuted the same way), so without position information a ViT sees a bag of patches and cannot tell top-left from bottom-right. ViT adds a learnable 1-D positional embedding to each token. Interestingly, the learned embeddings recover 2-D structure on their own: nearby patches end up with similar position vectors even though the model was never told the patches form a grid.
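
A small experiment shows what "no notion of order" means in practice. The sketch below runs a single attention head over three toy tokens, then over the same tokens shuffled, and checks that the outputs are the same vectors in shuffled order. It is a simplification: the Q, K, and V projections are taken to be the identity, and the token values and permutation are arbitrary illustrative choices.

```python
import math


def self_attention(X):
    """Single-head self-attention with identity Q/K/V projections (a simplification)."""
    d = len(X[0])
    out = []
    for q in X:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in X]
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        z = sum(exps)
        out.append([sum(e / z * v[i] for e, v in zip(exps, X)) for i in range(d)])
    return out


tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
perm = [2, 0, 1]
shuffled = [tokens[p] for p in perm]

out = self_attention(tokens)
out_shuffled = self_attention(shuffled)
expected = [out[p] for p in perm]      # the same shuffle applied to the outputs

max_diff = max(abs(a - b)
               for row_a, row_b in zip(out_shuffled, expected)
               for a, b in zip(row_a, row_b))
print("largest difference:", max_diff)
```

Since shuffling changes nothing but the order of the outputs, any order-insensitive readout (such as averaging the tokens) gives the same result for every arrangement of patches. Adding a distinct positional embedding to each slot is what breaks this symmetry.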

How is a ViT patch embedding actually computed?

Each P×P×C patch is flattened to a vector of length P²·C and multiplied by a learned matrix to produce a D-dimensional token. The slick trick is that this is mathematically identical to a single Conv2d layer with kernel size P and stride P — so most implementations literally use one strided convolution as the patch embedder. It is the one convolution a ViT keeps.
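
The equivalence can be checked on a toy example. The sketch below computes patch tokens two ways on a 4×4 single-channel image with P = 2 and D = 2: once as "flatten, then multiply by a matrix", and once as a strided convolution whose kernels hold the same weights. Everything here is illustrative; the function names and kernel values are made up, and bias terms are omitted.

```python
def patch_embed_linear(img, P, E):
    """Flatten each PxP patch (row-major) and multiply by E of shape [P*P][D]."""
    H, W, D = len(img), len(img[0]), len(E[0])
    tokens = []
    for py in range(0, H, P):
        for px in range(0, W, P):
            flat = [img[py + r][px + c] for r in range(P) for c in range(P)]
            tokens.append([sum(flat[k] * E[k][d] for k in range(P * P))
                           for d in range(D)])
    return tokens


def patch_embed_conv(img, P, kernels):
    """Strided convolution with kernels of shape [D][P][P], stride = kernel size = P."""
    H, W = len(img), len(img[0])
    tokens = []
    for py in range(0, H, P):
        for px in range(0, W, P):
            tokens.append([
                sum(img[py + r][px + c] * k[r][c]
                    for r in range(P) for c in range(P))
                for k in kernels
            ])
    return tokens


P = 2
img = [[float(4 * r + c) for c in range(4)] for r in range(4)]
kernels = [[[1.0, 0.0], [0.0, -1.0]],      # output channel 0
           [[0.5, 0.5], [0.5, 0.5]]]       # output channel 1
# The projection matrix is the same weights, with each kernel flattened into a column.
E = [[kernels[d][r][c] for d in range(len(kernels))]
     for r in range(P) for c in range(P)]

tokens_linear = patch_embed_linear(img, P, E)
tokens_conv = patch_embed_conv(img, P, kernels)
print(tokens_conv)
print("identical:", tokens_linear == tokens_conv)
```

The only difference between the two views is how the weights are stored: one column of the projection matrix per output dimension, or one P×P kernel per output channel.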

Is a Vision Transformer better than a CNN in 2026?

It depends on scale. With huge pre-training data and modern recipes, ViTs and hybrid ViT/CNN backbones top most ImageNet and detection leaderboards. But on small datasets, on edge devices, or when latency matters, well-tuned CNNs (or ConvNeXt, a CNN redesigned with Transformer-era tricks) remain competitive or better. The honest answer is that the architectures have converged more than the hype suggests.