MultiHeadSelfAttention #
End-to-end fderiv/backprop correctness for a Multi-Head Self-Attention graph,
decomposed into proven tape nodes:
- linear projections via matmul,
- head split/merge via reshape + swap_first_two3d,
- attention core via batched matmul + transpose3d_last_two + scale + batchedsoftmax_last.

This is spec-level over ℝ. It is a corollary of the general graph theorem once each node
used by the graph has a NodeFDerivCorrect instance.
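As an informal illustration of this decomposition (a minimal NumPy sketch, not the Lean
development; the helper names are hypothetical stand-ins for the tape nodes listed above):

```python
import numpy as np

def softmax_last(t):
    # Softmax over the last axis (the role played by batchedsoftmax_last).
    e = np.exp(t - t.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def attention_core(q, k, v, c):
    # Batched attention core: softmax(c * (Q Kᵀ)) V,
    # for q, k, v of shape (numHeads, n, headDim).
    kt = np.swapaxes(k, -1, -2)      # transpose3d_last_two: (numHeads, headDim, n)
    scores = c * (q @ kt)            # batched matmul + scale: (numHeads, n, n)
    return softmax_last(scores) @ v  # softmax + batched matmul
```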
PyTorch correspondence / citations #
- The construction matches the usual “project → split heads → scaled dot-product attention →
  concat heads → output projection” pipeline used by torch.nn.MultiheadAttention.
  https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html
- The core attention step corresponds to torch.nn.functional.scaled_dot_product_attention.
  https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
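The correspondence can be spot-checked numerically. The following is an illustrative
sketch (the weight and dimension names are not from the formalization), comparing the
manual pipeline against the fused PyTorch kernel:

```python
import torch
import torch.nn.functional as F

n, num_heads, head_dim = 5, 4, 8
d_model = num_heads * head_dim
x = torch.randn(n, d_model)
Wq, Wk, Wv = (torch.randn(d_model, d_model) for _ in range(3))

def split_heads(t):
    # project → split heads: (n, d_model) → (num_heads, n, head_dim)
    return t.view(n, num_heads, head_dim).transpose(0, 1)

q, k, v = split_heads(x @ Wq), split_heads(x @ Wk), split_heads(x @ Wv)

manual = F.softmax((q @ k.transpose(-1, -2)) / head_dim**0.5, dim=-1) @ v
fused = F.scaled_dot_product_attention(q, k, v)  # defaults to the same 1/√headDim scale
assert torch.allclose(manual, fused, atol=1e-5)
```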
Sequence input shape n×dModel.
Concatenated-head representation n×(numHeads*headDim).
Split-head representation (numHeads)×n×headDim.
Key-transposed shape (numHeads)×headDim×n used for Q Kᵀ.
Attention scores shape (numHeads)×n×n.
Intermediate shape after swapping axes for concatenation n×numHeads×headDim.
Intermediate node output shapes (tape “saved tensors”) for the MHA graph.
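The tape “saved tensors” play the same role as ctx.save_for_backward in a custom PyTorch
torch.autograd.Function: the forward pass records whatever the backward pass will need.
A generic toy illustration (not the Lean tape):

```python
import torch

class Scale(torch.autograd.Function):
    # Toy tape node y = c * x: the forward pass saves what backward needs.
    @staticmethod
    def forward(ctx, x, c):
        ctx.c = c  # scalar saved for backward (tensors would use ctx.save_for_backward)
        return c * x

    @staticmethod
    def backward(ctx, grad_out):
        # The VJP of the linear map x ↦ c*x is multiplication by c (its adjoint).
        return ctx.c * grad_out, None

x = torch.randn(3, requires_grad=True)
Scale.apply(x, 2.0).sum().backward()
assert torch.allclose(x.grad, torch.full((3,), 2.0))
```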
Projection weight shape dModel×(numHeads*headDim) (used for Q/K/V).
Output projection weight shape (numHeads*headDim)×dModel.
Input context shapes: [x, Wq, Wk, Wv, Wo].
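Putting the shape abbreviations together, here is an illustrative NumPy shape trace with
made-up sizes (note that dModel need not equal numHeads*headDim; the trace reuses one
projected tensor for both matmul operands since only shapes are tracked):

```python
import numpy as np

n, num_heads, head_dim, d_model = 5, 4, 8, 10
x = np.random.randn(n, d_model)                        # sequence input: n×dModel
Wq = np.random.randn(d_model, num_heads * head_dim)    # projection weight (Q/K/V)
Wo = np.random.randn(num_heads * head_dim, d_model)    # output projection weight

q = x @ Wq                                             # n×(numHeads*headDim)
qh = q.reshape(n, num_heads, head_dim).swapaxes(0, 1)  # numHeads×n×headDim
kt = np.swapaxes(qh, -1, -2)                           # numHeads×headDim×n (for Q Kᵀ)
scores = qh @ kt                                       # attention scores: numHeads×n×n
merged = (scores @ qh).swapaxes(0, 1)                  # n×numHeads×headDim
out = merged.reshape(n, num_heads * head_dim) @ Wo     # back to n×dModel
assert out.shape == (n, d_model)
```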
Context index of the sequence input x in ΓMHA.
Context index of Wq in ΓMHA.
Context index of Wk in ΓMHA.
Context index of Wv in ΓMHA.
Context index of Wo in ΓMHA.
Index of the most-recently appended tensor in a DGraph context.
Multi-head self-attention as a proof-carrying graph.
This implements
  (x, Wq, Wk, Wv, Wo) ↦ concat_heads (softmax (c * (Q Kᵀ)) V) Wo,
with Q/K/V projected from x and Wo applied as the final output projection (right
multiplication, matching its (numHeads*headDim)×dModel shape).
The graph is laid out to match typical runtime implementations:
view(...).transpose(...) is modeled by reshape + swap_first_two3d.
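That modeling choice is easy to sanity-check in PyTorch (an illustrative check, not
project code):

```python
import torch

n, num_heads, head_dim = 5, 4, 8
q = torch.randn(n, num_heads * head_dim)

# The runtime idiom: view, then transpose the first two axes ...
a = q.view(n, num_heads, head_dim).transpose(0, 1)
# ... agrees with reshape followed by swapping the first two axes,
# i.e. the reshape + swap_first_two3d decomposition used by the graph.
b = q.reshape(n, num_heads, head_dim).permute(1, 0, 2)
assert torch.equal(a, b) and a.shape == (num_heads, n, head_dim)
```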
Corollary of the general DAG theorem: backprop equals (fderiv eval)† for the MHA graph.
This is the formal “VJP correctness” statement for the full MHA computation (as laid out by
mhaDGraph).
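Concretely, “backprop equals (fderiv eval)†” is the adjoint identity ⟨J u, w⟩ = ⟨u, Jᵀ w⟩.
The theorem establishes it exactly over ℝ; a finite-size numerical spot-check of the same
identity in PyTorch, with a toy map standing in for the MHA evaluation, looks like this:

```python
import torch
from torch.autograd.functional import jvp, vjp

def f(x):
    # Toy smooth stand-in for the graph's evaluation map.
    return torch.softmax(x @ x.T, dim=-1)

x = torch.randn(3, 3)
u = torch.randn(3, 3)  # input (tangent) direction
w = torch.randn(3, 3)  # output (cotangent) direction

_, Ju = jvp(f, x, u)   # forward mode: J u
_, JTw = vjp(f, x, w)  # reverse mode (backprop): Jᵀ w

# Adjoint identity: ⟨J u, w⟩ = ⟨u, Jᵀ w⟩
assert torch.allclose((Ju * w).sum(), (u * JTw).sum(), atol=1e-5)
```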