Attention (spec layer) #
This file defines the standard scaled dot-product attention primitive and a simple multi-head wrapper.
Attention(Q,K,V) = softmax(Q Kᵀ / √d) V
The TorchLean goal here is to mirror the math you see in deep learning libraries (especially PyTorch),
but keep everything as pure functions on Spec.Tensor so the same definitions can be reused for:
- proofs (e.g. reasoning about shapes and gradients),
- reference implementations (runtime extraction),
- verification backends (e.g. interval semantics).
Shapes and conventions #
We model the "single batch element" case. Batched attention is obtained by adding an outer .dim B
and mapping over it.
Core shapes:
- Q : (nQ × d) queries
- K : (nK × d) keys
- V : (nK × dV) values
- output : (nQ × dV)
In many transformer blocks dV = d, and this file uses that common choice for simplicity.
The optional Boolean mask has shape (nQ × nK). In the main spec, masks use the true -∞
semantics: blocked entries receive zero numerator before row normalization, so their attention
weight is definitionally zero. This is the finite-scalar encoding of the PyTorch pattern
scores.masked_fill(~mask, -torch.inf).
PyTorch analogy:
- scaledDotProductAttention corresponds to torch.nn.functional.scaled_dot_product_attention (no dropout), with Boolean masks interpreted as true -∞ masks.
- MultiHeadAttention.forward corresponds to the core computation inside nn.MultiheadAttention / transformer blocks, ignoring biases and dropout.
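As a concrete point of comparison, here is a small PyTorch-only sketch (illustrative, not part of the Lean spec) checking that the Boolean-mask convention above matches a manual masked_fill(-inf) + softmax computation; the tensor sizes and the mask are arbitrary.

```python
import math
import torch
import torch.nn.functional as F

# Single-head SDPA with a Boolean mask (True = allowed), computed two ways.
torch.manual_seed(0)
nQ, nK, d = 3, 5, 8
Q, K, V = torch.randn(1, nQ, d), torch.randn(1, nK, d), torch.randn(1, nK, d)
mask = torch.tensor([[True, True, False, True, False],
                     [True, False, True, True, True],
                     [False, True, True, False, True]])      # (nQ, nK)

# Library call (no dropout): Boolean attn_mask, True means "attend".
out_lib = F.scaled_dot_product_attention(Q, K, V, attn_mask=mask)

# Manual path: scale, fill blocked logits with -inf, softmax, multiply by V.
scores = Q @ K.transpose(-2, -1) / math.sqrt(d)
weights = torch.softmax(scores.masked_fill(~mask, float("-inf")), dim=-1)
out_manual = weights @ V

assert torch.allclose(out_lib, out_manual, atol=1e-5)
```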
Scaled Dot-Product Attention #
We separate out the single-head primitive (scaledDotProductAttention) because:
- it is the core mathematical object, reused in multi-head attention,
- it is a good target for proofs and for "spec vs runtime" comparisons.
Boolean masks #
TorchLean uses the same boolean mask convention as PyTorch SDPA:
- true means a key/value position is allowed to be attended to,
- false means it is blocked (its softmax numerator is exactly zero).
PyTorch reference: torch.nn.functional.scaled_dot_product_attention uses the same convention for
boolean attn_mask entries: True entries are included, and False entries are blocked.
A (nQ × nK) mask where every position is allowed (true).
Instances For
A (nQ × nK) mask where every position is blocked (false).
Instances For
Causal (lower-triangular) self-attention mask of shape (n, n).
mask[i,j] = true iff j ≤ i, i.e. each query position can attend to itself and past positions.
Instances For
Future-only (upper-triangular) self-attention mask of shape (n, n).
This is the (strict) complement of causal_mask: mask[i,j] = true iff i < j.
Instances For
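For concreteness, a tiny sketch (PyTorch, illustrative; n = 4 is arbitrary) of the two mask patterns just described and of the complement relationship between them.

```python
import torch

# causal allows j <= i (diagonal included); future-only allows i < j.
n = 4
i = torch.arange(n).unsqueeze(1)          # query index, as a column
j = torch.arange(n).unsqueeze(0)          # key index, as a row
causal = j <= i                           # lower-triangular mask
future_only = i < j                       # strict upper-triangular mask
assert torch.equal(causal, ~future_only)  # exact complements, as stated above
```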
Bundled inputs and mask needed for scaled dot-product attention.
- Q : Tensor α (Shape.dim nQ (Shape.dim dModel Shape.scalar))
- K : Tensor α (Shape.dim nK (Shape.dim dModel Shape.scalar))
- V : Tensor α (Shape.dim nK (Shape.dim dModel Shape.scalar))
- bc_sum_to_target : (Shape.dim nQ Shape.scalar).BroadcastTo (Shape.dim nQ (Shape.dim nK Shape.scalar))
Instances For
Exact hard masking #
TorchLean encodes the usual "true -∞ before softmax" behavior without requiring the tensor scalar
type itself to contain infinities. Instead of replacing blocked logits by a finite sentinel, we form
softmax numerators directly:
numerator_j = if mask_j then exp(score_j) else 0.
This is exactly the contribution exp(-∞) = 0 makes to softmax. Blocked positions therefore have
exactly zero attention mass, which is the property causal proofs need.
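A quick numeric illustration (NumPy, not the Lean definition): forming zero numerators for blocked entries gives the same weights as softmax over -∞-filled logits, and the blocked weights are exactly zero.

```python
import numpy as np

rng = np.random.default_rng(0)
scores = rng.standard_normal((2, 4))                     # (nQ, nK) scaled scores
mask = np.array([[True, True, False, True],
                 [True, False, False, True]])            # True = allowed

# PyTorch-style: fill blocked logits with -inf, then normalize row-wise.
filled = np.where(mask, scores, -np.inf)
w_inf = np.exp(filled) / np.exp(filled).sum(axis=-1, keepdims=True)

# Spec-style: numerator_j = exp(score_j) if allowed, else 0.
numer = np.where(mask, np.exp(scores), 0.0)
w_spec = numer / numer.sum(axis=-1, keepdims=True)

assert np.allclose(w_inf, w_spec)
assert np.all(w_spec[~mask] == 0.0)                      # exactly zero mass on blocked keys
```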
Hard-masked softmax on one vector.
mask[j] = false makes the j-th numerator exactly zero before normalization. This is the
ordinary finite-scalar encoding of softmax with true -∞ masked logits.
If every mask entry is false, the denominator is zero. The definition remains total because the spec scalar division is total, but well-formed attention call sites should provide at least one allowed key per row. Causal self-attention satisfies this by construction, since each token can attend to itself.
Instances For
Row-wise hard-masked softmax for attention score matrices.
Instances For
VJP/JVP helper for a softmax-like row-normalization when the forward weights are already known.
For ordinary softmax, weights = softmax(scores). For hard-masked softmax, blocked entries have
weights = 0, and the same formula gives zero gradient through blocked logits:
dScores = weights ⊙ (dWeights - Σⱼ dWeightsⱼ * weightsⱼ).
Instances For
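A numeric sanity check of the formula above (PyTorch autograd, illustrative): the closed form computed from the forward weights alone agrees with the VJP obtained by differentiating softmax directly.

```python
import torch

# dScores = weights * (dWeights - sum_j dWeights_j * weights_j), row-wise.
torch.manual_seed(0)
scores = torch.randn(2, 5, requires_grad=True)
weights = torch.softmax(scores, dim=-1)
dWeights = torch.randn(2, 5)                       # upstream gradient w.r.t. the weights

# Closed form, using only the forward weights.
dScores_formula = weights * (dWeights - (dWeights * weights).sum(dim=-1, keepdim=True))

# Reference VJP via autograd through the softmax.
(dScores_autograd,) = torch.autograd.grad(weights, scores, grad_outputs=dWeights)

assert torch.allclose(dScores_formula, dScores_autograd, atol=1e-5)
```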
Scaled dot-product attention (forward).
Given Q, K, V, and an optional Boolean mask, we compute:
- scores S = Q Kᵀ with shape (nQ × nK)
- scaled scores S' = S / √d
- (optional) mask: for each (i,j), if mask[i,j] = false, its softmax numerator is exactly zero (the finite-scalar encoding of true -∞ masking)
- attention weights A = softmax(S') (softmax over the last axis, i.e. each query row sums to 1)
- output Out = A V with shape (nQ × d)
Mask convention:
mask[i,j] = true means "this key position is allowed", and false means "mask it out".
PyTorch analogy: torch.softmax(scores.masked_fill(~mask, -torch.inf), dim=-1) row-wise, then a
final matrix multiply by V.
Instances For
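The steps above transcribe directly into a few lines of NumPy. This is an illustrative reference, not the Lean definition; in particular it omits the usual max-subtraction trick, mirroring the bare spec equations, and the helper name is local to this sketch.

```python
import numpy as np

def sdpa_reference(Q, K, V, mask=None):
    """Illustrative: scores -> scale -> optional hard mask -> row softmax -> weighted sum of V."""
    d = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d)                        # S, shape (nQ, nK)
    numer = np.exp(scores)
    if mask is not None:
        numer = np.where(mask, numer, 0.0)               # blocked numerators are exactly 0
    weights = numer / numer.sum(axis=-1, keepdims=True)  # A, each query row sums to 1
    return weights @ V                                   # Out, shape (nQ, d)
```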
Alias documenting that the main attention spec uses exact hard-mask semantics.
Instances For
Backward/VJP for scaled dot-product attention.
Returns (dQ, dK, dV) given an upstream gradient dOut.
We recompute the forward intermediates locally so this spec stays self-contained and does not rely on a global tape.
For masked calls, this is the VJP for true hard masking. Blocked logits have zero forward weight,
and softmaxBackwardFromWeightsSpec therefore gives zero gradient through those blocked positions.
Instances For
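A concrete transcription of this VJP (PyTorch, illustrative, unmasked case), checked against autograd; the helper name sdpa_vjp is local to this sketch, not a TorchLean definition.

```python
import math
import torch

def sdpa_vjp(Q, K, V, dOut):
    """Illustrative: recompute forward intermediates, then pull dOut back through Out = softmax(QKᵀ/√d) V."""
    d = Q.shape[-1]
    W = torch.softmax(Q @ K.T / math.sqrt(d), dim=-1)    # forward attention weights
    dV = W.T @ dOut                                      # Out = W V   =>  dV = Wᵀ dOut
    dW = dOut @ V.T                                      #             =>  dW = dOut Vᵀ
    dS = W * (dW - (dW * W).sum(dim=-1, keepdim=True))   # softmax backward from weights
    dQ = dS @ K / math.sqrt(d)                           # S = Q Kᵀ / √d
    dK = dS.T @ Q / math.sqrt(d)
    return dQ, dK, dV

torch.manual_seed(0)
Q = torch.randn(3, 8, requires_grad=True)
K = torch.randn(5, 8, requires_grad=True)
V = torch.randn(5, 8, requires_grad=True)
dOut = torch.randn(3, 8)

out = torch.softmax(Q @ K.T / math.sqrt(8), dim=-1) @ V
dQ_ref, dK_ref, dV_ref = torch.autograd.grad(out, (Q, K, V), grad_outputs=dOut)
dQ, dK, dV = sdpa_vjp(Q.detach(), K.detach(), V.detach(), dOut)
assert all(torch.allclose(a, b, atol=1e-5)
           for a, b in zip((dQ, dK, dV), (dQ_ref, dK_ref, dV_ref)))
```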
Forward-mode JVP for scaled dot-product attention.
This differentiates the pure attention equation
Out = softmax(mask(Q Kᵀ / sqrt(d))) V
in the direction (dQ,dK,dV). For hard-masked calls, blocked logits have zero forward weight, so
their tangent contribution is zero in softmaxBackwardFromWeightsSpec. The row-wise softmax
Jacobian is symmetric, so the same formula serves as both VJP and JVP once the forward weights are
known.
Instances For
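A forward-mode transcription of the same rule (NumPy, illustrative, unmasked): the tangent of the scores goes through the same "backward from weights" formula because the row-wise softmax Jacobian is symmetric. The helper name is local to this sketch.

```python
import numpy as np

def sdpa_jvp(Q, K, V, dQ, dK, dV):
    """Illustrative: tangent of Out = softmax(Q Kᵀ / √d) V in direction (dQ, dK, dV)."""
    d = Q.shape[-1]
    numer = np.exp(Q @ K.T / np.sqrt(d))
    W = numer / numer.sum(axis=-1, keepdims=True)         # forward weights
    dS = (dQ @ K.T + Q @ dK.T) / np.sqrt(d)               # tangent of the scaled scores
    dW = W * (dS - (dS * W).sum(axis=-1, keepdims=True))  # softmax Jacobian applied to dS
    return dW @ V + W @ dV                                # product rule on Out = W V
```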
Multi-head attention parameters (projection matrices).
PyTorch analogy: this corresponds to the four linear maps used in attention blocks, i.e. the query, key, value, and output projections (Wq, Wk, Wv, Wo below).
This spec keeps them as explicit matrices (no bias terms) to keep the math simple and to make the gradients easy to audit.
- Wq : Tensor α (Shape.dim dModel (Shape.dim (numHeads * headDim) Shape.scalar))
Query projection Wq.
- Wk : Tensor α (Shape.dim dModel (Shape.dim (numHeads * headDim) Shape.scalar))
Key projection Wk.
- Wv : Tensor α (Shape.dim dModel (Shape.dim (numHeads * headDim) Shape.scalar))
Value projection Wv.
- Wo : Tensor α (Shape.dim (numHeads * headDim) (Shape.dim dModel Shape.scalar))
Output projection Wo.
Instances For
Split (n, dModel) into (numHeads, n, headDim) by reshaping.
We store heads as the outermost axis so that "per-head computation" is just a Tensor.dim over
Fin numHeads.
PyTorch analogy: conceptually similar to reshaping (n, numHeads*headDim) into
(n, numHeads, headDim) and then transposing to make heads a separate axis; here we go directly to
(numHeads, n, headDim) because it is convenient for later definitions.
Instances For
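For intuition, the same head split in NumPy (illustrative; the sizes are arbitrary).

```python
import numpy as np

# (n, numHeads*headDim) -> (numHeads, n, headDim), with heads outermost.
n, num_heads, head_dim = 6, 4, 8
x = np.arange(n * num_heads * head_dim).reshape(n, num_heads * head_dim)
heads = x.reshape(n, num_heads, head_dim).transpose(1, 0, 2)
assert heads.shape == (num_heads, n, head_dim)
```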
Concatenate a list of numHeads head tensors into a single (n, numHeads*headDim) tensor.
This is a straightforward list-based definition. The newer combine_heads_spec below does the same
thing starting from a tensor-of-heads representation.
Instances For
concat_heads_spec above is the original (list-based) definition.
For proofs/automation, it's often easier to work with a tensor of heads
Tensor α (.dim numHeads (.dim n (.dim headDim .scalar))) and then use shape-only transforms
to combine heads back into a single (n, numHeads*headDim) tensor.
Combine a tensor-of-heads back into a single (n, numHeads*headDim) tensor.
Implementation detail:
- swap_first_two_spec converts (numHeads, n, headDim) into (n, numHeads, headDim)
- reshape_spec flattens the last two axes into (n, numHeads*headDim)
Instances For
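And the inverse direction (NumPy, illustrative): swap the first two axes, then flatten the last two, which undoes the split sketched earlier.

```python
import numpy as np

# (numHeads, n, headDim) -> (n, numHeads, headDim) -> (n, numHeads*headDim)
num_heads, n, head_dim = 4, 6, 8
x = np.arange(n * num_heads * head_dim).reshape(n, num_heads * head_dim)
heads = x.reshape(n, num_heads, head_dim).transpose(1, 0, 2)      # split, as above
combined = heads.transpose(1, 0, 2).reshape(n, num_heads * head_dim)
assert np.array_equal(combined, x)                                # split then combine is the identity
```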
Convenience proof that (n) broadcasts to (n,n).
This is kept as a small helper because some attention-style proofs and wrappers want an explicit
BroadcastTo witness rather than relying on typeclass search.
Instances For
Multi-head attention forward pass (self-attention when mask is square).
High-level structure (PyTorch mental model):
- project x into Q, K, V
- split the projection dimension into heads
- run scaled dot-product attention per head (sharing the same mask)
- combine heads back and project with Wo
Instances For
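An end-to-end sketch of these steps (NumPy, illustrative, no biases, mask omitted for brevity); the names are local to this sketch, not the Lean definitions.

```python
import numpy as np

def mha_forward(x, Wq, Wk, Wv, Wo, num_heads):
    """Illustrative: project, split into heads, per-head SDPA (unmasked), combine, output-project."""
    n, d_model = x.shape
    head_dim = Wq.shape[1] // num_heads

    def split(t):                                        # (n, H*Dh) -> (H, n, Dh)
        return t.reshape(n, num_heads, head_dim).transpose(1, 0, 2)

    Q, K, V = split(x @ Wq), split(x @ Wk), split(x @ Wv)
    numer = np.exp(Q @ K.transpose(0, 2, 1) / np.sqrt(head_dim))   # per-head (n, n) scores
    W = numer / numer.sum(axis=-1, keepdims=True)
    heads = W @ V                                        # (H, n, Dh)
    combined = heads.transpose(1, 0, 2).reshape(n, num_heads * head_dim)
    return combined @ Wo                                 # back to (n, d_model)

# Example usage with random parameters.
rng = np.random.default_rng(0)
n, d_model, num_heads, head_dim = 6, 16, 4, 4
params = [rng.standard_normal((d_model, num_heads * head_dim)) for _ in range(3)]
Wo = rng.standard_normal((num_heads * head_dim, d_model))
y = mha_forward(rng.standard_normal((n, d_model)), *params, Wo, num_heads)
assert y.shape == (n, d_model)
```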
Multi-head attention backward pass.
Returns gradients for input x and all projection matrices (Wq,Wk,Wv,Wo).
We recompute forward intermediates locally so we don’t rely on a global tape.
Instances For
Forward-mode JVP for multi-head attention.
The rule follows the same computational graph as MultiHeadAttention.forward:
- project tangents through Q/K/V,
- split primal and tangent projections into heads,
- apply scaledDotProductAttentionJvp head-wise,
- combine head tangents, then differentiate the final output projection.
This keeps attention forward-mode AD explicit at the spec layer instead of hiding it behind a runtime-only implementation.
Instances For
Self-attention on a single sequence.
This uses the same input x for Q/K/V, runs scaled dot-product attention, then applies the output
projection Wo.
PyTorch mental model: the core of nn.MultiheadAttention / TransformerEncoderLayer (ignoring the
batch axis).
Instances For
Cross-attention between two sequences.
query is length n1 and attends to key/value of length n2.
PyTorch mental model: the attention block in a Transformer decoder layer (nn.MultiheadAttention
with distinct query and key/value inputs).
Instances For
Sparse Attention #
Uses sparse attention patterns for efficiency.