TorchLean API

NN.Proofs.Autograd.Tape.Ops.Transformer.FeedForward

Transformer Feed-Forward Sublayer VJP #

This file proves the VJP correctness theorem for the standard position-wise Transformer feed-forward sublayer, at the vector level:

x ↦ x + W₂ GELU(W₁ x + b₁) + b₂.

The theorem is intentionally about one token/vector. Batched sequence application is a map over positions, and the full Transformer encoder block additionally composes this FFN residual with MHA and LayerNorm. This file gives the clean proof component for the FFN half of that block.
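To fix intuition for the map being differentiated, here is a minimal executable sketch of the forward pass over plain `Array Float`, using the tanh approximation of GELU. All names here (`gelu`, `matVec`, `addVec`, `ffnResidual`) are ad-hoc illustrations, not TorchLean API; the library itself works over its `Vec`/`CtxVec` types with a smooth GELU primitive.

```lean
-- Illustrative forward pass over plain `Array Float`.
-- `gelu` is the tanh approximation; the library's smooth GELU may differ.
def gelu (x : Float) : Float :=
  let c := Float.sqrt (2.0 / 3.141592653589793)
  0.5 * x * (1.0 + Float.tanh (c * (x + 0.044715 * x * x * x)))

def matVec (w : Array (Array Float)) (x : Array Float) : Array Float :=
  w.map fun row => (row.zip x).foldl (fun acc (a, b) => acc + a * b) 0.0

def addVec (a b : Array Float) : Array Float :=
  (a.zip b).map fun (u, v) => u + v

/-- One token: `x ↦ x + W₂ GELU(W₁ x + b₁) + b₂`. -/
def ffnResidual (w1 : Array (Array Float)) (b1 : Array Float)
    (w2 : Array (Array Float)) (b2 : Array Float)
    (x : Array Float) : Array Float :=
  let h := (addVec (matVec w1 x) b1).map gelu
  addVec x (addVec (matVec w2 h) b2)
```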

@[reducible, inline]
def Proofs.Autograd.Transformer.FFNModelShape (dModel : ℕ) :
Spec.Shape

Model-vector shape.
@[reducible, inline]

Hidden feed-forward shape.
@[reducible, inline]
def Proofs.Autograd.Transformer.ΓFFN (dModel : ℕ) :
List Spec.Shape

Context for one position-wise FFN: just the input vector.
@[reducible, inline]
def Proofs.Autograd.Transformer.ssFFNResidual (dModel dFF : ℕ) :
List Spec.Shape

Saved tensors: first affine, activation, second affine, residual output.
def Proofs.Autograd.Transformer.ffnIdxX {dModel : ℕ} {ss : List Spec.Shape} :
Idx (ΓFFN dModel ++ ss) (FFNModelShape dModel)

Input vector index.

Most recently appended tensor helper.

First affine output index.

GELU activation output index.

Second affine output index.
noncomputable def Proofs.Autograd.Transformer.ffnResidualDGraph {dModel dFF : ℕ} (fc1 : Spec.LinearSpec dModel dFF) (fc2 : Spec.LinearSpec dFF dModel) :
DGraph (ΓFFN dModel) (ssFFNResidual dModel dFF)

Proof-carrying graph for a residual Transformer FFN sublayer.

The two affine maps are fixed LinearSpecs here, so the theorem covers the VJP with respect to the input vector. Parameter-gradient theorems live at the trainable-parameter/runtime layer.
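For reference, the input-VJP that this graph's backward pass computes can be written in closed form: with pre-activation p = W₁ x + b₁, a seed g pulls back to g + W₁ᵀ (GELU′(p) ⊙ (W₂ᵀ g)); b₂ contributes nothing to the input gradient. A sketch continuing the illustrative `Array Float` helpers above (again hypothetical names, not TorchLean API):

```lean
-- Derivative of the tanh-approximated `gelu` above (illustrative).
def geluDeriv (x : Float) : Float :=
  let c := Float.sqrt (2.0 / 3.141592653589793)
  let u := c * (x + 0.044715 * x * x * x)
  let t := Float.tanh u
  0.5 * (1.0 + t) + 0.5 * x * (1.0 - t * t) * c * (1.0 + 3.0 * 0.044715 * x * x)

def transpose (w : Array (Array Float)) : Array (Array Float) :=
  match w[0]? with
  | none => #[]
  | some row0 => (Array.range row0.size).map fun j => w.map fun row => row[j]!

def mulVec (a b : Array Float) : Array Float :=
  (a.zip b).map fun (u, v) => u * v

/-- Input-VJP of `ffnResidual` at `x`, applied to seed `g`:
    `g + W₁ᵀ (GELU′(W₁ x + b₁) ⊙ (W₂ᵀ g))`. Note `b₂` drops out. -/
def ffnResidualVJP (w1 : Array (Array Float)) (b1 : Array Float)
    (w2 : Array (Array Float)) (x g : Array Float) : Array Float :=
  let p := addVec (matVec w1 x) b1
  addVec g (matVec (transpose w1) (mulVec (p.map geluDeriv) (matVec (transpose w2) g)))
```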

theorem Proofs.Autograd.Transformer.ffnResidual_backpropVec_eq_adjoint_fderiv {dModel dFF : ℕ} (fc1 : Spec.LinearSpec dModel dFF) (fc2 : Spec.LinearSpec dFF dModel) (xV : CtxVec (ΓFFN dModel)) (seedV : CtxVec (ΓFFN dModel ++ ssFFNResidual dModel dFF)) :

End-to-end VJP theorem for the residual Transformer feed-forward sublayer.
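The adjoint identity the theorem expresses, ⟪backprop seed, v⟫ = ⟪seed, Df(x) v⟫ for every direction v, can be sanity-checked numerically against the sketches above with a finite difference. The theorem itself is exact and stated over the library's vector types; this `#eval` is only an illustration with made-up weights:

```lean
def dot (a b : Array Float) : Float :=
  (a.zip b).foldl (fun acc (u, v) => acc + u * v) 0.0

-- Compare ⟪ffnResidualVJP … g, v⟫ (the hand-rolled adjoint) with a
-- finite difference of ε ↦ ⟪g, f(x + εv)⟫.
#eval
  let w1 := #[#[0.3, -0.2], #[0.1, 0.5], #[-0.4, 0.2]]
  let b1 := #[0.1, -0.1, 0.2]
  let w2 := #[#[0.2, -0.1, 0.3], #[0.0, 0.4, -0.2]]
  let b2 := #[0.05, -0.05]
  let x := #[0.7, -0.3]
  let g := #[1.0, -2.0]
  let v := #[0.5, 0.25]
  let eps := 1.0e-6
  let f := ffnResidual w1 b1 w2 b2
  let xShift := (x.zip v).map fun (u, w) => u + eps * w
  let lhs := dot (ffnResidualVJP w1 b1 w2 x g) v
  let rhs := (dot g (f xShift) - dot g (f x)) / eps
  (lhs, rhs)  -- the two numbers should agree to ≈1e-5
```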

Sequence-shaped FFN residual #

The runtime Transformer applies the same FFN to every token in a (seqLen × dModel) tensor. For the model-level proof interface below, we package that operation as two fixed affine maps over the flattened sequence tensor. A concrete shared-weight implementation instantiates these maps with the usual block-diagonal/time-distributed linear operator; the VJP theorem itself only needs the affine maps and the smooth GELU primitive.
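A minimal sketch of that time-distributed packaging, in the same illustrative `Array Float` style as above (a hypothetical helper, not TorchLean API): applying one per-token map to each `dModel`-sized chunk of the flattened tensor is exactly the block-diagonal operator on the flattened vector.

```lean
-- Apply the same per-token map to each `dModel`-sized chunk of a flattened
-- `seqLen × dModel` tensor. When `f` is affine, this is the block-diagonal
-- (time-distributed) affine map referred to above.
def timeDistributed (dModel : Nat) (f : Array Float → Array Float)
    (xs : Array Float) : Array Float :=
  ((Array.range (xs.size / dModel)).map fun t =>
      f (xs.extract (t * dModel) ((t + 1) * dModel))).foldl (· ++ ·) #[]
```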

@[reducible, inline]
def Proofs.Autograd.Transformer.SeqFFNModelShape (seqLen dModel : ℕ) :
Spec.Shape

Sequence-shaped model stream.
@[reducible, inline]
def Proofs.Autograd.Transformer.SeqFFNHiddenShape (seqLen dFF : ℕ) :
Spec.Shape

Sequence-shaped FFN hidden stream.
@[reducible, inline]
def Proofs.Autograd.Transformer.ΓSeqFFN (seqLen dModel : ℕ) :
List Spec.Shape

Context for a sequence-level FFN residual block: just the sequence stream.
@[reducible, inline]
def Proofs.Autograd.Transformer.ssSeqFFNResidual (seqLen dModel dFF : ℕ) :
List Spec.Shape

Saved tensors for the sequence-level residual FFN.
def Proofs.Autograd.Transformer.seqFfnIdxX {seqLen dModel : ℕ} {ss : List Spec.Shape} :
Idx (ΓSeqFFN seqLen dModel ++ ss) (SeqFFNModelShape seqLen dModel)

Sequence input index, weakened through saved tensors.
def Proofs.Autograd.Transformer.seqFfnIdxHiddenPre {seqLen dModel dFF : ℕ} :
Idx (ΓSeqFFN seqLen dModel ++ [SeqFFNHiddenShape seqLen dFF]) (SeqFFNHiddenShape seqLen dFF)

First sequence affine output.
def Proofs.Autograd.Transformer.seqFfnIdxHiddenAct {seqLen dModel dFF : ℕ} :
Idx (ΓSeqFFN seqLen dModel ++ [SeqFFNHiddenShape seqLen dFF, SeqFFNHiddenShape seqLen dFF]) (SeqFFNHiddenShape seqLen dFF)

GELU activation output.
def Proofs.Autograd.Transformer.seqFfnIdxProjected {seqLen dModel dFF : ℕ} :
Idx (ΓSeqFFN seqLen dModel ++ [SeqFFNHiddenShape seqLen dFF, SeqFFNHiddenShape seqLen dFF, SeqFFNModelShape seqLen dModel]) (SeqFFNModelShape seqLen dModel)

Second sequence affine output.
noncomputable def Proofs.Autograd.Transformer.seqFfnResidualDGraph {seqLen dModel dFF : ℕ} (fc1 : Vec (SeqFFNModelShape seqLen dModel).size →L[ℝ] Vec (SeqFFNHiddenShape seqLen dFF).size) (b1 : Vec (SeqFFNHiddenShape seqLen dFF).size) (fc2 : Vec (SeqFFNHiddenShape seqLen dFF).size →L[ℝ] Vec (SeqFFNModelShape seqLen dModel).size) (b2 : Vec (SeqFFNModelShape seqLen dModel).size) :
DGraph (ΓSeqFFN seqLen dModel) (ssSeqFFNResidual seqLen dModel dFF)

Proof-carrying graph for the sequence-shaped residual FFN:

X ↦ X + A₂(GELU(A₁ X + b₁)) + b₂.

The affine maps are supplied explicitly over flattened sequence tensors. This keeps the theorem usable for shared-weight position-wise FFNs, fused FFN kernels, and future compiler-generated linearizations, as long as they expose the same affine maps.
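Under the shared-weight instantiation, the flattened sequence map is just the per-token residual FFN distributed over positions, which is why the vector-level theorem earlier is the natural proof component. Continuing the illustrative helpers from the sketches above (hypothetical names, not TorchLean API):

```lean
-- Shared-weight instantiation: the flattened sequence map is the per-token
-- residual FFN applied independently at each position.
def seqFfnResidualSketch (dModel : Nat)
    (w1 : Array (Array Float)) (b1 : Array Float)
    (w2 : Array (Array Float)) (b2 : Array Float)
    (xs : Array Float) : Array Float :=
  timeDistributed dModel (ffnResidual w1 b1 w2 b2) xs
```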

theorem Proofs.Autograd.Transformer.seqFfnResidual_backpropVec_eq_adjoint_fderiv {seqLen dModel dFF : ℕ} (fc1 : Vec (SeqFFNModelShape seqLen dModel).size →L[ℝ] Vec (SeqFFNHiddenShape seqLen dFF).size) (b1 : Vec (SeqFFNHiddenShape seqLen dFF).size) (fc2 : Vec (SeqFFNHiddenShape seqLen dFF).size →L[ℝ] Vec (SeqFFNModelShape seqLen dModel).size) (b2 : Vec (SeqFFNModelShape seqLen dModel).size) (xV : CtxVec (ΓSeqFFN seqLen dModel)) (seedV : CtxVec (ΓSeqFFN seqLen dModel ++ ssSeqFFNResidual seqLen dModel dFF)) :

End-to-end VJP theorem for the sequence-shaped residual FFN.