Gather Rows / Embedding Lookup #
This file proves the small linear-algebra fact behind token embeddings:
- forward gathers rows from an embedding table, and
- reverse mode scatters the upstream row gradients back into the table, summing repeated indices.
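The two operations can be sketched as follows. This is a minimal Fin-indexed sketch consistent with the theorem's types; the file's actual definitions may differ in detail (e.g. in how the sum is phrased):

```lean
import Mathlib

open BigOperators

-- Forward: pick out row `idx i` of the table for each token position `i`.
def gatherRows {vocab dim k : ℕ}
    (table : Fin vocab → Fin dim → ℝ) (idx : Fin k → Fin vocab) :
    Fin k → Fin dim → ℝ :=
  fun i j => table (idx i) j

-- Reverse: accumulate each upstream row gradient into the table row it
-- came from; positions sharing an index contribute to the same sum.
noncomputable def scatterAddRows {vocab dim k : ℕ}
    (idx : Fin k → Fin vocab) (dY : Fin k → Fin dim → ℝ) :
    Fin vocab → Fin dim → ℝ :=
  fun v j => ∑ i : Fin k, if idx i = v then dY i j else 0
```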
The theorem is deliberately stated with Fin indices. Runtime APIs often receive Nat token IDs
and perform bounds checks or totalization; the clean proof primitive should not mix that IO/error
policy into the VJP theorem. A later runtime bridge can say: if every Nat ID is in bounds, the
runtime gather/scatter path agrees with this Fin-indexed specification.
PyTorch analogy:
- `torch.nn.Embedding` forward is row gather.
- Its backward wrt the embedding table is scatter-add over token positions.
theorem Proofs.Autograd.Embedding.gatherRows_scatterAddRows_adjoint
    {vocab dim k : ℕ}
    (dTable : Fin vocab → Fin dim → ℝ)
    (idx : Fin k → Fin vocab)
    (dY : Fin k → Fin dim → ℝ) :
Gather and scatter-add are adjoint.
This is the local VJP theorem for embedding lookup wrt the embedding table:
⟨gatherRows dTable idx, dY⟩ = ⟨dTable, scatterAddRows idx dY⟩,
where ⟨·, ·⟩ is the entrywise (Frobenius) inner product over the matching index sets.
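For definitions of the shape sketched above, the identity follows by regrouping the sum over token positions i according to the table row v = idx i each position reads; a sketch of the computation:

$$
\langle \mathrm{gatherRows}\;dTable\;idx,\; dY\rangle
= \sum_{i,j} dTable_{\,idx(i),\,j}\, dY_{i,j}
= \sum_{v,j} dTable_{\,v,\,j} \sum_{i \,:\, idx(i) = v} dY_{i,j}
= \langle dTable,\; \mathrm{scatterAddRows}\;idx\;dY\rangle .
$$

The middle step is where repeated indices get summed: every position i with idx(i) = v contributes its row gradient to the same inner sum, which is exactly the (v, j) entry of the scatter-add.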