TorchLean API

NN.Spec.Models.Transformer

Transformer (spec model) #

This file defines a small Transformer-style model in a way that matches the usual PyTorch mental model.

Shapes follow the common convention: activations are (seqLen × embedDim), and the attention/feedforward blocks act per sequence position.

PyTorch analogy: the building blocks mirror torch.nn.TransformerEncoderLayer, torch.nn.TransformerDecoderLayer, torch.nn.TransformerEncoder/TransformerDecoder, and torch.nn.Transformer, ignoring dropout and most configuration knobs.

References: Vaswani et al., "Attention Is All You Need" (2017).

PyTorch docs (for API shape intuition, not semantics): the torch.nn.Transformer family of modules and torch.nn.MultiheadAttention.

Configuration helpers #

This file mostly defines reusable transformer building blocks (encoder/decoder layers, attention, layer-norm wrappers, etc.). To make "model-zoo style" instantiations easier, we also provide a small config record for the common hyperparameters together with a couple of canonical configs (Base/Big).

The core definitions below still expose the hyperparameters as Nat parameters so the spec remains flexible; the config layer is a convenience wrapper, not a new abstraction boundary.

Common transformer layer hyperparameters.

  • headCount : ℕ
    Number of attention heads.

  • embedDim : ℕ
    Embedding dimension (d_model).

  • hiddenDim : ℕ
    Feedforward hidden dimension (d_ff).

Instances For

    Stack hyperparameters for an encoder/decoder: common layer config plus a layer count.

    Instances For

      Well-formedness conditions for TransformerLayerConfig.

      The divisibility condition keeps the per-head width exact: embedDim / headCount should partition the model dimension without silently dropping a tail through Nat floor division.
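As a concrete illustration, a config like the following satisfies the divisibility requirement (8 divides 512, giving a per-head width of 64). The field names follow the list above; placing the record in the Spec namespace is an assumption that matches the other declarations in this file:

-- Example layer config: 512 / 8 = 64, so the per-head width is exact.
def exampleLayerCfg : Spec.TransformerLayerConfig :=
  { headCount := 8, embedDim := 512, hiddenDim := 2048 }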

      Instances For

        Well-formedness conditions for TransformerStackConfig.

• layer : cfg.layer.WF
  The underlying layer configuration is well-formed.
        Instances For

          Canonical Transformer "base" hyperparameters (Vaswani et al. 2017).

          Instances For

            Canonical Transformer "big" hyperparameters (Vaswani et al. 2017).

            Instances For

              Gradient containers #

              To keep the backward pass readable (and easy to reuse from downstream models like ViT/Seq2Seq), we bundle parameter gradients into records that mirror the parameter records.

structure Spec.FeedForwardGrads (embedDim hiddenDim : ℕ) (α : Type) : Type

              Gradients for a FeedForward block (field-for-field).

              This is a lightweight container used by downstream models that want a readable backward pass.

              Instances For
structure Spec.MultiHeadAttentionGrads (numHeads dModel headDim : ℕ) (α : Type) : Type

                Gradients for MultiHeadAttention parameters (field-for-field).

                This mirrors the MultiHeadAttention record defined in NN.Spec.Module.Attention.

                Instances For
structure Spec.TransformerEncoderLayerGrads (headCount embedDim hiddenDim : ℕ) (α : Type) : Type

                  Gradients for a TransformerEncoderLayer (field-for-field).

                  This container is intended to keep the backward pass readable by mirroring the parameter layout.

                  Instances For
structure Spec.FeedForward (embedDim hiddenDim : ℕ) (α : Type) [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] : Type

                    2-layer position-wise feedforward network used inside Transformer layers.

                    Semantics (per token): ffn(x) = (relu(x * W1 + b1) * W2) + b2.

PyTorch analogue: the linear1 / linear2 submodules in torch.nn.TransformerEncoderLayer.
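A minimal per-token sketch of this formula over plain Arrays (illustrative only, with made-up helper names; the actual definition works on the spec's Tensor type):

-- Plain-Array sketch of ffn(x) = relu(x * W1 + b1) * W2 + b2 for a single token.
-- W1 is stored as hiddenDim rows of length embedDim, W2 as embedDim rows of length hiddenDim.
def dotArr (v w : Array Float) : Float :=
  (v.zip w).foldl (fun acc (a, b) => acc + a * b) 0.0

def ffnToken (W1 : Array (Array Float)) (b1 : Array Float)
    (W2 : Array (Array Float)) (b2 : Array Float)
    (x : Array Float) : Array Float :=
  let h := (W1.zip b1).map fun (row, b) => Float.max 0.0 (dotArr row x + b)
  (W2.zip b2).map fun (row, b) => dotArr row h + b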

                    Instances For
def Spec.FeedForward.forward {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {embedDim hiddenDim seqLen : ℕ} (ffn : FeedForward embedDim hiddenDim α) (x : Tensor α (Shape.dim seqLen (Shape.dim embedDim Shape.scalar))) :
                      Tensor α (Shape.dim seqLen (Shape.dim embedDim Shape.scalar))

                      Forward pass for FeedForward.

                      Shape convention: inputs and outputs are (seqLen × embedDim); the feedforward operates independently on each sequence position.

                      Instances For
structure Spec.TransformerEncoderLayer (headCount embedDim hiddenDim : ℕ) (α : Type) [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] : Type

                        Transformer encoder layer (post-norm).

                        This follows the common "Add & Norm" structure:

                        1. Self-attention, residual add, LayerNorm
                        2. Feedforward, residual add, LayerNorm

                        PyTorch analogue: torch.nn.TransformerEncoderLayer with norm_first=False (post-norm), ignoring dropout and other configuration knobs.
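Schematically, with the sublayers abstracted as plain functions (a sketch of the composition pattern, not the spec's definition):

-- Post-norm "Add & Norm" skeleton, with the sublayers abstracted as plain functions.
def postNormEncoderBlock {T : Type} (layerNorm : T → T) (add : T → T → T)
    (selfAttn ffn : T → T) (x : T) : T :=
  let y1 := layerNorm (add x (selfAttn x))   -- 1. self-attention, residual add, LayerNorm
  layerNorm (add y1 (ffn y1))                -- 2. feedforward, residual add, LayerNorm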

                        Instances For
def Spec.TransformerEncoderLayer.forward {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {headCount embedDim hiddenDim seqLen : ℕ} (layer : TransformerEncoderLayer headCount embedDim hiddenDim α) (x : Tensor α (Shape.dim seqLen (Shape.dim embedDim Shape.scalar))) (h1 : seqLen > 0) (h2 : embedDim > 0) :
                          Tensor α (Shape.dim seqLen (Shape.dim embedDim Shape.scalar))

                          Forward pass for a post-norm TransformerEncoderLayer.

                          Input/output shape: (seqLen × embedDim). The proofs h1/h2 are used by layerNorm to justify nondegenerate normalization.

                          Instances For
structure Spec.TransformerEncoder (numLayers headCount embedDim hiddenDim : ℕ) (α : Type) [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] : Type

                            Transformer encoder: a stack of TransformerEncoderLayers.

                            PyTorch analogue: torch.nn.TransformerEncoder (a list of layers composed sequentially).

                            Instances For
def Spec.TransformerEncoder.forward {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {numLayers headCount embedDim hiddenDim seqLen : ℕ} (encoder : TransformerEncoder numLayers headCount embedDim hiddenDim α) (x : Tensor α (Shape.dim seqLen (Shape.dim embedDim Shape.scalar))) (h1 : seqLen > 0) (h2 : embedDim > 0) :
                              Tensor α (Shape.dim seqLen (Shape.dim embedDim Shape.scalar))

                              Forward pass for TransformerEncoder (left-fold over layers).

                              Input/output shape: (seqLen × embedDim).
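The fold itself is the standard one; a schematic version with the layer list and per-layer step abstracted:

-- Schematic left-fold over a stack of layers: feed each layer's output into the next.
-- In the spec, `step` corresponds to applying a layer's forward (with the h1/h2 proofs).
def stackForward {T L : Type} (layers : List L) (step : T → L → T) (x : T) : T :=
  layers.foldl step x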

                              Instances For

                                Config-indexed aliases #

                                These are small convenience abbreviations so downstream models can index transformer components by a config record rather than repeating the Nat parameters.
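The underlying pattern is a reducible abbreviation that unpacks a config record into the Nat indices of the component type; a toy version (with invented names, not the spec's) looks like this:

-- Toy illustration of a config-indexed alias. The spec's versions are additionally
-- marked @[reducible, inline]; the names below are made up for the example.
structure ToyCfg where
  headCount : Nat
  embedDim  : Nat

structure ToyLayer (headCount embedDim : Nat) (α : Type) where
  weight : α

abbrev ToyLayerCfg (cfg : ToyCfg) (α : Type) :=
  ToyLayer cfg.headCount cfg.embedDim α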

                                @[reducible, inline]

                                Encoder-layer gradients indexed by a TransformerLayerConfig.

                                Instances For
                                  @[reducible, inline]
abbrev Spec.TransformerEncoderCfg (cfg : TransformerStackConfig) (α : Type) [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] : Type

                                  Encoder stack indexed by a TransformerStackConfig.

                                  Instances For

                                    Decoder notes #

                                    We include a small Transformer-style decoder layer for completeness:

                                    PyTorch analogy: this corresponds to the core of torch.nn.TransformerDecoderLayer (ignoring dropout and a few configuration knobs).

                                    Cross-attention helper #

                                    The attention layer provides MultiHeadAttention.forward for the common self-attention case (Q=K=V=x). A decoder block also needs cross-attention, where Q comes from the decoder stream and K,V come from the encoder stream.

                                    We keep the helper here small and explicit by following the same structure as the self-attention definition: project, split into heads, run scaled dot-product attention per head, combine heads, then project with Wo.
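For intuition, the per-head scaled dot-product step can be sketched over plain Arrays (illustrative only; the spec's version works on Tensor values and also applies the optional mask):

-- softmax(q · kᵀ / sqrt(dk)) · v for a single query vector, over plain Arrays.
def dotQK (v w : Array Float) : Float :=
  (v.zip w).foldl (fun acc (a, b) => acc + a * b) 0.0

def softmaxArr (xs : Array Float) : Array Float :=
  let m := xs.foldl Float.max (xs.getD 0 0.0)   -- subtract the max for stability
  let es := xs.map fun x => Float.exp (x - m)
  let s := es.foldl (· + ·) 0.0
  es.map (· / s)

def attendOne (q : Array Float) (ks vs : Array (Array Float)) (dk : Float) : Array Float :=
  let w := softmaxArr (ks.map fun k => dotQK q k / Float.sqrt dk)
  (w.zip vs).foldl
    (fun acc (wi, v) => (acc.zip v).map fun (a, b) => a + wi * b)
    (Array.mkArray (vs.getD 0 #[]).size 0.0)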

def Spec.multiHeadCrossAttention {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {headCount embedDim nQ nK : ℕ} (hQ : nQ ≠ 0) (hK : nK ≠ 0) (mha : MultiHeadAttention α headCount embedDim (embedDim / headCount)) (qInput : Tensor α (Shape.dim nQ (Shape.dim embedDim Shape.scalar))) (kvInput : Tensor α (Shape.dim nK (Shape.dim embedDim Shape.scalar))) (mask : Option (Tensor Bool (Shape.dim nQ (Shape.dim nK Shape.scalar)))) :
Tensor α (Shape.dim nQ (Shape.dim embedDim Shape.scalar))

                                    Cross-attention forward pass using a MultiHeadAttention parameter record.

                                    This is the decoder-specific variant of MultiHeadAttention.forward:

                                    • queries come from qInput (decoder stream),
                                    • keys/values come from kvInput (encoder stream),
                                    • an optional boolean mask of shape (nQ × nK) can be applied.

                                    Shape conventions:

                                    • qInput : (nQ × embedDim),
                                    • kvInput : (nK × embedDim),
                                    • output : (nQ × embedDim).

                                    PyTorch analogue: the cross-attention inside torch.nn.TransformerDecoderLayer, typically implemented via torch.nn.MultiheadAttention with separate query and key/value inputs.

                                    Instances For
structure Spec.TransformerDecoderLayer (headCount embedDim hiddenDim : ℕ) (α : Type) [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] : Type

                                      Transformer decoder layer (post-norm).

                                      This mirrors the standard structure:

                                      1. Self-attention (decoder stream), residual add, LayerNorm
                                      2. Cross-attention (queries from decoder, keys/values from encoder), residual add, LayerNorm
                                      3. Feedforward, residual add, LayerNorm

                                      PyTorch analogue: torch.nn.TransformerDecoderLayer with norm_first=False (post-norm), ignoring dropout and a few configuration knobs.
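Schematically, using the same abstraction as the encoder sketch earlier, with cross-attention taking the encoder output as a second argument:

-- Post-norm decoder skeleton: self-attention, then cross-attention against the encoder
-- output `enc`, then the feedforward, each followed by a residual add and LayerNorm.
def postNormDecoderBlock {T : Type} (layerNorm : T → T) (add : T → T → T)
    (selfAttn ffn : T → T) (crossAttn : T → T → T) (x enc : T) : T :=
  let y1 := layerNorm (add x (selfAttn x))
  let y2 := layerNorm (add y1 (crossAttn y1 enc))
  layerNorm (add y2 (ffn y2))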

                                      Instances For
def Spec.TransformerDecoderLayer.forward {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {headCount embedDim hiddenDim seqLen : ℕ} (layer : TransformerDecoderLayer headCount embedDim hiddenDim α) (x encoderOutput : Tensor α (Shape.dim seqLen (Shape.dim embedDim Shape.scalar))) (h1 : seqLen > 0) (h2 : embedDim > 0) :
                                        Tensor α (Shape.dim seqLen (Shape.dim embedDim Shape.scalar))

                                        Forward pass for a post-norm TransformerDecoderLayer.

                                        Input/output shape: (seqLen × embedDim). This spec uses the same seqLen for encoder and decoder streams for simplicity (cross-attention uses nQ = nK = seqLen).

                                        Instances For
structure Spec.TransformerDecoder (numLayers headCount embedDim hiddenDim : ℕ) (α : Type) [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] : Type

                                          Transformer decoder: a stack of TransformerDecoderLayers.

                                          PyTorch analogue: torch.nn.TransformerDecoder (a list of decoder layers composed sequentially).

                                          Instances For
def Spec.TransformerDecoder.forward {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {numLayers headCount embedDim hiddenDim seqLen : ℕ} (decoder : TransformerDecoder numLayers headCount embedDim hiddenDim α) (x encoderOutput : Tensor α (Shape.dim seqLen (Shape.dim embedDim Shape.scalar))) (h1 : seqLen > 0) (h2 : embedDim > 0) :
                                            Tensor α (Shape.dim seqLen (Shape.dim embedDim Shape.scalar))

                                            Forward pass for TransformerDecoder (left-fold over layers).

                                            Input/output shape: (seqLen × embedDim).

                                            Instances For
structure Spec.Transformer (numLayers headCount embedDim hiddenDim : ℕ) (α : Type) [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] : Type

                                              End-to-end encoder-decoder Transformer (spec model).

                                              This is a seq2seq Transformer wrapper built out of the encoder and decoder stacks above. Compared to torch.nn.Transformer, it is intentionally simplified:

                                              • embeddings are modeled as explicit linear projections,
                                              • sequence length is shared between source and target streams,
                                              • we omit dropout, caching, and most configuration knobs.

                                              Shape convention: all activations in this file use (seqLen × embedDim). In a full implementation, outputProjection would usually map to a vocabulary size; here it is kept as an embedDim -> embedDim projection to stay in the "core tensor algebra" setting.

                                              Instances For
def Spec.Transformer.forward {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {numLayers headCount embedDim hiddenDim seqLen : ℕ} (transformer : Transformer numLayers headCount embedDim hiddenDim α) (input target : Tensor α (Shape.dim seqLen (Shape.dim embedDim Shape.scalar))) (h1 : seqLen > 0) (h2 : embedDim > 0) :
                                                Tensor α (Shape.dim seqLen (Shape.dim embedDim Shape.scalar))

                                                Forward pass for Transformer.

                                                Runs:

                                                1. source embedding projection,
                                                2. encoder stack,
                                                3. target embedding projection,
                                                4. decoder stack (with cross-attention to the encoder output),
                                                5. output projection.

                                                All tensors in this simplified spec have shape (seqLen × embedDim).
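As a schematic composition of steps 1-5 (every concrete piece abstracted into a function; not the spec's definition):

-- Schematic end-to-end composition: embed source, encode, embed target,
-- decode against the encoder output, project.
def seq2seqForward {T : Type} (srcEmbed tgtEmbed outProj : T → T)
    (encode : T → T) (decode : T → T → T) (input target : T) : T :=
  let memory := encode (srcEmbed input)
  outProj (decode (tgtEmbed target) memory)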

                                                Instances For
def Spec.FeedForward.backward {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {embedDim hiddenDim seqLen : ℕ} (ffn : FeedForward embedDim hiddenDim α) (x outputGrad : Tensor α (Shape.dim seqLen (Shape.dim embedDim Shape.scalar))) (h_seq : seqLen > 0) (_h_embed : embedDim > 0) :
                                                  FeedForwardGrads embedDim hiddenDim α × Tensor α (Shape.dim seqLen (Shape.dim embedDim Shape.scalar))

                                                  Backward pass for FeedForward.forward.

Given the input x and an upstream gradient outputGrad = dL/dy (w.r.t. the FFN output), returns the parameter gradients (a FeedForwardGrads record) together with the gradient w.r.t. the input x.

                                                  This is a spec-level backward that reconstructs the forward intermediates (pre-activations and ReLU mask) instead of relying on a mutable tape, similar to the math underlying PyTorch autograd.
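For reference, the gradient formulas such a backward computes, writing the forward as z = x * W1 + b1, h = relu(z), y = h * W2 + b2, with upstream gradient g = dL/dy, ᵀ for transpose, ⊙ for elementwise product, and [z > 0] for the ReLU mask reconstructed from the pre-activations:

dW2 = hᵀ * g        db2 = sum of g over sequence positions
dh  = g * W2ᵀ       dz  = dh ⊙ [z > 0]
dW1 = xᵀ * dz       db1 = sum of dz over sequence positions
dx  = dz * W1ᵀ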

                                                  Instances For
def Spec.TransformerEncoderLayer.backward {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {headCount embedDim hiddenDim seqLen : ℕ} (layer : TransformerEncoderLayer headCount embedDim hiddenDim α) (x outputGrad : Tensor α (Shape.dim seqLen (Shape.dim embedDim Shape.scalar))) (h1 : seqLen > 0) (h2 : embedDim > 0) :
                                                    TransformerEncoderLayerGrads headCount embedDim hiddenDim α × Tensor α (Shape.dim seqLen (Shape.dim embedDim Shape.scalar))

                                                    Backward pass for TransformerEncoderLayer.forward.

                                                    Inputs:

                                                    • x: the layer input (seqLen × embedDim),
                                                    • outputGrad: upstream gradient w.r.t. the layer output.

Outputs:

• the parameter gradients for the layer (a TransformerEncoderLayerGrads record),
• the gradient w.r.t. the layer input x.

                                                    The implementation mirrors the forward pass structure (residuals + LayerNorm) and uses layerNorm_backward and MultiHeadAttention_backward as its core primitives.

                                                    Instances For

                                                      Backward pass for an encoder stack #

                                                      The encoder is a list of layers applied sequentially. To compute gradients we:

                                                      1. re-run the forward pass to collect each layer's input (a small "cache"),
                                                      2. traverse layers in reverse, applying TransformerEncoderLayer.backward,
                                                      3. return per-layer parameter gradients plus the gradient w.r.t. the encoder input.

                                                      This is purely a spec (no mutation, no state), so we do the simplest thing: recompute.
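A generic sketch of this recompute-then-reverse scheme (layer, activation, and gradient types abstracted; this is not the spec's actual code):

-- Recompute-then-reverse backward over a sequential stack:
-- 1. re-run the forward pass, caching each layer's input;
-- 2. walk layers and cached inputs in reverse, threading the upstream gradient;
-- 3. return per-layer gradients (in forward order) and the gradient w.r.t. the stack input.
def stackBackward {T G L : Type} (layers : List L)
    (fwd : L → T → T) (bwd : L → T → T → G × T)
    (x gOut : T) : List G × T :=
  let cache : List T × T :=
    layers.foldl (fun acc layer => (acc.2 :: acc.1, fwd layer acc.2)) (([] : List T), x)
  (layers.reverse.zip cache.1).foldl
    (fun acc (layer, xin) =>
      let (g, dx) := bwd layer xin acc.2
      (g :: acc.1, dx))
    (([] : List G), gOut)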

def Spec.TransformerEncoder.backward {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {numLayers headCount embedDim hiddenDim seqLen : ℕ} (encoder : TransformerEncoder numLayers headCount embedDim hiddenDim α) (x outputGrad : Tensor α (Shape.dim seqLen (Shape.dim embedDim Shape.scalar))) (h1 : seqLen > 0) (h2 : embedDim > 0) :
                                                      List (TransformerEncoderLayerGrads headCount embedDim hiddenDim α) × Tensor α (Shape.dim seqLen (Shape.dim embedDim Shape.scalar))

                                                      Backward pass for TransformerEncoder.forward (a sequential stack of layers).

                                                      Returns:

                                                      • a list of per-layer parameter gradients (in the same order as encoder.layers),
                                                      • the gradient w.r.t. the encoder input x.

                                                      Because this is a pure spec, we recompute forward intermediates (each layer input) instead of storing a mutable cache.

                                                      Instances For
def Spec.TransformerEncoder.backward.collect_inputs {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {headCount embedDim hiddenDim seqLen : ℕ} (h1 : seqLen > 0) (h2 : embedDim > 0) (layers : List (TransformerEncoderLayer headCount embedDim hiddenDim α)) (cur : Tensor α (Shape.dim seqLen (Shape.dim embedDim Shape.scalar))) :
List (Tensor α (Shape.dim seqLen (Shape.dim embedDim Shape.scalar)))

Helper for TransformerEncoder.backward: re-runs the forward pass over layers starting from cur, collecting each layer's input so the reverse traversal can consume them.
                                                        Instances For
def Spec.positionalEncoding {α : Type} [Context α] {seqLen embedDim : ℕ} (x : Tensor α (Shape.dim seqLen (Shape.dim embedDim Shape.scalar))) :
                                                          Tensor α (Shape.dim seqLen (Shape.dim embedDim Shape.scalar))

                                                          Sinusoidal positional encoding (Vaswani et al., 2017), added to the input sequence.

                                                          Given x : (seqLen × embedDim), returns x + pe where pe[pos, i] alternates sin/cos features with geometrically-spaced frequencies.

                                                          PyTorch analogue: positional encodings are often applied externally in PyTorch examples; the high-level torch.nn.Transformer module does not force a particular encoding.
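The per-entry formula, written as a plain Float function (illustrative; the spec builds the whole pe table as a Tensor and adds it to x):

-- pe[pos, i] uses sin on even feature indices and cos on odd ones, with frequency
-- 1 / 10000^(2*(i/2)/embedDim)  (Nat division pairs up the sin/cos features).
def peEntry (embedDim pos i : Nat) : Float :=
  let angle := Float.ofNat pos /
    Float.pow 10000.0 (Float.ofNat (2 * (i / 2)) / Float.ofNat embedDim)
  if i % 2 == 0 then Float.sin angle else Float.cos angle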

                                                          Instances For
def Spec.maskedMultiHeadAttention {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {headCount embedDim seqLen : ℕ} (mha : MultiHeadAttention α headCount embedDim (embedDim / headCount)) (x : Tensor α (Shape.dim seqLen (Shape.dim embedDim Shape.scalar))) (mask : Option (Tensor Bool (Shape.dim seqLen (Shape.dim seqLen Shape.scalar)))) (h1 : seqLen > 0) :
                                                            Tensor α (Shape.dim seqLen (Shape.dim embedDim Shape.scalar))

                                                            Multi-head self-attention with an optional boolean mask.

                                                            This is a thin wrapper around MultiHeadAttention.forward that:

                                                            • derives the required seqLen ≠ 0 proof from h1 : seqLen > 0,
                                                            • forwards the provided mask (typically a causal mask for autoregressive decoding).

                                                            PyTorch analogue: masked self-attention in torch.nn.TransformerDecoderLayer implemented via torch.nn.MultiheadAttention(..., attn_mask=...).
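A typical causal mask lets position i attend only to positions j ≤ i. As a plain Bool matrix (illustrative; the spec expects a Tensor Bool of shape seqLen × seqLen, and whether true means "keep" or "mask out" follows the convention of MultiHeadAttention.forward):

-- Lower-triangular causal mask: entry (i, j) is true iff position j is at or before i.
def causalMask (seqLen : Nat) : Array (Array Bool) :=
  Array.ofFn (n := seqLen) fun i =>
    Array.ofFn (n := seqLen) fun j => decide (j.val ≤ i.val)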

                                                            Instances For