
k-d Tree

Binary tree partition of k-dimensional space

A k-d tree is a binary tree that partitions k-dimensional space by alternating axes at each level. For 3D data you split on x at the root, y at depth 1, z at depth 2, then back to x. In low dimensions it cuts nearest-neighbor queries from O(n) brute force to expected O(log n) and lets range queries skip most of the tree, until the curse of dimensionality strikes around k = 20.

  • Build (median pivot): O(n log n) with linear-time median selection; O(n log² n) if every level re-sorts
  • NN query (low-d): O(log n) expected
  • NN query (high-d): O(n), no speedup
  • Practical ceiling: k ≈ 20
  • Space: O(n)


How a k-d tree works

A k-d tree (short for "k-dimensional") is what you get when you generalize a binary search tree from 1D numbers to k-dimensional points. Bentley introduced it in 1975. The defining trick: at each level of the tree, you split on a different coordinate axis, cycling through all k of them.

To build a balanced k-d tree on n points:

  1. At the root (depth 0), pick the axis you'll split on — typically axis 0 (x), or the axis with the largest variance.
  2. Find the median point along that axis. Store the median at the root; everything below the median goes to the left subtree, everything above to the right.
  3. Recurse on each subtree, advancing to the next axis: depth 1 splits on axis 1 (y), depth 2 on axis 2 (z), depth k on axis 0 again.
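The three steps map almost line for line onto a recursive function. The sketch below is illustrative (the function and field names are ours, not from any library): it cycles axes with depth % k, finds the median by sorting, and lists each node's splitting plane so you can see the alternation on a small six-point 2D set.

```python
def build(points, depth=0, k=2):
    """Build a k-d tree as nested dicts, cycling the split axis with depth."""
    if not points:
        return None
    axis = depth % k                                  # steps 1 and 3: axis for this level
    pts = sorted(points, key=lambda p: p[axis])       # order points along that axis
    mid = len(pts) // 2                               # step 2: median position
    return {
        "point": pts[mid],
        "axis": axis,
        "left": build(pts[:mid], depth + 1, k),       # points below the median
        "right": build(pts[mid + 1:], depth + 1, k),  # points above the median
    }


def splits(node, depth=0):
    """List (depth, plane) pairs in pre-order, e.g. (0, 'x = 7')."""
    if node is None:
        return []
    axis = node["axis"]
    name = "xyz"[axis] if axis < 3 else f"axis {axis}"
    plane = f"{name} = {node['point'][axis]}"
    return [(depth, plane)] + splits(node["left"], depth + 1) + splits(node["right"], depth + 1)


points = [(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)]
for depth, plane in splits(build(points)):
    print("  " * depth + plane)
```

The root splits on x = 7, its children on y = 4 and y = 6, and the grandchildren on x again, which is exactly the depth-mod-k cycle described in step 3.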

The result is a binary tree where every node represents an axis-aligned hyperplane. The two halves of the data are cleanly separated, so a search for points "near" a query can prune entire subtrees by checking whether the query's bounding ball crosses the splitting plane.

The clever part of k-d trees isn't construction — it's the search. Naive descent (always go left if the query is below the splitting plane) finds a leaf, but not necessarily the nearest one. The true nearest could live on the other side of a splitting plane that's almost touching the query.
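A concrete failure on a small 2D set: the query (6.8, 6) sits just left of the root's plane x = 7, so greedy descent turns left and ends at (4, 7), while the true nearest point, (9, 6), lies across that plane. The sketch below is illustrative; the helper names are ours.

```python
def build(pts, depth=0, k=2):
    """Median-pivot k-d tree as (point, axis, left, right) tuples."""
    if not pts:
        return None
    axis = depth % k
    pts = sorted(pts, key=lambda p: p[axis])
    mid = len(pts) // 2
    return (pts[mid], axis, build(pts[:mid], depth + 1, k), build(pts[mid + 1:], depth + 1, k))


def naive_descent(node, q):
    """Follow the query down the tree with no backtracking; return the last point reached."""
    while True:
        point, axis, left, right = node
        child = left if q[axis] < point[axis] else right
        if child is None:
            return point
        node = child


def dist2(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))


points = [(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)]
q = (6.8, 6)
greedy = naive_descent(build(points), q)
true_nn = min(points, key=lambda p: dist2(q, p))
print(greedy, true_nn)  # (4, 7) (9, 6)
```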

The standard nearest-neighbor algorithm:

  1. Recursively descend to the leaf that contains the query point. Initialize best = that leaf's point and bestDist = distance(query, best).
  2. Unwind the recursion. At each ancestor, do two checks:
    • Is this ancestor's stored point closer than best? Update if so.
    • Is the perpendicular distance from query to the ancestor's splitting plane less than bestDist? If yes, the other subtree might contain a closer point — recurse into it. If no, prune the entire other subtree.

This backtracking-with-pruning is what gives expected O(log n) NN queries — the splitting planes prune the work as long as the dimensionality stays low.
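To see the pruning pay off, the sketch below instruments the backtracking search with a counter of visited nodes and checks its answer against a brute-force scan. It is a minimal illustration under assumed inputs (10,000 uniform random 2D points, a fixed seed); the names nn_search and visited are ours.

```python
import math
import random


def build(pts, depth, k):
    """Median-pivot k-d tree as (point, axis, left, right) tuples."""
    if not pts:
        return None
    axis = depth % k
    pts = sorted(pts, key=lambda p: p[axis])
    mid = len(pts) // 2
    return (pts[mid], axis, build(pts[:mid], depth + 1, k), build(pts[mid + 1:], depth + 1, k))


def dist2(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))


def nn_search(root, q):
    """Backtracking NN search; returns (best point, best dist², nodes visited)."""
    best, best_d2, visited = None, math.inf, 0

    def search(node):
        nonlocal best, best_d2, visited
        if node is None:
            return
        point, axis, left, right = node
        visited += 1
        d2 = dist2(q, point)
        if d2 < best_d2:                     # is this node closer than best?
            best, best_d2 = point, d2
        diff = q[axis] - point[axis]
        near, far = (left, right) if diff < 0 else (right, left)
        search(near)                         # descend toward the query's leaf first
        if diff * diff < best_d2:            # plane closer than best: far side may hold a winner
            search(far)

    search(root)
    return best, best_d2, visited


random.seed(0)
pts = [(random.random(), random.random()) for _ in range(10_000)]
root = build(pts, 0, 2)
q = (0.5, 0.5)
best, best_d2, visited = nn_search(root, q)
brute = min(pts, key=lambda p: dist2(q, p))
print(best_d2 == dist2(q, brute), visited, "of", len(pts), "nodes visited")
```

The exact visit count depends on the seed and query, but in 2D it should be a small fraction of the 10,000 nodes a brute-force scan touches.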

The curse of dimensionality — why k-d trees break above k = 20

In two dimensions, the splitting-plane prune nearly always succeeds: the region "near" the query is a tiny disc, and most of the tree is far away. In high dimensions, the geometry betrays you. The cube [−1, 1]^100 has corners at distance √100 = 10 from its center, yet the unit ball inscribed in it reaches only distance 1, so almost all of the cube's volume sits in the corners, outside the ball. For many data distributions, distances between points concentrate the same way: the nearest neighbor is barely closer than a typical point. The pruning test "does the ball of radius bestDist around the query cross this splitting plane?" then answers yes for almost every plane, so the search backtracks into nearly every subtree. You end up doing a brute-force scan with extra overhead.
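The corner argument can be checked numerically with the standard volume formula for a d-dimensional ball of radius r, V_d(r) = π^(d/2) r^d / Γ(d/2 + 1). The short sketch below (helper names are ours) computes how much of the cube [−1, 1]^d, whose volume is 2^d, the inscribed unit ball fills, alongside the center-to-corner distance √d.

```python
import math


def inscribed_ball_fraction(d):
    """Volume of the unit d-ball divided by the volume of the cube [-1, 1]^d (= 2^d)."""
    ball = math.pi ** (d / 2) / math.gamma(d / 2 + 1)
    return ball / 2 ** d


def corner_distance(d):
    """Distance from the center of [-1, 1]^d to any of its corners."""
    return math.sqrt(d)


for d in (2, 3, 10, 20, 100):
    print(f"d={d:>3}  corner at {corner_distance(d):5.2f}  "
          f"ball fills {inscribed_ball_fraction(d):.3g} of the cube")
```

In 2D the ball fills π/4 of the square; by d = 100 the fraction is astronomically small, which is why "nearby" regions stop lining up with the tree's cells.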

Empirical rule: k-d trees give meaningful speedup up to roughly 20 dimensions on Euclidean data. Beyond that, switch to approximate methods: locality-sensitive hashing (LSH), product quantization, or graph-based indexes like HNSW, which is widely used in vector databases such as Milvus and Weaviate.

k-d tree vs ball tree vs R-tree

Feature | k-d tree | Ball tree | R-tree | VP-tree
Partition primitive | Axis-aligned hyperplane | Hypersphere | Bounding rectangle | Sphere around vantage point
At d ≤ 8 | Fastest | Slightly slower | Designed for d = 2 or 3 | Competitive
At 8 < d ≤ 50 | Degrades quickly | Best of the three | Not applicable | Strong, metric-only
At d > 100 | Useless | Useless | Useless | Useless
Insertion | Hard, typically rebuild | Hard, typically rebuild | Native, designed for it | Hard
Distance metric | Euclidean / Lp | Any metric | Euclidean / overlap | Any metric
Memory | O(n) | O(n) | O(n) | O(n)
Indexes points or boxes | Points | Points | Bounding boxes | Points

scikit-learn's nearest-neighbor estimators (including its kNN classifier) expose both kd_tree and ball_tree as backends, and with algorithm='auto' they fall back to brute force once the data has more than 15 dimensions. That's the practical envelope.

JavaScript implementation — nearest-neighbor in k dimensions

This implementation uses dimension-cycling splits and median pivoting. It's general-purpose: pass k = 2 for 2D points, k = 3 for 3D, k = 8 for an 8-d feature vector. Watch how the pruning check at the splitting plane lets the search skip entire subtrees.

class KDTree {
  constructor(points, k) {
    this.k = k;
    this.root = this._build(points.slice(), 0);  // copy so the caller's array is not reordered
  }

  _build(pts, depth) {
    if (pts.length === 0) return null;
    const axis = depth % this.k;
    pts.sort((a, b) => a[axis] - b[axis]);
    const mid = Math.floor(pts.length / 2);
    return {
      point: pts[mid],
      axis,
      left:  this._build(pts.slice(0, mid),     depth + 1),
      right: this._build(pts.slice(mid + 1),    depth + 1),
    };
  }

  _dist2(a, b) {
    let s = 0;
    for (let i = 0; i < this.k; i++) { const d = a[i] - b[i]; s += d * d; }
    return s;
  }

  // Returns the single nearest point to q (Euclidean).
  nearest(q) {
    let best = null, bestDist = Infinity;
    const search = (node) => {
      if (!node) return;
      const d2 = this._dist2(q, node.point);
      if (d2 < bestDist) { best = node.point; bestDist = d2; }
      const diff = q[node.axis] - node.point[node.axis];
      const near  = diff < 0 ? node.left  : node.right;
      const far   = diff < 0 ? node.right : node.left;
      search(near);
      // Prune: only descend into far side if the splitting plane is closer
      // than the current best — otherwise no point on that side can win.
      if (diff * diff < bestDist) search(far);
    };
    search(this.root);
    return best;
  }

  // k-nearest neighbors via a max-heap of size kNeighbors.
  kNearest(q, kNeighbors) {
    const heap = [];  // [dist², point], max at index 0
    const push = (d, p) => {
      heap.push([d, p]);
      heap.sort((a, b) => b[0] - a[0]);  // toy max-heap; use a real one in prod
      if (heap.length > kNeighbors) heap.shift();
    };
    const search = (node) => {
      if (!node) return;
      const d2 = this._dist2(q, node.point);
      if (heap.length < kNeighbors || d2 < heap[0][0]) push(d2, node.point);
      const diff = q[node.axis] - node.point[node.axis];
      const near = diff < 0 ? node.left  : node.right;
      const far  = diff < 0 ? node.right : node.left;
      search(near);
      const worst = heap.length < kNeighbors ? Infinity : heap[0][0];
      if (diff * diff < worst) search(far);
    };
    search(this.root);
    // heap is kept sorted farthest-first; reverse so the nearest point comes first.
    return heap.reverse().map(([, p]) => p);
  }
}

// 2D — fast and useful.
const tree2D = new KDTree([[2,3],[5,4],[9,6],[4,7],[8,1],[7,2]], 2);
console.log(tree2D.nearest([6, 5]));   // [5, 4]

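// 8D: still log-time-ish if data is not adversarial.
// Synthetic data for illustration: 1,000 random points in the unit 8-cube.
const eightDimVectors = Array.from({ length: 1000 }, () =>
  Array.from({ length: 8 }, () => Math.random()));
const tree8D = new KDTree(eightDimVectors, 8);
console.log(tree8D.kNearest(new Array(8).fill(0.5), 3));  // the 3 points closest to the cube's center

// 100D: you'll see no speedup over brute force. Switch tools.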

Python implementation

import math

class KDNode:
    __slots__ = ('point', 'axis', 'left', 'right')
    def __init__(self, point, axis, left, right):
        self.point, self.axis, self.left, self.right = point, axis, left, right

class KDTree:
    def __init__(self, points, k):
        self.k = k
        self.root = self._build(list(points), 0)

    def _build(self, pts, depth):
        if not pts:
            return None
        axis = depth % self.k
        pts.sort(key=lambda p: p[axis])
        mid = len(pts) // 2
        return KDNode(pts[mid], axis,
                      self._build(pts[:mid], depth + 1),
                      self._build(pts[mid + 1:], depth + 1))

    def _dist2(self, a, b):
        return sum((a[i] - b[i]) ** 2 for i in range(self.k))

    def nearest(self, q):
        best = [None, math.inf]   # [point, dist²]
        def search(node):
            if node is None:
                return
            d2 = self._dist2(q, node.point)
            if d2 < best[1]:
                best[0], best[1] = node.point, d2
            diff = q[node.axis] - node.point[node.axis]
            near, far = (node.left, node.right) if diff < 0 else (node.right, node.left)
            search(near)
            if diff * diff < best[1]:
                search(far)
        search(self.root)
        return best[0]

    def k_nearest(self, q, k_neighbors):
        import heapq
        heap = []   # max-heap of (-dist², counter, point)
        counter = 0
        def search(node):
            nonlocal counter
            if node is None:
                return
            d2 = self._dist2(q, node.point)
            if len(heap) < k_neighbors:
                heapq.heappush(heap, (-d2, counter, node.point))
                counter += 1
            elif d2 < -heap[0][0]:
                heapq.heappushpop(heap, (-d2, counter, node.point))
                counter += 1
            diff = q[node.axis] - node.point[node.axis]
            near, far = (node.left, node.right) if diff < 0 else (node.right, node.left)
            search(near)
            worst = math.inf if len(heap) < k_neighbors else -heap[0][0]
            if diff * diff < worst:
                search(far)
        search(self.root)
        return [p for _, _, p in sorted(heap, key=lambda x: -x[0])]

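For production Python, use scipy.spatial.KDTree (or cKDTree; the two are equivalent since SciPy 1.6), a compiled implementation that is far faster than this teaching code. Its balanced_tree option (split at the median rather than the midpoint) and compact_nodes option (shrink each node's box to the data it holds) both default to True and trade a slower build for faster queries.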

Variants

  • bd-tree (box-decomposition tree). Arya and Mount's tree that splits with both axis-aligned planes and "shrinking" boxes that wrap around dense clusters. The shrink steps keep the depth O(log n) even on highly clustered data, which yields provable guarantees for approximate nearest-neighbor search.
  • Variance-split k-d tree. Choose the split axis as the one with maximum variance instead of cycling. Produces tighter cells on skewed data. Close relatives are common in practice: scikit-learn's KDTree splits on the axis of largest spread (max minus min), and FLANN picks randomly among the highest-variance axes.
  • Approximate k-d tree (FLANN / ANN). Stop searching after a fixed budget of leaves; you trade recall for speed. FLANN (Fast Library for Approximate Nearest Neighbors), used in OpenCV, builds randomized k-d forests for this.
  • k-d-B tree. A disk-based version that pages nodes the way B-trees do, used in older spatial databases.
  • Random projection k-d tree (RP tree). Split along a random direction at each node instead of a coordinate axis. Its cells shrink at a rate governed by the data's intrinsic dimension rather than the ambient d, so it helps most when high-dimensional data lies near a low-dimensional subspace or manifold.
  • Static-rebuild k-d tree. Rebuild from scratch periodically rather than supporting incremental insert. Standard practice — most use cases query a fixed point set.
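The "fixed budget" idea behind approximate k-d tree search can be sketched as best-bin-first search: keep a priority queue of unexplored subtrees ordered by a lower bound on their distance to the query, and stop after a fixed number of node checks. This is an illustrative sketch, not FLANN's or ANN's code; max_checks and the other names are hypothetical. With an unlimited budget it returns the exact nearest neighbor; with a small one it trades recall for speed.

```python
import heapq
import itertools
import math
import random


def build(pts, depth, k):
    """Median-pivot k-d tree as (point, axis, left, right) tuples."""
    if not pts:
        return None
    axis = depth % k
    pts = sorted(pts, key=lambda p: p[axis])
    mid = len(pts) // 2
    return (pts[mid], axis, build(pts[:mid], depth + 1, k), build(pts[mid + 1:], depth + 1, k))


def dist2(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))


def approx_nearest(root, q, max_checks):
    """Best-bin-first search that examines at most max_checks nodes."""
    best, best_d2, checks = None, math.inf, 0
    tie = itertools.count()  # tiebreaker so heapq never compares subtrees
    # Entries: (lower bound on dist² from q to anything in the subtree, tiebreak, subtree)
    frontier = [(0.0, next(tie), root)] if root else []
    while frontier and checks < max_checks:
        bound, _, node = heapq.heappop(frontier)
        if bound >= best_d2:
            break  # every remaining subtree is at least this far away
        point, axis, left, right = node
        checks += 1
        d2 = dist2(q, point)
        if d2 < best_d2:
            best, best_d2 = point, d2
        diff = q[axis] - point[axis]
        near, far = (left, right) if diff < 0 else (right, left)
        if near:
            heapq.heappush(frontier, (bound, next(tie), near))
        if far:
            # The far side lies across the plane, so it is at least diff² away.
            heapq.heappush(frontier, (max(bound, diff * diff), next(tie), far))
    return best, checks


random.seed(1)
pts = [tuple(random.random() for _ in range(8)) for _ in range(5_000)]
root = build(pts, 0, 8)
q = tuple(0.5 for _ in range(8))
exact = min(pts, key=lambda p: dist2(q, p))
for budget in (10, 100, 10**9):
    found, used = approx_nearest(root, q, budget)
    print(budget, used, found == exact)
```

Small budgets may return a near-miss; raising the budget converges on the exact answer, which is the recall-for-speed dial the variant describes.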

Common bugs and edge cases

  • Forgetting the pruning check. Without the "is the splitting plane closer than the current best?" test, NN search becomes a full traversal — O(n), not O(log n). The asymmetric search-and-backtrack pattern is the entire point.
  • Building from a stream of inserts. A k-d tree with no rebalancing degenerates fast under non-random insert order; points that arrive sorted turn it into a linked list. Either batch-build from medians, or rebuild periodically (the whole tree, or just the subtrees that have become unbalanced).
  • Mixing dimension scales. If x is in millimeters and y is in kilometers, x differences swamp every distance, splits on y almost never prune, and the tree behaves as if it were 1D. Convert to a common unit when the coordinates share one; otherwise normalize (z-score or min-max) before building.
  • Curse-of-dimensionality denial. Engineers paste a k-d tree into a 256-d image-embedding NN search, see no speedup, and assume the implementation is broken. It's not — the tree is fine, the dimension is the problem. Switch to HNSW or LSH.
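  • Median pivot on duplicates. If the build partitions by value (everything below the median left, everything at or above it right), a run of points sharing the median coordinate all lands on one side, the split turns lopsided, and deep recursion can overflow the stack. Split by position in the sorted order, as both implementations here do, or break ties with a secondary key such as the next coordinate.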
  • Wrong distance metric. The pruning check uses perpendicular distance to the axis-aligned plane, which is valid for Euclidean and L∞. It is not valid for cosine distance — convert vectors to unit length and use Euclidean instead.
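Two of the fixes above are one-liners. For unit vectors, squared Euclidean distance and cosine similarity are tied by ‖a − b‖² = 2 − 2·cos(a, b), so ranking unit-normalized vectors by Euclidean distance gives the same order as ranking by cosine similarity (highest first), and the axis-aligned prune stays valid. A sketch of both preprocessing steps, with helper names of our own choosing:

```python
import math


def to_unit(v):
    """Scale v to length 1 so Euclidean distance ranks neighbors like cosine similarity."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def zscore_columns(points):
    """Standardize each coordinate to mean 0, std 1 (assumes no constant column)."""
    k = len(points[0])
    cols = list(zip(*points))
    means = [sum(c) / len(c) for c in cols]
    stds = [math.sqrt(sum((x - m) ** 2 for x in c) / len(c)) for c, m in zip(cols, means)]
    return [[(p[i] - means[i]) / stds[i] for i in range(k)] for p in points]


def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


a, b = [3.0, 4.0, 0.0], [1.0, 2.0, 2.0]
ua, ub = to_unit(a), to_unit(b)
sq_euclid = sum((x - y) ** 2 for x, y in zip(ua, ub))
print(sq_euclid, 2 - 2 * cosine(a, b))  # equal up to floating-point rounding
```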

Frequently asked questions

Why do k-d trees fail in high dimensions?

As dimension grows, distances concentrate: the nearest neighbor ends up barely closer than a typical point, so the ball of radius bestDist around the query crosses almost every splitting plane and the search visits nearly every leaf. Empirically, k-d trees offer real speedup up to ~20 dimensions; above that they degenerate to O(n), the same as brute-force scan. This is one face of the curse of dimensionality.
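A quick experiment makes the degradation visible. The sketch below assumes uniform random points in the unit cube, 2,000 points, 20 random queries, and a fixed seed, and counts what fraction of tree nodes the exact backtracking search touches as d grows. The exact percentages depend on the data and seed, so none are quoted; what to look for is the fraction climbing toward 1.

```python
import math
import random


def build(pts, depth, k):
    """Median-pivot k-d tree as (point, axis, left, right) tuples."""
    if not pts:
        return None
    axis = depth % k
    pts = sorted(pts, key=lambda p: p[axis])
    mid = len(pts) // 2
    return (pts[mid], axis, build(pts[:mid], depth + 1, k), build(pts[mid + 1:], depth + 1, k))


def count_visits(root, q):
    """Run exact backtracking NN search and return how many nodes it visited."""
    best_d2, visited = math.inf, 0

    def search(node):
        nonlocal best_d2, visited
        if node is None:
            return
        point, axis, left, right = node
        visited += 1
        best_d2 = min(best_d2, sum((a - b) ** 2 for a, b in zip(q, point)))
        diff = q[axis] - point[axis]
        near, far = (left, right) if diff < 0 else (right, left)
        search(near)
        if diff * diff < best_d2:  # the prune that stops working as d grows
            search(far)

    search(root)
    return visited


def visited_fraction(d, n=2000, queries=20, seed=0):
    """Average fraction of the n nodes visited per query in d dimensions."""
    rng = random.Random(seed)
    pts = [tuple(rng.random() for _ in range(d)) for _ in range(n)]
    root = build(pts, 0, d)
    visits = sum(count_visits(root, tuple(rng.random() for _ in range(d))) for _ in range(queries))
    return visits / (queries * n)


for d in (2, 5, 10, 20, 30):
    print(f"d={d:>2}: search visited {visited_fraction(d):.1%} of nodes")
```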

What's the right way to choose split axes — cycling or variance?

Cycling (level 0 splits on x, level 1 on y, etc.) is simple and works well when data is roughly isotropic. Variance-based splitting picks the axis with the largest spread at each node; it produces tighter trees on skewed data, at the cost of recomputing variance per node. For real workloads, variance splitting plus median pivot is standard.
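Variance-based selection replaces depth % k with a per-node computation. A minimal sketch, assuming points are equal-length tuples (the function names are ours), using population variance from the standard library's statistics module:

```python
from statistics import pvariance


def widest_axis(points):
    """Return the axis whose coordinates have the largest population variance."""
    k = len(points[0])
    return max(range(k), key=lambda axis: pvariance([p[axis] for p in points]))


def build_variance(points):
    """Median-pivot k-d tree that picks the split axis by variance at every node."""
    if not points:
        return None
    axis = widest_axis(points)
    pts = sorted(points, key=lambda p: p[axis])
    mid = len(pts) // 2
    return {
        "point": pts[mid],
        "axis": axis,  # stored per node, so search code that reads node.axis works unchanged
        "left": build_variance(pts[:mid]),
        "right": build_variance(pts[mid + 1:]),
    }


# Skewed data: spread mostly along y, so the root splits on y (axis 1), not x.
skewed = [(0, 0), (1, 10), (2, 20), (3, 30)]
root = build_variance(skewed)
print(root["axis"], root["point"])
```

Because each node records its own axis, the nearest-neighbor search needs no changes; only the build pays the extra O(n·k) variance computation per node.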

How do you do nearest-neighbor search efficiently?

Descend to the leaf containing the query, record that leaf's distance as the current best, then walk back up. At each ancestor, check whether the splitting plane is closer than the current best — if so, the other subtree might contain a closer point and you must recurse into it. This 'backtracking with pruning' gives O(log n) expected for low dimensions.

k-d tree vs ball tree vs R-tree — which to use?

k-d trees are fastest in low dimensions (≤8) on Euclidean data. Ball trees use spherical bounds and degrade more gracefully in moderate dimensions (10–50). Note that scikit-learn's algorithm='auto' heuristic does not switch to a ball tree as dimension grows: it falls back to brute force when D > 15 and otherwise tries kd_tree first. R-trees index rectangles instead of points and dominate GIS work. For d > 100, abandon all three and use approximate methods like LSH or HNSW.

Why are k-d trees usually built once and never modified?

Inserting into an existing k-d tree without rebalancing skews the splits, ruining query performance. There's no clean analog of the rotations that rebalance red-black trees, because rotating a node would change which axis each level splits on. Production systems either rebuild the tree periodically or rebuild just the subtrees that have become unbalanced. For high-churn data, prefer R-trees or grids.
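The skew is easy to reproduce. The sketch below uses a naive insert (descend by each node's axis, ties go right, attach a new leaf) with no rebalancing, and compares its height with a median-built tree on the same points. The points arrive sorted along a diagonal, an assumed worst case: every new point is larger in both coordinates, so it always goes right and the tree becomes a linked list. All names here are ours.

```python
class Node:
    __slots__ = ("point", "axis", "left", "right")

    def __init__(self, point, axis):
        self.point, self.axis, self.left, self.right = point, axis, None, None


def insert(root, point, k):
    """Naive k-d tree insert with no rebalancing (iterative to avoid deep recursion)."""
    if root is None:
        return Node(point, 0)
    node = root
    while True:
        side = "left" if point[node.axis] < node.point[node.axis] else "right"
        child = getattr(node, side)
        if child is None:
            setattr(node, side, Node(point, (node.axis + 1) % k))
            return root
        node = child


def build(points, depth, k):
    """Median-pivot batch build, for comparison."""
    if not points:
        return None
    axis = depth % k
    pts = sorted(points, key=lambda p: p[axis])
    mid = len(pts) // 2
    node = Node(pts[mid], axis)
    node.left = build(pts[:mid], depth + 1, k)
    node.right = build(pts[mid + 1:], depth + 1, k)
    return node


def height(root):
    """Number of nodes on the longest root-to-leaf path (iterative)."""
    best = 0
    stack = [(root, 1)] if root is not None else []
    while stack:
        node, depth = stack.pop()
        best = max(best, depth)
        for child in (node.left, node.right):
            if child is not None:
                stack.append((child, depth + 1))
    return best


n = 500
diagonal = [(i, i) for i in range(n)]
naive = None
for p in diagonal:
    naive = insert(naive, p, 2)

print(height(naive), height(build(diagonal, 0, 2)))  # prints: 500 9
```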

What's the canonical use case for k-d trees?

Photon mapping in offline rendering, k-NN classification on small feature spaces, point-cloud nearest-neighbor in robotics (LIDAR registration with d=3), color quantization (d=3 RGB), and computational geometry primitives like closest-pair. Above d=20 the tree stops helping and people switch to approximate methods.