Soundness #
Tape-style (SSA/DAG) reverse-mode soundness for the proved-correct layer.
We model a dynamic graph as a sequence of nodes, each of which may reference any previously computed value (so sharing/fan-out is allowed). For each node we assume a local JVP/VJP adjointness law, and then prove that the global reverse-mode accumulation algorithm is sound.
This is a proof-only layer; the runtime engine in NN.Runtime.Autograd.Engine is an
executable implementation of the same idea.
PyTorch correspondence / citations #
- This file is the proof analogue of PyTorch’s dynamic autograd engine, which builds a tape of nodes during the forward pass and runs a reverse pass that accumulates gradients at shared inputs. https://pytorch.org/docs/stable/autograd.html
References (background):
- Reverse-mode AD as backpropagation on a computation graph is standard; see e.g. Baydin et al. (JMLR 2018) for an overview and terminology (JVP/VJP, duality, etc.).
A heterogeneous list of tensors indexed by a list of shapes.
This is the “typed context” used by the tape model: TList Γ stores one tensor for each shape in
the list Γ.
PyTorch analogy: the tape stores “saved tensors”/intermediates for backward, but PyTorch stores them in an untyped runtime list; here the shapes are tracked in the type.
Implementation note: we reuse the type-level context container from the backend-generic tape
development (Proofs.Autograd.Algebra.TList) and specialize it to ℝ. This avoids duplicating the
basic “heterogeneous list indexed by shapes” encoding in two different places.
Instances For
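For orientation, here is a minimal standalone sketch of this encoding. The axiom stand-ins (Shape, Tensor, and the tensor-level zero/add/dot) are assumptions replacing the actual Spec definitions; the development itself reuses the container from Proofs.Autograd.Algebra.TList rather than redefining it.

```lean
import Mathlib.Data.Real.Basic

-- Assumed stand-ins for the Spec-level vocabulary (`Spec.Shape`,
-- `Spec.Tensor ℝ`, `Spec.dot`, ...); illustrative only.
axiom Shape : Type
axiom Tensor : Shape → Type
axiom Tensor.zero (s : Shape) : Tensor s
axiom Tensor.add {s : Shape} : Tensor s → Tensor s → Tensor s
axiom Tensor.dot {s : Shape} : Tensor s → Tensor s → ℝ

/-- A typed context: one `Tensor s` for each shape `s` in the index `Γ`. -/
inductive TList : List Shape → Type
  | nil  : TList []
  | cons {s : Shape} {Γ : List Shape} : Tensor s → TList Γ → TList (s :: Γ)
```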
Constructor aliases for the specialized TList.
We reuse the inductive type from Proofs.Autograd.Algebra.TList, so its constructors are actually
Proofs.Autograd.Algebra.TList.nil/cons. A few analytic files expect the shorter names
Proofs.Autograd.TList.nil/cons, so we provide them here as abbreviations.
Instances For
Constructor alias for TList.cons specialized to ℝ.
Instances For
Get the ith tensor from a context, with its shape tracked by List.get.
Instances For
All-zero context (fills each tensor entry with zeros).
Instances For
Pointwise addition of two contexts of the same shape list.
Instances For
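Both helpers above (the all-zero context and pointwise addition) are plain structural recursions over the context. Continuing the sketch, with illustrative names rather than the file's:

```lean
-- All-zero context: one zero tensor per shape in Γ.
def TList.zeros : (Γ : List Shape) → TList Γ
  | []     => .nil
  | s :: Γ => .cons (Tensor.zero s) (TList.zeros Γ)

-- Pointwise addition of two contexts over the same shape list.
def TList.add : {Γ : List Shape} → TList Γ → TList Γ → TList Γ
  | _, .nil,       .nil       => .nil
  | _, .cons x xs, .cons y ys => .cons (Tensor.add x y) (TList.add xs ys)
```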
Append a tensor to the end of a context.
Instances For
Split a context of shape list ss ++ [τ] into its prefix and last tensor.
Instances For
Dot product over contexts: sum of per-entry tensor dot products.
Informally: dotList xs ys is the “context inner product” used to state global adjointness for
tape evaluation and backprop.
Instances For
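Schematically, the context inner product is the sum of entrywise tensor dot products; continuing the sketch (Tensor.dot stands in for Spec.dot):

```lean
-- Context inner product: sum of per-entry tensor dot products.
def TList.dot : {Γ : List Shape} → TList Γ → TList Γ → ℝ
  | _, .nil,       .nil       => 0
  | _, .cons x xs, .cons y ys => Tensor.dot x y + TList.dot xs ys
```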
Cast a context along an equality of shape lists.
Instances For
dotList respects appending: dot of two snoced contexts splits into prefix + last entry.
Informally: ⟪(x,a), (y,b)⟫ = ⟪x,y⟫ + ⟪a,b⟫.
Dotting any tensor with a zero-filled tensor gives 0.
This is the tensor-level fact used to show that “one-hot” cotangents behave as expected.
Context-level version: dotList x 0 = 0, where 0 denotes the all-zero context.
An index into a heterogeneous context, carrying a proof of the expected shape.
This lets us talk about “the ith saved tensor has shape s” without losing the shape invariant.
- i : the position of the saved tensor in the context.
- h : a proof that the shape stored at that position is the expected shape.
Instances For
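Continuing the sketch, a plausible shape for this index type; the exact field types are an assumption, the docstring fixes only their roles:

```lean
-- An index into `TList Γ` that remembers the shape it points at.
structure Idx (Γ : List Shape) (s : Shape) where
  i : Fin Γ.length   -- which entry of the context
  h : Γ.get i = s    -- proof that the entry there has shape `s`
```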
Read an element from a context using an index with an attached shape proof.
Instances For
Build a sparse context with a single nonzero entry at idx and zeros elsewhere.
This is used to express “one-hot” cotangents when proving local-to-global backprop correctness.
Instances For
single idx v is the “one-hot” context with value v at idx, and zeros elsewhere.
This lemma says the context dot product against single idx v picks out the corresponding entry
of dx:
⟪dx, single idx v⟫ = ⟪dx[idx], v⟫.
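In the sketch's vocabulary the lemma reads as follows (statement only; TList.getIdx and TList.single are assumed stand-ins for the read and one-hot constructions documented above):

```lean
axiom TList.getIdx {Γ : List Shape} {s : Shape} : TList Γ → Idx Γ s → Tensor s
axiom TList.single {Γ : List Shape} {s : Shape} : Idx Γ s → Tensor s → TList Γ

-- Dotting against a one-hot context picks out a single entry of `dx`.
theorem TList.dot_single {Γ : List Shape} {s : Shape}
    (dx : TList Γ) (idx : Idx Γ s) (v : Tensor s) :
    TList.dot dx (TList.single idx v) = Tensor.dot (TList.getIdx dx idx) v := by
  sorry
```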
A node with local JVP/VJP and an adjointness proof against the tensor dot product.
- forward : TList Γ → Spec.Tensor ℝ τ
  The primal computation: read any saved values from the context and produce the node’s output.
- jvp : TList Γ → TList Γ → Spec.Tensor ℝ τ
  The Jacobian–vector product: push a tangent context forward through the node at the given primal context.
- vjp : TList Γ → Spec.Tensor ℝ τ → TList Γ
  The vector–Jacobian product: pull an output cotangent back to a cotangent for every value the node read.
- correct (x dx : TList Γ) (δ : Spec.Tensor ℝ τ) : Spec.dot (self.jvp x dx) δ = dx.dotList (self.vjp x δ)
  The local adjointness law: jvp and vjp are adjoint with respect to the tensor and context dot products.
Instances For
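In the sketch's vocabulary the whole interface fits in one structure; the field types mirror the signatures listed above:

```lean
-- A differentiable node: primal map plus an adjoint JVP/VJP pair.
structure Node (Γ : List Shape) (τ : Shape) where
  forward : TList Γ → Tensor τ
  jvp     : TList Γ → TList Γ → Tensor τ
  vjp     : TList Γ → Tensor τ → TList Γ
  -- Local adjointness: ⟪jvp x dx, δ⟫ = ⟪dx, vjp x δ⟫.
  correct : ∀ (x dx : TList Γ) (δ : Tensor τ),
    Tensor.dot (jvp x dx) δ = TList.dot dx (vjp x δ)
```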
A tape/SSA graph: nodes are appended in topological order and may reference any previous value.
- nil {Γ : List Spec.Shape} : Graph Γ []
  The empty tape over inputs Γ.
- snoc {Γ ss : List Spec.Shape} {τ : Spec.Shape} : Graph Γ ss → Node (Γ ++ ss) τ → Graph Γ (ss ++ [τ])
  Append a node that may read both the inputs Γ and every previously computed value in ss.
Instances For
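The constructor types above translate directly into the sketch:

```lean
-- A tape over inputs Γ; the second index lists the intermediates' shapes.
inductive Graph (Γ : List Shape) : List Shape → Type
  | nil  : Graph Γ []
  | snoc {ss : List Shape} {τ : Shape} :
      Graph Γ ss → Node (Γ ++ ss) τ → Graph Γ (ss ++ [τ])
```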
Evaluate a tape/graph, returning the full context (inputs ++ intermediates).
Instances For
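A plausible definition, continuing the sketch: the forward pass folds over the tape, snoc-ing each node's output onto the context. TList.snoc and TList.castCtx play the roles of the append and cast operations documented above; the actual recursion in the file may differ in detail.

```lean
-- Append one tensor at the end of a context.
def TList.snoc {τ : Shape} : {Γ : List Shape} → TList Γ → Tensor τ → TList (Γ ++ [τ])
  | _, .nil,       t => .cons t .nil
  | _, .cons x xs, t => .cons x (TList.snoc xs t)

-- Cast a context along an equality of shape lists.
def TList.castCtx {Γ Δ : List Shape} (h : Γ = Δ) : TList Γ → TList Δ :=
  fun xs => h ▸ xs

-- Forward pass: evaluate nodes in order, extending the context each time.
def evalCtx {Γ : List Shape} : {ss : List Shape} → Graph Γ ss → TList Γ → TList (Γ ++ ss)
  | _, .nil, x => TList.castCtx (List.append_nil Γ).symm x
  | _, .snoc g n, x =>
      let ctx := evalCtx g x
      TList.castCtx (List.append_assoc Γ _ _) (TList.snoc ctx (n.forward ctx))
```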
Evaluate the JVP (“forward-mode tangent”) of a graph, producing tangents for all values in the
extended context Γ ++ ss.
Instances For
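Continuing the sketch, the tangent pass runs in lockstep with evalCtx, feeding each node both the primal context and the tangent context accumulated so far:

```lean
-- Tangent (JVP) pass: the tangent context grows alongside the primal one.
def jvpCtx {Γ : List Shape} :
    {ss : List Shape} → Graph Γ ss → TList Γ → TList Γ → TList (Γ ++ ss)
  | _, .nil, _x, dx => TList.castCtx (List.append_nil Γ).symm dx
  | _, .snoc g n, x, dx =>
      let dctx := jvpCtx g x dx
      TList.castCtx (List.append_assoc Γ _ _)
        (TList.snoc dctx (n.jvp (evalCtx g x) dctx))
```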
Reverse-mode backpropagation on a tape/graph, returning gradients for the inputs Γ.
This is the proof model of what PyTorch calls “running backward” starting from an output seed cotangent and accumulating gradients at shared parents.
Instances For
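A minimal sketch of that accumulation, assuming the helpers above plus a splitLast corresponding to the documented prefix/last split; the engine's actual loop may be organized differently:

```lean
-- Split a context over `Γ ++ [τ]` into its prefix and last tensor.
def TList.splitLast {τ : Shape} :
    {Γ : List Shape} → TList (Γ ++ [τ]) → TList Γ × Tensor τ
  | [],     .cons t .nil => (.nil, t)
  | _ :: _, .cons x xs   =>
      let (rest, t) := TList.splitLast xs
      (.cons x rest, t)

-- Reverse pass: peel off the last node, pull its cotangent back through
-- `vjp`, and add the result into the seeds of everything it read; the
-- addition is exactly where fan-out gradients get summed.
def backpropCtx {Γ : List Shape} :
    {ss : List Shape} → Graph Γ ss → TList Γ → TList (Γ ++ ss) → TList Γ
  | _, .nil, _x, seed => TList.castCtx (List.append_nil Γ) seed
  | _, .snoc g n, x, seed =>
      let (rest, δ) :=
        TList.splitLast (TList.castCtx (List.append_assoc Γ _ _).symm seed)
      backpropCtx g x (TList.add rest (n.vjp (evalCtx g x) δ))
```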
Global tape soundness: if each node satisfies a local JVP/VJP adjointness law, then the global
reverse-mode accumulation algorithm (backpropCtx) is correct.
Informally: for any input perturbation dx and any output seed cotangent seed,
⟪JVP(g, x, dx), seed⟫ = ⟪dx, backprop(g, x, seed)⟫.
This is the formal analogue of PyTorch’s guarantee that backward() computes vector–Jacobian
products and accumulates them through a dynamic DAG/tape.
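In the sketch's vocabulary, the statement (only) reads:

```lean
-- Global adjointness: ⟪jvpCtx g x dx, seed⟫ = ⟪dx, backpropCtx g x seed⟫.
theorem backprop_sound {Γ ss : List Shape}
    (g : Graph Γ ss) (x dx : TList Γ) (seed : TList (Γ ++ ss)) :
    TList.dot (jvpCtx g x dx) seed = TList.dot dx (backpropCtx g x seed) := by
  -- A proof would go by induction on the tape, using each node's
  -- `correct` law at the `snoc` step.
  sorry
```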