TorchLean API

NN.Proofs.Autograd.Tape.Ops.Attention.MultiHeadSelfAttention

MultiHeadSelfAttention #

End-to-end fderiv/backprop correctness for a Multi-Head Self-Attention graph, decomposed into proven tape nodes.

This is spec-level over . It is a corollary of the general graph theorem once each node used by the graph has a NodeFDerivCorrect instance.
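
Informally, and reading NodeFDerivCorrect only by analogy with the graph-level theorem below (the precise per-node statement lives with each node's instance), such an instance says that a node's backward map is the adjoint of its forward derivative: for a node computing y = f(x), backward(x, ȳ) = (fderiv f x)† ȳ. Composing these per-node facts along the tape is what yields the end-to-end statement.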

PyTorch correspondence / citations #

@[reducible, inline]
abbrev Proofs.Autograd.MultiHeadAttention.XShape (n dModel : ℕ) :
Spec.Shape

Sequence input shape n×dModel.

Instances For

@[reducible, inline]
abbrev Proofs.Autograd.MultiHeadAttention.BigShape (n numHeads headDim : ℕ) :
Spec.Shape

Concatenated-head representation n×(numHeads*headDim).

Instances For

@[reducible, inline]
abbrev Proofs.Autograd.MultiHeadAttention.HeadsShape (n numHeads headDim : ℕ) :
Spec.Shape

Split-head representation (numHeads)×n×headDim.

Instances For

@[reducible, inline]

Key-transposed shape (numHeads)×headDim×n used for Q Kᵀ.

Instances For

@[reducible, inline]

Attention scores shape (numHeads)×n×n.

Instances For

@[reducible, inline]

Intermediate shape after swapping axes for concatenation n×numHeads×headDim.

Instances For

@[reducible, inline]
abbrev Proofs.Autograd.MultiHeadAttention.ssMHA (n dModel numHeads headDim : ℕ) :
List Spec.Shape

Intermediate node output shapes (tape “saved tensors”) for the MHA graph.

Instances For
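
Read together, these abbreviations trace the usual shape flow of multi-head attention. A summary derived from the docstrings above (the exact node order inside mhaDGraph may differ):

x : n×dModel
→ Q, K, V = x·Wq, x·Wk, x·Wv : n×(numHeads*headDim)
→ split into heads (reshape + swap_first_two3d) : (numHeads)×n×headDim, with K further transposed to (numHeads)×headDim×n
→ scores = Q Kᵀ : (numHeads)×n×n, then softmax(c * scores) V : (numHeads)×n×headDim
→ swap back and concatenate heads : n×numHeads×headDim, then n×(numHeads*headDim)
→ output projection by Wo : n×dModel
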
@[reducible, inline]
abbrev Proofs.Autograd.MultiHeadAttention.WqShape (dModel numHeads headDim : ℕ) :
Spec.Shape

Projection weight shape dModel×(numHeads*headDim) (used for Q/K/V).

Instances For

@[reducible, inline]
abbrev Proofs.Autograd.MultiHeadAttention.WoShape (dModel numHeads headDim : ℕ) :
Spec.Shape

Output projection weight shape (numHeads*headDim)×dModel.

Instances For

@[reducible, inline]
abbrev Proofs.Autograd.MultiHeadAttention.ΓMHA (n dModel numHeads headDim : ℕ) :
List Spec.Shape

Input context shapes: [x, Wq, Wk, Wv, Wo].

Instances For
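
For intuition only, a minimal self-contained sketch of this context as a plain list of shapes, under the assumption that a Spec.Shape behaves like a list of dimension sizes (this is illustrative, not the library's definition):

-- Hypothetical sketch: ΓMHA = [x, Wq, Wk, Wv, Wo] modeled as `List (List Nat)`.
def ΓMHA_sketch (n dModel numHeads headDim : Nat) : List (List Nat) :=
  [ [n, dModel],                    -- x  : XShape n dModel
    [dModel, numHeads * headDim],   -- Wq : WqShape dModel numHeads headDim
    [dModel, numHeads * headDim],   -- Wk (same shape as Wq)
    [dModel, numHeads * headDim],   -- Wv (same shape as Wq)
    [numHeads * headDim, dModel] ]  -- Wo : WoShape dModel numHeads headDim
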
def Proofs.Autograd.MultiHeadAttention.idxX {n dModel numHeads headDim : ℕ} {ss : List Spec.Shape} :
Idx (ΓMHA n dModel numHeads headDim ++ ss) (XShape n dModel)

Context index of the sequence input x in ΓMHA.

Instances For

def Proofs.Autograd.MultiHeadAttention.idxWq {n dModel numHeads headDim : ℕ} {ss : List Spec.Shape} :
Idx (ΓMHA n dModel numHeads headDim ++ ss) (WqShape dModel numHeads headDim)

Context index of Wq in ΓMHA.

Instances For

def Proofs.Autograd.MultiHeadAttention.idxWk {n dModel numHeads headDim : ℕ} {ss : List Spec.Shape} :
Idx (ΓMHA n dModel numHeads headDim ++ ss) (WqShape dModel numHeads headDim)

Context index of Wk in ΓMHA.

Instances For

def Proofs.Autograd.MultiHeadAttention.idxWv {n dModel numHeads headDim : ℕ} {ss : List Spec.Shape} :
Idx (ΓMHA n dModel numHeads headDim ++ ss) (WqShape dModel numHeads headDim)

Context index of Wv in ΓMHA.

Instances For

def Proofs.Autograd.MultiHeadAttention.idxWo {n dModel numHeads headDim : ℕ} {ss : List Spec.Shape} :
Idx (ΓMHA n dModel numHeads headDim ++ ss) (WoShape dModel numHeads headDim)

Context index of Wo in ΓMHA.

Instances For

Index of the most-recently appended tensor in a DGraph context.

Instances For
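
These idx* definitions pick out the five entries of ΓMHA (x, then Wq, Wk, Wv, Wo) inside the extended context ΓMHA ++ ss. As a rough, self-contained illustration of what such a shape-indexed position type can look like (a de Bruijn-style sketch, not the library's Idx):

-- Hypothetical sketch of a shape-indexed context position (de Bruijn style).
inductive IdxSketch : List (List Nat) → List Nat → Type where
  | zero {Γ : List (List Nat)} {s : List Nat} : IdxSketch (s :: Γ) s
  | succ {Γ : List (List Nat)} {s t : List Nat} : IdxSketch Γ s → IdxSketch (t :: Γ) s
-- Under this sketch, idxX would play the role of `.zero`, idxWq of `.succ .zero`, and so on.
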
theorem Proofs.Autograd.MultiHeadAttention.size_big_to_heads (n numHeads headDim : ℕ) :
(BigShape n numHeads headDim).size = (HeadsShape n numHeads headDim).size
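
Assuming a shape's .size is the product of its dimensions, this equality is just commutativity/associativity of multiplication; for example, in core Lean:

-- The underlying arithmetic fact (illustration only; `.size` is assumed to multiply dims).
example (n numHeads headDim : Nat) :
    n * (numHeads * headDim) = numHeads * (n * headDim) :=
  Nat.mul_left_comm n numHeads headDim
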
noncomputable def Proofs.Autograd.MultiHeadAttention.mhaDGraph {n dModel numHeads headDim : ℕ} (c : ) :
DGraph (ΓMHA n dModel numHeads headDim) (ssMHA n dModel numHeads headDim)

Multi-head self-attention as a proof-carrying graph.

This implements: x Wq Wk Wv Wo ↦ Wo (concat_heads (softmax(c * (Q Kᵀ)) V)), with Q/K/V projected from x.

The graph is laid out to match typical runtime implementations: view(...).transpose(...) is modeled by reshape + swap_first_two3d.
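
In conventional notation (reading the application of Wo as a right matrix multiplication, which is what the concatenated-head shape n×(numHeads*headDim) and WoShape force): with Q = x Wq, K = x Wk, V = x Wv split into numHeads heads of width headDim, each head computes head_h = softmax(c * (Q_h K_hᵀ)) V_h, and the result is concat_heads(head_1, …, head_numHeads) Wo of shape n×dModel. In PyTorch-style scaled dot-product attention the scale c is typically 1/√headDim; here it is an explicit parameter of mhaDGraph.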

Instances For

theorem Proofs.Autograd.MultiHeadAttention.mha_backpropVec_eq_adjoint_fderiv {n dModel numHeads headDim : ℕ} (c : ) (xV : CtxVec (ΓMHA n dModel numHeads headDim)) (seedV : CtxVec (ΓMHA n dModel numHeads headDim ++ ssMHA n dModel numHeads headDim)) :

Corollary of the general DAG theorem: backprop equals (fderiv eval)† for the MHA graph.

This is the formal “VJP correctness” statement for the full MHA computation (as laid out by mhaDGraph).
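
Spelled out, “backprop equals (fderiv eval)†” is the usual vector–Jacobian-product property: writing D for fderiv eval at xV and b for the vector that backprop produces from the seed seedV, it says ⟨D v, seedV⟩ = ⟨v, b⟩ for every input perturbation v; that is, the backward pass applies the adjoint of the forward derivative to the seed. The inner-product identity is the standard defining property of the adjoint; the precise Lean statement is the theorem above.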