TorchLean API

NN.Spec.Layers.Attention

Attention (spec layer) #

This file defines the standard scaled dot-product attention primitive and a simple multi-head wrapper.

Attention(Q,K,V) = softmax(Q Kᵀ / √d) V

TorchLean's goal here is to mirror the math you see in deep learning libraries (especially PyTorch) while keeping everything as pure functions on Spec.Tensor, so the same definitions can be reused both as an executable reference semantics and in proofs.

Shapes and conventions #

We model the "single batch element" case. Batched attention is obtained by adding an outer .dim B and mapping over it.

Core shapes:

• Q : (nQ × d)
• K : (nK × d)
• V : (nK × dV)

In many transformer blocks dV = d, and this file uses that common choice for simplicity.

The optional Boolean mask has shape (nQ × nK). In the main spec, masks use the true -∞ semantics: blocked entries receive zero numerator before row normalization, so their attention weight is definitionally zero. This is the finite-scalar encoding of the PyTorch pattern scores.masked_fill(~mask, -torch.inf).

PyTorch analogy: torch.nn.functional.scaled_dot_product_attention(Q, K, V, attn_mask=mask) with a boolean attn_mask.

Scaled Dot-Product Attention #

We separate out the single-head primitive (scaledDotProductAttention) because it is the unit that the multi-head wrapper and the self-/cross-attention helpers below all reuse.

Boolean masks #

TorchLean uses the same boolean mask convention as PyTorch SDPA: true means the (query, key) position is allowed to attend, and false means it is blocked.

PyTorch reference: torch.nn.functional.scaled_dot_product_attention uses the same convention for boolean attn_mask entries: True entries are included, and False entries are blocked.

A (nQ × nK) mask where every position is allowed (true).

Instances For

    A (nQ × nK) mask where every position is blocked (false).

    Instances For

      Causal (lower-triangular) self-attention mask of shape (n, n).

      mask[i,j] = true iff j ≤ i, i.e. each query position can attend to itself and past positions.
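For example, with n = 4 (writing 1 for true and 0 for false) the causal mask is the lower-triangular matrix:

```latex
M \;=\; \begin{pmatrix}
1 & 0 & 0 & 0 \\
1 & 1 & 0 & 0 \\
1 & 1 & 1 & 0 \\
1 & 1 & 1 & 1
\end{pmatrix}
```

Row i allows exactly the keys j ≤ i, so every row has at least one allowed entry.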

      Instances For

        Future-only (upper-triangular) self-attention mask of shape (n, n).

        This is the (strict) complement of causal_mask: mask[i,j] = true iff i < j.

        Instances For
          structure Spec.AttentionContext (α : Type) [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] (nQ nK dModel : ℕ) (h1 : nQ ≠ 0) (h2 : nK ≠ 0) : Type

          Bundled inputs and mask needed for scaled dot-product attention.

          Instances For

            Exact hard masking #

            TorchLean encodes the usual "true -∞ before softmax" behavior without requiring the tensor scalar type itself to contain infinities. Instead of replacing blocked logits by a finite sentinel, we form softmax numerators directly:

            numerator_j = if mask_j then exp(score_j) else 0.

            This is exactly what exp(-∞)=0 contributes to softmax. Blocked positions therefore have exactly zero attention mass, which is the property causal proofs need.
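Combining the numerator rule with row normalization, the hard-masked softmax weight of key j under scores s and mask m is:

```latex
w_j \;=\; \frac{[m_j]\,\exp(s_j)}{\sum_k [m_k]\,\exp(s_k)},
\qquad\text{where } [m_j] = \begin{cases} 1 & \text{if } m_j = \mathtt{true} \\ 0 & \text{if } m_j = \mathtt{false} \end{cases}
```

For blocked j this gives w_j = 0 exactly, not merely a very small number, which is the property the causal proofs rely on.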

            Hard-masked softmax on one vector.

            mask[j] = false makes the j-th numerator exactly zero before normalization. This is the ordinary finite-scalar encoding of softmax with true -∞ masked logits.

            If every mask entry is false, the denominator is zero; the definition remains total because the spec scalar division is total, but well-formed attention call sites should provide at least one allowed key per row. Causal self-attention satisfies that by construction because each token can attend to itself.

            Instances For

              Row-wise hard-masked softmax for attention score matrices.

              Instances For
                def Spec.softmaxBackwardFromWeightsSpec {α : Type} [Context α] {s : Shape} :
                Tensor α s → Tensor α s → Tensor α s

                VJP/JVP helper for a softmax-like row-normalization when the forward weights are already known.

                For ordinary softmax, weights = softmax(scores). For hard-masked softmax, blocked entries have weights = 0, and the same formula gives zero gradient through blocked logits:

                dScores = weights ⊙ (dWeights - Σⱼ dWeightsⱼ * weightsⱼ).
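This is the standard softmax VJP; it follows from the row-wise Jacobian of softmax, ∂wᵢ/∂sⱼ = wᵢ(δᵢⱼ − wⱼ):

```latex
dS_j \;=\; \sum_i dW_i \,\frac{\partial w_i}{\partial s_j}
\;=\; \sum_i dW_i\, w_i\,(\delta_{ij} - w_j)
\;=\; w_j\Bigl(dW_j - \sum_i dW_i\, w_i\Bigr)
```

When wⱼ = 0 (a blocked entry under hard masking), the factor wⱼ forces dSⱼ = 0, as claimed above.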

                Instances For
                  def Spec.scaledDotProductAttention {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {nQ nK dModel : ℕ} {h1 : nQ ≠ 0} {h2 : nK ≠ 0} (ctx : AttentionContext α nQ nK dModel h1 h2) :
                  Tensor α (Shape.dim nQ (Shape.dim dModel Shape.scalar))

                  Scaled dot-product attention (forward).

                  Given:

                  • Q : (nQ × d), K : (nK × d), V : (nK × d),

                  we compute:

                  1. scores S = Q Kᵀ with shape (nQ × nK)
                  2. scaled scores S' = S / √d
                  3. (optional) mask: for each (i,j), if mask[i,j] = false, its softmax numerator is exactly zero (the finite-scalar encoding of true -∞ masking)
                  4. attention weights A = softmax(S') (softmax over the last axis, i.e. each query row sums to 1)
                  5. output Out = A V with shape (nQ × d)

                  Mask convention:

                  mask[i,j] = true means "this key position is allowed", and false means "mask it out".

                  PyTorch analogy: torch.softmax(scores.masked_fill(~mask, -torch.inf), dim=-1) row-wise, then a final matrix multiply by V.
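In index form, steps 1–5 compute, for each query row i (with [Mᵢⱼ] = 1 when the position is allowed and 0 when blocked, and [Mᵢⱼ] = 1 everywhere when no mask is given):

```latex
S_{ij} = \frac{\langle Q_i, K_j \rangle}{\sqrt{d}}, \qquad
A_{ij} = \frac{[M_{ij}]\,\exp(S_{ij})}{\sum_k [M_{ik}]\,\exp(S_{ik})}, \qquad
\mathrm{Out}_i = \sum_j A_{ij}\, V_j
```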

                  Instances For
                    def Spec.hardMaskedScaledDotProductAttention {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {nQ nK dModel : ℕ} {h1 : nQ ≠ 0} {h2 : nK ≠ 0} (ctx : AttentionContext α nQ nK dModel h1 h2) :
                    Tensor α (Shape.dim nQ (Shape.dim dModel Shape.scalar))

                    Alias documenting that the main attention spec uses exact hard-mask semantics.

                    Instances For
                      def Spec.scaledDotProductAttentionBackward {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {nQ nK dModel : ℕ} {h1 : nQ ≠ 0} {h2 : nK ≠ 0} (ctx : AttentionContext α nQ nK dModel h1 h2) (dOut : Tensor α (Shape.dim nQ (Shape.dim dModel Shape.scalar))) :
                      Tensor α (Shape.dim nQ (Shape.dim dModel Shape.scalar)) × Tensor α (Shape.dim nK (Shape.dim dModel Shape.scalar)) × Tensor α (Shape.dim nK (Shape.dim dModel Shape.scalar))

                      Backward/VJP for scaled dot-product attention.

                      Returns (dQ, dK, dV) given an upstream gradient dOut.

                      We recompute the forward intermediates locally so this spec stays self-contained and does not rely on a global tape.

                      For masked calls, this is the VJP for true hard masking. Blocked logits have zero forward weight, and softmaxBackwardFromWeightsSpec therefore gives zero gradient through those blocked positions.

                      Instances For
                        def Spec.scaledDotProductAttentionJvp {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {nQ nK dModel : ℕ} {h1 : nQ ≠ 0} {h2 : nK ≠ 0} (ctx : AttentionContext α nQ nK dModel h1 h2) (dQ : Tensor α (Shape.dim nQ (Shape.dim dModel Shape.scalar))) (dK dV : Tensor α (Shape.dim nK (Shape.dim dModel Shape.scalar))) :
                        Tensor α (Shape.dim nQ (Shape.dim dModel Shape.scalar))

                        Forward-mode JVP for scaled dot-product attention.

                        This differentiates the pure attention equation

                        Out = softmax(mask(Q Kᵀ / sqrt(d))) V

                        in the direction (dQ,dK,dV). For hard-masked calls, blocked logits have zero forward weight, so their tangent contribution is zero in softmaxBackwardFromWeightsSpec. The row-wise softmax Jacobian is symmetric, so the same formula serves as both VJP and JVP once the forward weights are known.
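Concretely, writing S = Q Kᵀ / √d for the scaled scores and A for the (hard-masked) softmax weights, the tangent the definition computes is the chain rule:

```latex
dS = \frac{dQ\,K^{\top} + Q\,dK^{\top}}{\sqrt{d}}, \qquad
dA_{ij} = A_{ij}\Bigl(dS_{ij} - \sum_k A_{ik}\, dS_{ik}\Bigr), \qquad
d\mathrm{Out} = dA\,V + A\,dV
```

The middle step is softmaxBackwardFromWeightsSpec applied to the forward weights and the tangent scores, which is valid in forward mode precisely because the row-wise softmax Jacobian is symmetric.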

                        Instances For
                          structure Spec.MultiHeadAttention (α : Type) (numHeads dModel headDim : ℕ) : Type

                          Multi-head attention parameters (projection matrices).

                          PyTorch analogy: this corresponds to the four linear maps used in attention blocks:

                          • Wq, Wk, Wv project dModel -> (numHeads * headDim)
                          • Wo projects (numHeads * headDim) -> dModel

                          This spec keeps them as explicit matrices (no bias terms) to keep the math simple and to make the gradients easy to audit.

                          Instances For
                            def Spec.splitHeadsSpec {α : Type} [Inhabited α] {n dModel : ℕ} (x : Tensor α (Shape.dim n (Shape.dim dModel Shape.scalar))) (numHeads headDim : ℕ) (h : dModel = numHeads * headDim) :
                            Tensor α (Shape.dim numHeads (Shape.dim n (Shape.dim headDim Shape.scalar)))

                            Split (n, dModel) into (numHeads, n, headDim) by reshaping.

                            We store heads as the outermost axis so that "per-head computation" is just a Tensor.dim over Fin numHeads.

                            PyTorch analogy: conceptually similar to reshaping (n, numHeads*headDim) into (n, numHeads, headDim) and then transposing to make heads a separate axis; here we go directly to (numHeads, n, headDim) because it is convenient for later definitions.
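Under the intended layout (the feature axis is split row-major into contiguous per-head slices, as the reshape-based description above suggests), the split is described index-wise by:

```latex
\mathrm{splitHeads}(x)[h, i, k] \;=\; x[i,\; h \cdot \mathrm{headDim} + k]
\qquad (0 \le h < \mathrm{numHeads},\; 0 \le k < \mathrm{headDim})
```

The head-combining definitions below invert this, recovering the flat (n, numHeads*headDim) layout.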

                            Instances For
                              def Spec.concatHeadsSpec {α : Type} [Context α] {n numHeads headDim : ℕ} (heads : List (Tensor α (Shape.dim n (Shape.dim headDim Shape.scalar)))) (h : heads.length = numHeads) :
                              Tensor α (Shape.dim n (Shape.dim (numHeads * headDim) Shape.scalar))

                              Concatenate a list of numHeads head tensors into a single (n, numHeads*headDim) tensor.

                              This is a straightforward list-based definition. The newer combine_heads_spec below does the same thing starting from a tensor-of-heads representation.

                              Instances For

                                concat_heads_spec above is the original (list-based) definition.

                                For proofs/automation, it's often easier to work with a tensor of heads Tensor α (.dim numHeads (.dim n (.dim headDim .scalar))) and then use shape-only transforms to combine heads back into a single (n, numHeads*headDim) tensor.

                                def Spec.combineHeadsSpec {α : Type} [Context α] {n numHeads headDim : ℕ} (heads : Tensor α (Shape.dim numHeads (Shape.dim n (Shape.dim headDim Shape.scalar)))) :
                                Tensor α (Shape.dim n (Shape.dim (numHeads * headDim) Shape.scalar))

                                Combine a tensor-of-heads back into a single (n, numHeads*headDim) tensor.

                                Implementation detail:

                                1. swap_first_two_spec converts (numHeads, n, headDim) into (n, numHeads, headDim)
                                2. reshape_spec flattens the last two axes into (n, numHeads*headDim)
                                Instances For
                                  @[reducible]

                                  Convenience proof that (n) broadcasts to (n,n).

                                  This is kept as a small helper because some attention-style proofs and wrappers want an explicit BroadcastTo witness rather than relying on typeclass search.

                                  Instances For
                                    def Spec.MultiHeadAttention.forward {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {numHeads dModel headDim : ℕ} (n : ℕ) (h1 : n ≠ 0) (mha : MultiHeadAttention α numHeads dModel headDim) (x : Tensor α (Shape.dim n (Shape.dim dModel Shape.scalar))) (mask : Option (Tensor Bool (Shape.dim n (Shape.dim n Shape.scalar)))) :
                                    Tensor α (Shape.dim n (Shape.dim dModel Shape.scalar))

                                    Multi-head attention forward pass (self-attention when mask is square).

                                    High-level structure (PyTorch mental model):

                                    1. project x into Q,K,V
                                    2. split the projection dimension into heads
                                    3. run scaled dot-product attention per head (sharing the same mask)
                                    4. combine heads back and project with Wo
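Written out, and writing W_q^{(h)} for the headDim columns of Wq feeding head h (notation introduced here only for exposition, with H = numHeads), the four steps compute:

```latex
\mathrm{head}_h \;=\; \mathrm{Attention}\bigl(x\,W_q^{(h)},\; x\,W_k^{(h)},\; x\,W_v^{(h)}\bigr), \qquad
\mathrm{out} \;=\; \bigl[\,\mathrm{head}_0 \;\Vert\; \cdots \;\Vert\; \mathrm{head}_{H-1}\,\bigr]\; W_o
```

where every head shares the same optional mask and ∥ denotes concatenation along the feature axis.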
                                    Instances For
                                      def Spec.MultiHeadAttentionBackward {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {n numHeads dModel headDim : ℕ} (h1 : n ≠ 0) (mha : MultiHeadAttention α numHeads dModel headDim) (x : Tensor α (Shape.dim n (Shape.dim dModel Shape.scalar))) (mask : Option (Tensor Bool (Shape.dim n (Shape.dim n Shape.scalar)))) (grad_output : Tensor α (Shape.dim n (Shape.dim dModel Shape.scalar))) :
                                      Tensor α (Shape.dim n (Shape.dim dModel Shape.scalar)) × Tensor α (Shape.dim dModel (Shape.dim (numHeads * headDim) Shape.scalar)) × Tensor α (Shape.dim dModel (Shape.dim (numHeads * headDim) Shape.scalar)) × Tensor α (Shape.dim dModel (Shape.dim (numHeads * headDim) Shape.scalar)) × Tensor α (Shape.dim (numHeads * headDim) (Shape.dim dModel Shape.scalar))

                                      Multi-head attention backward pass.

                                      Returns gradients for input x and all projection matrices (Wq,Wk,Wv,Wo). We recompute forward intermediates locally so we don’t rely on a global tape.

                                      Instances For
                                        def Spec.MultiHeadAttentionJvp {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {n numHeads dModel headDim : ℕ} (h1 : n ≠ 0) (mha dmha : MultiHeadAttention α numHeads dModel headDim) (x dx : Tensor α (Shape.dim n (Shape.dim dModel Shape.scalar))) (mask : Option (Tensor Bool (Shape.dim n (Shape.dim n Shape.scalar)))) :
                                        Tensor α (Shape.dim n (Shape.dim dModel Shape.scalar))

                                        Forward-mode JVP for multi-head attention.

                                        The rule follows the same computational graph as MultiHeadAttention.forward:

                                        1. project tangents through Q/K/V,
                                        2. split primal and tangent projections into heads,
                                        3. apply scaledDotProductAttentionJvp head-wise,
                                        4. combine head tangents, then differentiate the final output projection.

                                        This keeps attention forward-mode AD explicit at the spec layer instead of hiding it behind a runtime-only implementation.

                                        Instances For
                                          def Spec.selfAttention {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {n dModel projDim : ℕ} (x : Tensor α (Shape.dim n (Shape.dim dModel Shape.scalar))) (Wq Wk Wv : Tensor α (Shape.dim dModel (Shape.dim projDim Shape.scalar))) (Wo : Tensor α (Shape.dim projDim (Shape.dim dModel Shape.scalar))) (h1 : n ≠ 0) :
                                          Tensor α (Shape.dim n (Shape.dim dModel Shape.scalar))

                                          Self-attention on a single sequence.

                                          This uses the same input x for Q/K/V, runs scaled dot-product attention, then applies the output projection Wo.

                                          PyTorch mental model: the core of nn.MultiheadAttention / TransformerEncoderLayer (ignoring the batch axis).

                                          Instances For
                                            def Spec.crossAttention {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {n1 n2 dModel projDim : ℕ} (query : Tensor α (Shape.dim n1 (Shape.dim dModel Shape.scalar))) (key value : Tensor α (Shape.dim n2 (Shape.dim dModel Shape.scalar))) (Wq Wk Wv : Tensor α (Shape.dim dModel (Shape.dim projDim Shape.scalar))) (Wo : Tensor α (Shape.dim projDim (Shape.dim dModel Shape.scalar))) (h1 : n1 ≠ 0) (h2 : n2 ≠ 0) :
                                            Tensor α (Shape.dim n1 (Shape.dim dModel Shape.scalar))

                                            Cross-attention between two sequences.

                                            query is length n1 and attends to key/value of length n2.

                                            PyTorch mental model: the attention block in a Transformer decoder layer (nn.MultiheadAttention with distinct query and key/value inputs).

                                            Instances For
                                              def Spec.sparseAttention {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {n dModel projDim : ℕ} (x : Tensor α (Shape.dim n (Shape.dim dModel Shape.scalar))) (sparsityPattern : Tensor Bool (Shape.dim n (Shape.dim n Shape.scalar))) (Wq Wk Wv : Tensor α (Shape.dim dModel (Shape.dim projDim Shape.scalar))) (Wo : Tensor α (Shape.dim projDim (Shape.dim dModel Shape.scalar))) (h1 : n ≠ 0) :
                                              Tensor α (Shape.dim n (Shape.dim dModel Shape.scalar))

                                              Sparse self-attention: the supplied Boolean sparsityPattern is used as the attention mask, so each query position attends only to the key positions the pattern allows. This is the spec-level form of efficiency-oriented sparse attention patterns.

                                              Instances For