TorchLean API

NN.Proofs.Autograd.Tape.Core.FDeriv

FDeriv #

Analytic (HasFDerivAt/fderiv) correctness for tape-style SSA/DAG graphs.

NN/Proofs/Autograd/Tape/Core/Soundness.lean proves the global JVP/VJP adjointness law for DAG graphs against the tensor dot product.

This file adds the analytic upgrade (spec-level, over ℝ).

PyTorch correspondence / citations #

noncomputable def Proofs.Autograd.toVecT {s : Spec.Shape} (t : Spec.Tensor s) :
Vec s.size

Vectorize a tensor by flattening it (spec flattening order) and then using the Euclidean equivalence.

Instances For
    noncomputable def Proofs.Autograd.ofVecT {s : Spec.Shape} (v : Vec s.size) :
    Spec.Tensor s

    Inverse of toVecT: interpret a vector as a tensor of shape s.

    Instances For
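
The pair toVecT/ofVecT is, conceptually, the identity on the underlying coordinate data. Below is a minimal self-contained sketch of the pattern over mathlib's EuclideanSpace, with a toy Tensor type standing in for Spec.Tensor; the names Vec, Tensor, toVecT', ofVecT' are illustrative, not TorchLean's definitions.

```lean
import Mathlib

-- Toy model: a tensor is (already) its flattened coordinate function.
-- The real `Spec.Tensor` and its flattening order live in the TorchLean spec.
abbrev Vec (n : ℕ) : Type := EuclideanSpace ℝ (Fin n)

structure Tensor (n : ℕ) where
  coord : Fin n → ℝ

-- Vectorize: take the coordinates, then impose the Euclidean (L²) structure.
noncomputable def toVecT' {n : ℕ} (t : Tensor n) : Vec n :=
  WithLp.toLp 2 t.coord

-- Inverse: read the coordinates back off.
noncomputable def ofVecT' {n : ℕ} (v : Vec n) : Tensor n :=
  ⟨WithLp.ofLp v⟩

-- The round trip is definitionally the identity.
example {n : ℕ} (t : Tensor n) : ofVecT' (toVecT' t) = t := rfl
```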

      Total number of scalar coordinates in a heterogeneous context shape list.

      Instances For
        @[reducible, inline]

        A vectorized context: one Euclidean vector containing all TList Γ entries concatenated.

        Instances For
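
A hedged sketch of the size bookkeeping, with ℕ standing in for Spec.Shape (each shape contributes its size in scalars); ctxSize' and CtxVec' below are illustrative stand-ins for the TorchLean definitions.

```lean
import Mathlib

-- Each shape contributes `size` scalar coordinates; a context's size is the sum.
def ctxSize' : List ℕ → ℕ
  | [] => 0
  | s :: Γ => s + ctxSize' Γ

-- A vectorized context: one Euclidean vector of the total size.
abbrev CtxVec' (Γ : List ℕ) := EuclideanSpace ℝ (Fin (ctxSize' Γ))

example : ctxSize' [2, 3, 4] = 9 := rfl
```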
          noncomputable def Proofs.Autograd.vecOfFun {n : ℕ} (f : Fin n → ℝ) :
          Vec n

          Build a Euclidean vector from its coordinate function.

          This is the small helper used throughout the file when reindexing vectors across context concatenation and shape casts.

          Instances For
            @[simp]
            theorem Proofs.Autograd.vecOfFun_apply {n : ℕ} (f : Fin n → ℝ) (i : Fin n) :
            (vecOfFun f).ofLp i = f i
            @[simp]
            theorem Proofs.Autograd.vecOfFun_eta {n : ℕ} (v : Vec n) :
            (vecOfFun fun (i : Fin n) => v.ofLp i) = v
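
Over EuclideanSpace, "build a vector from its coordinate function" is just mathlib's WithLp.toLp 2, and both simp lemmas above are then rfl-true. A sketch in the toy model, not TorchLean's definition:

```lean
import Mathlib

-- `vecOfFun_apply`, in the toy model: reading a coordinate back is `rfl`.
example {n : ℕ} (f : Fin n → ℝ) (i : Fin n) :
    WithLp.ofLp (WithLp.toLp 2 f : EuclideanSpace ℝ (Fin n)) i = f i := rfl

-- `vecOfFun_eta`, in the toy model: rebuilding from the coordinates is `rfl`.
example {n : ℕ} (v : EuclideanSpace ℝ (Fin n)) :
    (WithLp.toLp 2 fun i => WithLp.ofLp v i) = v := rfl
```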

            Flatten a typed context TList Γ into one big Euclidean vector.

            Unlike PyTorch’s runtime “saved tensor list”, this is an actual typed isomorphism: shapes are tracked in Γ, so the split points are definitional from ctxSize.

            Instances For

              Inverse of flattenCtx: split a CtxVec Γ back into a TList Γ.

              Instances For
                noncomputable def Proofs.Autograd.castVec {n m : ℕ} (h : n = m) :
                Vec n → Vec m

                Cast a Vec n to Vec m along an equality, by reindexing coordinates.

                Instances For
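
A sketch of the construction over mathlib primitives (castVec' is an illustrative name): reindex every coordinate along Fin.cast h.symm.

```lean
import Mathlib

-- Reindex coordinates along the equality; no data moves, only indices.
noncomputable def castVec' {n m : ℕ} (h : n = m)
    (v : EuclideanSpace ℝ (Fin n)) : EuclideanSpace ℝ (Fin m) :=
  WithLp.toLp 2 fun i => WithLp.ofLp v (Fin.cast h.symm i)

-- At `h = rfl` the reindexing is definitionally the identity (`castVec_rfl`).
example {n : ℕ} (v : EuclideanSpace ℝ (Fin n)) : castVec' rfl v = v := rfl
```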
                  @[simp]
                  theorem Proofs.Autograd.castVec_apply {n m : ℕ} (h : n = m) (v : Vec n) (i : Fin m) :
                  (castVec h v).ofLp i = v.ofLp (Fin.cast ⋯ i)
                  @[simp]
                  theorem Proofs.Autograd.castVec_rfl {n : ℕ} (v : Vec n) :
                  castVec ⋯ v = v
                  @[simp]
                  theorem Proofs.Autograd.castVec_add {n m : ℕ} (h : n = m) (u v : Vec n) :
                  castVec h (u + v) = castVec h u + castVec h v
                  @[simp]
                  theorem Proofs.Autograd.castVec_smul {n m : ℕ} (h : n = m) (r : ℝ) (v : Vec n) :
                  castVec h (r • v) = r • castVec h v
                  @[simp]
                  theorem Proofs.Autograd.castVec_castVec {n m k : ℕ} (h₁ : n = m) (h₂ : m = k) (v : Vec n) :
                  castVec h₂ (castVec h₁ v) = castVec ⋯ v
                  theorem Proofs.Autograd.inner_castVec_castVec {n m : ℕ} (h : n = m) (x y : Vec n) :
                  inner (castVec h x) (castVec h y) = inner x y

                  castVec preserves the Euclidean inner product.

                  This is the core “cast isometry” lemma used throughout the vectorized graph development.
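
The reason casts are isometries: the inner product is a finite coordinate sum, and sums are invariant under reindexing along a bijection. In mathlib vocabulary (a sketch, not the TorchLean proof):

```lean
import Mathlib

-- Reindexing a finite sum along the bijection `finCongr h : Fin n ≃ Fin m`
-- leaves it unchanged; the cast-isometry lemma is this fact applied to the
-- coordinatewise products making up the inner product.
example {n m : ℕ} (h : n = m) (f : Fin m → ℝ) :
    ∑ i : Fin n, f (Fin.cast h i) = ∑ j : Fin m, f j :=
  Fintype.sum_equiv (finCongr h) _ _ (fun _ => rfl)
```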

                  theorem Proofs.Autograd.sum_spec_dim {n : ℕ} {s : Spec.Shape} (values : Fin n → Spec.Tensor s) :
                  (Spec.Tensor.dim values).sumSpec = ∑ i : Fin n, (values i).sumSpec

                  sum_spec over an outer dimension is a sum over slices.

                  This tensor-level “Fubini rule” is used to relate Spec.dot to Euclidean inner products after vectorization.
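
The finite "Fubini rule" in mathlib form, which is what the slice decomposition boils down to:

```lean
import Mathlib

-- Summing over a product index equals an iterated sum over slices.
example {n m : ℕ} (g : Fin n × Fin m → ℝ) :
    ∑ p : Fin n × Fin m, g p = ∑ i : Fin n, ∑ j : Fin m, g (i, j) :=
  Fintype.sum_prod_type
```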

                  theorem Proofs.Autograd.toVecT_dim_apply {n : ℕ} {s : Spec.Shape} (hmpos : 0 < s.size) (f : Fin n → Spec.Tensor s) (p : Fin n × Fin s.size) :

                  Coordinate characterization of toVecT on a tensor .dim n s.

                  Informally, the vectorization order is the standard product order induced by finProdFinEquiv.
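
Mathlib's finProdFinEquiv : Fin m × Fin n ≃ Fin (m * n) fixes the flattening order. A small concrete instance, checked by decide:

```lean
import Mathlib

-- The pair (1, 2) in `Fin 2 × Fin 3` lands at flat index 5.
example : (finProdFinEquiv ((1 : Fin 2), (2 : Fin 3))).val = 5 := by decide
```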

                  theorem Proofs.Autograd.inner_toVecT_dim {n : ℕ} {s : Spec.Shape} (a b : Fin n → Spec.Tensor s) :
                  inner (toVecT (Spec.Tensor.dim a)) (toVecT (Spec.Tensor.dim b)) = ∑ i : Fin n, inner (toVecT (a i)) (toVecT (b i))

                  toVecT turns dot products on .dim n s into sums of Euclidean inner products over slices.

                  Main compatibility lemma: tensor dot equals Euclidean inner product of vectorizations.

                  This is the bridge between soundness.lean (stated using Spec.dot) and the analytic theorems here (stated using Euclidean inner).

                  noncomputable def Proofs.Autograd.appendVec {m n : ℕ} (a : Vec m) (b : Vec n) :
                  Vec (m + n)

                  Concatenate two Euclidean vectors using Fin.append.

                  Instances For
                    theorem Proofs.Autograd.inner_append {m n : ℕ} (a c : Vec m) (b d : Vec n) :
                    inner (appendVec a b) (appendVec c d) = inner a c + inner b d

                    Inner product of concatenated vectors splits as a sum of inner products.

                    TList.dotList equals Euclidean inner product of flattenCtx.

                    This shows that the “context inner product” used in tape soundness is exactly the Euclidean inner product on the vectorized context representation.

                    noncomputable def Proofs.Autograd.castCtxVec {Γ₁ Γ₂ : List Spec.Shape} (h : Γ₁ = Γ₂) :
                    CtxVec Γ₁ → CtxVec Γ₂

                    Cast a vectorized context along an equality of shape lists (reindexing coordinates).

                    Instances For
                      @[simp]
                      theorem Proofs.Autograd.castCtxVec_cast {Γ₁ Γ₂ Γ₃ : List Spec.Shape} (h₁ : Γ₁ = Γ₂) (h₂ : Γ₂ = Γ₃) (v : CtxVec Γ₁) :
                      castCtxVec h₂ (castCtxVec h₁ v) = castCtxVec ⋯ v

                      The next few lemmas are bookkeeping for splitting/concatenating vectorized contexts. They are “obvious” from the list structure of Γ, but it is useful to expose them as named facts so that the calculus proofs later can use them without redoing shape arithmetic.
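
At the index level this bookkeeping reduces to two mathlib facts about Fin.append: prefix indices read from the first block, the rest from the second.

```lean
import Mathlib

-- Indices in the prefix read from the first block …
example {m n : ℕ} (u : Fin m → ℝ) (v : Fin n → ℝ) (i : Fin m) :
    Fin.append u v (Fin.castAdd n i) = u i :=
  Fin.append_left u v i

-- … and indices past the prefix read from the second block.
example {m n : ℕ} (u : Fin m → ℝ) (v : Fin n → ℝ) (j : Fin n) :
    Fin.append u v (Fin.natAdd m j) = v j :=
  Fin.append_right u v j
```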

                      theorem Proofs.Autograd.inner_castCtxVec {Γ₁ Γ₂ : List Spec.Shape} (h : Γ₁ = Γ₂) (x : CtxVec Γ₁) (y : CtxVec Γ₂) :
                      inner (castCtxVec h x) y = inner x (castCtxVec ⋯ y)

                      castCtxVec is inner-product preserving (up to flipping the cast on the other argument).

                      ctxSize respects list append (sizes add).

                      Specialized ctxSize_append for snoc (Γ ++ [τ]).

                      noncomputable def Proofs.Autograd.snocCtx {Γ : List Spec.Shape} {τ : Spec.Shape} (ctx : CtxVec Γ) (t : Vec τ.size) :
                      CtxVec (Γ ++ [τ])

                      Append one tensor-vector block to a vectorized context.

                      Instances For
                        noncomputable def Proofs.Autograd.unsnocCtx {Γ : List Spec.Shape} {τ : Spec.Shape} (ctx : CtxVec (Γ ++ [τ])) :
                        CtxVec Γ × Vec τ.size

                        Inverse of snocCtx: split CtxVec (Γ ++ [τ]) into its prefix and last block.

                        Instances For
                          theorem Proofs.Autograd.unsnocCtx_snocCtx {Γ : List Spec.Shape} {τ : Spec.Shape} (ctx : CtxVec Γ) (t : Vec τ.size) :
                          unsnocCtx (snocCtx ctx t) = (ctx, t)

                          unsnocCtx (snocCtx ctx t) = (ctx, t).

                          theorem Proofs.Autograd.snocCtx_unsnocCtx {Γ : List Spec.Shape} {τ : Spec.Shape} (ctx : CtxVec (Γ ++ [τ])) :
                          snocCtx (unsnocCtx ctx).1 (unsnocCtx ctx).2 = ctx

                          snocCtx (unsnocCtx ctx) = ctx.

                          noncomputable def Proofs.Autograd.Node.forwardVec {Γ : List Spec.Shape} {τ : Spec.Shape} (node : Node Γ τ) :
                          CtxVec Γ → Vec τ.size

                          Vectorized forward map of a tape Node: CtxVec Γ → Vec (Shape.size τ).

                          Instances For
                            noncomputable def Proofs.Autograd.Node.jvpVec {Γ : List Spec.Shape} {τ : Spec.Shape} (node : Node Γ τ) :
                            CtxVec Γ → CtxVec Γ → Vec τ.size

                            Vectorized JVP of a tape Node: the node-level forward-mode action on tangents.

                            Instances For
                              noncomputable def Proofs.Autograd.Node.vjpVec {Γ : List Spec.Shape} {τ : Spec.Shape} (node : Node Γ τ) :
                              CtxVec Γ → Vec τ.size → CtxVec Γ

                              Vectorized VJP of a tape Node: pushes a cotangent vector back to the input context.

                              Instances For
                                theorem Proofs.Autograd.Node.correct_inner {Γ : List Spec.Shape} {τ : Spec.Shape} (node : Node Γ τ) (ctxV dctxV : CtxVec Γ) (δV : Vec τ.size) :
                                inner (node.jvpVec ctxV dctxV) δV = inner dctxV (node.vjpVec ctxV δV)

                                Vectorized form of Node.correct (adjointness law).

                                Statement: ⟪jvp(x,dx), δ⟫ = ⟪dx, vjp(x,δ)⟫.
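
For a linear map this law is exactly the defining property of the Hilbert-space adjoint; in mathlib it is ContinuousLinearMap.adjoint_inner_right. A sketch on plain Euclidean spaces, not the TorchLean statement:

```lean
import Mathlib

open RealInnerProductSpace in
example {n m : ℕ} (A : EuclideanSpace ℝ (Fin n) →L[ℝ] EuclideanSpace ℝ (Fin m))
    (dx : EuclideanSpace ℝ (Fin n)) (δ : EuclideanSpace ℝ (Fin m)) :
    ⟪A dx, δ⟫_ℝ = ⟪dx, ContinuousLinearMap.adjoint A δ⟫_ℝ :=
  (ContinuousLinearMap.adjoint_inner_right A dx δ).symm
```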

                                def Proofs.Autograd.Graph.evalVec {Γ ss : List Spec.Shape} (g : Graph Γ ss) (xV : CtxVec Γ) :
                                CtxVec (Γ ++ ss)

                                Vectorized evaluation of a tape Graph.

                                Returns a CtxVec (Γ ++ ss) containing the original inputs and all intermediate node outputs.

                                Instances For
                                  def Proofs.Autograd.Graph.jvpVec {Γ ss : List Spec.Shape} (g : Graph Γ ss) (xV dxV : CtxVec Γ) :
                                  CtxVec (Γ ++ ss)

                                  Vectorized JVP for a whole graph: forward-mode derivative of evalVec.

                                  Instances For
                                    def Proofs.Autograd.Graph.backpropVec {Γ ss : List Spec.Shape} (g : Graph Γ ss) (xV : CtxVec Γ) (seedV : CtxVec (Γ ++ ss)) :
                                    CtxVec Γ

                                    Vectorized reverse-mode accumulation (VJP) for a whole graph.

                                    seedV is a cotangent for the entire Γ ++ ss context (inputs plus intermediates), matching the global tape soundness theorem.

                                    Instances For

                                      The next theorem is exactly soundness.lean rewritten into Euclidean vector form. It is the key input to later “backprop = (fderiv eval)†” proofs.

                                      theorem Proofs.Autograd.Graph.backprop_correct_inner {Γ ss : List Spec.Shape} (g : Graph Γ ss) (xV dxV : CtxVec Γ) (seedV : CtxVec (Γ ++ ss)) :
                                      inner (g.jvpVec xV dxV) seedV = inner dxV (g.backpropVec xV seedV)

                                      Vectorized tape soundness: ⟪jvp, seed⟫ = ⟪dx, backprop seed⟫.

                                      structure Proofs.Autograd.NodeFDerivCorrect {Γ : List Spec.Shape} {τ : Spec.Shape} (node : Node Γ τ) :

                                      Per-node analytic correctness assumption: JVP is the Fréchet derivative.

                                      This is the hypothesis that upgrades the dot-level soundness theorem into an fderiv statement.

                                      Instances For
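
A hedged sketch of the shape of such a structure (FDerivCorrect below is illustrative, with plain Euclidean types in place of Node):

```lean
import Mathlib

-- A global differentiability witness plus "the JVP is that derivative".
structure FDerivCorrect {n m : ℕ}
    (f : EuclideanSpace ℝ (Fin n) → EuclideanSpace ℝ (Fin m))
    (jvp : EuclideanSpace ℝ (Fin n) → EuclideanSpace ℝ (Fin n) →
      EuclideanSpace ℝ (Fin m)) : Prop where
  differentiable : Differentiable ℝ f
  jvp_eq : ∀ x dx, jvp x dx = fderiv ℝ f x dx
```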

                                        Graph predicate: every node satisfies NodeFDerivCorrect.

                                        Instances For
                                          structure Proofs.Autograd.NodeFDerivCorrectAt {Γ : List Spec.Shape} {τ : Spec.Shape} (node : Node Γ τ) (xV : CtxVec Γ) :

                                          Pointwise per-node analytic correctness.

                                          Used when a node is only differentiable under side conditions at a particular basepoint xV (e.g. inv, sqrt, log, or piecewise ops).

                                          Instances For
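
The pointwise pattern in miniature: y ↦ y⁻¹ only has a derivative away from 0, so an inv node can satisfy at best a pointwise hypothesis at nonzero basepoints.

```lean
import Mathlib

-- `hasDerivAt_inv` needs the side condition `x ≠ 0` at the basepoint.
example {x : ℝ} (hx : x ≠ 0) :
    HasDerivAt (fun y : ℝ => y⁻¹) (-(x ^ 2)⁻¹) x :=
  hasDerivAt_inv hx
```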
                                            def Proofs.Autograd.NodeFDerivCorrect.at {Γ : List Spec.Shape} {τ : Spec.Shape} {node : Node Γ τ} (hn : NodeFDerivCorrect node) (xV : CtxVec Γ) :
                                            NodeFDerivCorrectAt node xV

                                            Specialize a global NodeFDerivCorrect proof to a particular basepoint.

                                            This is the common “turn an everywhere-differentiable node into a pointwise differentiable node” adapter used when assembling GraphFDerivCorrectAt proofs.

                                            Instances For
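
In mathlib vocabulary, the adapter is the observation that Differentiable is, by definition, pointwise differentiability everywhere:

```lean
import Mathlib

example {f : ℝ → ℝ} (h : Differentiable ℝ f) (x : ℝ) :
    DifferentiableAt ℝ f x :=
  h x
```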

                                              Pointwise graph predicate: every node is differentiable at the actual intermediate values.

                                              Note the recursion uses Graph.evalVec to compute the basepoint for each successive node.

                                              Instances For
                                                noncomputable def Proofs.Autograd.Graph.appendCLM (m n : ℕ) :
                                                Vec m × Vec n →L[ℝ] Vec (m + n)

                                                Fin.append packaged as a continuous linear map on Euclidean vectors.

                                                Instances For
                                                  noncomputable def Proofs.Autograd.Graph.castCLM {n m : ℕ} (h : n = m) :
                                                  Vec n →L[ℝ] Vec m

                                                  castVec packaged as a continuous linear map (finite-dimensional, hence continuous).

                                                  Instances For
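
The "finite-dimensional, hence continuous" step is mathlib's LinearMap.toContinuousLinearMap. A sketch (clmOfLinear is an illustrative name):

```lean
import Mathlib

-- On a finite-dimensional domain, every linear map is automatically
-- continuous, so it upgrades to a continuous linear map.
noncomputable def clmOfLinear {n m : ℕ}
    (f : EuclideanSpace ℝ (Fin n) →ₗ[ℝ] EuclideanSpace ℝ (Fin m)) :
    EuclideanSpace ℝ (Fin n) →L[ℝ] EuclideanSpace ℝ (Fin m) :=
  f.toContinuousLinearMap
```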

                                                    Continuous linear map version of snocCtx (concatenation + cast).

                                                    Instances For
                                                      theorem Proofs.Autograd.Graph.hasFDerivAt_evalVec_and_jvp {Γ ss : List Spec.Shape} (g : Graph Γ ss) (hg : GraphFDerivCorrect g) (xV : CtxVec Γ) :
                                                      ∃ (D : CtxVec Γ →L[ℝ] CtxVec (Γ ++ ss)), HasFDerivAt g.evalVec D xV ∧ ∀ (dxV : CtxVec Γ), g.jvpVec xV dxV = D dxV

                                                      Main induction: evalVec is differentiable and its derivative agrees with jvpVec.

                                                      This is the technical heart of the jvp = fderiv theorem.
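
The inductive step composes per-node derivatives; at the mathlib level the kernel of that step is the chain rule:

```lean
import Mathlib

-- `HasFDerivAt.comp`: derivatives of composites compose.
example {f g : ℝ → ℝ} {Df Dg : ℝ →L[ℝ] ℝ} {x : ℝ}
    (hg : HasFDerivAt g Dg (f x)) (hf : HasFDerivAt f Df x) :
    HasFDerivAt (g ∘ f) (Dg.comp Df) x :=
  hg.comp x hf
```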

                                                      Convenience corollaries:

                                                      Once we know evalVec satisfies HasFDerivAt with derivative jvpVec, the rest is immediate: jvpVec = fderiv, and then backpropVec = (fderiv evalVec)† by the inner-product characterization of adjoints.
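
The first step of that upgrade is mathlib's HasFDerivAt.fderiv: once a candidate derivative is certified, fderiv is pinned down.

```lean
import Mathlib

example {f : ℝ → ℝ} {D : ℝ →L[ℝ] ℝ} {x : ℝ} (h : HasFDerivAt f D x) :
    fderiv ℝ f x = D :=
  h.fderiv
```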

                                                      theorem Proofs.Autograd.Graph.jvpVec_eq_fderiv {Γ ss : List Spec.Shape} (g : Graph Γ ss) (hg : GraphFDerivCorrect g) (xV dxV : CtxVec Γ) :
                                                      g.jvpVec xV dxV = (fderiv ℝ g.evalVec xV) dxV

                                                      Under GraphFDerivCorrect, the graph JVP equals the Fréchet derivative fderiv of evalVec.

                                                      theorem Proofs.Autograd.Graph.backpropVec_eq_adjoint_fderiv {Γ ss : List Spec.Shape} (g : Graph Γ ss) (hg : GraphFDerivCorrect g) (xV : CtxVec Γ) (seedV : CtxVec (Γ ++ ss)) :
                                                      g.backpropVec xV seedV = ContinuousLinearMap.adjoint (fderiv ℝ g.evalVec xV) seedV

                                                      Main analytic theorem: backpropVec equals the adjoint of the derivative of evalVec.

                                                      This is the proof-level formalization of “reverse-mode computes a VJP”, stated as an equality of linear maps in a Euclidean space.
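
The inner-product characterization used in this final step, in mathlib form: any map related to B by the adjointness law is ContinuousLinearMap.adjoint B.

```lean
import Mathlib

open RealInnerProductSpace in
example {n m : ℕ}
    (B : EuclideanSpace ℝ (Fin n) →L[ℝ] EuclideanSpace ℝ (Fin m))
    (A : EuclideanSpace ℝ (Fin m) →L[ℝ] EuclideanSpace ℝ (Fin n))
    (h : ∀ x y, ⟪A x, y⟫_ℝ = ⟪x, B y⟫_ℝ) :
    A = ContinuousLinearMap.adjoint B :=
  (ContinuousLinearMap.eq_adjoint_iff A B).mpr h
```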

                                                      theorem Proofs.Autograd.Graph.hasFDerivAt_evalVec_and_jvp_at {Γ ss : List Spec.Shape} (g : Graph Γ ss) (xV : CtxVec Γ) :
                                                      ∀ (a : GraphFDerivCorrectAt g xV), ∃ (D : CtxVec Γ →L[ℝ] CtxVec (Γ ++ ss)), HasFDerivAt g.evalVec D xV ∧ ∀ (dxV : CtxVec Γ), g.jvpVec xV dxV = D dxV

                                                      Pointwise induction: evalVec is differentiable at xV, and its derivative agrees with jvpVec.

                                                      This is the version used for graphs involving non-smooth or partial primitives, where we only assume differentiability at the values encountered during execution.

                                                      Pointwise corollaries: these mirror jvpVec_eq_fderiv and backpropVec_eq_adjoint_fderiv, but only require GraphFDerivCorrectAt at the specific execution point.

                                                      theorem Proofs.Autograd.Graph.jvpVec_eq_fderiv_at {Γ ss : List Spec.Shape} (g : Graph Γ ss) (xV dxV : CtxVec Γ) :
                                                      ∀ (a : GraphFDerivCorrectAt g xV), g.jvpVec xV dxV = (fderiv ℝ g.evalVec xV) dxV

                                                      Pointwise version of jvpVec_eq_fderiv.

                                                      theorem Proofs.Autograd.Graph.backpropVec_eq_adjoint_fderiv_at {Γ ss : List Spec.Shape} (g : Graph Γ ss) (xV : CtxVec Γ) (seedV : CtxVec (Γ ++ ss)) :
                                                      ∀ (a : GraphFDerivCorrectAt g xV), g.backpropVec xV seedV = ContinuousLinearMap.adjoint (fderiv ℝ g.evalVec xV) seedV

                                                      Pointwise version of backpropVec_eq_adjoint_fderiv.