TorchLean API

NN.Spec.Models.Vit

Vit #

Vision Transformer (ViT) model.

This is a compact “ViT-style” specification:

Notes:

@[reducible, inline]
abbrev Models.ViTPatchOutH (inH patchH stride padding : ℕ) : ℕ

Output height of the patch-embedding convolution in ViT.

@[reducible, inline]
abbrev Models.ViTPatchOutW (inW patchW stride padding : ℕ) : ℕ

Output width of the patch-embedding convolution in ViT.

@[reducible, inline]
abbrev Models.ViTPatchCount (inH inW patchH patchW stride padding : ℕ) : ℕ

Number of patch tokens T = outH * outW produced by the patch embedding.

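For concreteness, the sketch below mirrors the patch-grid arithmetic these abbreviations name, assuming the usual convolution output-size formula; the Lean definitions above are authoritative.

```python
# Hypothetical mirror of the patch-grid arithmetic, assuming the usual
# convolution output-size formula; the Lean definitions are the source of truth.
def patch_out(size: int, patch: int, stride: int, padding: int) -> int:
    return (size + 2 * padding - patch) // stride + 1

def patch_count(in_h, in_w, patch_h, patch_w, stride, padding) -> int:
    return (patch_out(in_h, patch_h, stride, padding)
            * patch_out(in_w, patch_w, stride, padding))

# Classic ViT-Base/16 on a 224x224 input: a 14 x 14 grid, i.e. 196 patch tokens.
assert patch_count(224, 224, 16, 16, 16, 0) == 196
```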

Configuration #

We keep ViT architectural hyperparameters in a dedicated config record so the model definition does not hide numeric choices in its types. This mirrors the usual config-object pattern in PyTorch/torchvision model-zoo code.

ViT architectural hyperparameters (spec layer).

• patchH : ℕ
  Patch height (kernel height for the patch-embedding conv).

• patchW : ℕ
  Patch width (kernel width for the patch-embedding conv).

• stride : ℕ
  Stride for the patch-embedding conv (typical: equal to patch size for non-overlapping patches).

• padding : ℕ
  Padding for the patch-embedding conv (typical: 0).

• embedDim : ℕ
  Transformer embedding dimension (d_model).

• hiddenDim : ℕ
  Transformer feedforward hidden dimension (d_ff).

• headCount : ℕ
  Number of attention heads.

• numLayers : ℕ
  Number of encoder layers.

• numClasses : ℕ
  Output classes for the classifier head.

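A minimal Python analogue of the record is sketched below; the field names are transliterated from the spec, and this is not TorchLean code.

```python
from dataclasses import dataclass

# Hypothetical Python transliteration of the ViTConfig record.
@dataclass(frozen=True)
class ViTConfigPy:
    patch_h: int       # patchH
    patch_w: int       # patchW
    stride: int
    padding: int
    embed_dim: int     # embedDim (d_model)
    hidden_dim: int    # hiddenDim (d_ff)
    head_count: int    # headCount
    num_layers: int    # numLayers
    num_classes: int   # numClasses
```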

Well-formedness conditions for ViTConfig (the nonzero facts needed by some layer specs).

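The exact well-formedness facts live in the Lean source; a plausible reading, sketched with the hypothetical ViTConfigPy above, is that the dimensions layer specs divide by or index into must be nonzero.

```python
# Illustrative guess at the "nonzero facts"; the Lean WF predicate may require
# a different (or smaller) set of conditions.
def plausible_wf(cfg: ViTConfigPy) -> bool:
    return all(n > 0 for n in (cfg.patch_h, cfg.patch_w, cfg.stride,
                               cfg.embed_dim, cfg.hidden_dim,
                               cfg.head_count, cfg.num_layers, cfg.num_classes))
```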

Classic ViT-Base/16-ish hyperparameters (mean-pool variant; spec layer).


Classic ViT-Large/16-ish hyperparameters (mean-pool variant; spec layer).

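The numeric choices behind these constants are not reproduced on this page; the classic paper values, written with the hypothetical ViTConfigPy above, would be:

```python
# Classic ViT-Base/16 and ViT-Large/16 hyperparameters (paper values); the
# spec's actual constants (e.g. numClasses) may differ.
vit_base16 = ViTConfigPy(patch_h=16, patch_w=16, stride=16, padding=0,
                         embed_dim=768, hidden_dim=3072, head_count=12,
                         num_layers=12, num_classes=1000)
vit_large16 = ViTConfigPy(patch_h=16, patch_w=16, stride=16, padding=0,
                          embed_dim=1024, hidden_dim=4096, head_count=16,
                          num_layers=24, num_classes=1000)
```
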
structure Models.ViTSpec (cfg : ViTConfig) (inC inH inW : ℕ) (α : Type) [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] (h_inC : inC ≠ 0) (hCfg : cfg.WF) : Type

ViT parameter bundle (patch embedding + positional encoding + transformer + head).


Forward pass (patches -> tokens -> encoder -> head) #

The forward is the standard ViT dataflow, but with explicit shape transforms so it stays obvious what each axis means:

1. conv2d_spec produces a feature map (embedDim, outH, outW).
2. We flatten (outH, outW) into a single token axis tokN = outH*outW.
3. We swap to token-major layout (tokN, embedDim) (this is the usual transformer convention).
4. We add positional embeddings and run the transformer encoder.
5. We mean-pool tokens and apply a final linear classifier.

PyTorch analogy (no batch axis here):
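
The sketch below is illustrative only; the tensor arguments and the generic encoder callable are assumptions, not this library's code.

```python
import torch
import torch.nn.functional as F

# Rough PyTorch-style mirror of the five steps above (no batch axis).
def vit_forward(x, patch_w, patch_b, pos_emb, encoder, head_w, head_b, stride, padding):
    # 1. patch embedding: (inC, inH, inW) -> (embedDim, outH, outW)
    feat = F.conv2d(x.unsqueeze(0), patch_w, patch_b,
                    stride=stride, padding=padding).squeeze(0)
    embed_dim, out_h, out_w = feat.shape
    # 2. flatten (outH, outW) into one token axis tokN = outH * outW
    tokens = feat.reshape(embed_dim, out_h * out_w)
    # 3. token-major layout: (tokN, embedDim)
    tokens = tokens.transpose(0, 1)
    # 4. add positional embeddings, run the transformer encoder
    tokens = encoder(tokens + pos_emb)
    # 5. mean-pool tokens, apply the linear classifier head
    pooled = tokens.mean(dim=0)
    return head_w @ pooled + head_b
```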

def Models.ViTSpec.forward {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {cfg : ViTConfig} {inC inH inW : ℕ} {h_inC : inC ≠ 0} {hCfg : cfg.WF} (m : ViTSpec cfg inC inH inW α h_inC hCfg) (x : Spec.MultiChannelImage inC inH inW α) (h_tok : ViTPatchCount inH inW cfg.patchH cfg.patchW cfg.stride cfg.padding > 0) :
Spec.Tensor α (Spec.Shape.dim cfg.numClasses Spec.Shape.scalar)

ViT forward pass (patch embedding → tokens → transformer encoder → pool → head).


Backward pass #

This is a fully explicit reverse-mode spec (no meta-autograd): we recompute intermediates locally, which keeps the spec self-contained and avoids adding a global "tape" type for every model.
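
To make "recompute intermediates locally" concrete, here is an illustrative hand-written backward for just the last two forward steps (mean-pool, then linear head); the names are hypothetical and this is not the spec's code.

```python
import numpy as np

def head_forward(tokens, head_w, head_b):
    pooled = tokens.mean(axis=0)               # (embedDim,)
    return head_w @ pooled + head_b            # (numClasses,)

def head_backward(tokens, head_w, grad_output):
    # Recompute the intermediate instead of reading it from a saved tape.
    pooled = tokens.mean(axis=0)
    grad_head_w = np.outer(grad_output, pooled)        # (numClasses, embedDim)
    grad_head_b = grad_output
    grad_pooled = head_w.T @ grad_output               # back through the linear map
    # Mean-pooling spreads the incoming gradient evenly over the tokN tokens.
    grad_tokens = np.broadcast_to(grad_pooled, tokens.shape) / tokens.shape[0]
    return grad_head_w, grad_head_b, grad_tokens
```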

def Models.ViTSpec.backward {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {cfg : ViTConfig} {inC inH inW : ℕ} {h_inC : inC ≠ 0} {hCfg : cfg.WF} (m : ViTSpec cfg inC inH inW α h_inC hCfg) (x : Spec.MultiChannelImage inC inH inW α) (grad_output : Spec.Tensor α (Spec.Shape.dim cfg.numClasses Spec.Shape.scalar)) (h_tok : ViTPatchCount inH inW cfg.patchH cfg.patchW cfg.stride cfg.padding > 0) :
ViTGrads cfg inC inH inW α × Spec.MultiChannelImage inC inH inW α

Fully explicit reverse-mode backward pass for ViTSpec.forward.


CLS-token ViT variant (classic pooling) #

Many ViT implementations (including the original ViT paper and torchvision.models.vit_*) use a learnable CLS token.

We keep the existing mean-pool ViTSpec unchanged; this is a separate parameter bundle and explicit backward pass.
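
An illustrative sketch of how the CLS variant's pooling differs from the mean-pool forward above; the names and the generic encoder callable are assumptions, not this library's code.

```python
import torch

def cls_pool_forward(tokens, cls_token, pos_emb, encoder, head_w, head_b):
    # tokens: (tokN, embedDim); cls_token: (embedDim,)
    seq = torch.cat([cls_token.unsqueeze(0), tokens], dim=0)  # (tokN + 1, embedDim)
    seq = encoder(seq + pos_emb)                              # pos_emb covers tokN + 1 rows
    return head_w @ seq[0] + head_b                           # read token 0 (the CLS slot)
```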

structure Models.ViTClsSpec (cfg : ViTConfig) (inC inH inW : ℕ) (α : Type) [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] (h_inC : inC ≠ 0) (hCfg : cfg.WF) : Type

ViT parameter bundle with a learnable CLS token (classic ViT variant).

def Models.ViTClsSpec.forward {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {cfg : ViTConfig} {inC inH inW : ℕ} {h_inC : inC ≠ 0} {hCfg : cfg.WF} (m : ViTClsSpec cfg inC inH inW α h_inC hCfg) (x : Spec.MultiChannelImage inC inH inW α) :
Spec.Tensor α (Spec.Shape.dim cfg.numClasses Spec.Shape.scalar)

CLS-token ViT forward pass (prepend CLS → transformer encoder → take token 0 → head).

def Models.ViTClsSpec.backward {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {cfg : ViTConfig} {inC inH inW : ℕ} {h_inC : inC ≠ 0} {hCfg : cfg.WF} (m : ViTClsSpec cfg inC inH inW α h_inC hCfg) (x : Spec.MultiChannelImage inC inH inW α) (grad_output : Spec.Tensor α (Spec.Shape.dim cfg.numClasses Spec.Shape.scalar)) :
ViTClsGrads cfg inC inH inW α × Spec.MultiChannelImage inC inH inW α

Fully explicit reverse-mode backward pass for ViTClsSpec.forward.

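As an illustration of the CLS-specific gradient structure (only token 0 feeds the classifier head), a hand-written sketch of the last step's backward; names are hypothetical.

```python
import torch

def cls_head_backward(seq, head_w, grad_output):
    # seq: encoder output of shape (tokN + 1, embedDim); logits = head_w @ seq[0] + head_b.
    grad_head_w = torch.outer(grad_output, seq[0])
    grad_head_b = grad_output.clone()
    grad_seq = torch.zeros_like(seq)
    grad_seq[0] = head_w.t() @ grad_output   # only the CLS slot receives gradient from the head
    return grad_head_w, grad_head_b, grad_seq
```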