TorchLean API

NN.Proofs.Autograd.Tape.Ops.Transformer.EncoderBlock

Post-Norm Transformer Encoder Block

This module states the theorem for the full post-norm encoder block, composing both sublayers.

The theorem is stated at the block boundary:

x ↦ LN₂(FFN(LN₁(MHA(x) + x)) + LN₁(MHA(x) + x)).

The two sublayers already have graph-level VJP theorems in PostNorm.lean. The theorem here is the composition result that a model proof imports when it needs the whole encoder block as one differentiable map. A future lowering pass can still build one concrete SSA graph for the exact runtime layout; this theorem is the mathematical block contract that such a graph must implement.
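As a rough illustration of the block-boundary formula, the sketch below writes it as a plain Lean function over an abstract type with addition. The names `PostNormSketch.encoderBlock`, `mha`, `ffn`, `ln₁` and `ln₂` are hypothetical stand-ins, not TorchLean definitions. The `rfl` example only checks that the `let` form unfolds to the formula above.

```lean
namespace PostNormSketch

/-- Hypothetical sketch of the block-boundary formula over an abstract type `α`
with addition. `mha`, `ffn`, `ln₁`, `ln₂` are stand-ins, not TorchLean API. -/
def encoderBlock {α : Type} [Add α] (mha ffn ln₁ ln₂ : α → α) (x : α) : α :=
  -- First sublayer: attention, residual add, LayerNorm.
  let h := ln₁ (mha x + x)
  -- Second sublayer: FFN, residual add (reusing `h`), LayerNorm.
  ln₂ (ffn h + h)

-- Unfolding the `let` gives the formula stated at the block boundary.
example {α : Type} [Add α] (mha ffn ln₁ ln₂ : α → α) (x : α) :
    encoderBlock mha ffn ln₁ ln₂ x =
      ln₂ (ffn (ln₁ (mha x + x)) + ln₁ (mha x + x)) := rfl

end PostNormSketch
```

Binding the first sublayer's output to `h` once reflects how a single SSA node can feed both the FFN and the second residual add.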

Concrete SSA encoder block

The definitions below assemble one executable-style proof graph for a post-norm encoder block. The context is:

[x, Wq, Wk, Wv, Wo, gamma₁, beta₁, gamma₂, beta₂].

The FFN affine maps are not context entries: each is supplied as a fixed sequence-level continuous linear map plus a bias (fc1, b1 and fc2, b2), matching the interface used by seqFfnResidualDGraph.
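The nine-entry context, and the role of the `ss` suffix in the index signatures that follow, can be illustrated with ordinary lists. This sketch assumes the slots are numbered from 0 in the order listed above; `EncoderCtxSketch.ctx` and the appended names are hypothetical and model only positions, not shapes or TorchLean's `Idx` type.

```lean
namespace EncoderCtxSketch

/-- Hypothetical stand-in for the encoder-block context: slot names in order. -/
def ctx : List String :=
  ["x", "Wq", "Wk", "Wv", "Wo", "gamma₁", "beta₁", "gamma₂", "beta₂"]

-- Nine inputs; under this numbering the LayerNorm parameters sit in slots 5 to 8.
example : ctx.length = 9 := rfl
example : ctx[5]? = some "gamma₁" := rfl
example : ctx[8]? = some "beta₂" := rfl

-- Appending saved SSA values (the role of `ss` in `ΓEncoderBlock ... ++ ss`)
-- does not move any context slot, so one index works for every suffix.
-- "attn" and "ln1Out" are hypothetical placeholder names.
example : (ctx ++ ["attn", "ln1Out"])[5]? = some "gamma₁" := rfl

end EncoderCtxSketch
```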

@[reducible, inline]
abbrev Proofs.Autograd.Transformer.ΓEncoderBlock (seqLen dModel numHeads headDim : ℕ) :

Full encoder-block context: the input x, the MHA parameters, and the affine parameters of the first and second LayerNorm.

@[reducible, inline]
abbrev Proofs.Autograd.Transformer.ssEncoderBlock (seqLen dModel numHeads headDim dFF : ℕ) :

Saved tensors for one post-norm encoder block.

def Proofs.Autograd.Transformer.idxEncoderNorm1Gamma {seqLen dModel numHeads headDim : ℕ} {ss : List Spec.Shape} :
Idx (ΓEncoderBlock seqLen dModel numHeads headDim ++ ss) (LayerNorm.VecShape dModel)

First LayerNorm scale parameter in the full encoder-block context.

def Proofs.Autograd.Transformer.idxEncoderNorm1Beta {seqLen dModel numHeads headDim : ℕ} {ss : List Spec.Shape} :
Idx (ΓEncoderBlock seqLen dModel numHeads headDim ++ ss) (LayerNorm.VecShape dModel)

First LayerNorm shift parameter in the full encoder-block context.

def Proofs.Autograd.Transformer.idxEncoderNorm2Gamma {seqLen dModel numHeads headDim : ℕ} {ss : List Spec.Shape} :
Idx (ΓEncoderBlock seqLen dModel numHeads headDim ++ ss) (LayerNorm.VecShape dModel)

Second LayerNorm scale parameter in the full encoder-block context.

def Proofs.Autograd.Transformer.idxEncoderNorm2Beta {seqLen dModel numHeads headDim : ℕ} {ss : List Spec.Shape} :
Idx (ΓEncoderBlock seqLen dModel numHeads headDim ++ ss) (LayerNorm.VecShape dModel)

Second LayerNorm shift parameter in the full encoder-block context.

noncomputable def Proofs.Autograd.Transformer.encoderMhaResidualDGraph {seqLen dModel numHeads headDim : ℕ} (c : ℝ) :
DGraph (ΓEncoderBlock seqLen dModel numHeads headDim) (ssMHAResidual seqLen dModel numHeads headDim)

Residual-attention prefix while carrying both LayerNorm parameter pairs.

def Proofs.Autograd.Transformer.idxEncoderMhaResidual {seqLen dModel numHeads headDim : ℕ} :
Idx (ΓEncoderBlock seqLen dModel numHeads headDim ++ ssMHAResidual seqLen dModel numHeads headDim) (LayerNorm.MatShape seqLen dModel)

Residual stream x + MHA(x) after the attention prefix.

def Proofs.Autograd.Transformer.encoderNorm1Inputs {seqLen dModel numHeads headDim : ℕ} :
LayerNorm.Inputs (ΓEncoderBlock seqLen dModel numHeads headDim ++ ssMHAResidual seqLen dModel numHeads headDim) seqLen dModel

First LayerNorm input triple in the concrete encoder block.

noncomputable def Proofs.Autograd.Transformer.encoderAfterNorm1Graph {seqLen dModel numHeads headDim : ℕ} (c ε₁ : ℝ) :
Graph (ΓEncoderBlock seqLen dModel numHeads headDim) (ssMHAResidual seqLen dModel numHeads headDim ++ [LayerNorm.MatShape seqLen dModel])

Encoder graph through the first post-norm attention sublayer.

def Proofs.Autograd.Transformer.idxEncoderNorm1Out {seqLen dModel numHeads headDim : ℕ} :
Idx (ΓEncoderBlock seqLen dModel numHeads headDim ++ ssMHAResidual seqLen dModel numHeads headDim ++ [LayerNorm.MatShape seqLen dModel]) (SeqFFNModelShape seqLen dModel)

First post-norm sublayer output.

def Proofs.Autograd.Transformer.idxEncoderNorm1OutAfterFfn {seqLen dModel numHeads headDim dFF : ℕ} :
Idx (ΓEncoderBlock seqLen dModel numHeads headDim ++ ssMHAResidual seqLen dModel numHeads headDim ++ [LayerNorm.MatShape seqLen dModel] ++ [SeqFFNHiddenShape seqLen dFF, SeqFFNHiddenShape seqLen dFF, SeqFFNModelShape seqLen dModel]) (SeqFFNModelShape seqLen dModel)

First post-norm output weakened through the FFN intermediates.

def Proofs.Autograd.Transformer.idxEncoderFfnHiddenPre {seqLen dModel numHeads headDim dFF : ℕ} :
Idx (ΓEncoderBlock seqLen dModel numHeads headDim ++ ssMHAResidual seqLen dModel numHeads headDim ++ [LayerNorm.MatShape seqLen dModel] ++ [SeqFFNHiddenShape seqLen dFF]) (SeqFFNHiddenShape seqLen dFF)

First FFN affine output in the full encoder graph.

def Proofs.Autograd.Transformer.idxEncoderFfnHiddenAct {seqLen dModel numHeads headDim dFF : ℕ} :
Idx (ΓEncoderBlock seqLen dModel numHeads headDim ++ ssMHAResidual seqLen dModel numHeads headDim ++ [LayerNorm.MatShape seqLen dModel] ++ [SeqFFNHiddenShape seqLen dFF, SeqFFNHiddenShape seqLen dFF]) (SeqFFNHiddenShape seqLen dFF)

FFN activation output in the full encoder graph.

def Proofs.Autograd.Transformer.idxEncoderFfnProjected {seqLen dModel numHeads headDim dFF : ℕ} :
Idx (ΓEncoderBlock seqLen dModel numHeads headDim ++ ssMHAResidual seqLen dModel numHeads headDim ++ [LayerNorm.MatShape seqLen dModel] ++ [SeqFFNHiddenShape seqLen dFF, SeqFFNHiddenShape seqLen dFF, SeqFFNModelShape seqLen dModel]) (SeqFFNModelShape seqLen dModel)

FFN projection output in the full encoder graph.

def Proofs.Autograd.Transformer.idxEncoderFfnResidual {seqLen dModel numHeads headDim dFF : ℕ} :
Idx (ΓEncoderBlock seqLen dModel numHeads headDim ++ ssMHAResidual seqLen dModel numHeads headDim ++ [LayerNorm.MatShape seqLen dModel] ++ ssSeqFFNResidual seqLen dModel dFF) (SeqFFNModelShape seqLen dModel)

FFN residual output after the second sublayer's residual add.

noncomputable def Proofs.Autograd.Transformer.encoderFfnResidualGraph {seqLen dModel numHeads headDim dFF : ℕ} (fc1 : Vec (SeqFFNModelShape seqLen dModel).size →L[ℝ] Vec (SeqFFNHiddenShape seqLen dFF).size) (b1 : Vec (SeqFFNHiddenShape seqLen dFF).size) (fc2 : Vec (SeqFFNHiddenShape seqLen dFF).size →L[ℝ] Vec (SeqFFNModelShape seqLen dModel).size) (b2 : Vec (SeqFFNModelShape seqLen dModel).size) (c ε₁ : ℝ) :
Graph (ΓEncoderBlock seqLen dModel numHeads headDim) (ssMHAResidual seqLen dModel numHeads headDim ++ [LayerNorm.MatShape seqLen dModel] ++ ssSeqFFNResidual seqLen dModel dFF)

Graph through the FFN residual part after the first LayerNorm.

def Proofs.Autograd.Transformer.encoderNorm2Inputs {seqLen dModel numHeads headDim dFF : ℕ} :
LayerNorm.Inputs (ΓEncoderBlock seqLen dModel numHeads headDim ++ ssMHAResidual seqLen dModel numHeads headDim ++ [LayerNorm.MatShape seqLen dModel] ++ ssSeqFFNResidual seqLen dModel dFF) seqLen dModel

Second LayerNorm input triple in the concrete encoder block.

noncomputable def Proofs.Autograd.Transformer.encoderBlockGraph {seqLen dModel numHeads headDim dFF : ℕ} (fc1 : Vec (SeqFFNModelShape seqLen dModel).size →L[ℝ] Vec (SeqFFNHiddenShape seqLen dFF).size) (b1 : Vec (SeqFFNHiddenShape seqLen dFF).size) (fc2 : Vec (SeqFFNHiddenShape seqLen dFF).size →L[ℝ] Vec (SeqFFNModelShape seqLen dModel).size) (b2 : Vec (SeqFFNModelShape seqLen dModel).size) (c ε₁ ε₂ : ℝ) :
Graph (ΓEncoderBlock seqLen dModel numHeads headDim) (ssEncoderBlock seqLen dModel numHeads headDim dFF)

Concrete SSA graph for one full post-norm Transformer encoder block.

noncomputable def Proofs.Autograd.Transformer.encoderFfnResidualGraphFDerivCorrectAt {seqLen dModel numHeads headDim dFF : ℕ} (fc1 : Vec (SeqFFNModelShape seqLen dModel).size →L[ℝ] Vec (SeqFFNHiddenShape seqLen dFF).size) (b1 : Vec (SeqFFNHiddenShape seqLen dFF).size) (fc2 : Vec (SeqFFNHiddenShape seqLen dFF).size →L[ℝ] Vec (SeqFFNModelShape seqLen dModel).size) (b2 : Vec (SeqFFNModelShape seqLen dModel).size) (c ε₁ : ℝ) (xV : CtxVec (ΓEncoderBlock seqLen dModel numHeads headDim)) (hNorm1 : NodeFDerivCorrectAt (LayerNorm.wholeNode encoderNorm1Inputs ε₁) ((encoderMhaResidualDGraph c).g.evalVec xV)) :

Pointwise analytic correctness for the graph through the FFN residual.

noncomputable def Proofs.Autograd.Transformer.encoderBlockGraphFDerivCorrectAt {seqLen dModel numHeads headDim dFF : ℕ} (fc1 : Vec (SeqFFNModelShape seqLen dModel).size →L[ℝ] Vec (SeqFFNHiddenShape seqLen dFF).size) (b1 : Vec (SeqFFNHiddenShape seqLen dFF).size) (fc2 : Vec (SeqFFNHiddenShape seqLen dFF).size →L[ℝ] Vec (SeqFFNModelShape seqLen dModel).size) (b2 : Vec (SeqFFNModelShape seqLen dModel).size) (c ε₁ ε₂ : ℝ) (xV : CtxVec (ΓEncoderBlock seqLen dModel numHeads headDim)) (hNorm1 : NodeFDerivCorrectAt (LayerNorm.wholeNode encoderNorm1Inputs ε₁) ((encoderMhaResidualDGraph c).g.evalVec xV)) (hNorm2 : NodeFDerivCorrectAt (LayerNorm.wholeNode encoderNorm2Inputs ε₂) ((encoderFfnResidualGraph fc1 b1 fc2 b2 c ε₁).evalVec xV)) :
GraphFDerivCorrectAt (encoderBlockGraph fc1 b1 fc2 b2 c ε₁ ε₂) xV

Pointwise analytic correctness for the complete concrete encoder-block graph.

theorem Proofs.Autograd.Transformer.encoderBlock_backpropVec_eq_adjoint_fderiv_at {seqLen dModel numHeads headDim dFF : ℕ} (fc1 : Vec (SeqFFNModelShape seqLen dModel).size →L[ℝ] Vec (SeqFFNHiddenShape seqLen dFF).size) (b1 : Vec (SeqFFNHiddenShape seqLen dFF).size) (fc2 : Vec (SeqFFNHiddenShape seqLen dFF).size →L[ℝ] Vec (SeqFFNModelShape seqLen dModel).size) (b2 : Vec (SeqFFNModelShape seqLen dModel).size) (c ε₁ ε₂ : ℝ) (xV : CtxVec (ΓEncoderBlock seqLen dModel numHeads headDim)) (seedV : CtxVec (ΓEncoderBlock seqLen dModel numHeads headDim ++ ssEncoderBlock seqLen dModel numHeads headDim dFF)) (hNorm1 : NodeFDerivCorrectAt (LayerNorm.wholeNode encoderNorm1Inputs ε₁) ((encoderMhaResidualDGraph c).g.evalVec xV)) (hNorm2 : NodeFDerivCorrectAt (LayerNorm.wholeNode encoderNorm2Inputs ε₂) ((encoderFfnResidualGraph fc1 b1 fc2 b2 c ε₁).evalVec xV)) :
(encoderBlockGraph fc1 b1 fc2 b2 c ε₁ ε₂).backpropVec xV seedV = (ContinuousLinearMap.adjoint (fderiv ℝ (encoderBlockGraph fc1 b1 fc2 b2 c ε₁ ε₂).evalVec xV)) seedV

End-to-end VJP theorem for the concrete post-norm Transformer encoder-block graph.

theorem Proofs.Autograd.Transformer.postNormEncoderBlock_hasFDerivAt {E : Type u} [NormedAddCommGroup E] [NormedSpace ℝ E] {seqLen dModel : ℕ} (ε₁ ε₂ : ℝ) (attnPack : E → CtxVec (ΓPostNorm seqLen dModel)) (DattnPack : E →L[ℝ] CtxVec (ΓPostNorm seqLen dModel)) (ffnPack : CtxVec (ΓPostNorm seqLen dModel ++ ssPostNorm seqLen dModel) → CtxVec (ΓPostNorm seqLen dModel)) (DffnPack : CtxVec (ΓPostNorm seqLen dModel ++ ssPostNorm seqLen dModel) →L[ℝ] CtxVec (ΓPostNorm seqLen dModel)) (x : E) (hAttnPack : HasFDerivAt attnPack DattnPack x) (hNorm1VarEpsPos : ∀ (i : Fin (LayerNorm.VecShape seqLen).size), 0 < (CtxVec.get LayerNorm.idxVarEps ((LayerNorm.layerNormPrefix6 ε₁).evalVec (attnPack x))).ofLp i) (hNorm1StdNe0 : ∀ (i : Fin (LayerNorm.VecShape seqLen).size), (CtxVec.get LayerNorm.idxStd ((LayerNorm.layerNormPrefix7 ε₁).evalVec (attnPack x))).ofLp i ≠ 0) (hFfnPack : HasFDerivAt ffnPack DffnPack ((postNormGraph ε₁).evalVec (attnPack x))) (hNorm2VarEpsPos : ∀ (i : Fin (LayerNorm.VecShape seqLen).size), 0 < (CtxVec.get LayerNorm.idxVarEps ((LayerNorm.layerNormPrefix6 ε₂).evalVec (ffnPack ((postNormGraph ε₁).evalVec (attnPack x))))).ofLp i) (hNorm2StdNe0 : ∀ (i : Fin (LayerNorm.VecShape seqLen).size), (CtxVec.get LayerNorm.idxStd ((LayerNorm.layerNormPrefix7 ε₂).evalVec (ffnPack ((postNormGraph ε₁).evalVec (attnPack x))))).ofLp i ≠ 0) :
HasFDerivAt (fun (z : E) => (postNormGraph ε₂).evalVec (ffnPack ((postNormGraph ε₁).evalVec (attnPack z)))) (fderiv ℝ (postNormGraph ε₂).evalVec (ffnPack ((postNormGraph ε₁).evalVec (attnPack x))) ∘SL DffnPack ∘SL fderiv ℝ (postNormGraph ε₁).evalVec (attnPack x) ∘SL DattnPack) x

Fréchet differentiability of a complete post-norm Transformer encoder block.

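attnPack builds the first LayerNorm input triple from the outer model context, and ffnPack builds the second LayerNorm input triple after the first post-norm sublayer has been evaluated. The LayerNorm side conditions (a positive variance-plus-ε entry and a nonzero standard deviation at every row index i) are stated only at the two points where the normalizations are actually evaluated.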
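The derivative in `postNormEncoderBlock_hasFDerivAt` is a composite of four linear maps, which is the shape produced by applying Mathlib's `HasFDerivAt.comp` three times. The sketch below states that generic four-stage chain rule. Matching `f₁ … f₄` with `attnPack`, the first LayerNorm graph, `ffnPack` and the second LayerNorm graph is an assumption for illustration, and the actual proof may be organized differently. The two scalar examples suggest why the LayerNorm hypotheses are presumably needed: the square root and the reciprocal are differentiable only away from zero.

```lean
import Mathlib

namespace EncoderChainSketch

variable {E F G H K : Type*}
  [NormedAddCommGroup E] [NormedSpace ℝ E]
  [NormedAddCommGroup F] [NormedSpace ℝ F]
  [NormedAddCommGroup G] [NormedSpace ℝ G]
  [NormedAddCommGroup H] [NormedSpace ℝ H]
  [NormedAddCommGroup K] [NormedSpace ℝ K]

/-- Hypothetical four-stage chain rule with the same shape as the encoder-block
statement: the derivative is `D₄ ∘ D₃ ∘ D₂ ∘ D₁`, each taken at the right point. -/
theorem four_stage_hasFDerivAt {f₁ : E → F} {f₂ : F → G} {f₃ : G → H} {f₄ : H → K}
    {D₁ : E →L[ℝ] F} {D₂ : F →L[ℝ] G} {D₃ : G →L[ℝ] H} {D₄ : H →L[ℝ] K} (x : E)
    (h₁ : HasFDerivAt f₁ D₁ x)
    (h₂ : HasFDerivAt f₂ D₂ (f₁ x))
    (h₃ : HasFDerivAt f₃ D₃ (f₂ (f₁ x)))
    (h₄ : HasFDerivAt f₄ D₄ (f₃ (f₂ (f₁ x)))) :
    HasFDerivAt (fun z => f₄ (f₃ (f₂ (f₁ z)))) (D₄.comp (D₃.comp (D₂.comp D₁))) x :=
  h₄.comp x (h₃.comp x (h₂.comp x h₁))

-- The square root is differentiable at any `v ≠ 0`; `0 < v` plays the part
-- of a positive variance-plus-ε entry.
example (v : ℝ) (hv : 0 < v) : HasDerivAt Real.sqrt (1 / (2 * Real.sqrt v)) v :=
  Real.hasDerivAt_sqrt hv.ne'

-- `t ↦ t⁻¹` is differentiable at any `s ≠ 0`, the part played by a nonzero
-- standard-deviation entry when LayerNorm divides by it.
example (s : ℝ) (hs : s ≠ 0) : HasDerivAt (fun t : ℝ => t⁻¹) (-(s ^ 2)⁻¹) s :=
  hasDerivAt_inv hs

end EncoderChainSketch
```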