TorchLean API

NN.Proofs.Autograd.Tape.Nodes.GraphComposition

Differentiable graph composition #

The DGraph wrapper packages a tape graph together with node-local NodeFDerivCorrect proofs. Composition lemmas here let us build large model-level VJP theorems from proved primitive nodes rather than reproving backprop correctness for each architecture from scratch.

DGraph (“differentiable graph”) is a small wrapper bundling a Graph together with proofs that every node in it satisfies NodeFDerivCorrect.

This is a convenience for end-to-end examples: you can build a graph incrementally with snoc, and then immediately use backpropVec_eq_adjoint_fderiv without separately threading a proof object.
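A minimal sketch of that pattern (the empty-graph constructor's signature is elided on this page, so the name `DGraph.empty` is an assumption here, and the two primitive nodes with their certificates are hypothetical):

```lean
open Proofs.Autograd

-- Hypothetical: start from the (assumed) empty graph `DGraph.empty Γ` and
-- snoc two primitive nodes `n₁`, `n₂` with their certificates `h₁`, `h₂`.
example (Γ : List Spec.Shape) (τ₁ τ₂ : Spec.Shape)
    (n₁ : Node (Γ ++ []) τ₁) (h₁ : NodeFDerivCorrect n₁)
    (n₂ : Node (Γ ++ [τ₁]) τ₂) (h₂ : NodeFDerivCorrect n₂) :
    DGraph Γ ([τ₁] ++ [τ₂]) :=
  ((DGraph.empty Γ).snoc n₁ h₁).snoc n₂ h₂
```

The bundled theorem backpropVec_eq_adjoint_fderiv at the end of this file then applies to the result without any further proof obligations.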

noncomputable def Proofs.Autograd.DGraph.takeLeftCLM (m n : ℕ) :
    Vec (m + n) →L[ℝ] Vec m

Continuous linear map taking the left block of a concatenated Euclidean vector.

noncomputable def Proofs.Autograd.DGraph.takeRightCLM (m n : ℕ) :
    Vec (m + n) →L[ℝ] Vec n

Continuous linear map taking the right block of a concatenated Euclidean vector.

noncomputable def Proofs.Autograd.DGraph.dropMiddleCLM (Γ extra ss : List Spec.Shape) :
    CtxVec (Γ ++ extra ++ ss) →L[ℝ] CtxVec (Γ ++ ss)

Drop an unused middle context block from a graph node context.

When a graph dg : DGraph Γ ss is reused inside a larger context Γ ++ extra, a node that originally reads Γ ++ ss is evaluated in the actual context (Γ ++ extra) ++ ss. This projection keeps the original inputs Γ and the already-computed intermediates ss, and ignores the carried parameters in extra.
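As a typing sketch, the projection applies like any continuous linear map from this file:

```lean
open Proofs.Autograd

-- Forget the `extra` block of a context vector. The adjoint of this map
-- re-inserts zeros in that block, which the weakened VJP below relies on.
example (Γ extra ss : List Spec.Shape) (v : CtxVec (Γ ++ extra ++ ss)) :
    CtxVec (Γ ++ ss) :=
  DGraph.dropMiddleCLM Γ extra ss v
```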

noncomputable def Proofs.Autograd.DGraph.weakenNodeMiddle {Γ extra ss : List Spec.Shape} {τ : Spec.Shape} (node : Node (Γ ++ ss) τ) :
    Node (Γ ++ extra ++ ss) τ

Reuse a node in a context that carries extra unused inputs between the original inputs and the current SSA intermediates.

The VJP is obtained by applying the adjoint of dropMiddleCLM, so gradients land only in the original inputs and previous intermediates; the extra carried parameters receive zero contribution from this reused node.
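A typing sketch (extra is implicit in the definition, so it is passed by name here):

```lean
open Proofs.Autograd

-- Drop a node proved over `Γ ++ ss` into the enlarged context `Γ ++ extra ++ ss`.
example (Γ extra ss : List Spec.Shape) (τ : Spec.Shape)
    (node : Node (Γ ++ ss) τ) :
    Node (Γ ++ extra ++ ss) τ :=
  DGraph.weakenNodeMiddle (extra := extra) node
```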

noncomputable def Proofs.Autograd.DGraph.weakenNodeMiddleFDerivCorrect {Γ extra ss : List Spec.Shape} {τ : Spec.Shape} {node : Node (Γ ++ ss) τ} (hn : NodeFDerivCorrect node) :
    NodeFDerivCorrect (weakenNodeMiddle node)

Transport a global node derivative certificate across weakenNodeMiddle.

Empty differentiable graph.

def Proofs.Autograd.DGraph.snoc {Γ ss : List Spec.Shape} {τ : Spec.Shape} (dg : DGraph Γ ss) (node : Node (Γ ++ ss) τ) (hn : NodeFDerivCorrect node) :
    DGraph Γ (ss ++ [τ])

Append a node together with its NodeFDerivCorrect certificate.

def Proofs.Autograd.DGraph.castNodeContext {Γ₁ Γ₂ : List Spec.Shape} {τ : Spec.Shape} (h : Γ₁ = Γ₂) (node : Node Γ₁ τ) :
    Node Γ₂ τ

Transport a node across a definitional/context-list equality.

This is mostly used by graph composition: the second graph sees its context as (Γ ++ ss₁) ++ ss₂, while the composed graph sees the same values as Γ ++ (ss₁ ++ ss₂).
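For example, the reassociation needed by composition is exactly List.append_assoc:

```lean
open Proofs.Autograd

-- The second graph's view `(Γ ++ ss₁) ++ ss₂` equals the composed graph's
-- view `Γ ++ (ss₁ ++ ss₂)` as lists; `castNodeContext` transports the node.
example (Γ ss₁ ss₂ : List Spec.Shape) (τ : Spec.Shape)
    (node : Node ((Γ ++ ss₁) ++ ss₂) τ) :
    Node (Γ ++ (ss₁ ++ ss₂)) τ :=
  DGraph.castNodeContext (List.append_assoc Γ ss₁ ss₂) node
```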

def Proofs.Autograd.DGraph.castNodeFDerivCorrect {Γ₁ Γ₂ : List Spec.Shape} {τ : Spec.Shape} (h : Γ₁ = Γ₂) {node : Node Γ₁ τ} (hn : NodeFDerivCorrect node) :
    NodeFDerivCorrect (castNodeContext h node)

Transport a node F-derivative certificate along castNodeContext.

Specialize an everywhere-correct graph proof to a pointwise graph proof.

We use this when a globally smooth block feeds a pointwise block such as LayerNorm. The first block does not need any domain hypotheses, so its GraphFDerivCorrect certificate can be read at the actual runtime point.

noncomputable def Proofs.Autograd.DGraph.appendCore {Γ ss₁ ss₂ : List Spec.Shape} (dg₁ : DGraph Γ ss₁) (g₂ : Graph (Γ ++ ss₁) ss₂) (hg₂ : GraphFDerivCorrect g₂) :
    DGraph Γ (ss₁ ++ ss₂)

Recursive implementation for append, stated over an explicit graph and proof.

noncomputable def Proofs.Autograd.DGraph.append {Γ ss₁ ss₂ : List Spec.Shape} (dg₁ : DGraph Γ ss₁) (dg₂ : DGraph (Γ ++ ss₁) ss₂) :
    DGraph Γ (ss₁ ++ ss₂)

Append a proof-carrying graph after another proof-carrying graph.

If dg₁ : DGraph Γ ss₁ has already computed some SSA values, then a second graph dg₂ : DGraph (Γ ++ ss₁) ss₂ may use both the original inputs and those saved values. append turns the pair into one DGraph Γ (ss₁ ++ ss₂).

This is the general composition adapter needed for model-level proofs: residual attention can feed LayerNorm, a recurrent cell can feed the next unrolled step, and larger blocks can be assembled while reusing the existing node-level correctness proofs.
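A typing sketch of that composition shape:

```lean
open Proofs.Autograd

-- Block A is proved over the raw inputs; block B is proved over the inputs
-- plus A's intermediates; `append` fuses them into one proof-carrying graph.
example (Γ ss₁ ss₂ : List Spec.Shape)
    (blockA : DGraph Γ ss₁) (blockB : DGraph (Γ ++ ss₁) ss₂) :
    DGraph Γ (ss₁ ++ ss₂) :=
  blockA.append blockB
```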

noncomputable def Proofs.Autograd.DGraph.weakenContextCore {Γ ss : List Spec.Shape} (extra : List Spec.Shape) (g : Graph Γ ss) (hg : GraphFDerivCorrect g) :
    DGraph (Γ ++ extra) ss

Recursive implementation for weakenContext, stated over an explicit graph and proof.

noncomputable def Proofs.Autograd.DGraph.weakenContext {Γ ss : List Spec.Shape} (dg : DGraph Γ ss) (extra : List Spec.Shape) :
    DGraph (Γ ++ extra) ss

Run a proof-carrying graph while carrying extra unused inputs.

If dg : DGraph Γ ss, then weakenContext dg extra : DGraph (Γ ++ extra) ss evaluates the same nodes while preserving an enlarged input context. Each reused node sees the projection Γ ++ ss_so_far of the actual context (Γ ++ extra) ++ ss_so_far; gradients are inserted back by the adjoint projection, so the carried extras receive no gradient contribution from nodes that do not read them.
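A typing sketch of reusing one proved block inside a larger input context:

```lean
open Proofs.Autograd

-- `cell` never reads `extra`, so after weakening, the extras flow through
-- evaluation unchanged and receive zero gradient from the cell's nodes.
example (Γ extra ss : List Spec.Shape) (cell : DGraph Γ ss) :
    DGraph (Γ ++ extra) ss :=
  cell.weakenContext extra
```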

theorem Proofs.Autograd.DGraph.weakenContext_backpropVec_eq_adjoint_fderiv {Γ ss : List Spec.Shape} (dg : DGraph Γ ss) (extra : List Spec.Shape) (xV : CtxVec (Γ ++ extra)) (seedV : CtxVec (Γ ++ extra ++ ss)) :
    (dg.weakenContext extra).g.backpropVec xV seedV = (ContinuousLinearMap.adjoint (fderiv ℝ (dg.weakenContext extra).g.evalVec xV)) seedV

VJP theorem for context-weakened proof graphs.

The statement is intentionally direct: after threading unused inputs through the graph, the ordinary backprop = (fderiv eval)† theorem still applies. The useful content lives in weakenNodeMiddle, where unused parameters receive zero contribution by the adjoint of the drop-middle projection.

theorem Proofs.Autograd.DGraph.append_backpropVec_eq_adjoint_fderiv {Γ ss₁ ss₂ : List Spec.Shape} (dg₁ : DGraph Γ ss₁) (dg₂ : DGraph (Γ ++ ss₁) ss₂) (xV : CtxVec Γ) (seedV : CtxVec (Γ ++ (ss₁ ++ ss₂))) :
    (dg₁.append dg₂).g.backpropVec xV seedV = (ContinuousLinearMap.adjoint (fderiv ℝ (dg₁.append dg₂).g.evalVec xV)) seedV

VJP theorem for appended proof-carrying graphs.

This is just Graph.backpropVec_eq_adjoint_fderiv specialized to append, but the named theorem makes model proofs read like the construction we are formalizing: prove block A, prove block B over A's extended context, append them, and immediately get the end-to-end reverse-mode theorem.
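A usage sketch, restating the conclusion for two composed blocks (mirroring the signature above):

```lean
open Proofs.Autograd

example (Γ ss₁ ss₂ : List Spec.Shape)
    (dg₁ : DGraph Γ ss₁) (dg₂ : DGraph (Γ ++ ss₁) ss₂)
    (xV : CtxVec Γ) (seedV : CtxVec (Γ ++ (ss₁ ++ ss₂))) :
    (dg₁.append dg₂).g.backpropVec xV seedV
      = (ContinuousLinearMap.adjoint (fderiv ℝ (dg₁.append dg₂).g.evalVec xV)) seedV :=
  DGraph.append_backpropVec_eq_adjoint_fderiv dg₁ dg₂ xV seedV
```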

noncomputable def Proofs.Autograd.DGraph.snocUnaryOp {Γ ss : List Spec.Shape} {inDim outDim : ℕ} (dg : DGraph Γ ss) (idx : Idx (Γ ++ ss) (Spec.Shape.dim inDim Spec.Shape.scalar)) (C : OpSpecFDerivCorrect inDim outDim) :
    DGraph Γ (ss ++ [Spec.Shape.dim outDim Spec.Shape.scalar])

Helper: append a unary op specified by an OpSpecFDerivCorrect proof object.

This is the common pattern for parameterized ops such as linear.
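A sketch of that pattern (linearSpec names a hypothetical proved affine op; the result type is left to be inferred):

```lean
open Proofs.Autograd

-- Hypothetical: `linearSpec` packages a proved op of input dimension `inDim`
-- and output dimension `outDim`; `idx` points at the activation it reads.
example (Γ ss : List Spec.Shape) (inDim outDim : ℕ)
    (dg : DGraph Γ ss)
    (idx : Idx (Γ ++ ss) (Spec.Shape.dim inDim Spec.Shape.scalar))
    (linearSpec : OpSpecFDerivCorrect inDim outDim) :=
  dg.snocUnaryOp idx linearSpec
```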

theorem Proofs.Autograd.DGraph.backpropVec_eq_adjoint_fderiv {Γ ss : List Spec.Shape} (dg : DGraph Γ ss) (xV : CtxVec Γ) (seedV : CtxVec (Γ ++ ss)) :
    dg.g.backpropVec xV seedV = (ContinuousLinearMap.adjoint (fderiv ℝ dg.g.evalVec xV)) seedV

End-to-end analytic theorem for bundled graphs.

This is just Graph.backpropVec_eq_adjoint_fderiv with the bundled proof dg.hg.