TorchLean API

NN.Proofs.Autograd.Tape.Nodes.Context

Tape-node context primitives

This module contains the low-level vectorized context operations used by the tape-node proof library: block projections, one-hot cotangent injections, and the bridge from generic OpSpecFDerivCorrect witnesses to NodeFDerivCorrect nodes.

@[simp]
theorem Proofs.Autograd.piLpContinuousLinearEquiv2_symm_apply {n : ℕ} (f : Fin n → ℝ) (i : Fin n) :
((PiLp.continuousLinearEquiv 2 ℝ fun (x : Fin n) => ℝ).symm f).ofLp i = f i
@[simp]
theorem Proofs.Autograd.piLpContinuousLinearEquiv2_symm_clm_apply {n : ℕ} (f : Fin n → ℝ) (i : Fin n) :
((PiLp.continuousLinearEquiv 2 ℝ fun (x : Fin n) => ℝ).symm f).ofLp i = f i
@[simp]
theorem Proofs.Autograd.euclideanEquiv_symm_ofLp {n : ℕ} (f : Fin n → ℝ) (i : Fin n) :

Raw projection from a vectorized context onto the ith block.

This is the underlying block-splitting operation; CtxVec.get below wraps it with an Idx Γ s that also remembers the expected shape.

Instances For

    Raw injection into a vectorized context: place v into block i, fill others with zeros.

    This is the adjoint of getRaw with respect to the Euclidean inner product (proved below).

    Instances For
      theorem Proofs.Autograd.CtxVec.inner_getRaw_singleRaw {Γ : List Spec.Shape} (i : Fin Γ.length) (x : CtxVec Γ) (v : Vec (Γ.get i).size) :
      inner x (singleRaw i v) = inner (getRaw i x) v

      Adjointness of raw projection/injection: ⟪x, singleRaw i v⟫ = ⟪getRaw i x, v⟫.

      This is the vectorized counterpart of the “one-hot cotangent” principle used in tape soundness.
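
      As a hedged usage sketch (the surrounding namespace, imports, and instance arguments are assumed to be in scope), the lemma is applied pointwise:

```lean
-- Hypothetical usage sketch: pairing a context against a one-hot
-- (block-sparse) cotangent reads off exactly one block of the inner
-- product. `Γ`, `i`, `x`, `v` are assumptions with the types shown.
example {Γ : List Spec.Shape} (i : Fin Γ.length)
    (x : CtxVec Γ) (v : Vec (Γ.get i).size) :
    inner x (CtxVec.singleRaw i v) = inner (CtxVec.getRaw i x) v :=
  CtxVec.inner_getRaw_singleRaw i x v
```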

      noncomputable def Proofs.Autograd.CtxVec.get {Γ : List Spec.Shape} {s : Spec.Shape} (idx : Idx Γ s) (x : CtxVec Γ) :
      Vec s.size

      Project the block specified by idx : Idx Γ s out of a vectorized context.

      Instances For
        noncomputable def Proofs.Autograd.CtxVec.single {Γ : List Spec.Shape} {s : Spec.Shape} (idx : Idx Γ s) (v : Vec s.size) :
        CtxVec Γ

        Inject a block into a vectorized context at idx, filling other blocks with zeros.

        Instances For
          theorem Proofs.Autograd.CtxVec.inner_get_single {Γ : List Spec.Shape} {s : Spec.Shape} (idx : Idx Γ s) (x : CtxVec Γ) (v : Vec s.size) :
          inner x (single idx v) = inner (get idx x) v

          Adjointness of get/single: ⟪x, single idx v⟫ = ⟪get idx x, v⟫.
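
          The shape-indexed pair is used the same way; a minimal sketch, assuming an `Idx Γ s` is available in scope:

```lean
-- Sketch: same adjointness as the raw version, but with a
-- shape-remembering index `idx : Idx Γ s` instead of a raw position.
example {Γ : List Spec.Shape} {s : Spec.Shape} (idx : Idx Γ s)
    (x : CtxVec Γ) (v : Vec s.size) :
    inner x (CtxVec.single idx v) = inner (CtxVec.get idx x) v :=
  CtxVec.inner_get_single idx x v
```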

          Continuous linear map extracting the head block of a nonempty vectorized context.

          Instances For
            @[simp]
            theorem Proofs.Autograd.CtxVec.headCLM_apply {s : Spec.Shape} {ss : List Spec.Shape} (x : CtxVec (s :: ss)) (j : Fin s.size) :

            Continuous linear map extracting the tail blocks of a nonempty vectorized context.

            Instances For
              @[simp]
              theorem Proofs.Autograd.CtxVec.tailCLM_apply {s : Spec.Shape} {ss : List Spec.Shape} (x : CtxVec (s :: ss)) (j : Fin (ctxSize ss)) :

              getRaw packaged as a continuous linear map (constructed by recursion with headCLM/tailCLM).

              Instances For
                @[simp]
                noncomputable def Proofs.Autograd.CtxVec.getCLM {Γ : List Spec.Shape} {s : Spec.Shape} (idx : Idx Γ s) :
                CtxVec Γ →L[ℝ] Vec s.size

                get packaged as a continuous linear map.

                Instances For
                  @[simp]
                  theorem Proofs.Autograd.CtxVec.getCLM_apply {Γ : List Spec.Shape} {s : Spec.Shape} (idx : Idx Γ s) (x : CtxVec Γ) :
                  (getCLM idx) x = get idx x

                  Nodes in this file are authored directly on the vectorized context CtxVec.

                  This is the most convenient authoring style for analytic proofs: forwardVec/jvpVec/vjpVec are definitional, and the correctness obligation is an inner-product identity on Euclidean vectors.

                  noncomputable def Proofs.Autograd.Node.ofVec {Γ : List Spec.Shape} {τ : Spec.Shape} (f : CtxVec Γ → Vec τ.size) (jvp : CtxVec Γ → CtxVec Γ → Vec τ.size) (vjp : CtxVec Γ → Vec τ.size → CtxVec Γ) (correct_inner : ∀ (x dx : CtxVec Γ) (δ : Vec τ.size), inner (jvp x dx) δ = inner dx (vjp x δ)) :
                  Node Γ τ

                  Convenience constructor: build a tape Node from vector-level forward/JVP/VJP plus adjointness.

                  The correct_inner field is exactly the local VJP/JVP law: ⟪jvp x dx, δ⟫ = ⟪dx, vjp x δ⟫.

                  Instances For
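
                    A minimal sketch of this authoring style (a hypothetical node, not part of the module): a projection node whose adjointness obligation is discharged directly by inner_get_single:

```lean
-- Hypothetical example node: forward one block of the context.
-- The JVP is the same projection (it is linear in `dx`); the VJP
-- injects the cotangent one-hot into the context.
noncomputable def projNode {Γ : List Spec.Shape} {s : Spec.Shape}
    (idx : Idx Γ s) : Node Γ s :=
  Node.ofVec
    (fun x => CtxVec.get idx x)           -- forward
    (fun _x dx => CtxVec.get idx dx)      -- jvp
    (fun _x δ => CtxVec.single idx δ)     -- vjp
    (fun _x dx δ => (CtxVec.inner_get_single idx dx δ).symm)
```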
                    @[simp]
                    theorem Proofs.Autograd.Node.forwardVec_ofVec {Γ : List Spec.Shape} {τ : Spec.Shape} (f : CtxVec Γ → Vec τ.size) (jvp : CtxVec Γ → CtxVec Γ → Vec τ.size) (vjp : CtxVec Γ → Vec τ.size → CtxVec Γ) (h : ∀ (x dx : CtxVec Γ) (δ : Vec τ.size), inner (jvp x dx) δ = inner dx (vjp x δ)) :
                    (ofVec f jvp vjp h).forwardVec = f
                    @[simp]
                    theorem Proofs.Autograd.Node.jvpVec_ofVec {Γ : List Spec.Shape} {τ : Spec.Shape} (f : CtxVec Γ → Vec τ.size) (jvp : CtxVec Γ → CtxVec Γ → Vec τ.size) (vjp : CtxVec Γ → Vec τ.size → CtxVec Γ) (h : ∀ (x dx : CtxVec Γ) (δ : Vec τ.size), inner (jvp x dx) δ = inner dx (vjp x δ)) :
                    (ofVec f jvp vjp h).jvpVec = jvp
                    @[simp]
                    theorem Proofs.Autograd.Node.vjpVec_ofVec {Γ : List Spec.Shape} {τ : Spec.Shape} (f : CtxVec Γ → Vec τ.size) (jvp : CtxVec Γ → CtxVec Γ → Vec τ.size) (vjp : CtxVec Γ → Vec τ.size → CtxVec Γ) (h : ∀ (x dx : CtxVec Γ) (δ : Vec τ.size), inner (jvp x dx) δ = inner dx (vjp x δ)) :
                    (ofVec f jvp vjp h).vjpVec = vjp
                    noncomputable def Proofs.Autograd.OpSpecFDerivCorrect.linear {inDim outDim : ℕ} (m : Spec.LinearSpec inDim outDim) :
                    OpSpecFDerivCorrect inDim outDim

                    OpSpecFDerivCorrect instance for a linear layer.

                    This is the analytic correctness lemma behind the tape node constructors: it identifies the JVP with the Fréchet derivative (a matrix multiplication) for linear_spec.

                    PyTorch analogue: the torch.nn.Linear forward map is affine, so its derivative is a constant linear map. https://pytorch.org/docs/stable/generated/torch.nn.Linear.html

                    Instances For
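
                      The derivative-is-constant fact has a standard Mathlib form; a hedged sketch over generic normed spaces (not the concrete LinearSpec types of this module):

```lean
-- Sketch: the Fréchet derivative of an affine map `y ↦ L y + b`
-- is the constant continuous linear map `L`, at every point `x`.
example {E F : Type*} [NormedAddCommGroup E] [NormedSpace ℝ E]
    [NormedAddCommGroup F] [NormedSpace ℝ F]
    (L : E →L[ℝ] F) (b : F) (x : E) :
    fderiv ℝ (fun y => L y + b) x = L :=
  (L.hasFDerivAt.add_const b).fderiv
```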