TorchLean API

NN.Proofs.Autograd.Tape.Ops.Transformer.ResidualAttention

Residual Attention Blocks #

This file proves the next composition step after full multi-head self-attention:

x ↦ x + MHA(x).

That residual add is the first half of a post-norm Transformer encoder sublayer:

LayerNorm(x + MultiHeadSelfAttention(x)).

The existing MHA theorem already proves that the attention graph's reverse pass computes the adjoint of the Fréchet derivative of the full multi-head-attention map. Here we append the residual add as one more proved tape node, giving a reusable graph theorem for the residual stream that is then passed to post-norm LayerNorm in the runtime Transformer blocks.
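
Concretely, the appended node is justified by linearity of the residual add. Writing y = x + MHA(x), the Fréchet derivative and its adjoint split into an identity branch plus the already-proved attention branch:

    D y(x)[v] = v + D MHA(x)[v]
    (D y(x))* w = w + (D MHA(x))* w

So the residual node's backward pass only has to copy the incoming seed into the x slot and pass the same seed on to the existing MHA reverse pass.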

@[reducible, inline]
abbrev Proofs.Autograd.Transformer.ssMHAResidual (n dModel numHeads headDim : ℕ) :

Intermediate list for MHA followed by one residual-add output.

Instances For
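
The right-hand side of the abbreviation is not reproduced on this page. As a rough sketch of the intent only (the primed name, the `++ [...]` body, and the reuse of XShape for the residual output are assumptions, not the source definition): the MHA saved-tensor list extended by one input-shaped entry.

    -- Hypothetical sketch, not the source definition.
    abbrev ssMHAResidual' (n dModel numHeads headDim : ℕ) :=
      MultiHeadAttention.ssMHA n dModel numHeads headDim ++ [MultiHeadAttention.XShape n dModel]
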
def Proofs.Autograd.Transformer.residualIdxX {n dModel numHeads headDim : ℕ} :
    Idx (MultiHeadAttention.ΓMHA n dModel numHeads headDim ++ MultiHeadAttention.ssMHA n dModel numHeads headDim) (MultiHeadAttention.XShape n dModel)

Original sequence input x, weakened into the context after the MHA intermediates.

Instances For
def Proofs.Autograd.Transformer.residualIdxAttnOut {n dModel numHeads headDim : ℕ} :
    Idx (MultiHeadAttention.ΓMHA n dModel numHeads headDim ++ MultiHeadAttention.ssMHA n dModel numHeads headDim) (MultiHeadAttention.XShape n dModel)

The final output of mhaDGraph, i.e. the projected attention result.

The literal index is intentional: MultiHeadAttention.ssMHA is a fixed 14-entry saved-tensor list, and the final entry is the attention output with the same shape as the input sequence.

Instances For
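
Both index definitions are instances of the usual de Bruijn-style bookkeeping for appended contexts. The following self-contained miniature (MiniIdx is ours, not the library's Idx) shows the two operations involved: weakening an index from the original context across the appended intermediates, as residualIdxX does, and injecting an index of the appended list into the combined context, as residualIdxAttnOut does.

    -- Self-contained miniature, not the library's Idx type.
    inductive MiniIdx {α : Type} : List α → α → Type where
      | here  {a : α} {as : List α} : MiniIdx (a :: as) a
      | there {a b : α} {as : List α} : MiniIdx as a → MiniIdx (b :: as) a

    /-- Weaken an index into `Γ` to an index into `Γ ++ Δ`
        (how the original `x` is still addressable after appending the MHA intermediates). -/
    def MiniIdx.weakenRight {α : Type} {Γ Δ : List α} {a : α} :
        MiniIdx Γ a → MiniIdx (Γ ++ Δ) a
      | .here => .here
      | .there i => .there (MiniIdx.weakenRight i)

    /-- Inject an index into `Δ` as an index into `Γ ++ Δ`
        (how a saved intermediate such as the attention output is addressed). -/
    def MiniIdx.injectRight {α : Type} {Δ : List α} {a : α} :
        (Γ : List α) → MiniIdx Δ a → MiniIdx (Γ ++ Δ) a
      | [], i => i
      | _ :: Γ, i => .there (MiniIdx.injectRight Γ i)
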
noncomputable def Proofs.Autograd.Transformer.mhaResidualDGraph {n dModel numHeads headDim : ℕ} (c : ℝ) :
    DGraph (MultiHeadAttention.ΓMHA n dModel numHeads headDim) (ssMHAResidual n dModel numHeads headDim)

Proof-carrying graph for x + MHA(x).

Context layout is inherited from MHA: [x, Wq, Wk, Wv, Wo].

Instances For
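
In terms of that layout, the graph's backward pass distributes a seed w on the residual output across the context slots as follows (the slot-subscript notation is ours, writing (D MHA(x))*_s for the component of the MHA adjoint that lands in slot s):

    x  slot:  w + (D MHA(x))*_x w
    Wq slot:  (D MHA(x))*_Wq w      (and similarly for Wk, Wv, Wo)

Only the x slot picks up the extra identity term; the weight slots receive exactly the cotangents already proved correct for the bare MHA graph.
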
theorem Proofs.Autograd.Transformer.mhaResidual_backpropVec_eq_adjoint_fderiv {n dModel numHeads headDim : ℕ} (c : ℝ) (xV : CtxVec (MultiHeadAttention.ΓMHA n dModel numHeads headDim)) (seedV : CtxVec (MultiHeadAttention.ΓMHA n dModel numHeads headDim ++ ssMHAResidual n dModel numHeads headDim)) :

End-to-end VJP theorem for the residual-attention sublayer x + MHA(x).

This is the proved residual-stream component used by post-norm Transformer blocks. LayerNorm itself has its own current-spec VJP theorem in NN.Proofs.Autograd.Tape.Ops.Norm.LayerNorm; the composed post-norm attention sublayer and the two-sublayer post-norm bridge are packaged in NN.Proofs.Autograd.Tape.Ops.Transformer.PostNorm.
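
For the PostNorm composition mentioned above, the remaining step is just the chain rule for adjoints. With g(x) = x + MHA(x) and LN the LayerNorm map:

    D (LN ∘ g)(x) = D LN(g(x)) ∘ D g(x)
    (D (LN ∘ g)(x))* = (D g(x))* ∘ (D LN(g(x)))*

So seeding the LayerNorm VJP first and feeding its output into this theorem's seed reproduces the adjoint of the full post-norm sublayer.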