TorchLean API

NN.Proofs.Autograd.Tape.Core.Soundness

Soundness #

Tape-style (SSA/DAG) reverse-mode soundness for the proved-correct layer.

We model a dynamic graph as a sequence of nodes that may reference any previously computed values (so sharing/fan-out is allowed). For each node we assume a local JVP/VJP adjointness law, then prove the global reverse-mode accumulation algorithm is sound.

This is a proof-only layer; the runtime engine in NN.Runtime.Autograd.Engine is an executable implementation of the same idea.

PyTorch correspondence / citations #


theorem Proofs.Autograd.tensor_cast_shape_proof_irrel {α : Type} {s₁ s₂ : Spec.Shape} (t : Spec.Tensor α s₁) (p q : s₁ = s₂) :
p ▸ t = q ▸ t

Casting a tensor along an equality of shapes does not depend on which proof of the equality is used.
@[reducible, inline]
abbrev Proofs.Autograd.TList (ss : List Spec.Shape) : Type

A heterogeneous list of tensors indexed by a list of shapes.

This is the “typed context” used by the tape model: TList Γ stores one tensor for each shape in the list Γ.

PyTorch analogy: the tape stores “saved tensors”/intermediates for backward, but PyTorch stores them in an untyped runtime list; here the shapes are tracked in the type.

Implementation note: we reuse the type-level context container from the backend-generic tape development (Proofs.Autograd.Algebra.TList) and specialize it to Spec.Tensor. This avoids duplicating the basic “heterogeneous list indexed by shapes” encoding in two different places.

Instances For
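
To make the encoding concrete, here is a minimal standalone sketch of the same idea. All names in it (Shape, Tensor, the Float payload) are illustrative stand-ins for the library's Spec.Shape and Spec.Tensor, not its actual definitions; the later sketches in this section extend this toy development.

```lean
-- Minimal standalone sketch of a shape-indexed heterogeneous list.
-- `Shape` and `Tensor` are stand-ins, not the library's `Spec.Shape`/`Spec.Tensor`.
abbrev Shape := List Nat

/-- Stand-in tensor: flat data with a phantom shape index. -/
structure Tensor (s : Shape) where
  data : Array Float

/-- One tensor per shape in the index list; the shapes live in the type. -/
inductive TList : List Shape → Type where
  | nil  : TList []
  | cons : {s : Shape} → {ss : List Shape} → Tensor s → TList ss → TList (s :: ss)
```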
    @[reducible, inline]
    abbrev Proofs.Autograd.TList.nil : TList []

    Constructor aliases for the specialized TList.

    We reuse the inductive type from Proofs.Autograd.Algebra.TList, so its constructors are actually Proofs.Autograd.Algebra.TList.nil/cons. A few analytic files expect the shorter names Proofs.Autograd.TList.nil/cons, so we provide them here as abbreviations.

    Instances For
      @[reducible, inline]
      abbrev Proofs.Autograd.TList.cons {s : Spec.Shape} {ss : List Spec.Shape} (x : Spec.Tensor s) (xs : TList ss) :
      TList (s :: ss)

      Constructor alias for TList.cons specialized to Spec.Tensor.

      Instances For
        @[reducible, inline]
        abbrev Proofs.Autograd.TList.get {ss : List Spec.Shape} :
        TList ss → (i : Fin ss.length) → Spec.Tensor (ss.get i)

        Get the ith tensor from a context, with its shape tracked by List.get.

        Instances For
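
Continuing the standalone sketch, one possible implementation recurses on the context and the Fin index in lockstep; the result shape Tensor (ss.get i) is computed from the shape list, so no runtime check is needed:

```lean
-- Continuing the sketch: look up the i-th entry. The two Fin.mk proofs that
-- arise differ only propositionally, which Lean accepts by proof irrelevance.
def TList.get : {ss : List Shape} → TList ss → (i : Fin ss.length) → Tensor (ss.get i)
  | _, .cons x _,  ⟨0, _⟩     => x
  | _, .cons _ xs, ⟨i + 1, h⟩ => TList.get xs ⟨i, Nat.lt_of_succ_lt_succ h⟩
```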
          @[reducible, inline]

          All-zero context (fills each tensor entry with zeros).

          Instances For
            @[reducible, inline]
            abbrev Proofs.Autograd.TList.add {ss : List Spec.Shape} :
            TList ss → TList ss → TList ss

            Pointwise addition of two contexts of the same shape list.

            Instances For
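
In the sketch, the all-zero context and pointwise addition can be written as below; a per-shape zero tensor and an Add instance on the stand-in tensors are assumptions of the sketch:

```lean
-- Continuing the sketch: an all-zero context, given a zero tensor per shape.
def TList.zeros (zero : (s : Shape) → Tensor s) : (ss : List Shape) → TList ss
  | []      => .nil
  | s :: ss => .cons (zero s) (TList.zeros zero ss)

-- Pointwise addition: the shared index `ss` guarantees the entries line up.
def TList.add [∀ s, Add (Tensor s)] : {ss : List Shape} → TList ss → TList ss → TList ss
  | _, .nil,       .nil       => .nil
  | _, .cons x xs, .cons y ys => .cons (x + y) (TList.add xs ys)
```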
              @[reducible, inline]
              abbrev Proofs.Autograd.TList.snoc {τ : Spec.Shape} {ss : List Spec.Shape} :
              TList ss → Spec.Tensor τ → TList (ss ++ [τ])

              Append a tensor to the end of a context.

              Instances For
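
In the sketch, appending on the right is structural recursion on the prefix; the index ss ++ [τ] records the growth:

```lean
-- Continuing the sketch: append one tensor at the end, mirroring how a tape
-- appends one node result at a time.
def TList.snoc {τ : Shape} : {ss : List Shape} → TList ss → Tensor τ → TList (ss ++ [τ])
  | _, .nil,       t => .cons t .nil
  | _, .cons x xs, t => .cons x (TList.snoc xs t)
```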
                @[reducible, inline]

                Split a context of shape list ss ++ [τ] into its prefix and last tensor.

                Instances For
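
A sketch of the inverse direction; recursion is on the prefix shape list, which therefore appears explicitly in the match:

```lean
-- Continuing the sketch: split `prefix ++ [last]` back into its parts.
def TList.unsnoc {τ : Shape} : {ss : List Shape} → TList (ss ++ [τ]) → TList ss × Tensor τ
  | [],     .cons t .nil => (.nil, t)
  | _ :: _, .cons x xs   =>
    let (init, last) := TList.unsnoc xs
    (.cons x init, last)
```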

                  Dot product over contexts: sum of per-entry tensor dot products.

                  Informally: dotList xs ys is the “context inner product” used to state global adjointness for tape evaluation and backprop.

                  Instances For
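
In the sketch, with a per-shape dot product taken as a parameter:

```lean
-- Continuing the sketch: the context inner product, a sum of per-entry dots.
def TList.dotList (dot : {s : Shape} → Tensor s → Tensor s → Float) :
    {ss : List Shape} → TList ss → TList ss → Float
  | _, .nil,       .nil       => 0
  | _, .cons x xs, .cons y ys => dot x y + TList.dotList dot xs ys
```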
                    @[reducible, inline]
                    abbrev Proofs.Autograd.TList.cast {ss₁ ss₂ : List Spec.Shape} (h : ss₁ = ss₂) (xs : TList ss₁) :
                    TList ss₂

                    Cast a context along an equality of shape lists.

                    Instances For
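
In the sketch this is just transport along the equality:

```lean
-- Continuing the sketch: `▸` rewrites the type of `xs` along `h`.
def TList.cast {ss₁ ss₂ : List Shape} (h : ss₁ = ss₂) (xs : TList ss₁) : TList ss₂ :=
  h ▸ xs
```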
                      theorem Proofs.Autograd.TList.dotList_cast_left {ss₁ ss₂ : List Spec.Shape} (h : ss₁ = ss₂) (x : TList ss₁) (y : TList ss₂) :
                      (cast h x).dotList y = x.dotList (cast h.symm y)

                      Casting the left argument of a context dot product equals casting the right argument along the symmetric equality.

                      dotList is linear in its right argument with respect to TList.add.

                      Informally: ⟪x, y + z⟫ = ⟪x, y⟫ + ⟪x, z⟫ for contexts.

                      theorem Proofs.Autograd.TList.dotList_snoc {ss : List Spec.Shape} {τ : Spec.Shape} (x y : TList ss) (a b : Spec.Tensor τ) :
                      (x.snoc a).dotList (y.snoc b) = x.dotList y + Spec.dot a b

                      dotList respects appending: dot of two snoced contexts splits into prefix + last entry.

                      Informally: ⟪(x,a), (y,b)⟫ = ⟪x,y⟫ + ⟪a,b⟫.

                      theorem Proofs.Autograd.TList.unsnoc_snoc {ss : List Spec.Shape} {τ : Spec.Shape} (xs : TList ss) (t : Spec.Tensor τ) :
                      (xs.snoc t).unsnoc = (xs, t)

                      unsnoc is a left inverse of snoc.

                      theorem Proofs.Autograd.TList.snoc_unsnoc {ss : List Spec.Shape} {τ : Spec.Shape} (xs : TList (ss ++ [τ])) :
                      xs.unsnoc.1.snoc xs.unsnoc.2 = xs

                      snoc is a left inverse of unsnoc; equivalently, unsnoc is a right inverse of snoc.

                      Dotting any tensor with a zero-filled tensor gives 0.

                      This is the tensor-level fact used to show that “one-hot” cotangents behave as expected.

                      dotList x 0 = 0 for the all-zero context.

                      An index into a heterogeneous context, carrying a proof of the expected shape.

                      This lets us talk about “the ith saved tensor has shape s” without losing the shape invariant.

                      Instances For
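
One way to realize this in the standalone sketch is a de Bruijn-style index whose type carries the shape it points at; the library's actual Idx may instead package a Fin with an equality proof:

```lean
-- Continuing the sketch: a context pointer that remembers, in its type,
-- the shape of the entry it designates.
inductive Idx : List Shape → Shape → Type where
  | here  : {s : Shape} → {ss : List Shape} → Idx (s :: ss) s
  | there : {s t : Shape} → {ss : List Shape} → Idx ss s → Idx (t :: ss) s
```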
                        def Proofs.Autograd.getIdx {Γ : List Spec.Shape} {s : Spec.Shape} (xs : TList Γ) (idx : Idx Γ s) :
                        Spec.Tensor s

                        Read an element from a context using an index with an attached shape proof.

                        Instances For
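
In the sketch, reading through such an index needs no shape check, because the types line up by construction:

```lean
-- Continuing the sketch: follow the index; each `.there` skips one entry.
def TList.getIdx {s : Shape} : {Γ : List Shape} → TList Γ → Idx Γ s → Tensor s
  | _, .cons x _,  .here      => x
  | _, .cons _ xs, .there idx => TList.getIdx xs idx
```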

                          Build a sparse context with a single nonzero entry at idx and zeros elsewhere.

                          This is used to express “one-hot” cotangents when proving local-to-global backprop correctness.

                          Instances For
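
In the sketch, a one-hot context places v at the index and the assumed zero tensor everywhere else:

```lean
-- Continuing the sketch: zero everywhere except at `idx`.
def TList.single {s : Shape} (zero : (s : Shape) → Tensor s) :
    {Γ : List Shape} → Idx Γ s → Tensor s → TList Γ
  | _, .here,      v => .cons v (TList.zeros zero _)
  | _, .there idx, v => .cons (zero _) (TList.single zero idx v)
```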
                            theorem Proofs.Autograd.TList.dotList_single {Γ : List Spec.Shape} {s : Spec.Shape} (dx : TList Γ) (idx : Idx Γ s) (v : Spec.Tensor s) :
                            dx.dotList (single idx v) = Spec.dot (getIdx dx idx) v

                            single idx v is the “one-hot” context with value v at idx, and zeros elsewhere.

                            This lemma says the context dot product against single idx v picks out the corresponding entry of dx:

                            ⟪dx, single idx v⟫ = ⟪dx[idx], v⟫.

                            A node with local JVP/VJP and an adjointness proof against the tensor dot product.

                            Instances For
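
A plausible shape for such a node in the standalone sketch: it consumes the whole previously computed context, produces one tensor, and carries the local adjointness law as data. The library's actual Node may differ, e.g. in how it references parents:

```lean
-- Continuing the sketch: a node bundles forward function, JVP, VJP, and the
-- local law ⟪jvp x dx, dy⟫ = ⟪dx, vjp x dy⟫ that soundness lifts globally.
structure Node (dot : {s : Shape} → Tensor s → Tensor s → Float)
    (Γ : List Shape) (τ : Shape) where
  fwd : TList Γ → Tensor τ
  jvp : TList Γ → TList Γ → Tensor τ     -- primal context, tangent context
  vjp : TList Γ → Tensor τ → TList Γ     -- primal context, output cotangent
  adjoint : ∀ (x dx : TList Γ) (dy : Tensor τ),
    dot (jvp x dx) dy = TList.dotList dot dx (vjp x dy)
```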

                              A tape/SSA graph: nodes are appended in topological order and may reference any previous value.

                              Instances For
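
In the sketch, a tape is a snoc-list of nodes; the node appended after intermediates ss sees the context Γ ++ ss, so it may read the inputs and every earlier value (fan-out is free):

```lean
-- Continuing the sketch: nodes are appended in topological order.
inductive Graph (dot : {s : Shape} → Tensor s → Tensor s → Float)
    (Γ : List Shape) : List Shape → Type where
  | nil  : Graph dot Γ []
  | snoc : {ss : List Shape} → {τ : Shape} →
      Graph dot Γ ss → Node dot (Γ ++ ss) τ → Graph dot Γ (ss ++ [τ])
```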
                                def Proofs.Autograd.Graph.eval {Γ ss : List Spec.Shape} (g : Graph Γ ss) (x : TList Γ) :
                                TList (Γ ++ ss)

                                Evaluate a tape/graph, returning the full context (inputs ++ intermediates).

                                Instances For
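
In the sketch, evaluation threads the growing context through the tape; the ▸ casts do the Γ ++ [] = Γ and associativity bookkeeping that TList.cast and dotList_cast_left handle in the library:

```lean
-- Continuing the sketch: run the tape left to right, snoc-ing each node's
-- output onto the context of everything computed so far.
def Graph.eval {dot : {s : Shape} → Tensor s → Tensor s → Float} {Γ : List Shape} :
    {ss : List Shape} → Graph dot Γ ss → TList Γ → TList (Γ ++ ss)
  | _, .nil,      x => (List.append_nil Γ).symm ▸ x
  | _, .snoc g n, x =>
    let ctx := Graph.eval g x            -- inputs ++ intermediates so far
    List.append_assoc Γ _ _ ▸ TList.snoc ctx (n.fwd ctx)
```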
                                  def Proofs.Autograd.Graph.jvpCtx {Γ ss : List Spec.Shape} (g : Graph Γ ss) (x dx : TList Γ) :
                                  TList (Γ ++ ss)

                                  Evaluate the JVP (“forward-mode tangent”) of a graph, producing tangents for all values in the extended context Γ ++ ss.

                                  Instances For
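
The forward-mode pass in the sketch mirrors eval, carrying one tangent per computed value:

```lean
-- Continuing the sketch: each node's jvp consumes the primal and tangent
-- contexts of everything before it.
def Graph.jvpCtx {dot : {s : Shape} → Tensor s → Tensor s → Float} {Γ : List Shape} :
    {ss : List Shape} → Graph dot Γ ss → TList Γ → TList Γ → TList (Γ ++ ss)
  | _, .nil,      _, dx => (List.append_nil Γ).symm ▸ dx
  | _, .snoc g n, x, dx =>
    let primals  := Graph.eval g x
    let tangents := Graph.jvpCtx g x dx
    List.append_assoc Γ _ _ ▸ TList.snoc tangents (n.jvp primals tangents)
```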
                                    def Proofs.Autograd.Graph.backpropCtx {Γ ss : List Spec.Shape} (g : Graph Γ ss) (x : TList Γ) (seed : TList (Γ ++ ss)) :
                                    TList Γ

                                    Reverse-mode backpropagation on a tape/graph, returning gradients for the inputs Γ.

                                    This is the proof model of what PyTorch calls “running backward” starting from an output seed cotangent and accumulating gradients at shared parents.

                                    Instances For
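
A sketch of the reverse pass: walk the tape from the last node backwards, pop its cotangent, send it to all parents through the node's vjp, and accumulate with TList.add; the accumulation step is where gradients from shared parents (fan-out) get summed. Re-running Graph.eval at each step keeps the sketch short; a real implementation evaluates once and saves the intermediates, which is exactly the role of the tape's saved context:

```lean
-- Continuing the sketch: reverse accumulation over the tape.
def Graph.backpropCtx [∀ s, Add (Tensor s)]
    {dot : {s : Shape} → Tensor s → Tensor s → Float} {Γ : List Shape} :
    {ss : List Shape} → Graph dot Γ ss → TList Γ → TList (Γ ++ ss) → TList Γ
  | _, .nil,      _, seed => List.append_nil Γ ▸ seed
  | _, .snoc g n, x, seed =>
    let primals := Graph.eval g x
    -- split off the cotangent of the last node ...
    let (rest, dy) := TList.unsnoc ((List.append_assoc Γ _ _).symm ▸ seed)
    -- ... push it to the parents and accumulate (fan-out sums here)
    Graph.backpropCtx g x (TList.add rest (n.vjp primals dy))
```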
                                      theorem Proofs.Autograd.Graph.backprop_correct {Γ ss : List Spec.Shape} (g : Graph Γ ss) (x dx : TList Γ) (seed : TList (Γ ++ ss)) :
                                      (g.jvpCtx x dx).dotList seed = dx.dotList (g.backpropCtx x seed)

                                      Global tape soundness: if each node satisfies a local JVP/VJP adjointness law, then the global reverse-mode accumulation algorithm (backpropCtx) is correct.

                                      Informally: for any input perturbation dx and any output seed cotangent seed,

                                      ⟪JVP(g, x, dx), seed⟫ = ⟪dx, backprop(g, x, seed)⟫.

                                      This is the formal analogue of PyTorch’s guarantee that backward() computes vector–Jacobian products and accumulates them through a dynamic DAG/tape.
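
In essence, the inductive step of the proof is a three-line calculation with the lemmas above. Split the tangent context as (dc, dy) with dy = jvp(x, dc), and the seed as (c̄, ȳ); then

⟪(dc, dy), (c̄, ȳ)⟫ = ⟪dc, c̄⟫ + ⟪dy, ȳ⟫ (by dotList_snoc)
= ⟪dc, c̄⟫ + ⟪dc, vjp(x, ȳ)⟫ (by the node's local adjointness law)
= ⟪dc, c̄ + vjp(x, ȳ)⟫ (by linearity of dotList in its right argument),

and the right-hand side is the dot product against exactly the seed handed to the shorter tape, so the induction hypothesis closes the proof.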