TorchLean API

NN.Proofs.Autograd.Tape.Ops.Embedding.GatherRows

Gather Rows / Embedding Lookup #

This file proves the small linear-algebra fact behind token embeddings: gathering rows from a table and scatter-adding row cotangents back are adjoint operations, so scatter-add is the vector-Jacobian product (VJP) of embedding lookup with respect to the table.

The theorem is deliberately stated with Fin indices. Runtime APIs often receive Nat token IDs and perform bounds checks or totalization; the clean proof primitive should not mix that IO/error policy into the VJP theorem. A later runtime bridge can say: if every Nat ID is in bounds, the runtime gather/scatter path agrees with this Fin-indexed specification.

PyTorch analogy:

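The analogy itself appears to have been lost in extraction. A minimal sketch of the intended correspondence, written with NumPy so it is self-contained (here `np.add.at` plays the role PyTorch's `Tensor.index_add_` plays in the embedding backward pass; the variable names mirror the Lean definitions below):

```python
import numpy as np

# Forward: embedding lookup is a row gather, Y[i] = table[idx[i]].
# In PyTorch this is table[idx] / nn.functional.embedding(idx, table).
table = np.arange(12, dtype=np.float64).reshape(4, 3)  # vocab=4, dim=3
idx = np.array([2, 0, 2])                              # k=3 in-bounds token IDs
Y = table[idx]                                         # gatherRows table idx

# Backward: the VJP w.r.t. the table scatter-adds the row cotangents.
# In PyTorch this is dTable.index_add_(0, idx, dY).
dY = np.ones_like(Y)
dTable = np.zeros_like(table)
np.add.at(dTable, idx, dY)                             # scatterAddRows idx dY
```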
def Proofs.Autograd.Embedding.matInner {rows cols : ℕ} (A B : Fin rows → Fin cols → ℝ) : ℝ

Matrix-style inner product for finite row/column tables.

def Proofs.Autograd.Embedding.gatherRows {vocab dim k : ℕ} (table : Fin vocab → Fin dim → ℝ) (idx : Fin k → Fin vocab) :
    Fin k → Fin dim → ℝ

Gather rows from a table using in-bounds finite token IDs.

def Proofs.Autograd.Embedding.scatterAddRows {vocab dim k : ℕ} (idx : Fin k → Fin vocab) (dY : Fin k → Fin dim → ℝ) :
    Fin vocab → Fin dim → ℝ

Scatter-add row cotangents back into an embedding table.

If the same row appears several times in idx, its gradient contributions are summed.

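The summing behaviour on repeated indices can be illustrated numerically; a sketch using NumPy's `np.add.at`, which is an unbuffered scatter-add, so duplicated indices accumulate rather than overwrite:

```python
import numpy as np

vocab, dim = 3, 2
idx = np.array([1, 1, 0])          # row 1 appears twice
dY = np.array([[1.0, 2.0],
               [10.0, 20.0],
               [100.0, 200.0]])

dTable = np.zeros((vocab, dim))
np.add.at(dTable, idx, dY)         # scatterAddRows idx dY

# Row 1 receives the sum of both contributions: dTable[1] == [11.0, 22.0].
# Note that dTable[idx] += dY would NOT be correct here: buffered
# fancy-index assignment keeps only the last write for a repeated index.
```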
theorem Proofs.Autograd.Embedding.gatherRows_scatterAddRows_adjoint {vocab dim k : ℕ} (dTable : Fin vocab → Fin dim → ℝ) (idx : Fin k → Fin vocab) (dY : Fin k → Fin dim → ℝ) :
    matInner (gatherRows dTable idx) dY = matInner dTable (scatterAddRows idx dY)

Gather and scatter-add are adjoint.

This is the local VJP theorem for embedding lookup with respect to the embedding table:

<gatherRows dTable idx, dY> = <dTable, scatterAddRows idx dY>.
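The identity can be spot-checked numerically. A sketch with NumPy, taking `matInner` to be the Frobenius-style inner product `(A * B).sum()`, gather to be row indexing, and scatter-add to be `np.add.at`:

```python
import numpy as np

rng = np.random.default_rng(0)
vocab, dim, k = 5, 4, 7
dTable = rng.standard_normal((vocab, dim))
idx = rng.integers(0, vocab, size=k)   # every ID in bounds by construction
dY = rng.standard_normal((k, dim))

def mat_inner(A, B):
    # matInner: entrywise product, then sum over all rows and columns.
    return float((A * B).sum())

gathered = dTable[idx]                 # gatherRows dTable idx : (k, dim)
scattered = np.zeros((vocab, dim))
np.add.at(scattered, idx, dY)          # scatterAddRows idx dY : (vocab, dim)

lhs = mat_inner(gathered, dY)
rhs = mat_inner(dTable, scattered)
assert abs(lhs - rhs) < 1e-9           # <gather, dY> = <dTable, scatter>
```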