Attention module wrappers #
This file wraps a few attention blocks as NNModuleSpecs so we can:
- compose them with SpecChain (shape-safe pipelines), and
- attach simple export/pretty-print metadata for demos.
The wrapper below builds a self-attention context with Q=K=V=x and no mask, which matches the
common "encoder block" usage. More specialized variants (cross-attention, causal masks, etc.) are
defined at the layer-spec level in NN/Spec/Layers/Attention.lean.
In PyTorch terms, the core computation is scaled dot-product self-attention,
softmax(QK^T / sqrt(d)) V; recent PyTorch versions expose it as
torch.nn.functional.scaled_dot_product_attention.
This wrapper stays intentionally narrow: it is self-attention only (Q=K=V=x) with no causal mask.
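Spelled out with this wrapper's specialization Q = K = V = x (writing d for dModel, and assuming no
learned projections beyond what is stated above), the target computation is

  SelfAttn(x) = softmax(x x^T / sqrt(d)) x,   for x of shape n × d.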
def Spec.ScaledDotProductAttentionModuleSpec {α : Type} [Context α]
    [DecidableRel fun (x1 x2 : α) => x1 > x2] (n dModel : ℕ) (h1 : n ≠ 0) :
    ModSpec.NNModuleSpec α
      (Shape.dim n (Shape.dim dModel Shape.scalar))
      (Shape.dim n (Shape.dim dModel Shape.scalar))
Self-attention block (Q=K=V=x, no mask) as an NNModuleSpec.
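A minimal instantiation sketch, not taken from the file above: the carrier type α, the availability of
the required Context and DecidableRel instances, and the sizes 16 and 64 are illustrative assumptions;
the n ≠ 0 side condition is discharged by decide.

-- Illustrative sketch only: α, the instance arguments, and the sizes 16/64
-- are assumptions; the `by decide` proof discharges `16 ≠ 0`.
example {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] :
    ModSpec.NNModuleSpec α
      (Shape.dim 16 (Shape.dim 64 Shape.scalar))
      (Shape.dim 16 (Shape.dim 64 Shape.scalar)) :=
  Spec.ScaledDotProductAttentionModuleSpec 16 64 (by decide)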