TorchLean API

NN.Proofs.Autograd.Runtime.Link

Link #

Link the executable runtime tape (Runtime.Autograd.Tape) to the proved SSA/DAG tape model (Proofs.Autograd.Algebra.Graph).

This file provides a small compiler from proved graphs to runtime tapes. The compiler bakes the proved vjp into each runtime node's backward closure.

What is proved here #

The core invariant making the runtime reverse loop well-founded is that compiled nodes only emit contributions to earlier node ids (pid < id).
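The invariant can be pictured on a simplified node type. This is a hypothetical sketch, not this file's actual definitions: `MiniNode` and `EmitsEarlier` are illustrative names, with backward closures returning `(parent id, contribution)` pairs.

```lean
/- Sketch: every parent id a node's backward closure emits is strictly
   smaller than the node's own id, so the reverse loop's measure (the
   current node id) strictly decreases. -/
structure MiniNode (α : Type) where
  value    : α
  backward : α → List (Nat × α)  -- (parent id, contribution)

def EmitsEarlier {α : Type} (nodes : Array (MiniNode α)) : Prop :=
  ∀ id (h : id < nodes.size) (g : α) (pid : Nat) (c : α),
    (pid, c) ∈ nodes[id].backward g → pid < id
```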

PyTorch correspondence / citations #

This is analogous to taking a proven “graph IR” and compiling it to an executable autograd tape whose nodes carry a backward closure (PyTorch does this internally for the eager autograd engine). https://pytorch.org/docs/stable/autograd.html

Extend a tape with leaf nodes for every tensor in the input context Γ.

Each leaf has requires_grad = true and backward = ok [], so the runtime backward loop treats them as gradient accumulation slots but never produces parent contributions from them.
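A minimal sketch of what a leaf looks like under these conventions (hypothetical types and names; the real runtime's backward returns `ok []` in an error monad, elided here to a plain empty list):

```lean
/- A leaf is a gradient-accumulation slot: `requiresGrad := true`, and its
   backward closure never emits parent contributions. -/
structure MiniLeafNode (α : Type) where
  value        : α
  requiresGrad : Bool
  backward     : α → List (Nat × α)  -- (parent id, contribution)

def mkLeaf {α : Type} (v : α) : MiniLeafNode α :=
  { value := v, requiresGrad := true, backward := fun _ => [] }
```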

    def Proofs.Autograd.Algebra.Graph.compileAuxData {α Δ : Type} [DecidableEq Spec.Shape] {Γ ss : List Spec.Shape} (g : GraphData α Δ Γ ss) (x : TList α Γ) (d : Δ) :

    Compile an executable graph (GraphData) to a runtime tape by evaluating forward nodes and baking each node’s proved vjp into its runtime backward closure.

    PyTorch analogy: this corresponds to building a tape of autograd nodes during the forward pass, where each node stores enough information to compute parent contributions when given an upstream cotangent.
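The "baking" step can be sketched as follows. This is a hypothetical simplification (`bakeNode`, `MiniNode`, and the one-input shape are illustrative, not this file's API): the forward value is computed once, and the proved vjp is partially applied to the saved input, leaving a closure from an upstream cotangent to parent contributions.

```lean
structure MiniNode (α : Type) where
  value    : α
  backward : α → List (Nat × α)  -- (parent id, contribution)

def bakeNode {α : Type} (f : α → α) (vjp : α → α → α)
    (input : α) (parentId : Nat) : MiniNode α :=
  { value    := f input                             -- forward result, stored
    backward := fun cot => [(parentId, vjp input cot)] }
    -- upstream cotangent ↦ contribution to the (earlier) parent
```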


      Forward-pass correspondence #

      The next lemmas show that compileAuxData preserves the proved forward semantics, and that the resulting runtime tape contains exactly the evaluated context (erased to AnyTensor) in order.

      theorem Proofs.Autograd.Algebra.Graph.compileAuxData_ctx_eq_eval {α Δ : Type} [DecidableEq Spec.Shape] {Γ ss : List Spec.Shape} (g : GraphData α Δ Γ ss) (x : TList α Γ) (d : Δ) :
      (compileAuxData g x d).2 = g.eval x d

      The context returned by compileAuxData agrees with the proved GraphData.eval.

      theorem Proofs.Autograd.Algebra.Graph.compileAuxData_values_eq {α Δ : Type} [DecidableEq Spec.Shape] {Γ ss : List Spec.Shape} (g : GraphData α Δ Γ ss) (x : TList α Γ) (d : Δ) :
      Array.map (fun (node : Runtime.Autograd.Node α) => node.value) (compileAuxData g x d).1.nodes = (compileAuxData g x d).2.toAnyArray

      The compiled tape’s .value array is GraphData.eval erased to AnyTensor, in the same order.

      theorem Proofs.Autograd.Algebra.Graph.compileAuxData_nodes_size {α Δ : Type} [DecidableEq Spec.Shape] {Γ ss : List Spec.Shape} (g : GraphData α Δ Γ ss) (x : TList α Γ) (d : Δ) :
      (compileAuxData g x d).1.nodes.size = Γ.length + ss.length

      Size bookkeeping: the compiled tape contains one runtime node for each element of Γ ++ ss.

      def Proofs.Autograd.Algebra.Graph.compileAux {α Δ : Type} [DecidableEq Spec.Shape] [CommSemiring α] {Γ ss : List Spec.Shape} (g : Graph Δ Γ ss) (x : TList α Γ) (d : Δ) :

      Compile a proved graph (Graph) to a runtime tape by evaluating forward nodes and baking in each node’s proved vjp.

      Compared to compileAuxData, this uses the pure graph interface (no explicit GraphData payload).

        theorem Proofs.Autograd.Algebra.Graph.compileAux_ctx_eq_eval {α Δ : Type} [DecidableEq Spec.Shape] [CommSemiring α] {Γ ss : List Spec.Shape} (g : Graph Δ Γ ss) (x : TList α Γ) (d : Δ) :
        (g.compileAux x d).2 = g.eval x d

        The context returned by compileAux agrees with the proved Graph.eval.

        theorem Proofs.Autograd.Algebra.Graph.compileAux_values_eq {α Δ : Type} [DecidableEq Spec.Shape] [CommSemiring α] {Γ ss : List Spec.Shape} (g : Graph Δ Γ ss) (x : TList α Γ) (d : Δ) :
        Array.map (fun (node : Runtime.Autograd.Node α) => node.value) (g.compileAux x d).1.nodes = (g.compileAux x d).2.toAnyArray

        The compiled tape’s .value array is Graph.eval erased to AnyTensor, in the same order.

        theorem Proofs.Autograd.Algebra.Graph.compileAux_nodes_size {α Δ : Type} [DecidableEq Spec.Shape] [CommSemiring α] {Γ ss : List Spec.Shape} (g : Graph Δ Γ ss) (x : TList α Γ) (d : Δ) :
        (g.compileAux x d).1.nodes.size = Γ.length + ss.length

        Size bookkeeping: compileAux produces Γ.length + ss.length runtime nodes.

        Full backpropagation (dense) for proofs and runtime #

        The runtime engine computes a dense gradient array, accumulating cotangents for every node in the tape (inputs and intermediates). The following definition and theorems connect that behavior to the proved backpropagation semantics.
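The dense loop can be sketched over simplified nodes. This is a hypothetical sketch (the real engine's node and tape types differ): visit ids from last to first, feed each node's accumulated cotangent to its backward closure, and add every contribution into an earlier slot, which is well-founded precisely because pid < id.

```lean
structure MiniNode (α : Type) where
  value    : α
  backward : α → List (Nat × α)  -- (parent id, contribution)

def backwardDense {α : Type} [Add α] [Inhabited α]
    (nodes : Array (MiniNode α)) (seed : Array α) : Array α := Id.run do
  let mut grads := seed
  for i in [0:nodes.size] do
    let id := nodes.size - 1 - i           -- visit the last node first
    if h : id < nodes.size then
      for (pid, c) in nodes[id].backward grads[id]! do
        grads := grads.set! pid (grads[pid]! + c)  -- accumulate into pid < id
  return grads
```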

        def Proofs.Autograd.Algebra.Graph.backpropAllCtx {α Δ : Type} [CommSemiring α] {Γ ss : List Spec.Shape} (g : Graph Δ Γ ss) (x : TList α Γ) (d : Δ) (seed : TList α (Γ ++ ss)) :
        TList α (Γ ++ ss)

        A "full" backpropagation that returns gradients for all values (Γ ++ ss), not just Γ.

          def Proofs.Autograd.Algebra.GraphData.backpropAllCtx {α Δ : Type} [Add α] {Γ ss : List Spec.Shape} (g : GraphData α Δ Γ ss) (x : TList α Γ) (d : Δ) (seed : TList α (Γ ++ ss)) :
          TList α (Γ ++ ss)

          “Full” backpropagation for GraphData that returns gradients for all values (Γ ++ ss), not just inputs.

          This is the GraphData-analogue of backpropAllCtx above. We keep both definitions because:

          • Graph uses [CommSemiring α] (so it can express dot products and semiring-based accumulation), while
          • GraphData only needs [Add α] here (it just adds contributions).

          Both follow the same reverse-mode accumulation structure: peel off the last node, apply its vjp to that node’s seed entry, add the resulting contribution into the previous seed, and recurse.
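The shared recursion skeleton, in a hypothetical simplification with one parent per node and nodes listed last-first (the actual definitions recurse over `TList α (Γ ++ ss)` and emit contributions to arbitrary earlier positions):

```lean
/- Peel the head node, apply its vjp `v` to its seed `s`, add the result
   into the next (parent) seed `s'`, recurse on the rest. -/
def backpropAll {α : Type} [Add α] : List (α → α) → List α → List α
  | _ :: _,  []            => []
  | [],      seeds         => seeds
  | v :: vs, s :: s' :: ss => s :: backpropAll vs ((s' + v s) :: ss)
  | _ :: _,  [s]           => [s]
```

Note that only the `[Add α]` accumulation step is needed here, which is why the GraphData variant can get by without the full `[CommSemiring α]` assumption.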


            compileAux produces a runtime tape whose node ids correspond to positions in the proof context Γ ++ ss, and bakes the proved vjp into each node’s runtime backward closure.

            The theorem backwardDenseFrom_compileAux_eq_backpropAllCtx states that executing the runtime reverse-mode loop on this compiled tape matches the proved backpropAllCtx.

            Main runtime/link theorem: running the runtime dense backward loop on a tape produced by compileAux matches the proved “full backpropagation” backpropAllCtx.

            This is the formal statement that the executable engine implements the same reverse-mode accumulation semantics as the proved tape model.
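Schematically, the statement has the following shape (the actual theorem is `backwardDenseFrom_compileAux_eq_backpropAllCtx`; `runtimeBackwardDense` and `erase`, the coercion between typed `TList`s and the runtime's untyped arrays, are hypothetical placeholder names):

```lean
/-
  runtimeBackwardDense (g.compileAux x d).1 (erase seed)
    = erase (g.backpropAllCtx x d seed)
-/
```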

            Variant of backwardDenseFrom_compileAux_eq_backpropAllCtx for the GraphData interface.

            This is useful when a graph carries extra payload Δ (e.g. parameters/config) through forward and backward closures.
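A minimal sketch of what threading the payload looks like, with hypothetical names: both closures take Δ as an extra argument, so parameters or configuration supplied at compile time are visible on the backward pass as well.

```lean
/- Illustrative only: a node whose forward and backward both consume the
   payload Δ (e.g. parameters/config). -/
structure PayloadNode (α Δ : Type) where
  forward  : Δ → α
  backward : Δ → α → List (Nat × α)  -- (parent id, contribution)
```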