TorchLean API

NN.Proofs.Autograd.Tape.Ops.Norm.LayerNorm

LayerNorm #

Pointwise analytic correctness for a LayerNorm graph.

This is spec-level, over exact reals. It is the proof-tape counterpart of the runtime/spec LayerNorm in Spec.layerNorm: a seqLen × embedDim tensor is normalized across the last axis, the row-wise normalizer is broadcast back over each token, and the affine parameters gamma/beta are broadcast over the sequence dimension. The runtime API and compiled IR path both route through that spec definition; this file proves the corresponding reverse-mode graph rule.

Because the proof graph uses the differentiable scalar nodes sqrt (max x 0) and inv, the main theorem is pointwise (GraphFDerivCorrectAt) with explicit domain assumptions. Those hypotheses are the honest mathematical boundary: away from the clamp kink and zero denominator, backprop is the adjoint of the Fréchet derivative. The executable Spec.layerNorm additionally clamps the raw variance before adding epsilon as a numerical guard; over exact real variance this is the same contract on the positive branch used by the proof.
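In conventional notation (writing the graph's sqrt-of-clamp explicitly), the map whose reverse-mode rule is proved here is, per row i and column j:

```latex
\mu_i = \frac{1}{n}\sum_{j=1}^{n} x_{ij}, \qquad
\sigma^2_i = \frac{1}{n}\sum_{j=1}^{n} \bigl(x_{ij} - \mu_i\bigr)^2,

y_{ij} \;=\; \gamma_j \cdot \frac{x_{ij} - \mu_i}{\sqrt{\max\bigl(\sigma^2_i + \varepsilon,\, 0\bigr)}} \;+\; \beta_j .
```

The domain hypotheses of the main theorem say exactly that σ²ᵢ + ε > 0 and that the resulting denominator is nonzero for every row i.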

PyTorch correspondence / citations #

@[reducible, inline]

Matrix shape m×n.

@[reducible, inline]

Vector shape k.

@[reducible, inline]

Input context shapes: [X, gamma, beta] for layer norm over the last axis.

@[reducible, inline]

First 6 intermediates in the LayerNorm computation (up to var_eps).

@[reducible, inline]

Prefix intermediates up to std (adds one more vector).

@[reducible, inline]

Full list of intermediates for the LayerNorm graph in this file.

Index of the input matrix X in the base LayerNorm context ΓLN m n ++ ss.

Index of the scale vector gamma in the base LayerNorm context ΓLN m n ++ ss.

Index of the shift vector beta in the base LayerNorm context ΓLN m n ++ ss.

Index helper for the last element of an extended context Γ ++ ss ++ [τ].
noncomputable def Proofs.Autograd.LayerNorm.nodeMean {m n : ℕ} :
    Node (ΓLN m n) (VecShape m)

Mean over the last axis: mean : ℝ^{m×n} → ℝ^{m}.

noncomputable def Proofs.Autograd.LayerNorm.g1 {m n : ℕ} :

Graph prefix producing [mean].

Index of mean in the extended context ΓLN ++ [mean].

noncomputable def Proofs.Autograd.LayerNorm.nodeMeanB {m n : ℕ} :

Broadcast mean back to m×n (row-wise).

noncomputable def Proofs.Autograd.LayerNorm.g2 {m n : ℕ} :

Graph prefix producing [mean, mean_b].

Index of mean_b in ΓLN ++ [mean, mean_b].

noncomputable def Proofs.Autograd.LayerNorm.nodeCentered {m n : ℕ} :

Center: centered := X - mean_b.

noncomputable def Proofs.Autograd.LayerNorm.g3 {m n : ℕ} :

Graph prefix producing [mean, mean_b, centered].

Index of centered in ΓLN ++ [mean, mean_b, centered].

Square centered: centered_sq := centered ⊙ centered.

noncomputable def Proofs.Autograd.LayerNorm.g4 {m n : ℕ} :

Graph prefix producing [mean, mean_b, centered, centered_sq].

Index of centered_sq in the extended context.

Variance per row: var := mean(centered_sq), producing a length-m vector.

Graph prefix producing [mean, mean_b, centered, centered_sq, var].

Index of var in the extended context.
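A useful observation for the prefix up to this point: every node so far (mean, broadcast, subtraction, and the row-wise mean again) is linear or affine in its inputs, so its Fréchet derivative is position-independent; for instance, for the centering step,

```latex
D\bigl(X \mapsto X - \mathrm{mean\_b}(X)\bigr)(X)\,[V] \;=\; V - \mathrm{mean\_b}(V),
```

independent of X. The pointwise hypotheses only become necessary once the nonlinear scalar nodes (sqrt of clamp, inv) enter below.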
noncomputable def Proofs.Autograd.LayerNorm.nodeVarEps {m n : ℕ} (ε : ℝ) :

Add epsilon: var_eps := var + ε.

noncomputable def Proofs.Autograd.LayerNorm.layerNormPrefix6 {m n : ℕ} (ε : ℝ) :
    Graph (ΓLN m n) (ssPrefix6 m n)

Graph prefix computing the first 6 intermediates (ssPrefix6).

Index of var_eps in ΓLN ++ ssPrefix6.

noncomputable def Proofs.Autograd.LayerNorm.nodeStd {m n : ℕ} :
    Node (ΓLN m n ++ ssPrefix6 m n) (VecShape m)

Standard deviation: std := sqrt_clamp(var_eps).

This is where the development becomes pointwise: differentiability depends on the (clamped) input.

noncomputable def Proofs.Autograd.LayerNorm.layerNormPrefix7 {m n : ℕ} (ε : ℝ) :
    Graph (ΓLN m n) (ssPrefix7 m n)

Graph prefix computing ssPrefix7 (adds std).

Index of std in ΓLN ++ ssPrefix7.
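Concretely, the reason std forces a pointwise statement is the clamp kink: on the positive branch demanded by hVarEpsPos the scalar node differentiates as usual,

```latex
\frac{d}{dx}\,\sqrt{\max(x,0)} \;=\; \frac{1}{2\sqrt{x}} \qquad (x > 0),
```

while at x = 0 the clamp creates a kink where no Fréchet derivative exists (and for x < 0 the derivative is 0, so no single global rule covers all inputs). The hypothesis hVarEpsPos places every row's var_eps on the smooth positive branch.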
noncomputable def Proofs.Autograd.LayerNorm.nodeInvStd {m n : ℕ} :
    Node (ΓLN m n ++ ssPrefix7 m n) (VecShape m)

Inverse standard deviation: inv_std := 1/std.

noncomputable def Proofs.Autograd.LayerNorm.g8 {m n : ℕ} (ε : ℝ) :

Graph prefix adding inv_std.

Index of inv_std in the extended context.

noncomputable def Proofs.Autograd.LayerNorm.nodeInvStdB {m n : ℕ} :
    Node (ΓLN m n ++ (ssPrefix7 m n ++ [VecShape m])) (MatShape m n)

Broadcast inv_std back to m×n (row-wise).

noncomputable def Proofs.Autograd.LayerNorm.g9 {m n : ℕ} (ε : ℝ) :

Graph prefix adding inv_std_b.

Index of centered in the stage-g9 context.

Index of inv_std_b in the stage-g9 context.

noncomputable def Proofs.Autograd.LayerNorm.nodeNorm {m n : ℕ} :

Node computing normalized := centered ⊙ inv_std_b.

noncomputable def Proofs.Autograd.LayerNorm.g10 {m n : ℕ} (ε : ℝ) :

Graph prefix producing normalized := centered ⊙ inv_std_b.
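The inv node contributes the second domain condition: differentiating the scalar reciprocal requires a nonzero input,

```latex
\frac{d}{dx}\,x^{-1} \;=\; -\frac{1}{x^{2}} \qquad (x \neq 0),
```

which is exactly what hStdNe0 asserts row by row. Note that on the branch where hVarEpsPos holds, std = √(var_eps) > 0, so the two hypotheses are satisfied together in practice; the theorem statements keep them separate because they guard different scalar nodes.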
noncomputable def Proofs.Autograd.LayerNorm.nodeGammaB {m n : ℕ} :

Broadcast gamma to m×n (column-wise).

noncomputable def Proofs.Autograd.LayerNorm.g11 {m n : ℕ} (ε : ℝ) :

Graph prefix adding gamma_b.

Index of normalized in the context at stage g11.

Index of gamma_b at stage g11.

noncomputable def Proofs.Autograd.LayerNorm.nodeScaled {m n : ℕ} :

Scale: scaled := normalized ⊙ gamma_b.

noncomputable def Proofs.Autograd.LayerNorm.g12 {m n : ℕ} (ε : ℝ) :

Graph prefix adding scaled.

Broadcast beta to m×n (column-wise).

noncomputable def Proofs.Autograd.LayerNorm.g13 {m n : ℕ} (ε : ℝ) :

Graph prefix adding beta_b.

Index of scaled at stage g13.

Index of beta_b at stage g13.

noncomputable def Proofs.Autograd.LayerNorm.nodeY {m n : ℕ} :

Output: y := scaled + beta_b.

noncomputable def Proofs.Autograd.LayerNorm.layerNormGraph {m n : ℕ} (ε : ℝ) :
    Graph (ΓLN m n) (ssLayerNorm m n)

Full LayerNorm graph (as an explicit snoc chain).
noncomputable def Proofs.Autograd.LayerNorm.layerNormGraphFderivCorrectAt {m n : ℕ} (ε : ℝ)
    (xV : CtxVec (ΓLN m n))
    (hVarEpsPos : ∀ (i : Fin (VecShape m).size), 0 < (CtxVec.get idxVarEps ((layerNormPrefix6 ε).evalVec xV)).ofLp i)
    (hStdNe0 : ∀ (i : Fin (VecShape m).size), (CtxVec.get idxStd ((layerNormPrefix7 ε).evalVec xV)).ofLp i ≠ 0) :

Pointwise proof that layerNormGraph satisfies GraphFDerivCorrectAt.

The hypotheses hVarEpsPos and hStdNe0 are explicit domain assumptions ensuring that sqrt and inv are differentiable at the execution point.
theorem Proofs.Autograd.LayerNorm.backprop_eq_adjoint_fderiv_layerNorm_at {m n : ℕ} (ε : ℝ)
    (xV : CtxVec (ΓLN m n)) (seedV : CtxVec (ΓLN m n ++ ssLayerNorm m n))
    (hVarEpsPos : ∀ (i : Fin (VecShape m).size), 0 < (CtxVec.get idxVarEps ((layerNormPrefix6 ε).evalVec xV)).ofLp i)
    (hStdNe0 : ∀ (i : Fin (VecShape m).size), (CtxVec.get idxStd ((layerNormPrefix7 ε).evalVec xV)).ofLp i ≠ 0) :

Pointwise end-to-end result: backprop equals (fderiv eval)† for layerNormGraph.

The hypotheses hVarEpsPos and hStdNe0 are the explicit domain assumptions needed for differentiability of sqrt (after clamp) and inv at the actual execution point.
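For orientation: although the theorem is stated abstractly as backprop = (fderiv eval)†, on the positive branch the adjoint it computes agrees with the well-known closed-form LayerNorm backward. Per row i, writing x̂ for normalized, std for the row normalizer, and gᵢⱼ = (∂L/∂yᵢⱼ)·γⱼ:

```latex
\frac{\partial L}{\partial x_{ij}}
  = \frac{1}{\mathrm{std}_i}\Bigl(g_{ij}
    - \frac{1}{n}\sum_{k} g_{ik}
    - \hat{x}_{ij}\cdot\frac{1}{n}\sum_{k} g_{ik}\,\hat{x}_{ik}\Bigr),
\qquad
\frac{\partial L}{\partial \gamma_j} = \sum_{i} \frac{\partial L}{\partial y_{ij}}\,\hat{x}_{ij},
\qquad
\frac{\partial L}{\partial \beta_j} = \sum_{i} \frac{\partial L}{\partial y_{ij}} .
```

This uses the biased (1/n) variance, matching var := mean(centered_sq) above.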

LayerNorm inputs inside an arbitrary tape context.

This is the model-level interface we use once LayerNorm is no longer the root graph. For example, in a post-norm Transformer block, x is the residual stream produced by an earlier SSA node, while gamma and beta are carried parameters in the surrounding context.

• x : Idx Γ (MatShape m n)

  Sequence/residual matrix, normalized across its last axis.

• gamma : Idx Γ (VecShape n)

  Affine scale vector.

• beta : Idx Γ (VecShape n)

  Affine shift vector.

@[reducible, inline]

Saved tensors before the final LayerNorm output y.

Index of the final LayerNorm output in ΓLN ++ ssLayerNorm.
noncomputable def Proofs.Autograd.LayerNorm.packInputsCLM {Γ : List Spec.Shape} {m n : ℕ} (inputs : Inputs Γ m n) :

Linear map that packs arbitrary-context LayerNorm inputs into the canonical context [X, gamma, beta].

Project the final LayerNorm output from the full canonical graph context.
noncomputable def Proofs.Autograd.LayerNorm.wholeNode {Γ : List Spec.Shape} {m n : ℕ} (inputs : Inputs Γ m n) (ε : ℝ) :
    Node Γ (MatShape m n)

LayerNorm as one reusable pointwise node over arbitrary context indices.

Internally this node runs the already-proved detailed LayerNorm graph. Its JVP is defined as the Fréchet derivative of that composed map at the current point, and its VJP is the adjoint of that derivative. This is exactly the block-level abstraction needed for large model proofs: the detailed LayerNorm proof remains in this file, while Transformer/GPT/ViT proofs can treat LayerNorm as a single pointwise node with explicit domain assumptions.
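Schematically (a sketch of the composition, not the literal Lean definition): write F for the node's forward map, built from the output projection, the canonical graph's evaluation, and the packing map. Since packInputsCLM and the projection are continuous linear maps, the chain rule and the adjoint identity (A ∘ B)† = B† ∘ A† give:

```latex
F = P_{\mathrm{out}} \circ \mathrm{eval} \circ \mathrm{pack},
\qquad
\mathrm{JVP}(x) = DF(x) = P_{\mathrm{out}} \circ D(\mathrm{eval})(\mathrm{pack}\,x) \circ \mathrm{pack},
\qquad
\mathrm{VJP}(x) = DF(x)^{\dagger} = \mathrm{pack}^{\dagger} \circ D(\mathrm{eval})(\mathrm{pack}\,x)^{\dagger} \circ P_{\mathrm{out}}^{\dagger}.
```

The middle factor D(eval)† is exactly what backprop_eq_adjoint_fderiv_layerNorm_at certifies, which is why wholeNode needs no new analytic work beyond repackaging the domain hypotheses.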
noncomputable def Proofs.Autograd.LayerNorm.wholeNodeFDerivCorrectAt {Γ : List Spec.Shape} {m n : ℕ} (inputs : Inputs Γ m n) (ε : ℝ)
    (xV : CtxVec Γ)
    (hVarEpsPos : ∀ (i : Fin (VecShape m).size), 0 < (CtxVec.get idxVarEps ((layerNormPrefix6 ε).evalVec ((packInputsCLM inputs) xV))).ofLp i)
    (hStdNe0 : ∀ (i : Fin (VecShape m).size), (CtxVec.get idxStd ((layerNormPrefix7 ε).evalVec ((packInputsCLM inputs) xV))).ofLp i ≠ 0) :

Pointwise derivative certificate for wholeNode.

The hypotheses are the same LayerNorm domain conditions as the detailed graph theorem, but evaluated after packing the arbitrary context into [X, gamma, beta].