Transformer (spec model) #
This file defines a small Transformer-style model in a way that matches the usual PyTorch mental model:
- encoder layers (self-attention + FFN, each wrapped in residual + LayerNorm),
- decoder layers (masked self-attention, cross-attention, FFN, each wrapped in residual + LayerNorm),
- an encoder-decoder wrapper (Transformer),
- spec-level backward passes for the encoder stack,
- small utilities like sinusoidal positional encodings and causal masks.
Shapes follow the common convention:
- sequence tensors are (seqLen × embedDim),
- attention is "last-axis softmax" over the key dimension.
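For reference, "last-axis softmax" means the standard scaled dot-product attention of Vaswani et al. (2017), with the softmax taken over the key index for each fixed query position:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V, \qquad \mathrm{softmax}(z)_k = \frac{e^{z_k}}{\sum_{k'} e^{z_{k'}}}.$$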
PyTorch analogy:
- TransformerEncoderLayer.forward corresponds to the core of torch.nn.TransformerEncoderLayer (ignoring dropout and some configuration knobs),
- TransformerEncoder.forward corresponds to torch.nn.TransformerEncoder,
- TransformerDecoderLayer.forward corresponds to the core of torch.nn.TransformerDecoderLayer,
- TransformerDecoder.forward corresponds to torch.nn.TransformerDecoder,
- Transformer.forward is similar in spirit to torch.nn.Transformer (but simplified).
References:
- Vaswani et al., "Attention Is All You Need" (2017).
- Ba et al., "Layer Normalization" (2016).
- He et al., "Deep Residual Learning for Image Recognition" (2015) for the residual/skip-connection pattern.
PyTorch docs (for API shape intuition, not semantics):
- torch.nn.TransformerEncoderLayer: https://pytorch.org/docs/stable/generated/torch.nn.TransformerEncoderLayer.html
- torch.nn.TransformerEncoder: https://pytorch.org/docs/stable/generated/torch.nn.TransformerEncoder.html
- torch.nn.TransformerDecoderLayer: https://pytorch.org/docs/stable/generated/torch.nn.TransformerDecoderLayer.html
- torch.nn.TransformerDecoder: https://pytorch.org/docs/stable/generated/torch.nn.TransformerDecoder.html
- torch.nn.Transformer: https://pytorch.org/docs/stable/generated/torch.nn.Transformer.html
Configuration helpers #
This file mostly defines reusable transformer building blocks (encoder/decoder layers, attention, layer-norm wrappers, etc.). To make "model-zoo style" instantiations easier, we also provide a small config record for the common hyperparameters together with a couple of canonical configs (Base/Big).
The core definitions below still expose the hyperparameters as Nat parameters so the spec remains flexible; the config layer is a convenience wrapper, not a new abstraction boundary.
Common transformer layer hyperparameters.
Instances For
Well-formedness conditions for TransformerLayerConfig.
The divisibility condition keeps the per-head width exact: headCount should divide embedDim, so that embedDim / headCount (Nat floor division) does not silently drop a remainder.
Instances For
Canonical Transformer "base" hyperparameters (Vaswani et al. 2017).
Instances For
transformerBaseConfig is well-formed.
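As a rough, self-contained sketch of how the config layer fits together. The record, field, and definition names below follow the prose above rather than this file's exact source, so treat it as illustrative:

```lean
-- Hypothetical stand-in for the config record described above.
structure TransformerLayerConfig where
  headCount : Nat
  embedDim  : Nat
  hiddenDim : Nat

-- Well-formedness: headCount is positive and divides embedDim exactly,
-- so `embedDim / headCount` (Nat floor division) drops no remainder.
-- (`abbrev` keeps it reducible so `decide` can see through it.)
abbrev TransformerLayerConfig.WellFormed (c : TransformerLayerConfig) : Prop :=
  0 < c.headCount ∧ c.embedDim % c.headCount = 0

-- "Base" hyperparameters from Vaswani et al. (2017):
-- d_model = 512, h = 8, d_ff = 2048.
def transformerBaseConfig : TransformerLayerConfig :=
  { headCount := 8, embedDim := 512, hiddenDim := 2048 }

example : transformerBaseConfig.WellFormed := by decide
```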
Canonical Transformer "big" hyperparameters (Vaswani et al. 2017).
Instances For
transformerBigConfig is well-formed.
Gradient containers #
To keep the backward pass readable (and easy to reuse from downstream models like ViT/Seq2Seq), we bundle parameter gradients into records that mirror the parameter records.
Gradients for a FeedForward block (field-for-field).
This is a lightweight container used by downstream models that want a readable backward pass.
- dW1 : Tensor α (Shape.dim embedDim (Shape.dim hiddenDim Shape.scalar))
  Gradient of W1.
- dW2 : Tensor α (Shape.dim hiddenDim (Shape.dim embedDim Shape.scalar))
  Gradient of W2.
- db1 : Tensor α (Shape.dim hiddenDim Shape.scalar)
  Gradient of b1.
- db2 : Tensor α (Shape.dim embedDim Shape.scalar)
  Gradient of b2.
Instances For
Gradients for MultiHeadAttention parameters (field-for-field).
This mirrors the MultiHeadAttention record defined in NN.Spec.Module.Attention.
- dWq : Tensor α (Shape.dim dModel (Shape.dim (numHeads * headDim) Shape.scalar))
  Gradient of the query projection matrix Wq.
- dWk : Tensor α (Shape.dim dModel (Shape.dim (numHeads * headDim) Shape.scalar))
  Gradient of the key projection matrix Wk.
- dWv : Tensor α (Shape.dim dModel (Shape.dim (numHeads * headDim) Shape.scalar))
  Gradient of the value projection matrix Wv.
- dWo : Tensor α (Shape.dim (numHeads * headDim) (Shape.dim dModel Shape.scalar))
  Gradient of the output projection matrix Wo.
Instances For
Gradients for a TransformerEncoderLayer (field-for-field).
This container is intended to keep the backward pass readable by mirroring the parameter layout.
- mha : MultiHeadAttentionGrads headCount embedDim (embedDim / headCount) α
Gradients for the self-attention block.
- ffn : FeedForwardGrads embedDim hiddenDim α
Gradients for the feedforward block.
- d_norm1_gamma : Tensor α (Shape.dim embedDim Shape.scalar)
Gradient of LayerNorm 1 gamma (attention "Add & Norm").
- d_norm1_beta : Tensor α (Shape.dim embedDim Shape.scalar)
Gradient of LayerNorm 1 beta (attention "Add & Norm").
- d_norm2_gamma : Tensor α (Shape.dim embedDim Shape.scalar)
Gradient of LayerNorm 2 gamma (FFN "Add & Norm").
- d_norm2_beta : Tensor α (Shape.dim embedDim Shape.scalar)
Gradient of LayerNorm 2 beta (FFN "Add & Norm").
Instances For
2-layer position-wise feedforward network used inside Transformer layers.
Semantics (per token):
ffn(x) = (relu(x * W1 + b1) * W2) + b2.
PyTorch analogue: the linear1 / linear2 submodules in torch.nn.TransformerEncoderLayer.
- W1 : Tensor α (Shape.dim embedDim (Shape.dim hiddenDim Shape.scalar))
  First layer weight matrix.
- W2 : Tensor α (Shape.dim hiddenDim (Shape.dim embedDim Shape.scalar))
  Second layer weight matrix.
- b1 : Tensor α (Shape.dim hiddenDim Shape.scalar)
  First layer bias (length hiddenDim).
- b2 : Tensor α (Shape.dim embedDim Shape.scalar)
  Second layer bias (length embedDim).
Instances For
Forward pass for FeedForward.
Shape convention: inputs and outputs are (seqLen × embedDim); the feedforward operates
independently on each sequence position.
Instances For
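To make the per-token semantics concrete, here is a minimal executable sketch over plain List Float values, with weight matrices stored column-major so that x·W is a map over columns. The names vecMat and ffnToken are illustrative only; this is not the shape-indexed Tensor code used in this file:

```lean
-- x · W, with W stored as a list of columns (each column has x.length entries).
def vecMat (x : List Float) (cols : List (List Float)) : List Float :=
  cols.map fun col => (List.zipWith (· * ·) x col).foldl (· + ·) 0.0

-- ffn(x) = relu(x·W1 + b1)·W2 + b2, applied to a single token x.
def ffnToken (W1 W2 : List (List Float)) (b1 b2 x : List Float) : List Float :=
  let h := List.zipWith (fun a b => let s := a + b; if s > 0.0 then s else 0.0)
    (vecMat x W1) b1
  List.zipWith (· + ·) (vecMat h W2) b2
```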
Transformer encoder layer (post-norm).
This follows the common "Add & Norm" structure:
- Self-attention, residual add, LayerNorm
- Feedforward, residual add, LayerNorm
PyTorch analogue: torch.nn.TransformerEncoderLayer with norm_first=False (post-norm),
ignoring dropout and other configuration knobs.
- mha : MultiHeadAttention α headCount embedDim (embedDim / headCount)
Multi-head self-attention block.
- ffn : FeedForward embedDim hiddenDim α
Position-wise feedforward block.
- norm1_gamma : Tensor α (Shape.dim embedDim Shape.scalar)
LayerNorm 1 gamma (attention "Add & Norm").
- norm1_beta : Tensor α (Shape.dim embedDim Shape.scalar)
LayerNorm 1 beta (attention "Add & Norm").
- norm2_gamma : Tensor α (Shape.dim embedDim Shape.scalar)
LayerNorm 2 gamma (FFN "Add & Norm").
- norm2_beta : Tensor α (Shape.dim embedDim Shape.scalar)
LayerNorm 2 beta (FFN "Add & Norm").
Instances For
Forward pass for a post-norm TransformerEncoderLayer.
Input/output shape: (seqLen × embedDim).
The proofs h1/h2 are used by layerNorm to justify nondegenerate normalization.
Instances For
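In equations, the post-norm composition computed by this forward pass is

$$a = \mathrm{LayerNorm}_1\big(x + \mathrm{SelfAttn}(x)\big), \qquad y = \mathrm{LayerNorm}_2\big(a + \mathrm{FFN}(a)\big),$$

so normalization is applied after each residual add, matching norm_first=False in PyTorch.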
Transformer encoder: a stack of TransformerEncoderLayers.
PyTorch analogue: torch.nn.TransformerEncoder (a list of layers composed sequentially).
- layers : List (TransformerEncoderLayer headCount embedDim hiddenDim α)
Layer list; typically has length numLayers, but the spec does not enforce that invariant.
Instances For
Forward pass for TransformerEncoder (left-fold over layers).
Input/output shape: (seqLen × embedDim).
Instances For
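Schematically, the stack is nothing more than a left fold of the layer functions. A sketch with an opaque activation type (illustrative signature, not this file's actual one):

```lean
-- Apply layers in order: x ↦ layerN (… (layer2 (layer1 x)) …).
def stackForward {Act : Type} (layers : List (Act → Act)) (x : Act) : Act :=
  layers.foldl (fun h f => f h) x
```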
Config-indexed aliases #
These are small convenience abbreviations so downstream models can index transformer components by a config record rather than repeating the Nat parameters.
Encoder stack indexed by a TransformerStackConfig.
Instances For
Decoder notes #
We include a small Transformer-style decoder layer for completeness:
- self-attention over the decoder sequence,
- cross-attention where queries come from the decoder and keys/values come from the encoder,
- then the same feedforward block.
PyTorch analogy: this corresponds to the core of torch.nn.TransformerDecoderLayer (ignoring
dropout and a few configuration knobs).
Cross-attention helper #
The attention layer provides MultiHeadAttention.forward for the common self-attention case
(Q=K=V=x). A decoder block also needs cross-attention, where Q comes from the decoder stream
and K,V come from the encoder stream.
We keep the helper here small and explicit by following the same structure as the self-attention
definition: project, split into heads, run scaled dot-product attention per head, combine heads,
then project with Wo.
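Per head i, with queries from the decoder stream x_q and keys/values from the encoder stream x_kv, this is the standard formulation:

$$\mathrm{head}_i = \mathrm{softmax}\!\left(\frac{Q_i K_i^\top}{\sqrt{d_k}}\right) V_i, \qquad Q_i = x_q W^Q_i,\; K_i = x_{kv} W^K_i,\; V_i = x_{kv} W^V_i,$$

after which the heads are concatenated and projected: $[\mathrm{head}_1; \dots; \mathrm{head}_h]\, W^O$. When a mask is supplied, masked (query, key) pairs are excluded from the softmax (conventionally by pushing their scores to $-\infty$; the exact mechanism here is the spec's masking primitive).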
Cross-attention forward pass using a MultiHeadAttention parameter record.
This is the decoder-specific variant of MultiHeadAttention.forward:
- queries come from qInput (decoder stream),
- keys/values come from kvInput (encoder stream),
- an optional boolean mask of shape (nQ × nK) can be applied.
Shape conventions:
- qInput : (nQ × embedDim),
- kvInput : (nK × embedDim),
- output : (nQ × embedDim).
PyTorch analogue: the cross-attention inside torch.nn.TransformerDecoderLayer, typically
implemented via torch.nn.MultiheadAttention with separate query and key/value inputs.
Instances For
Transformer decoder layer (post-norm).
This mirrors the standard structure:
- Self-attention (decoder stream), residual add, LayerNorm
- Cross-attention (queries from decoder, keys/values from encoder), residual add, LayerNorm
- Feedforward, residual add, LayerNorm
PyTorch analogue: torch.nn.TransformerDecoderLayer with norm_first=False (post-norm),
ignoring dropout and a few configuration knobs.
- selfAttn : MultiHeadAttention α headCount embedDim (embedDim / headCount)
Self-attention block over the decoder sequence.
- crossAttn : MultiHeadAttention α headCount embedDim (embedDim / headCount)
Cross-attention block (decoder queries, encoder keys/values).
- ffn : FeedForward embedDim hiddenDim α
Position-wise feedforward block.
- norm1_gamma : Tensor α (Shape.dim embedDim Shape.scalar)
LayerNorm 1 gamma (self-attention "Add & Norm").
- norm1_beta : Tensor α (Shape.dim embedDim Shape.scalar)
LayerNorm 1 beta (self-attention "Add & Norm").
- norm2_gamma : Tensor α (Shape.dim embedDim Shape.scalar)
LayerNorm 2 gamma (cross-attention "Add & Norm").
- norm2_beta : Tensor α (Shape.dim embedDim Shape.scalar)
LayerNorm 2 beta (cross-attention "Add & Norm").
- norm3_gamma : Tensor α (Shape.dim embedDim Shape.scalar)
LayerNorm 3 gamma (FFN "Add & Norm").
- norm3_beta : Tensor α (Shape.dim embedDim Shape.scalar)
LayerNorm 3 beta (FFN "Add & Norm").
Instances For
Forward pass for a post-norm TransformerDecoderLayer.
Input/output shape: (seqLen × embedDim). This spec uses the same seqLen for encoder and decoder
streams for simplicity (cross-attention uses nQ = nK = seqLen).
Instances For
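In equations, with mem denoting the encoder output, the decoder layer computes

$$\begin{aligned}
a &= \mathrm{LayerNorm}_1\big(x + \mathrm{SelfAttn}(x)\big),\\
b &= \mathrm{LayerNorm}_2\big(a + \mathrm{CrossAttn}(a, \mathrm{mem})\big),\\
y &= \mathrm{LayerNorm}_3\big(b + \mathrm{FFN}(b)\big).
\end{aligned}$$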
Transformer decoder: a stack of TransformerDecoderLayers.
PyTorch analogue: torch.nn.TransformerDecoder (a list of decoder layers composed sequentially).
- layers : List (TransformerDecoderLayer headCount embedDim hiddenDim α)
Layer list; typically has length numLayers, but the spec does not enforce that invariant.
Instances For
Forward pass for TransformerDecoder (left-fold over layers).
Input/output shape: (seqLen × embedDim).
Instances For
End-to-end encoder-decoder Transformer (spec model).
This is a seq2seq Transformer wrapper built out of the encoder and decoder stacks above.
Compared to torch.nn.Transformer, it is intentionally simplified:
- embeddings are modeled as explicit linear projections,
- sequence length is shared between source and target streams,
- we omit dropout, caching, and most configuration knobs.
Shape convention: all activations in this file use (seqLen × embedDim).
In a full implementation, outputProjection would usually map to a vocabulary size; here it is
kept as an embedDim -> embedDim projection to stay in the "core tensor algebra" setting.
- encoder : TransformerEncoder numLayers headCount embedDim hiddenDim α
Encoder stack.
- decoder : TransformerDecoder numLayers headCount embedDim hiddenDim α
Decoder stack.
- inputEmbedding : Tensor α (Shape.dim embedDim (Shape.dim embedDim Shape.scalar))
Source/input embedding projection matrix.
- outputEmbedding : Tensor α (Shape.dim embedDim (Shape.dim embedDim Shape.scalar))
Target embedding projection matrix.
- outputProjection : Tensor α (Shape.dim embedDim (Shape.dim embedDim Shape.scalar))
  Final output projection matrix (kept as embedDim → embedDim in this spec; see the note above about vocabulary-sized projections in a full implementation).
Instances For
Forward pass for Transformer.
Runs:
- source embedding projection,
- encoder stack,
- target embedding projection,
- decoder stack (with cross-attention to the encoder output),
- output projection.
All tensors in this simplified spec have shape (seqLen × embedDim).
Instances For
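Schematically, the whole pipeline is a straight composition; a sketch with opaque activations and illustrative names (the real definition also threads shape/nonemptiness proofs):

```lean
-- decode takes the embedded target stream and the encoder memory.
def seq2seq {Act : Type}
    (embedSrc embedTgt project : Act → Act)
    (encode : Act → Act) (decode : Act → Act → Act)
    (src tgt : Act) : Act :=
  let mem := encode (embedSrc src)
  project (decode (embedTgt tgt) mem)
```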
Backward pass for FeedForward.forward.
Given the input x and an upstream gradient outputGrad = dL/dy (w.r.t. the FFN output),
returns:
- parameter gradients (as FeedForwardGrads),
- the gradient w.r.t. the input x.
This is a spec-level backward that reconstructs the forward intermediates (pre-activations and ReLU mask) instead of relying on a mutable tape, similar to the math underlying PyTorch autograd.
Instances For
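Concretely, writing $h = \mathrm{relu}(x W_1 + b_1)$, $y = h W_2 + b_2$, and $G = \partial L/\partial y$ for outputGrad, the reconstructed backward is the standard one (the four matrix/vector results populate dW2, db2, dW1, db1, and the final term is the input gradient):

$$\begin{aligned}
\frac{\partial L}{\partial W_2} &= h^\top G, & \frac{\partial L}{\partial b_2} &= \textstyle\sum_{t} G_{t,:},\\
\Delta &= (G\, W_2^\top) \odot \mathbf{1}[\,x W_1 + b_1 > 0\,], & &\\
\frac{\partial L}{\partial W_1} &= x^\top \Delta, & \frac{\partial L}{\partial b_1} &= \textstyle\sum_{t} \Delta_{t,:}, \qquad \frac{\partial L}{\partial x} = \Delta\, W_1^\top.
\end{aligned}$$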
Backward pass for TransformerEncoderLayer.forward.
Inputs:
- x: the layer input (seqLen × embedDim),
- outputGrad: upstream gradient w.r.t. the layer output.
Outputs:
- parameter gradients (TransformerEncoderLayerGrads),
- gradient w.r.t. x.
The implementation mirrors the forward pass structure (residuals + LayerNorm) and uses
layerNorm_backward and MultiHeadAttention_backward as its core primitives.
Instances For
Backward pass for an encoder stack #
The encoder is a list of layers applied sequentially. To compute gradients we:
- re-run the forward pass to collect each layer's input (a small "cache"),
- traverse layers in reverse, applying TransformerEncoderLayer.backward,
- return per-layer parameter gradients plus the gradient w.r.t. the encoder input.
This is purely a spec (no mutation, no state), so we do the simplest thing: recompute.
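A sketch of this recompute-and-reverse pattern with opaque Layer/Grad/Act types (illustrative names, not this file's signatures): the forward fold caches each layer's input, then a right fold runs the per-layer backward from the last layer to the first.

```lean
def stackBackward {Layer Grad Act : Type}
    (fwd : Layer → Act → Act)
    (bwd : Layer → Act → Act → Grad × Act)  -- bwd layer input upstreamGrad
    (layers : List Layer) (x : Act) (outGrad : Act) : List Grad × Act :=
  -- Re-run the forward pass, caching each layer's input (inputs[i] feeds layer i).
  let inputs := (layers.foldl
      (fun (acc : List Act × Act) layer => (acc.1 ++ [acc.2], fwd layer acc.2))
      ([], x)).1
  -- Traverse in reverse: foldr sees the last layer first, threading the gradient.
  (layers.zip inputs).foldr
    (fun li acc =>
      let (g, gx) := bwd li.1 li.2 acc.2
      (g :: acc.1, gx))
    ([], outGrad)
```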
Backward pass for TransformerEncoder.forward (a sequential stack of layers).
Returns:
- a list of per-layer parameter gradients (in the same order as encoder.layers),
- the gradient w.r.t. the encoder input x.
Because this is a pure spec, we recompute forward intermediates (each layer input) instead of storing a mutable cache.
Instances For
Sinusoidal positional encoding (Vaswani et al., 2017), added to the input sequence.
Given x : (seqLen × embedDim), returns x + pe where pe[pos, i] alternates sin/cos
features with geometrically-spaced frequencies.
PyTorch analogue: positional encodings are often applied externally in PyTorch examples; the
high-level torch.nn.Transformer module does not force a particular encoding.
Instances For
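The alternating features are the ones from Vaswani et al. (2017):

$$PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{\,2i/d_{\text{model}}}}\right), \qquad PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{\,2i/d_{\text{model}}}}\right).$$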
Multi-head self-attention with an optional boolean mask.
This is a thin wrapper around MultiHeadAttention.forward that:
- derives the required seqLen ≠ 0 proof from h1 : seqLen > 0,
- forwards the provided mask (typically a causal mask for autoregressive decoding).
PyTorch analogue: masked self-attention in torch.nn.TransformerDecoderLayer implemented via
torch.nn.MultiheadAttention(..., attn_mask=...).
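A minimal sketch of a causal mask, assuming the convention that a true entry marks a visible (query, key) pair; if this spec's masking primitive uses the opposite polarity, the predicate flips. The name causalMask and the List-of-List representation are illustrative only:

```lean
-- Entry (q, k) is true iff key position k is visible to query position q,
-- i.e. k ≤ q: each position attends only to itself and earlier positions.
def causalMask (seqLen : Nat) : List (List Bool) :=
  (List.range seqLen).map fun q =>
    (List.range seqLen).map fun k => decide (k ≤ q)
```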