TorchLean API

NN.Proofs.Autograd.Tape.Ops.Attention.MaskedMultiHeadSelfAttention

Masked Multi-Head Attention Core

This module proves the head-wise finite-mask attention core used inside causal Transformer blocks. The surrounding projection/split/merge graph lives in MultiHeadSelfAttention.lean; the theorem here is the reusable masked replacement for the score/probability/value part of that graph:

softmax(c · QKᵀ + bias) V.

The mask is a fixed finite score bias with shape (heads × seq × seq). This covers the differentiable finite-mask convention used by GPT-style causal attention. A true -∞ hard mask would need a separate semantic theorem; it is not the differentiable runtime computation proved here.
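The role of the finite bias can be sketched in Lean for a single score row. This is an illustrative model only, not the module's definition: the name biasedSoftmaxRow and the use of plain functions Fin n → ℝ in place of the module's Vec and shape types are assumptions, and the block assumes a recent Mathlib.

```lean
import Mathlib

/-- Hypothetical single-row model: softmax of scaled scores plus a finite bias.
    `scores` stands in for one row of `QKᵀ`, `bias` for the matching mask row. -/
noncomputable def biasedSoftmaxRow {n : ℕ} (c : ℝ) (scores bias : Fin n → ℝ) :
    Fin n → ℝ :=
  fun j =>
    Real.exp (c * scores j + bias j) / (∑ k, Real.exp (c * scores k + bias k))

/-- Because the bias is a real number, every exponential weight is strictly
    positive, so a masked entry is down-weighted rather than removed. -/
example {n : ℕ} (c : ℝ) (scores bias : Fin n → ℝ) (j : Fin n) :
    0 < Real.exp (c * scores j + bias j) :=
  Real.exp_pos _
```

A hard -∞ entry has no counterpart in this model because the bias ranges over ℝ, which is one way to see why that case is treated as a separate statement.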

@[reducible, inline]

Context for a masked multi-head attention core: [Q_heads, Kᵀ_heads, V_heads].

    @[reducible, inline]

    Saved tensors for the fixed-bias multi-head attention core.

      def Proofs.Autograd.MultiHeadAttention.idxMaskedCoreQ {n numHeads headDim : ℕ} {ss : List Spec.Shape} :
      Idx (ΓMaskedCore n numHeads headDim ++ ss) (HeadsShape n numHeads headDim)

      Query-head index in the masked attention core context.

        def Proofs.Autograd.MultiHeadAttention.idxMaskedCoreKt {n numHeads headDim : ℕ} {ss : List Spec.Shape} :
        Idx (ΓMaskedCore n numHeads headDim ++ ss) (KtShape n numHeads headDim)

        Transposed-key index in the masked attention core context.

          def Proofs.Autograd.MultiHeadAttention.idxMaskedCoreV {n numHeads headDim : ℕ} {ss : List Spec.Shape} :
          Idx (ΓMaskedCore n numHeads headDim ++ ss) (HeadsShape n numHeads headDim)

          Value-head index in the masked attention core context.

            noncomputable def Proofs.Autograd.MultiHeadAttention.maskedCoreDGraph {n numHeads headDim : ℕ} (c : ℝ) (bias : Vec (ScoresShape n numHeads).size := 0) :
            DGraph (ΓMaskedCore n numHeads headDim) (ssMaskedCore n numHeads headDim)

            Proof-carrying masked multi-head attention core.

            The fixed bias is added after scaling the score tensor and before the row-wise softmax.

              theorem Proofs.Autograd.MultiHeadAttention.maskedCore_backpropVec_eq_adjoint_fderiv {n numHeads headDim : ℕ} (c : ℝ) (bias : Vec (ScoresShape n numHeads).size := 0) (xV : CtxVec (ΓMaskedCore n numHeads headDim)) (seedV : CtxVec (ΓMaskedCore n numHeads headDim ++ ssMaskedCore n numHeads headDim)) :

              Reverse-mode theorem for the fixed-bias multi-head attention core.

              theorem Proofs.Autograd.MultiHeadAttention.maskedCoreAfterProjection_hasFDerivAt {E : Type u} [NormedAddCommGroup E] [NormedSpace ℝ E] {n numHeads headDim : ℕ} (c : ℝ) (bias : Vec (ScoresShape n numHeads).size := 0) (projectPack : E → CtxVec (ΓMaskedCore n numHeads headDim)) (DprojectPack : E →L[ℝ] CtxVec (ΓMaskedCore n numHeads headDim)) (x : E) (hProject : HasFDerivAt projectPack DprojectPack x) :
              HasFDerivAt (fun (z : E) => (maskedCoreDGraph c bias).g.evalVec (projectPack z)) (fderiv ℝ (maskedCoreDGraph c bias).g.evalVec (projectPack x) ∘SL DprojectPack) x

              Composition theorem for the masked core after a projection/split stage.

              projectPack is the mathematical interface exposed by a full MHA front half: it builds [Q_heads, Kᵀ_heads, V_heads] from an outer context. The theorem composes that front half with the proved finite-mask attention core.
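The shape of this composition can be sketched with a generic chain-rule statement in Mathlib. This is a hedged illustration rather than the module's proof: core, project, and Dproject are hypothetical stand-ins for the core's evalVec, projectPack, and DprojectPack, and differentiability of the core is taken here as the hypothesis hCore, whereas the module derives it from maskedCoreDGraph.

```lean
import Mathlib

-- Generic analogue: a differentiable core composed after a front half
-- whose Fréchet derivative at `x` is known.
example {E F G : Type*}
    [NormedAddCommGroup E] [NormedSpace ℝ E]
    [NormedAddCommGroup F] [NormedSpace ℝ F]
    [NormedAddCommGroup G] [NormedSpace ℝ G]
    (core : F → G) (project : E → F) (Dproject : E →L[ℝ] F) (x : E)
    (hCore : DifferentiableAt ℝ core (project x))
    (hProject : HasFDerivAt project Dproject x) :
    HasFDerivAt (fun z => core (project z))
      ((fderiv ℝ core (project x)).comp Dproject) x :=
  -- `hasFDerivAt` turns differentiability into a statement about `fderiv`,
  -- and `HasFDerivAt.comp` is the chain rule.
  hCore.hasFDerivAt.comp x hProject
```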

              theorem Proofs.Autograd.MultiHeadAttention.projectedMaskedAttention_hasFDerivAt {E : Type u} [NormedAddCommGroup E] [NormedSpace ℝ E] {F : Type v} [NormedAddCommGroup F] [NormedSpace ℝ F] {n numHeads headDim : ℕ} (c : ℝ) (bias : Vec (ScoresShape n numHeads).size := 0) (projectPack : E → CtxVec (ΓMaskedCore n numHeads headDim)) (DprojectPack : E →L[ℝ] CtxVec (ΓMaskedCore n numHeads headDim)) (mergePack : CtxVec (ΓMaskedCore n numHeads headDim ++ ssMaskedCore n numHeads headDim) → F) (DmergePack : CtxVec (ΓMaskedCore n numHeads headDim ++ ssMaskedCore n numHeads headDim) →L[ℝ] F) (x : E) (hProject : HasFDerivAt projectPack DprojectPack x) (hMerge : HasFDerivAt mergePack DmergePack ((maskedCoreDGraph c bias).g.evalVec (projectPack x))) :
              HasFDerivAt (fun (z : E) => mergePack ((maskedCoreDGraph c bias).g.evalVec (projectPack z))) (DmergePack ∘SL fderiv ℝ (maskedCoreDGraph c bias).g.evalVec (projectPack x) ∘SL DprojectPack) x

              Full masked-attention composition contract.

              The theorem separates the already-proved pieces of a GPT-style attention block:

              • a differentiable front half that projects/splits inputs into Q, Kᵀ, and V;
              • the proved finite-mask split-head attention core;
              • a differentiable back half that merges/projects outputs or packages residual data for the caller.

              Instantiating projectPack and mergePack with the concrete projection/split/merge graphs gives the full masked-MHA differentiability statement without changing the core proof.
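The three-stage contract has the same shape as a two-step chain rule, which can be sketched generically in Mathlib. The names project, core, and merge below are hypothetical placeholders for projectPack, the core's evalVec, and mergePack; hCore is an assumed hypothesis standing in for the differentiability that the module proves for the core.

```lean
import Mathlib

-- Generic analogue of front half, core, and back half composed in sequence.
example {E F G H : Type*}
    [NormedAddCommGroup E] [NormedSpace ℝ E]
    [NormedAddCommGroup F] [NormedSpace ℝ F]
    [NormedAddCommGroup G] [NormedSpace ℝ G]
    [NormedAddCommGroup H] [NormedSpace ℝ H]
    (project : E → F) (Dproject : E →L[ℝ] F)
    (core : F → G)
    (merge : G → H) (Dmerge : G →L[ℝ] H) (x : E)
    (hProject : HasFDerivAt project Dproject x)
    (hCore : DifferentiableAt ℝ core (project x))
    (hMerge : HasFDerivAt merge Dmerge (core (project x))) :
    HasFDerivAt (fun z => merge (core (project z)))
      (Dmerge.comp ((fderiv ℝ core (project x)).comp Dproject)) x := by
  -- Step 1: core after the front half.
  have hFront : HasFDerivAt (fun z => core (project z))
      ((fderiv ℝ core (project x)).comp Dproject) x :=
    hCore.hasFDerivAt.comp x hProject
  -- Step 2: the back half after the result of step 1.
  exact hMerge.comp x hFront
```

Only the two outer hypotheses change when a different front or back half is plugged in; the middle step is untouched, which mirrors the claim that the core proof does not need to change.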