Vit #
Vision Transformer (ViT) model.
This is a compact "ViT-style" specification:

- patch embedding via `Conv2D` (kernel = patch size),
- flatten patches into a token sequence,
- add a learnable positional encoding,
- run a Transformer encoder,
- mean-pool tokens and apply a linear classifier head.
Notes:

- PyTorch mental model: this corresponds to the core dataflow of `torchvision.models.vit_*`, but written without batching: tensors are `(C, H, W)` images and `(T, D)` token sequences.
- This file provides both mean-pool (`ViTSpec`) and CLS-token (`ViTClsSpec`) variants. The CLS-token variant prepends one learnable token before the encoder and pools by taking token `0`.
- We intentionally keep the patch embedding as a `Conv2d` with `kernel_size = (patchH, patchW)`. When `stride = (patchH, patchW)` and `padding = 0`, that matches the usual "non-overlapping patches" embedding used in many ViT implementations.
Output height of the patch-embedding convolution in ViT.
Instances For
Output width of the patch-embedding convolution in ViT.
Instances For
Number of patch tokens T = outH*outW produced by the patch embedding.
Instances For
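The three definitions above follow standard convolution output-size arithmetic. A small Python sketch (helper names `conv_out` and `vit_patch_count` are illustrative; the Lean definitions above are the authoritative ones):

```python
def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    # Standard convolution output-size formula:
    # floor((size + 2*padding - kernel) / stride) + 1
    return (size + 2 * padding - kernel) // stride + 1

def vit_patch_count(in_h, in_w, patch_h, patch_w, stride, padding):
    # T = outH * outW, the number of patch tokens
    out_h = conv_out(in_h, patch_h, stride, padding)
    out_w = conv_out(in_w, patch_w, stride, padding)
    return out_h * out_w

# Non-overlapping 16x16 patches on a 224x224 image: a 14x14 grid, so 196 tokens
print(vit_patch_count(224, 224, 16, 16, 16, 0))  # 196
```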
Configuration #
We keep ViT architectural hyperparameters in a dedicated config record so the model definition does not hide numeric choices in its types. This mirrors the usual config-object pattern in PyTorch/torchvision model-zoo code.
ViT architectural hyperparameters (spec layer).
- patchH : ℕ
Patch height (kernel height for the patch-embedding conv).
- patchW : ℕ
Patch width (kernel width for the patch-embedding conv).
- stride : ℕ
Stride for the patch-embedding conv (typical: equal to patch size for non-overlapping patches).
- padding : ℕ
Padding for the patch-embedding conv (typical: `0`).
- embedDim : ℕ
Transformer embedding dimension (`d_model`).
- hiddenDim : ℕ
Feed-forward hidden dimension of each encoder layer.
- headCount : ℕ
Number of attention heads.
- numLayers : ℕ
Number of encoder layers.
- numClasses : ℕ
Output classes for the classifier head.
Instances For
Well-formedness conditions for ViTConfig (the nonzero facts needed by some layer specs).
Instances For
Classic ViT-Base/16-ish hyperparameters (mean-pool variant; spec layer).
Instances For
vitBasePatch16Config satisfies ViTConfig.WF.
Classic ViT-Large/16-ish hyperparameters (mean-pool variant; spec layer).
Instances For
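For orientation, the widely quoted ViT-Base/16 paper values can be written out as a hedged Python sketch (the Lean records above are the authoritative spec; the dict below just restates the classic numbers and the kind of nonzero facts `ViTConfig.WF` tracks):

```python
# Classic ViT-Base/16 hyperparameters (well-known paper values; the exact
# record in this file may differ, e.g. in numClasses).
vit_base_patch16 = {
    "patchH": 16, "patchW": 16,
    "stride": 16, "padding": 0,   # non-overlapping patches
    "embedDim": 768,              # d_model
    "hiddenDim": 3072,            # feed-forward (MLP) dimension
    "headCount": 12,
    "numLayers": 12,
    "numClasses": 1000,           # ImageNet-1k head
}

# Well-formedness sketch: the nonzero facts some layer specs need.
for k in ("patchH", "patchW", "stride", "embedDim",
          "headCount", "numLayers", "numClasses"):
    assert vit_base_patch16[k] > 0

# Multi-head attention additionally needs embedDim divisible by headCount.
assert vit_base_patch16["embedDim"] % vit_base_patch16["headCount"] == 0
```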
ViT parameter bundle (patch embedding + positional encoding + transformer + head).
- posEnc : Spec.PositionalEncodingSpec (ViTPatchCount inH inW cfg.patchH cfg.patchW cfg.stride cfg.padding) cfg.embedDim α
- encoder : Spec.TransformerEncoder cfg.numLayers cfg.headCount cfg.embedDim cfg.hiddenDim α
- head : Spec.LinearSpec α cfg.embedDim cfg.numClasses
Instances For
Forward pass (patches -> tokens -> encoder -> head) #
The forward is the standard ViT dataflow, but with explicit shape transforms so it stays obvious what each axis means:
- `conv2d_spec` produces a feature map `(embedDim, outH, outW)`.
- We flatten `(outH, outW)` into a single token axis `tokN = outH*outW`.
- We swap to token-major layout `(tokN, embedDim)` (this is the usual transformer convention).
- We add positional embeddings and run the transformer encoder.
- We mean-pool tokens and apply a final linear classifier.

PyTorch analogy (no batch axis here):

- patch embedding: `Conv2d(inC, embedDim, kernel_size=patch, stride=stride, padding=padding)`
- flatten: `x.flatten(1).transpose(0, 1)` to get `(T, D)`, depending on your convention
- encoder: `TransformerEncoder(...)`
- pooling + head: `encoded.mean(dim=0)`, then `Linear(embedDim, numClasses)`
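The shape bookkeeping of this dataflow can be traced with a small pure-Python sketch (shapes only, no numerics; `vit_forward_shapes` is an illustrative name, not part of the spec):

```python
def vit_forward_shapes(in_c, in_h, in_w, cfg):
    """Trace the tensor shapes through the (batch-free) ViT forward pass."""
    # patch embedding: Conv2d(inC, embedDim, kernel=patch, stride, padding)
    out_h = (in_h + 2 * cfg["padding"] - cfg["patchH"]) // cfg["stride"] + 1
    out_w = (in_w + 2 * cfg["padding"] - cfg["patchW"]) // cfg["stride"] + 1
    feat = (cfg["embedDim"], out_h, out_w)   # (embedDim, outH, outW)
    tok_n = out_h * out_w
    tokens = (tok_n, cfg["embedDim"])        # flatten + swap: token-major (T, D)
    encoded = tokens                         # pos-enc add and encoder keep the shape
    pooled = (cfg["embedDim"],)              # mean over the token axis
    logits = (cfg["numClasses"],)            # Linear(embedDim, numClasses)
    return feat, tokens, encoded, pooled, logits
```

On a 224x224 RGB image with 16x16 non-overlapping patches and `embedDim = 768`, this yields a `(768, 14, 14)` feature map, a `(196, 768)` token sequence, and `(numClasses,)` logits.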
Gradients for the compact ViT spec (matching ViTSpec).
- d_patch_kernel : Spec.Tensor α (Spec.Shape.dim cfg.embedDim (Spec.Shape.dim inC (Spec.Shape.dim cfg.patchH (Spec.Shape.dim cfg.patchW Spec.Shape.scalar))))
- d_patch_bias : Spec.Tensor α (Spec.Shape.dim cfg.embedDim Spec.Shape.scalar)
- d_pos : Spec.Tensor α (Spec.Shape.dim (ViTPatchCount inH inW cfg.patchH cfg.patchW cfg.stride cfg.padding) (Spec.Shape.dim cfg.embedDim Spec.Shape.scalar))
- d_encoder : List (Spec.TransformerEncoderLayerGrads cfg.headCount cfg.embedDim cfg.hiddenDim α)
- d_head_W : Spec.Tensor α (Spec.Shape.dim cfg.numClasses (Spec.Shape.dim cfg.embedDim Spec.Shape.scalar))
- d_head_b : Spec.Tensor α (Spec.Shape.dim cfg.numClasses Spec.Shape.scalar)
Instances For
ViT forward pass (patch embedding → tokens → transformer encoder → pool → head).
Instances For
Backward pass #
This is a fully explicit reverse-mode spec (no meta-autograd):

- patch embedding: `Conv2D` backward gives `∂kernel`, `∂bias`, and `∂image`,
- positional encoding: addition splits the gradient (`∂pos = ∂tokens`),
- transformer encoder: `TransformerEncoder.backward` (in `NN/Spec/Models/Transformer.lean`),
- mean pooling over tokens: broadcast + scale by `1/tokN`,
- classifier head: `linear_backward_spec`.
We recompute intermediates locally; this keeps the spec self-contained and avoids adding a global "tape" type for every model.
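The mean-pool step is the easiest piece to see concretely: the forward averages `tokN` token vectors, so the backward broadcasts the incoming gradient to every token and scales it by `1/tokN`. A minimal pure-Python sketch (illustrative names, lists standing in for tensors):

```python
def mean_pool(tokens):
    # tokens: T vectors of length D -> one vector of length D
    t, d = len(tokens), len(tokens[0])
    return [sum(tok[j] for tok in tokens) / t for j in range(d)]

def mean_pool_backward(d_pooled, tok_n):
    # d(mean)/d(token_i) = 1/T, so every token receives d_pooled / T
    return [[g / tok_n for g in d_pooled] for _ in range(tok_n)]

tokens = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
pooled = mean_pool(tokens)                      # [3.0, 4.0]
grads = mean_pool_backward([1.0, 1.0], len(tokens))
# each of the 3 tokens receives [1/3, 1/3]
```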
Fully explicit reverse-mode backward pass for ViTSpec.forward.
Instances For
CLS-token ViT variant (classic pooling) #
Many ViT implementations (including the original ViT paper and `torchvision.models.vit_*`) use a learnable CLS token:

- prepend `clsToken` to the patch-token sequence,
- use positional encodings of length `tokN + 1`,
- run the encoder on a sequence of length `tokN + 1`,
- take token `0` after the encoder as the pooled representation, then apply the head.

We keep the existing mean-pool `ViTSpec` unchanged; this is a separate parameter bundle and explicit backward pass.
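The CLS bookkeeping can be sketched the same way (illustrative helper names). Note how the pooling backward routes gradients: only slot `0` receives the head's input gradient, and after the encoder backward, slot `0` of the token gradient corresponds to `d_clsToken` while slots `1..tokN` are the patch-token gradients:

```python
def prepend_cls(cls_token, patch_tokens):
    # (T, D) patch tokens -> (T+1, D) sequence with the CLS token at index 0
    return [list(cls_token)] + [list(t) for t in patch_tokens]

def cls_pool(encoded):
    # pooled representation = token 0 after the encoder
    return encoded[0]

def cls_pool_backward(d_pooled, tok_n_plus_1, d):
    # only token 0 receives gradient; every other token gets zeros
    grads = [[0.0] * d for _ in range(tok_n_plus_1)]
    grads[0] = list(d_pooled)
    return grads
```

For example, prepending a CLS token to two patch tokens gives a length-3 sequence, and `cls_pool_backward([1.0, 1.0], 3, 2)` puts `[1.0, 1.0]` at slot 0 and zeros elsewhere.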
ViT parameter bundle with a learnable CLS token (classic ViT variant).
- clsToken : Spec.Tensor α (Spec.Shape.dim cfg.embedDim Spec.Shape.scalar)
Learnable CLS token embedding (prepended as token 0).
- posEnc : Spec.PositionalEncodingSpec (ViTPatchCount inH inW cfg.patchH cfg.patchW cfg.stride cfg.padding + 1) cfg.embedDim α
- encoder : Spec.TransformerEncoder cfg.numLayers cfg.headCount cfg.embedDim cfg.hiddenDim α
- head : Spec.LinearSpec α cfg.embedDim cfg.numClasses
Instances For
Gradients for the CLS-token ViT spec (matching ViTClsSpec).
- d_patch_kernel : Spec.Tensor α (Spec.Shape.dim cfg.embedDim (Spec.Shape.dim inC (Spec.Shape.dim cfg.patchH (Spec.Shape.dim cfg.patchW Spec.Shape.scalar))))
- d_patch_bias : Spec.Tensor α (Spec.Shape.dim cfg.embedDim Spec.Shape.scalar)
- d_clsToken : Spec.Tensor α (Spec.Shape.dim cfg.embedDim Spec.Shape.scalar)
- d_pos : Spec.Tensor α (Spec.Shape.dim (ViTPatchCount inH inW cfg.patchH cfg.patchW cfg.stride cfg.padding + 1) (Spec.Shape.dim cfg.embedDim Spec.Shape.scalar))
- d_encoder : List (Spec.TransformerEncoderLayerGrads cfg.headCount cfg.embedDim cfg.hiddenDim α)
- d_head_W : Spec.Tensor α (Spec.Shape.dim cfg.numClasses (Spec.Shape.dim cfg.embedDim Spec.Shape.scalar))
- d_head_b : Spec.Tensor α (Spec.Shape.dim cfg.numClasses Spec.Shape.scalar)
Instances For
CLS-token ViT forward pass (prepend CLS → transformer encoder → take token 0 → head).
Instances For
Fully explicit reverse-mode backward pass for ViTClsSpec.forward.