TorchLean API

NN.Proofs.Autograd.Tape.Ops.Transformer.DecoderBlock

GPT-Style Decoder Block

This module packages the post-norm GPT decoder-block composition theorem. The masked attention front half is supplied as a differentiable residual-pack map; the concrete finite-mask attention core and its projection/merge composition theorem live in NN.Proofs.Autograd.Tape.Ops.Attention.MaskedMultiHeadSelfAttention.
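For orientation, here is the usual textbook form of a post-norm GPT-style decoder block, in which each LayerNorm is applied after its residual addition. LN denotes LayerNorm with the indicated scale, shift and ε. MaskedMHA and FFN are the two sublayers. These equations summarize the structure this module describes; they are not a transcription of its Lean definitions.

```latex
\begin{aligned}
h &= \mathrm{LN}\bigl(x + \mathrm{MaskedMHA}(x);\ \gamma_1, \beta_1, \varepsilon_1\bigr), \\
y &= \mathrm{LN}\bigl(h + \mathrm{FFN}(h);\ \gamma_2, \beta_2, \varepsilon_2\bigr).
\end{aligned}
```

A pre-norm block would instead compute x + MaskedMHA(LN(x)). The theorems in this module are stated for the post-norm ordering.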

Concrete finite-mask decoder-core SSA graph

The concrete graph below starts after the Q/K/V projection split:

[Q_heads, Kᵀ_heads, V_heads, residual_stream, gamma₁, beta₁, gamma₂, beta₂].

It then runs finite-mask split-head attention, merges the head output through a supplied affine map, adds the residual stream, and applies the two post-norm sublayers. A separate projection theorem can feed this graph from a full token/parameter context.
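Reading the index and graph definitions in this section in order gives roughly the following node sequence. Each line is annotated with the index that names its output. Some symbols are shorthand introduced here:

- φ stands for the FFN activation, which this section does not name.
- MaskedCore abbreviates the finite-mask core from MaskedMultiHeadSelfAttention. This sketch makes no claim about how c and bias enter that core.

```latex
\begin{aligned}
O   &= \mathrm{MaskedCore}_{c,\,\mathrm{bias}}(Q, K^{\top}, V) && \text{idxDecoderHeadOut} \\
m   &= \mathrm{merge}(O) + \mathrm{mergeBias}                   && \text{idxDecoderMergedAttention} \\
r_1 &= x + m                                                    && \text{idxDecoderAttentionResidual} \\
h   &= \mathrm{LN}(r_1;\ \gamma_1, \beta_1, \varepsilon_1)      && \text{idxDecoderNorm1Out} \\
a   &= \mathrm{fc1}(h) + b_1                                    && \text{idxDecoderFfnHiddenPre} \\
g   &= \phi(a)                                                  && \text{idxDecoderFfnHiddenAct} \\
p   &= \mathrm{fc2}(g) + b_2                                    && \text{idxDecoderFfnProjected} \\
r_2 &= h + p                                                    && \text{idxDecoderFfnResidual} \\
y   &= \mathrm{LN}(r_2;\ \gamma_2, \beta_2, \varepsilon_2)
\end{aligned}
```

In the correctness results of this section, only the two LN nodes carry extra hypotheses (hNorm1, hNorm2). The remaining lines are affine maps, additions, or the proved masked core.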

@[reducible, inline]
abbrev Proofs.Autograd.Transformer.ΓDecoderCore (seqLen dModel numHeads headDim : ℕ) : List Spec.Shape

Concrete decoder-core context: the masked-attention core inputs, the residual stream, and the two LayerNorm parameter pairs.

@[reducible, inline]
abbrev Proofs.Autograd.Transformer.ssDecoderCore (seqLen dModel numHeads headDim dFF : ℕ) : List Spec.Shape

Saved tensors for the concrete finite-mask decoder-core block.

def Proofs.Autograd.Transformer.idxDecoderResidualInput {seqLen dModel numHeads headDim : ℕ} {ss : List Spec.Shape} :
Idx (ΓDecoderCore seqLen dModel numHeads headDim ++ ss) (LayerNorm.MatShape seqLen dModel)

Residual-stream input in the decoder-core context.

def Proofs.Autograd.Transformer.idxDecoderNorm1Gamma {seqLen dModel numHeads headDim : ℕ} {ss : List Spec.Shape} :
Idx (ΓDecoderCore seqLen dModel numHeads headDim ++ ss) (LayerNorm.VecShape dModel)

First LayerNorm scale parameter in the decoder-core context.

def Proofs.Autograd.Transformer.idxDecoderNorm1Beta {seqLen dModel numHeads headDim : ℕ} {ss : List Spec.Shape} :
Idx (ΓDecoderCore seqLen dModel numHeads headDim ++ ss) (LayerNorm.VecShape dModel)

First LayerNorm shift parameter in the decoder-core context.

def Proofs.Autograd.Transformer.idxDecoderNorm2Gamma {seqLen dModel numHeads headDim : ℕ} {ss : List Spec.Shape} :
Idx (ΓDecoderCore seqLen dModel numHeads headDim ++ ss) (LayerNorm.VecShape dModel)

Second LayerNorm scale parameter in the decoder-core context.

def Proofs.Autograd.Transformer.idxDecoderNorm2Beta {seqLen dModel numHeads headDim : ℕ} {ss : List Spec.Shape} :
Idx (ΓDecoderCore seqLen dModel numHeads headDim ++ ss) (LayerNorm.VecShape dModel)

Second LayerNorm shift parameter in the decoder-core context.

noncomputable def Proofs.Autograd.Transformer.decoderMaskedCoreDGraph {seqLen dModel numHeads headDim : ℕ} (c : ℝ) (bias : Vec (MultiHeadAttention.ScoresShape seqLen numHeads).size := 0) :
DGraph (ΓDecoderCore seqLen dModel numHeads headDim) (MultiHeadAttention.ssMaskedCore seqLen numHeads headDim)

Masked attention core, with the residual stream and LayerNorm parameters carried along in the context.

def Proofs.Autograd.Transformer.idxDecoderHeadOut {seqLen dModel numHeads headDim : ℕ} :
Idx (ΓDecoderCore seqLen dModel numHeads headDim ++ MultiHeadAttention.ssMaskedCore seqLen numHeads headDim) (MultiHeadAttention.HeadsShape seqLen numHeads headDim)

Split-head masked attention output after the masked core.

def Proofs.Autograd.Transformer.idxDecoderMergedAttention {seqLen dModel numHeads headDim : ℕ} :
Idx (ΓDecoderCore seqLen dModel numHeads headDim ++ MultiHeadAttention.ssMaskedCore seqLen numHeads headDim ++ [LayerNorm.MatShape seqLen dModel]) (LayerNorm.MatShape seqLen dModel)

Merged attention output after the supplied output projection.

def Proofs.Autograd.Transformer.idxDecoderResidualInputAfterMerge {seqLen dModel numHeads headDim : ℕ} :
Idx (ΓDecoderCore seqLen dModel numHeads headDim ++ MultiHeadAttention.ssMaskedCore seqLen numHeads headDim ++ [LayerNorm.MatShape seqLen dModel]) (LayerNorm.MatShape seqLen dModel)

Residual-stream input, weakened past the merged attention output.

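Weakening an index changes its type, because its context grows by the appended saved tensors. It does not change which slot the index points at.

The sketch below illustrates this with an intrinsically typed de Bruijn index over a list of shape names. Every name in it (Idx, weakenRight, toNat, ΓcoreToy) is hypothetical and written in plain Lean 4 without TorchLean or Mathlib imports. TorchLean's actual Idx and Spec.Shape are not shown in this section and may be defined differently.

```lean
-- Hypothetical stand-in for TorchLean's `Idx`: a typed position in a context list.
inductive Idx {α : Type} : List α → α → Type where
  | here  {a : α} {Γ : List α} : Idx (a :: Γ) a
  | there {a b : α} {Γ : List α} : Idx Γ a → Idx (b :: Γ) a

-- Weaken an index past saved tensors appended on the right: Γ ↦ Γ ++ ss.
def Idx.weakenRight {α : Type} {Γ : List α} {a : α} (ss : List α) :
    Idx Γ a → Idx (Γ ++ ss) a
  | .here    => .here
  | .there i => .there (Idx.weakenRight ss i)

-- Numeric position of an index, counting from the front of the context.
def Idx.toNat {α : Type} {Γ : List α} {a : α} : Idx Γ a → Nat
  | .here    => 0
  | .there i => Idx.toNat i + 1

-- Toy decoder-core context mirroring the input list given for ΓDecoderCore.
abbrev ΓcoreToy : List String :=
  ["Q_heads", "Kt_heads", "V_heads", "residual_stream",
   "gamma1", "beta1", "gamma2", "beta2"]

def residualIdx : Idx ΓcoreToy "residual_stream" :=
  .there (.there (.there .here))

-- After one saved node (the merged attention output) is appended, the
-- weakened index has a larger context type but still points at slot 3.
def residualAfterMerge : Idx (ΓcoreToy ++ ["merged_attention"]) "residual_stream" :=
  Idx.weakenRight ["merged_attention"] residualIdx

#eval residualAfterMerge.toNat  -- 3
```

With this representation, inputs keep their positions as SSA nodes are appended. Only the newly appended nodes get fresh positions at the end of the context.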
def Proofs.Autograd.Transformer.idxDecoderAttentionResidual {seqLen dModel numHeads headDim : ℕ} :
Idx (ΓDecoderCore seqLen dModel numHeads headDim ++ MultiHeadAttention.ssMaskedCore seqLen numHeads headDim ++ [LayerNorm.MatShape seqLen dModel, LayerNorm.MatShape seqLen dModel]) (LayerNorm.MatShape seqLen dModel)

Residual attention stream x + MaskedMHA(x), the input to the first LayerNorm.

def Proofs.Autograd.Transformer.decoderNorm1Inputs {seqLen dModel numHeads headDim : ℕ} :
LayerNorm.Inputs (ΓDecoderCore seqLen dModel numHeads headDim ++ MultiHeadAttention.ssMaskedCore seqLen numHeads headDim ++ [LayerNorm.MatShape seqLen dModel, LayerNorm.MatShape seqLen dModel]) seqLen dModel

First LayerNorm input triple [x + MaskedMHA(x), gamma₁, beta₁] for the concrete decoder-core graph.

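The triple (r₁, γ₁, β₁) feeds a row-wise LayerNorm. The standard formula is given below for reference. It is assumed here rather than read off TorchLean's LayerNorm module, whose exact variance convention this section does not show. Each of the seqLen token rows is normalized over its d = dModel features:

```latex
\begin{aligned}
\mu_t &= \frac{1}{d}\sum_{j=1}^{d} r_{t,j}, \qquad
\sigma_t^2 = \frac{1}{d}\sum_{j=1}^{d} \bigl(r_{t,j} - \mu_t\bigr)^2, \\
\mathrm{LN}(r;\ \gamma, \beta, \varepsilon)_{t,j}
  &= \gamma_j \, \frac{r_{t,j} - \mu_t}{\sqrt{\sigma_t^2 + \varepsilon}} + \beta_j,
  \qquad t = 1, \dots, \mathrm{seqLen}, \quad d = \mathrm{dModel}.
\end{aligned}
```

The square root and the division are differentiable wherever σ_t² + ε > 0 and the resulting standard deviation is nonzero. These are per-row conditions, matching the shape of the hNorm1VarEpsPos, hNorm1StdNe0, hNorm2VarEpsPos and hNorm2StdNe0 hypotheses of postNormGptDecoderBlock_hasFDerivAt, which range over Fin (LayerNorm.VecShape seqLen).size.

Under this formula, ε > 0 makes the first condition automatic, since σ_t² ≥ 0. The second condition then follows from the first.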
noncomputable def Proofs.Autograd.Transformer.decoderAfterNorm1Graph {seqLen dModel numHeads headDim : ℕ} (merge : Vec (MultiHeadAttention.HeadsShape seqLen numHeads headDim).size →L[ℝ] Vec (LayerNorm.MatShape seqLen dModel).size) (mergeBias : Vec (LayerNorm.MatShape seqLen dModel).size) (c ε₁ : ℝ) (bias : Vec (MultiHeadAttention.ScoresShape seqLen numHeads).size := 0) :
Graph (ΓDecoderCore seqLen dModel numHeads headDim) (MultiHeadAttention.ssMaskedCore seqLen numHeads headDim ++ [LayerNorm.MatShape seqLen dModel, LayerNorm.MatShape seqLen dModel, LayerNorm.MatShape seqLen dModel])

Decoder graph through the first post-norm masked-attention sublayer.

def Proofs.Autograd.Transformer.idxDecoderNorm1Out {seqLen dModel numHeads headDim : ℕ} :
Idx (ΓDecoderCore seqLen dModel numHeads headDim ++ MultiHeadAttention.ssMaskedCore seqLen numHeads headDim ++ [LayerNorm.MatShape seqLen dModel, LayerNorm.MatShape seqLen dModel, LayerNorm.MatShape seqLen dModel]) (SeqFFNModelShape seqLen dModel)

First decoder post-norm output.

def Proofs.Autograd.Transformer.idxDecoderNorm1OutAfterFfn {seqLen dModel numHeads headDim dFF : ℕ} :
Idx (ΓDecoderCore seqLen dModel numHeads headDim ++ MultiHeadAttention.ssMaskedCore seqLen numHeads headDim ++ [LayerNorm.MatShape seqLen dModel, LayerNorm.MatShape seqLen dModel, LayerNorm.MatShape seqLen dModel] ++ [SeqFFNHiddenShape seqLen dFF, SeqFFNHiddenShape seqLen dFF, SeqFFNModelShape seqLen dModel]) (SeqFFNModelShape seqLen dModel)

First decoder post-norm output, weakened past the FFN intermediates.

def Proofs.Autograd.Transformer.idxDecoderFfnHiddenPre {seqLen dModel numHeads headDim dFF : ℕ} :
Idx (ΓDecoderCore seqLen dModel numHeads headDim ++ MultiHeadAttention.ssMaskedCore seqLen numHeads headDim ++ [LayerNorm.MatShape seqLen dModel, LayerNorm.MatShape seqLen dModel, LayerNorm.MatShape seqLen dModel] ++ [SeqFFNHiddenShape seqLen dFF]) (SeqFFNHiddenShape seqLen dFF)

First FFN affine output in the concrete decoder graph.

def Proofs.Autograd.Transformer.idxDecoderFfnHiddenAct {seqLen dModel numHeads headDim dFF : ℕ} :
Idx (ΓDecoderCore seqLen dModel numHeads headDim ++ MultiHeadAttention.ssMaskedCore seqLen numHeads headDim ++ [LayerNorm.MatShape seqLen dModel, LayerNorm.MatShape seqLen dModel, LayerNorm.MatShape seqLen dModel] ++ [SeqFFNHiddenShape seqLen dFF, SeqFFNHiddenShape seqLen dFF]) (SeqFFNHiddenShape seqLen dFF)

FFN activation output in the concrete decoder graph.

def Proofs.Autograd.Transformer.idxDecoderFfnProjected {seqLen dModel numHeads headDim dFF : ℕ} :
Idx (ΓDecoderCore seqLen dModel numHeads headDim ++ MultiHeadAttention.ssMaskedCore seqLen numHeads headDim ++ [LayerNorm.MatShape seqLen dModel, LayerNorm.MatShape seqLen dModel, LayerNorm.MatShape seqLen dModel] ++ [SeqFFNHiddenShape seqLen dFF, SeqFFNHiddenShape seqLen dFF, SeqFFNModelShape seqLen dModel]) (SeqFFNModelShape seqLen dModel)

FFN projection output in the concrete decoder graph.

def Proofs.Autograd.Transformer.idxDecoderFfnResidual {seqLen dModel numHeads headDim dFF : ℕ} :
Idx (ΓDecoderCore seqLen dModel numHeads headDim ++ MultiHeadAttention.ssMaskedCore seqLen numHeads headDim ++ [LayerNorm.MatShape seqLen dModel, LayerNorm.MatShape seqLen dModel, LayerNorm.MatShape seqLen dModel] ++ ssSeqFFNResidual seqLen dModel dFF) (SeqFFNModelShape seqLen dModel)

FFN residual output before the second decoder LayerNorm.

noncomputable def Proofs.Autograd.Transformer.decoderFfnResidualGraph {seqLen dModel numHeads headDim dFF : ℕ} (merge : Vec (MultiHeadAttention.HeadsShape seqLen numHeads headDim).size →L[ℝ] Vec (LayerNorm.MatShape seqLen dModel).size) (mergeBias : Vec (LayerNorm.MatShape seqLen dModel).size) (fc1 : Vec (SeqFFNModelShape seqLen dModel).size →L[ℝ] Vec (SeqFFNHiddenShape seqLen dFF).size) (b1 : Vec (SeqFFNHiddenShape seqLen dFF).size) (fc2 : Vec (SeqFFNHiddenShape seqLen dFF).size →L[ℝ] Vec (SeqFFNModelShape seqLen dModel).size) (b2 : Vec (SeqFFNModelShape seqLen dModel).size) (c ε₁ : ℝ) (bias : Vec (MultiHeadAttention.ScoresShape seqLen numHeads).size := 0) :
Graph (ΓDecoderCore seqLen dModel numHeads headDim) (MultiHeadAttention.ssMaskedCore seqLen numHeads headDim ++ [LayerNorm.MatShape seqLen dModel, LayerNorm.MatShape seqLen dModel, LayerNorm.MatShape seqLen dModel] ++ ssSeqFFNResidual seqLen dModel dFF)

Decoder graph through the FFN residual.

def Proofs.Autograd.Transformer.decoderNorm2Inputs {seqLen dModel numHeads headDim dFF : ℕ} :
LayerNorm.Inputs (ΓDecoderCore seqLen dModel numHeads headDim ++ MultiHeadAttention.ssMaskedCore seqLen numHeads headDim ++ [LayerNorm.MatShape seqLen dModel, LayerNorm.MatShape seqLen dModel, LayerNorm.MatShape seqLen dModel] ++ ssSeqFFNResidual seqLen dModel dFF) seqLen dModel

Second LayerNorm input triple for the concrete decoder-core graph.

noncomputable def Proofs.Autograd.Transformer.decoderCoreGraph {seqLen dModel numHeads headDim dFF : ℕ} (merge : Vec (MultiHeadAttention.HeadsShape seqLen numHeads headDim).size →L[ℝ] Vec (LayerNorm.MatShape seqLen dModel).size) (mergeBias : Vec (LayerNorm.MatShape seqLen dModel).size) (fc1 : Vec (SeqFFNModelShape seqLen dModel).size →L[ℝ] Vec (SeqFFNHiddenShape seqLen dFF).size) (b1 : Vec (SeqFFNHiddenShape seqLen dFF).size) (fc2 : Vec (SeqFFNHiddenShape seqLen dFF).size →L[ℝ] Vec (SeqFFNModelShape seqLen dModel).size) (b2 : Vec (SeqFFNModelShape seqLen dModel).size) (c ε₁ ε₂ : ℝ) (bias : Vec (MultiHeadAttention.ScoresShape seqLen numHeads).size := 0) :
Graph (ΓDecoderCore seqLen dModel numHeads headDim) (ssDecoderCore seqLen dModel numHeads headDim dFF)

Concrete SSA graph for one finite-mask GPT-style decoder-core block.

noncomputable def Proofs.Autograd.Transformer.decoderFfnResidualGraphFDerivCorrectAt {seqLen dModel numHeads headDim dFF : ℕ} (merge : Vec (MultiHeadAttention.HeadsShape seqLen numHeads headDim).size →L[ℝ] Vec (LayerNorm.MatShape seqLen dModel).size) (mergeBias : Vec (LayerNorm.MatShape seqLen dModel).size) (fc1 : Vec (SeqFFNModelShape seqLen dModel).size →L[ℝ] Vec (SeqFFNHiddenShape seqLen dFF).size) (b1 : Vec (SeqFFNHiddenShape seqLen dFF).size) (fc2 : Vec (SeqFFNHiddenShape seqLen dFF).size →L[ℝ] Vec (SeqFFNModelShape seqLen dModel).size) (b2 : Vec (SeqFFNModelShape seqLen dModel).size) (c ε₁ : ℝ) (bias : Vec (MultiHeadAttention.ScoresShape seqLen numHeads).size := 0) (xV : CtxVec (ΓDecoderCore seqLen dModel numHeads headDim)) (hNorm1 : NodeFDerivCorrectAt (LayerNorm.wholeNode decoderNorm1Inputs ε₁) ((((decoderMaskedCoreDGraph c bias).g.snoc (TapeNodes.affine idxDecoderHeadOut merge mergeBias)).snoc (TapeNodes.add idxDecoderResidualInputAfterMerge idxDecoderMergedAttention)).evalVec xV)) :
GraphFDerivCorrectAt (decoderFfnResidualGraph merge mergeBias fc1 b1 fc2 b2 c ε₁ bias) xV

Pointwise analytic correctness at xV for the decoder graph through the FFN residual. It requires the first LayerNorm node to be correct at its input (hNorm1).

noncomputable def Proofs.Autograd.Transformer.decoderCoreGraphFDerivCorrectAt {seqLen dModel numHeads headDim dFF : ℕ} (merge : Vec (MultiHeadAttention.HeadsShape seqLen numHeads headDim).size →L[ℝ] Vec (LayerNorm.MatShape seqLen dModel).size) (mergeBias : Vec (LayerNorm.MatShape seqLen dModel).size) (fc1 : Vec (SeqFFNModelShape seqLen dModel).size →L[ℝ] Vec (SeqFFNHiddenShape seqLen dFF).size) (b1 : Vec (SeqFFNHiddenShape seqLen dFF).size) (fc2 : Vec (SeqFFNHiddenShape seqLen dFF).size →L[ℝ] Vec (SeqFFNModelShape seqLen dModel).size) (b2 : Vec (SeqFFNModelShape seqLen dModel).size) (c ε₁ ε₂ : ℝ) (bias : Vec (MultiHeadAttention.ScoresShape seqLen numHeads).size := 0) (xV : CtxVec (ΓDecoderCore seqLen dModel numHeads headDim)) (hNorm1 : NodeFDerivCorrectAt (LayerNorm.wholeNode decoderNorm1Inputs ε₁) ((((decoderMaskedCoreDGraph c bias).g.snoc (TapeNodes.affine idxDecoderHeadOut merge mergeBias)).snoc (TapeNodes.add idxDecoderResidualInputAfterMerge idxDecoderMergedAttention)).evalVec xV)) (hNorm2 : NodeFDerivCorrectAt (LayerNorm.wholeNode decoderNorm2Inputs ε₂) ((decoderFfnResidualGraph merge mergeBias fc1 b1 fc2 b2 c ε₁ bias).evalVec xV)) :
GraphFDerivCorrectAt (decoderCoreGraph merge mergeBias fc1 b1 fc2 b2 c ε₁ ε₂ bias) xV

Pointwise analytic correctness at xV for the complete concrete decoder-core graph. It requires both LayerNorm nodes to be correct at their inputs (hNorm1, hNorm2).

theorem Proofs.Autograd.Transformer.decoderCore_backpropVec_eq_adjoint_fderiv_at {seqLen dModel numHeads headDim dFF : ℕ} (merge : Vec (MultiHeadAttention.HeadsShape seqLen numHeads headDim).size →L[ℝ] Vec (LayerNorm.MatShape seqLen dModel).size) (mergeBias : Vec (LayerNorm.MatShape seqLen dModel).size) (fc1 : Vec (SeqFFNModelShape seqLen dModel).size →L[ℝ] Vec (SeqFFNHiddenShape seqLen dFF).size) (b1 : Vec (SeqFFNHiddenShape seqLen dFF).size) (fc2 : Vec (SeqFFNHiddenShape seqLen dFF).size →L[ℝ] Vec (SeqFFNModelShape seqLen dModel).size) (b2 : Vec (SeqFFNModelShape seqLen dModel).size) (c ε₁ ε₂ : ℝ) (bias : Vec (MultiHeadAttention.ScoresShape seqLen numHeads).size := 0) (xV : CtxVec (ΓDecoderCore seqLen dModel numHeads headDim)) (seedV : CtxVec (ΓDecoderCore seqLen dModel numHeads headDim ++ ssDecoderCore seqLen dModel numHeads headDim dFF)) (hNorm1 : NodeFDerivCorrectAt (LayerNorm.wholeNode decoderNorm1Inputs ε₁) ((((decoderMaskedCoreDGraph c bias).g.snoc (TapeNodes.affine idxDecoderHeadOut merge mergeBias)).snoc (TapeNodes.add idxDecoderResidualInputAfterMerge idxDecoderMergedAttention)).evalVec xV)) (hNorm2 : NodeFDerivCorrectAt (LayerNorm.wholeNode decoderNorm2Inputs ε₂) ((decoderFfnResidualGraph merge mergeBias fc1 b1 fc2 b2 c ε₁ bias).evalVec xV)) :
(decoderCoreGraph merge mergeBias fc1 b1 fc2 b2 c ε₁ ε₂ bias).backpropVec xV seedV = (ContinuousLinearMap.adjoint (fderiv ℝ (decoderCoreGraph merge mergeBias fc1 b1 fc2 b2 c ε₁ ε₂ bias).evalVec xV)) seedV

End-to-end VJP theorem for the concrete finite-mask GPT-style decoder-core graph.
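The theorem says that running the tape backwards on a seed yields the adjoint of the Fréchet derivative applied to that seed. Equivalently, ⟨DF(x) v, s⟩ = ⟨v, backprop(x, s)⟩ holds for every tangent v.

The toy below checks that identity for a linear map F(v) = W v. The derivative of such a map is W at every point, and its adjoint under the dot product is the transpose. The code is plain Lean 4 with no Mathlib, and all names (W, dot, jvp, vjp) are hypothetical and unrelated to TorchLean's API.

```lean
-- Hypothetical 2×3 integer matrix standing in for a linear graph F(v) = W v.
def W : List (List Int) := [[1, 2, 0], [0, -1, 3]]

-- Dot product of two equal-length integer vectors.
def dot (u v : List Int) : Int :=
  (List.zipWith (· * ·) u v).foldl (· + ·) 0

-- Forward-mode JVP: F is linear, so its derivative is W and v ↦ W v.
def jvp (v : List Int) : List Int :=
  W.map (fun row => dot row v)

-- Reverse-mode VJP: the seed s is mapped to Wᵀ s, one column of W at a time.
def vjp (s : List Int) : List Int :=
  (List.range 3).map (fun j => dot (W.map (fun row => row.getD j 0)) s)

#eval vjp [1, 1]                                                -- [1, 1, 3]
#eval (dot (jvp [1, 2, 3]) [5, 7], dot [1, 2, 3] (vjp [5, 7]))  -- (74, 74)
```

For the decoder graph, W is replaced by fderiv ℝ (…).evalVec xV, which is no longer constant in xV. The theorem needs hNorm1 and hNorm2 because the LayerNorm nodes are the only places where that derivative is not automatic.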

theorem Proofs.Autograd.Transformer.projectedMaskedDecoderAttentionPack_hasFDerivAt {E : Type u} [NormedAddCommGroup E] [NormedSpace ℝ E] {seqLen dModel numHeads headDim : ℕ} (c : ℝ) (bias : Vec (MultiHeadAttention.ScoresShape seqLen numHeads).size := 0) (projectPack : E → CtxVec (MultiHeadAttention.ΓMaskedCore seqLen numHeads headDim)) (DprojectPack : E →L[ℝ] CtxVec (MultiHeadAttention.ΓMaskedCore seqLen numHeads headDim)) (attentionPack : CtxVec (MultiHeadAttention.ΓMaskedCore seqLen numHeads headDim ++ MultiHeadAttention.ssMaskedCore seqLen numHeads headDim) → CtxVec (ΓPostNorm seqLen dModel)) (DattentionPack : CtxVec (MultiHeadAttention.ΓMaskedCore seqLen numHeads headDim ++ MultiHeadAttention.ssMaskedCore seqLen numHeads headDim) →L[ℝ] CtxVec (ΓPostNorm seqLen dModel)) (x : E) (hProject : HasFDerivAt projectPack DprojectPack x) (hAttentionPack : HasFDerivAt attentionPack DattentionPack ((MultiHeadAttention.maskedCoreDGraph c bias).g.evalVec (projectPack x))) :
HasFDerivAt (fun (z : E) => attentionPack ((MultiHeadAttention.maskedCoreDGraph c bias).g.evalVec (projectPack z))) (DattentionPack ∘SL fderiv ℝ (MultiHeadAttention.maskedCoreDGraph c bias).g.evalVec (projectPack x) ∘SL DprojectPack) x

Projection-to-residual bridge for a GPT-style masked decoder attention sublayer.

The concrete decoder-core graph above starts from already-split Q, Kᵀ, and V heads. This theorem is the reusable front-end hook for full GPT blocks. Any differentiable projection/split stage may build those heads, and any differentiable merge/residual pack may turn the masked attention trace into the first LayerNorm input triple [x + MaskedMHA(x), gamma₁, beta₁].

Combine this theorem with postNormGptDecoderBlock_hasFDerivAt below to obtain the differentiability statement for the full projected finite-mask decoder block.
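With the three stages abbreviated, the conclusion is the chain rule for a threefold composition:

- hProject supplies DP(x) = DprojectPack.
- hAttentionPack supplies DA at the core output, DattentionPack.
- The middle factor is the derivative of the proved masked core, taken at the projected heads.

The abbreviations P, G and A are introduced here only for readability.

```latex
\begin{aligned}
&P = \texttt{projectPack}, \quad
 G = (\texttt{maskedCoreDGraph}\ c\ \texttt{bias}).g.\texttt{evalVec}, \quad
 A = \texttt{attentionPack}, \\
&D(A \circ G \circ P)(x) = DA\bigl(G(P(x))\bigr) \circ DG\bigl(P(x)\bigr) \circ DP(x).
\end{aligned}
```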

theorem Proofs.Autograd.Transformer.postNormGptDecoderBlock_hasFDerivAt {E : Type u} [NormedAddCommGroup E] [NormedSpace ℝ E] {seqLen dModel : ℕ} (ε₁ ε₂ : ℝ) (maskedAttentionPack : E → CtxVec (ΓPostNorm seqLen dModel)) (DmaskedAttentionPack : E →L[ℝ] CtxVec (ΓPostNorm seqLen dModel)) (ffnPack : CtxVec (ΓPostNorm seqLen dModel ++ ssPostNorm seqLen dModel) → CtxVec (ΓPostNorm seqLen dModel)) (DffnPack : CtxVec (ΓPostNorm seqLen dModel ++ ssPostNorm seqLen dModel) →L[ℝ] CtxVec (ΓPostNorm seqLen dModel)) (x : E) (hMaskedAttentionPack : HasFDerivAt maskedAttentionPack DmaskedAttentionPack x) (hNorm1VarEpsPos : ∀ (i : Fin (LayerNorm.VecShape seqLen).size), 0 < (CtxVec.get LayerNorm.idxVarEps ((LayerNorm.layerNormPrefix6 ε₁).evalVec (maskedAttentionPack x))).ofLp i) (hNorm1StdNe0 : ∀ (i : Fin (LayerNorm.VecShape seqLen).size), (CtxVec.get LayerNorm.idxStd ((LayerNorm.layerNormPrefix7 ε₁).evalVec (maskedAttentionPack x))).ofLp i ≠ 0) (hFfnPack : HasFDerivAt ffnPack DffnPack ((postNormGraph ε₁).evalVec (maskedAttentionPack x))) (hNorm2VarEpsPos : ∀ (i : Fin (LayerNorm.VecShape seqLen).size), 0 < (CtxVec.get LayerNorm.idxVarEps ((LayerNorm.layerNormPrefix6 ε₂).evalVec (ffnPack ((postNormGraph ε₁).evalVec (maskedAttentionPack x))))).ofLp i) (hNorm2StdNe0 : ∀ (i : Fin (LayerNorm.VecShape seqLen).size), (CtxVec.get LayerNorm.idxStd ((LayerNorm.layerNormPrefix7 ε₂).evalVec (ffnPack ((postNormGraph ε₁).evalVec (maskedAttentionPack x))))).ofLp i ≠ 0) :
HasFDerivAt (fun (z : E) => (postNormGraph ε₂).evalVec (ffnPack ((postNormGraph ε₁).evalVec (maskedAttentionPack z)))) (fderiv ℝ (postNormGraph ε₂).evalVec (ffnPack ((postNormGraph ε₁).evalVec (maskedAttentionPack x))) ∘SL DffnPack ∘SL fderiv ℝ (postNormGraph ε₁).evalVec (maskedAttentionPack x) ∘SL DmaskedAttentionPack) x

Fréchet differentiability of a GPT-style post-norm decoder block.

maskedAttentionPack builds the first LayerNorm input triple [x + MaskedMHA(x), gamma₁, beta₁]. When the attention sublayer is built from the proved finite-mask split-head core, discharge its differentiability hypothesis with MultiHeadAttention.projectedMaskedAttention_hasFDerivAt.
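The conclusion is a four-factor chain rule, written here with some abbreviations:

- M is maskedAttentionPack, with DM(x) = DmaskedAttentionPack.
- Φ is ffnPack, with DΦ = DffnPack.
- N_i is (postNormGraph ε_i).evalVec.
- F = N₂ ∘ Φ ∘ N₁ ∘ M.

Sometimes M is itself the projected attention pack A ∘ G ∘ P of projectedMaskedDecoderAttentionPack_hasFDerivAt. There A = attentionPack, G is the masked-core evaluation and P = projectPack. Substituting that theorem's derivative for DM(x) expands the last factor, and this is one way the two theorems combine. The sketch only spells out the composition; the four LayerNorm side conditions must still hold at the indicated points.

```latex
\begin{aligned}
DF(x) &= DN_2\bigl(\Phi(N_1(M(x)))\bigr) \circ D\Phi\bigl(N_1(M(x))\bigr)
         \circ DN_1\bigl(M(x)\bigr) \circ DM(x), \\
DM(x) &= DA\bigl(G(P(x))\bigr) \circ DG\bigl(P(x)\bigr) \circ DP(x)
         \qquad \text{when } M = A \circ G \circ P.
\end{aligned}
```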