TorchLean API

NN.Proofs.Autograd.Tape.Ops.Transformer.PostNorm

Post-Norm Transformer Sublayers #

This file packages the LayerNorm theorem at the interface used by post-norm Transformer blocks.

The preceding files prove the smooth residual components: the residual attention update x + MHA(x) and the residual feed-forward update X + FFN(X).

This module proves the next runtime-facing boundary:

residual_stream ↦ LayerNorm(residual_stream, gamma, beta).

That is the exact post-norm sublayer shape used by classical Transformer encoder blocks, LayerNorm(x + Sublayer(x)). We deliberately keep this proof factored at the residual-stream interface. It avoids pretending that LayerNorm is globally smooth (its domain hypotheses are genuinely pointwise), and it gives later full-block proofs a clean seam: compose a globally smooth residual graph with this pointwise post-norm graph once the context-threading adapter for unused parameters is in place.
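
For reference, the per-row map proved at this boundary is the standard LayerNorm over the model dimension, applied to each of the seqLen rows of the residual stream (this is the textbook form; the precise epsilon placement and the mean/variance nodes are the ones defined in the LayerNorm graph file):

LayerNorm(x, gamma, beta)_j = gamma_j * (x_j - mean(x)) / sqrt(var(x) + ε) + beta_j,
with mean(x) = (1/dModel) * Σ_j x_j and var(x) = (1/dModel) * Σ_j (x_j - mean(x))².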

@[reducible, inline]

A post-norm Transformer stage normalizes a sequence-shaped residual stream.

@[reducible, inline]

Saved tensors for the LayerNorm part of a post-norm Transformer stage.

@[reducible, inline]
abbrev Proofs.Autograd.Transformer.ΓMHAWithNorm (seqLen dModel numHeads headDim : ℕ) :
List Spec.Shape

MHA context extended with the affine parameters for the following LayerNorm.

@[reducible, inline]
abbrev Proofs.Autograd.Transformer.ssMHAWithPostNorm (seqLen dModel numHeads headDim : ℕ) :
List Spec.Shape

Residual-MHA plus one whole-node post-norm output.

def Proofs.Autograd.Transformer.idxMhaPostNormGamma {seqLen dModel numHeads headDim : ℕ} {ss : List Spec.Shape} :
Idx (ΓMHAWithNorm seqLen dModel numHeads headDim ++ ss) (LayerNorm.VecShape dModel)

Gamma parameter for the post-norm LayerNorm, weakened through any saved tensors.

def Proofs.Autograd.Transformer.idxMhaPostNormBeta {seqLen dModel numHeads headDim : ℕ} {ss : List Spec.Shape} :
Idx (ΓMHAWithNorm seqLen dModel numHeads headDim ++ ss) (LayerNorm.VecShape dModel)

Beta parameter for the post-norm LayerNorm, weakened through any saved tensors.

def Proofs.Autograd.Transformer.idxMhaResidualForPostNorm {seqLen dModel numHeads headDim : ℕ} :
Idx (ΓMHAWithNorm seqLen dModel numHeads headDim ++ ssMHAResidual seqLen dModel numHeads headDim) (LayerNorm.MatShape seqLen dModel)

The residual stream x + MHA(x) after the residual-attention prefix graph.

def Proofs.Autograd.Transformer.mhaPostNormInputs {seqLen dModel numHeads headDim : ℕ} :
LayerNorm.Inputs (ΓMHAWithNorm seqLen dModel numHeads headDim ++ ssMHAResidual seqLen dModel numHeads headDim) seqLen dModel

LayerNorm input triple after the residual-MHA prefix has run.

noncomputable def Proofs.Autograd.Transformer.mhaResidualWithNormParamsDGraph {seqLen dModel numHeads headDim : ℕ} (c : ℝ) :
DGraph (ΓMHAWithNorm seqLen dModel numHeads headDim) (ssMHAResidual seqLen dModel numHeads headDim)

Residual-MHA graph while carrying the following LayerNorm's affine parameters.

The carried gamma/beta are not read by attention; DGraph.weakenContext ensures their gradients from the attention prefix are zero, while still keeping them available to the appended LayerNorm node.
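
Concretely, because the attention prefix never reads gamma or beta, the weakened graph's partial derivatives with respect to them vanish:

∂(x + MHA(x)) / ∂gamma = 0   and   ∂(x + MHA(x)) / ∂beta = 0,

so after the LayerNorm node is appended, the only gradient reaching gamma and beta is the one contributed by LayerNorm's own backward pass.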

noncomputable def Proofs.Autograd.Transformer.mhaPostNormGraph {seqLen dModel numHeads headDim : ℕ} (c ε : ℝ) :
Graph (ΓMHAWithNorm seqLen dModel numHeads headDim) (ssMHAWithPostNorm seqLen dModel numHeads headDim)

Single SSA graph for the first post-norm Transformer encoder sublayer:

LayerNorm(x + MultiHeadSelfAttention(x), gamma, beta).

LayerNorm is appended as a whole pointwise node backed by the detailed LayerNorm graph theorem.

Pointwise correctness for the single-graph residual-MHA post-norm sublayer.

End-to-end VJP theorem for the single-graph residual-MHA post-norm sublayer.

Sequence feed-forward plus post-norm #

This is the second sublayer shape in a post-norm Transformer encoder block:

LayerNorm(X + FFN(X), gamma, beta).

The FFN residual is sequence-shaped, not merely one-token-shaped. Its affine maps are supplied over flattened sequence tensors, so either a shared position-wise implementation or a fused backend can instantiate the theorem by exposing its fixed affine maps.
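
Schematically, writing vec X for the flattened seqLen × dModel tensor and act for whichever pointwise nonlinearity the sequence-FFN graph places between the two affine maps, the residual update proved by the prefix is

X + FFN(X),   with   FFN(X) = fc2 (act (fc1 (vec X) + b1)) + b2.

A position-wise implementation supplies fc1/fc2 as the same per-position weight applied block-diagonally over the flattened tensor; a fused backend may supply any other fixed continuous linear maps of the same flattened shapes.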

@[reducible, inline]

Sequence-FFN context extended with the affine parameters for the following LayerNorm.

@[reducible, inline]

Sequence-FFN residual plus one whole-node post-norm output.

Gamma parameter for the FFN post-norm LayerNorm.

Beta parameter for the FFN post-norm LayerNorm.

def Proofs.Autograd.Transformer.idxSeqFfnResidualForPostNorm {seqLen dModel dFF : ℕ} :
Idx (ΓSeqFFNWithNorm seqLen dModel ++ ssSeqFFNResidual seqLen dModel dFF) (LayerNorm.MatShape seqLen dModel)

The residual stream X + FFN(X) after the sequence-FFN prefix graph.

def Proofs.Autograd.Transformer.seqFfnPostNormInputs {seqLen dModel dFF : ℕ} :
LayerNorm.Inputs (ΓSeqFFNWithNorm seqLen dModel ++ ssSeqFFNResidual seqLen dModel dFF) seqLen dModel

LayerNorm input triple after the residual-FFN prefix has run.

noncomputable def Proofs.Autograd.Transformer.seqFfnResidualWithNormParamsDGraph {seqLen dModel dFF : ℕ} (fc1 : Vec (SeqFFNModelShape seqLen dModel).size →L[ℝ] Vec (SeqFFNHiddenShape seqLen dFF).size) (b1 : Vec (SeqFFNHiddenShape seqLen dFF).size) (fc2 : Vec (SeqFFNHiddenShape seqLen dFF).size →L[ℝ] Vec (SeqFFNModelShape seqLen dModel).size) (b2 : Vec (SeqFFNModelShape seqLen dModel).size) :
DGraph (ΓSeqFFNWithNorm seqLen dModel) (ssSeqFFNResidual seqLen dModel dFF)

Sequence-FFN graph while carrying the following LayerNorm's affine parameters.

noncomputable def Proofs.Autograd.Transformer.seqFfnPostNormGraph {seqLen dModel dFF : ℕ} (fc1 : Vec (SeqFFNModelShape seqLen dModel).size →L[ℝ] Vec (SeqFFNHiddenShape seqLen dFF).size) (b1 : Vec (SeqFFNHiddenShape seqLen dFF).size) (fc2 : Vec (SeqFFNHiddenShape seqLen dFF).size →L[ℝ] Vec (SeqFFNModelShape seqLen dModel).size) (b2 : Vec (SeqFFNModelShape seqLen dModel).size) (ε : ℝ) :
Graph (ΓSeqFFNWithNorm seqLen dModel) (ssSeqFFNWithPostNorm seqLen dModel dFF)

Single SSA graph for the second post-norm Transformer encoder sublayer:

LayerNorm(X + FFN(X), gamma, beta).

Pointwise correctness for the single-graph residual-FFN post-norm sublayer.

End-to-end VJP theorem for the single-graph residual-FFN post-norm sublayer.

noncomputable def Proofs.Autograd.Transformer.postNormGraph {seqLen dModel : ℕ} (ε : ℝ) :
Graph (ΓPostNorm seqLen dModel) (ssPostNorm seqLen dModel)

The post-norm graph itself.

The context is [residual_stream, gamma, beta]. For attention this residual stream is x + MHA(x); for the feed-forward half it is x + FFN(x). The residual computation is proved in ResidualAttention/FeedForward; this graph proves the LayerNorm boundary that follows it.

noncomputable def Proofs.Autograd.Transformer.postNormGraphFderivCorrectAt {seqLen dModel : ℕ} (ε : ℝ) (xV : CtxVec (ΓPostNorm seqLen dModel)) (hVarEpsPos : ∀ (i : Fin (LayerNorm.VecShape seqLen).size), 0 < (CtxVec.get LayerNorm.idxVarEps ((LayerNorm.layerNormPrefix6 ε).evalVec xV)).ofLp i) (hStdNe0 : ∀ (i : Fin (LayerNorm.VecShape seqLen).size), (CtxVec.get LayerNorm.idxStd ((LayerNorm.layerNormPrefix7 ε).evalVec xV)).ofLp i ≠ 0) :

Pointwise correctness for the post-norm Transformer boundary.

The two hypotheses are exactly LayerNorm's differentiability side conditions at the runtime point: the variance-plus-epsilon branch is positive, and the standard deviation denominator is nonzero.
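
In practice both side conditions follow from a strictly positive ε, assuming the idxVarEps and idxStd nodes compute the per-row variance plus epsilon and its square root, as their names suggest: since the variance is a mean of squares,

var + ε ≥ ε > 0   and therefore   sqrt(var + ε) ≠ 0.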

theorem Proofs.Autograd.Transformer.postNorm_backpropVec_eq_adjoint_fderiv_at {seqLen dModel : ℕ} (ε : ℝ) (xV : CtxVec (ΓPostNorm seqLen dModel)) (seedV : CtxVec (ΓPostNorm seqLen dModel ++ ssPostNorm seqLen dModel)) (hVarEpsPos : ∀ (i : Fin (LayerNorm.VecShape seqLen).size), 0 < (CtxVec.get LayerNorm.idxVarEps ((LayerNorm.layerNormPrefix6 ε).evalVec xV)).ofLp i) (hStdNe0 : ∀ (i : Fin (LayerNorm.VecShape seqLen).size), (CtxVec.get LayerNorm.idxStd ((LayerNorm.layerNormPrefix7 ε).evalVec xV)).ofLp i ≠ 0) :

VJP theorem for the post-norm Transformer boundary.

This is the model-level theorem used after either residual attention or a residual feed-forward block has produced its sequence-shaped residual stream.
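
Informally, as the theorem name indicates, the statement has the usual vector-Jacobian-product shape: writing f for (postNormGraph ε).evalVec and Df(xV) for its Fréchet derivative at xV, the tape's backward pass applied to the cotangent seedV equals the adjoint of Df(xV) applied to seedV, schematically

backprop(seedV) = (Df(xV))* seedV,

under the same two LayerNorm domain hypotheses as the pointwise-correctness result above.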

theorem Proofs.Autograd.Transformer.residualThenPostNorm_hasFDerivAt {E : Type u} [NormedAddCommGroup E] [NormedSpace ℝ E] {seqLen dModel : ℕ} (ε : ℝ) (residualPack : E → CtxVec (ΓPostNorm seqLen dModel)) (DresidualPack : E →L[ℝ] CtxVec (ΓPostNorm seqLen dModel)) (x : E) (hResidual : HasFDerivAt residualPack DresidualPack x) (hVarEpsPos : ∀ (i : Fin (LayerNorm.VecShape seqLen).size), 0 < (CtxVec.get LayerNorm.idxVarEps ((LayerNorm.layerNormPrefix6 ε).evalVec (residualPack x))).ofLp i) (hStdNe0 : ∀ (i : Fin (LayerNorm.VecShape seqLen).size), (CtxVec.get LayerNorm.idxStd ((LayerNorm.layerNormPrefix7 ε).evalVec (residualPack x))).ofLp i ≠ 0) :
HasFDerivAt (fun (z : E) => (postNormGraph ε).evalVec (residualPack z)) ((fderiv ℝ (postNormGraph ε).evalVec (residualPack x)).comp DresidualPack) x

Calculus bridge for a residual block followed by post-norm LayerNorm.

This is the theorem we use to move from separately proved pieces to a whole Transformer sublayer. Suppose some residual-producing map

residualPack : E → [residual_stream, gamma, beta]

is differentiable at x. It may come from residual attention, residual feed-forward, or any future block that produces the same LayerNorm context. If the LayerNorm domain hypotheses hold at residualPack x, then the composed post-norm map

x ↦ LayerNorm(residualPack x)

is differentiable, with derivative given by the usual chain rule.

This is intentionally more general than MHA: the same theorem covers Transformer, ViT, GPT-style blocks, and future residual modules once they expose the residual stream plus affine LayerNorm parameters.
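
As a usage illustration, here is a minimal sketch of discharging one composed sublayer with this bridge. The import path and the open lines are assumptions inferred from the module name above, and myResidualPack, DmyResidualPack, and the h* hypotheses are placeholders for whatever a concrete residual-attention or residual-FFN proof provides; this is a sketch of the intended call shape, not library code.

```lean
-- Sketch only: the import path and namespaces below are assumed, not verified.
import NN.Proofs.Autograd.Tape.Ops.Transformer.PostNorm

open Proofs.Autograd Proofs.Autograd.Transformer

example {E : Type*} [NormedAddCommGroup E] [NormedSpace ℝ E]
    {seqLen dModel : ℕ} (ε : ℝ)
    -- Placeholder residual pack: any map producing [residual_stream, gamma, beta].
    (myResidualPack : E → CtxVec (ΓPostNorm seqLen dModel))
    (DmyResidualPack : E →L[ℝ] CtxVec (ΓPostNorm seqLen dModel))
    (x : E)
    (hResidual : HasFDerivAt myResidualPack DmyResidualPack x)
    -- LayerNorm's pointwise domain hypotheses at the runtime point.
    (hVarEpsPos : ∀ i : Fin (LayerNorm.VecShape seqLen).size,
      0 < (CtxVec.get LayerNorm.idxVarEps
        ((LayerNorm.layerNormPrefix6 ε).evalVec (myResidualPack x))).ofLp i)
    (hStdNe0 : ∀ i : Fin (LayerNorm.VecShape seqLen).size,
      (CtxVec.get LayerNorm.idxStd
        ((LayerNorm.layerNormPrefix7 ε).evalVec (myResidualPack x))).ofLp i ≠ 0) :
    -- Composed post-norm sublayer is differentiable, with the chain-rule derivative.
    HasFDerivAt (fun z : E => (postNormGraph ε).evalVec (myResidualPack z))
      ((fderiv ℝ (postNormGraph ε).evalVec (myResidualPack x)).comp DmyResidualPack) x :=
  residualThenPostNorm_hasFDerivAt ε myResidualPack DmyResidualPack x
    hResidual hVarEpsPos hStdNe0
```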

theorem Proofs.Autograd.Transformer.twoSublayerPostNormBlock_hasFDerivAt {E : Type u} [NormedAddCommGroup E] [NormedSpace ℝ E] {seqLen dModel : ℕ} (ε₁ ε₂ : ℝ) (attnPack : E → CtxVec (ΓPostNorm seqLen dModel)) (DattnPack : E →L[ℝ] CtxVec (ΓPostNorm seqLen dModel)) (ffnPack : CtxVec (ΓPostNorm seqLen dModel ++ ssPostNorm seqLen dModel) → CtxVec (ΓPostNorm seqLen dModel)) (DffnPack : CtxVec (ΓPostNorm seqLen dModel ++ ssPostNorm seqLen dModel) →L[ℝ] CtxVec (ΓPostNorm seqLen dModel)) (x : E) (hAttnPack : HasFDerivAt attnPack DattnPack x) (hNorm1VarEpsPos : ∀ (i : Fin (LayerNorm.VecShape seqLen).size), 0 < (CtxVec.get LayerNorm.idxVarEps ((LayerNorm.layerNormPrefix6 ε₁).evalVec (attnPack x))).ofLp i) (hNorm1StdNe0 : ∀ (i : Fin (LayerNorm.VecShape seqLen).size), (CtxVec.get LayerNorm.idxStd ((LayerNorm.layerNormPrefix7 ε₁).evalVec (attnPack x))).ofLp i ≠ 0) (hFfnPack : HasFDerivAt ffnPack DffnPack ((postNormGraph ε₁).evalVec (attnPack x))) (hNorm2VarEpsPos : ∀ (i : Fin (LayerNorm.VecShape seqLen).size), 0 < (CtxVec.get LayerNorm.idxVarEps ((LayerNorm.layerNormPrefix6 ε₂).evalVec (ffnPack ((postNormGraph ε₁).evalVec (attnPack x))))).ofLp i) (hNorm2StdNe0 : ∀ (i : Fin (LayerNorm.VecShape seqLen).size), (CtxVec.get LayerNorm.idxStd ((LayerNorm.layerNormPrefix7 ε₂).evalVec (ffnPack ((postNormGraph ε₁).evalVec (attnPack x))))).ofLp i ≠ 0) :
HasFDerivAt (fun (z : E) => (postNormGraph ε₂).evalVec (ffnPack ((postNormGraph ε₁).evalVec (attnPack z)))) ((fderiv ℝ (postNormGraph ε₂).evalVec (ffnPack ((postNormGraph ε₁).evalVec (attnPack x)))).comp (DffnPack.comp ((fderiv ℝ (postNormGraph ε₁).evalVec (attnPack x)).comp DattnPack))) x

Calculus bridge for a full two-sublayer post-norm Transformer encoder block.

This theorem is deliberately stated at the map level rather than for one concrete SSA graph. A post-norm encoder block has two domain-sensitive LayerNorms:

1. attnPack builds [x + MHA(x), gamma₁, beta₁];
2. after the first LayerNorm evaluation, ffnPack builds [norm₁ + FFN(norm₁), gamma₂, beta₂];
3. the second LayerNorm produces the block output.

The theorem says that if the two residual-pack maps are differentiable and both LayerNorm calls satisfy their local denominator hypotheses, then the whole two-sublayer block is differentiable by ordinary Fréchet chain rule.
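
Spelled out in composition order, the conclusion is the expected four-factor chain. Writing norm₁ for (postNormGraph ε₁).evalVec (attnPack x),

D(block)(x) = fderiv ℝ (postNormGraph ε₂).evalVec (ffnPack norm₁) ∘ DffnPack ∘ fderiv ℝ (postNormGraph ε₁).evalVec (attnPack x) ∘ DattnPack.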

The concrete graph-level VJP theorems for each sublayer are mhaPostNorm_* and seqFfnPostNorm_*. This bridge is the public mathematical composition point for Transformer, ViT, and GPT-style post-norm blocks while the final monolithic SSA graph is assembled.

Named theorem for the post-normalized residual-attention interface.

The residual attention graph proves production of the first input in this context: residual_stream = x + MHA(x). This theorem proves the LayerNorm pass once that residual stream is the current tensor.

Named theorem for the post-normalized residual feed-forward interface.

The position-wise FFN proof establishes the smooth residual update. This theorem is the common LayerNorm boundary used after that update in post-norm encoder blocks.