TorchLean API

NN.Spec.Models.Resnet

ResNet (spec model) #

Defines a small ResNet‑style architecture with residual/skip connections.

PyTorch analogy: this mirrors the high-level structure of torchvision.models.resnet*.

Important scope note: residual blocks keep stride = 1, so spatial resolution stays constant inside layer1..layer4 (details below).

Torchvision compatibility note: the pipeline matches the main stages of torchvision.models.resnet18, but this spec does not aim to match trained torchvision weights or initialization schemes.

This is a spec model: operations are written in terms of Spec.Tensor and layer specs from NN/Spec/Layers/*.

ResNet architectural hyperparameters (simplified, spec-layer).

PyTorch mental model:

  • torchvision.models.resnet.ResNet fixes a few "stem" choices (7×7 conv, stride 2, etc.) and varies the per-stage widths and block counts based on a small config.

TorchLean’s spec ResNet is intentionally smaller in scope:

  • blocks keep stride=1 so spatial resolution stays constant inside layer1..layer4,
  • we still expose the stem / stage widths / stage block counts as explicit configuration so the model definition does not hide numeric architecture choices in its types.
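The stem choices above follow the standard convolution output-size arithmetic. As a quick sanity check (an illustrative Python sketch, not TorchLean code; `conv_out_size` is a hypothetical helper):

```python
# Standard convolution output-size formula (floor division), the same
# Nat-level formula the spec's output-size theorems state.
def conv_out_size(n: int, kernel: int, stride: int, padding: int) -> int:
    return (n + 2 * padding - kernel) // stride + 1

# Typical stem: 7x7 conv, stride 2, padding 3
# (stemKernel / stemStride / stemPadding above).
# For a 224x224 input this halves resolution, as in torchvision's stem:
print(conv_out_size(224, kernel=7, stride=2, padding=3))  # 112
```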
  • stemOutChannels :

    Output channels of the initial conv stem (typical: 64).

  • stemKernel :

    Kernel size of the initial conv stem (typical: 7).

  • stemStride :

    Stride of the initial conv stem (typical: 2).

  • stemPadding :

    Symmetric padding of the initial conv stem (typical: 3).

  • poolKernel :

    MaxPool kernel size (typical: 3).

  • poolStride :

    MaxPool stride (typical: 2).

  • stage1OutChannels :

    Stage 1 output channels (typical: 64).

  • stage1Blocks :

    Stage 1 block count (typical: 2 for ResNet-18).

  • stage2OutChannels :

    Stage 2 output channels (typical: 128).

  • stage2Blocks :

    Stage 2 block count (typical: 2 for ResNet-18).

  • stage3OutChannels :

    Stage 3 output channels (typical: 256).

  • stage3Blocks :

    Stage 3 block count (typical: 2 for ResNet-18).

  • stage4OutChannels :

    Stage 4 output channels (typical: 512).

  • stage4Blocks :

    Stage 4 block count (typical: 2 for ResNet-18).

  • useProjectionShortcuts : Bool

    If true, the first block in each stage whose channel count changes will use a learned 1×1 projection shortcut (and optional shortcut BN params) when constructed by helpers like ResNetSpec.zeroInit.

    If false (default), we omit the projection parameters and rely on the total fallback shortcut used by ResNetBlockSpec.forward when shortcut_conv = none: identity / channel padding / channel slicing.

    Note: this simplified Spec ResNet keeps the main-path stride fixed to 1, so the projection shortcut (when enabled) also uses stride = 1 here.
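The total fallback shortcut can be illustrated at the level of a per-channel list (a hedged Python sketch; `fallback_shortcut` is a hypothetical name, not TorchLean API):

```python
def fallback_shortcut(channels, out_channels):
    """Total fallback when shortcut_conv = none:
    identity if counts match, zero-pad extra channels when growing,
    slice off the tail when shrinking."""
    n = len(channels)
    if n == out_channels:
        return channels                                   # identity
    if n < out_channels:
        return channels + [0.0] * (out_channels - n)      # channel padding
    return channels[:out_channels]                        # channel slicing

print(fallback_shortcut([1.0, 2.0], 4))  # [1.0, 2.0, 0.0, 0.0]
```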

Instances For

    Well-formedness conditions for ResNetConfig.

    We keep these separate from the data record so "PyTorch-like configs" stay ergonomic, while still letting the spec model use the nonzero facts needed by some layer specs.

    Instances For

      Torchvision-style ResNet-18 hyperparameters (for our simplified spec).

      Instances For

        resnet18Config satisfies the nonzero facts required by the spec layer.

        resnet18Config, but with torchvision-style projection shortcuts enabled for stage transitions.

        This keeps the same block counts and widths, but changes the default shortcut used by ResNetSpec.zeroInit when channel counts change: the first block of each affected stage gets a learned 1×1 projection shortcut instead of the pad/slice fallback.

        Instances For
          theorem Spec.conv3x3_outSize_eq (n : ℕ) (hn : n ≠ 0) :
          (n + 2 * 1 - 3) / 1 + 1 = n

          Output-size identity for a 3×3 conv with stride=1 and padding=1 (Nat-level formula).

          theorem Spec.conv1x1_outSize_eq (n : ℕ) (hn : n ≠ 0) :
          (n + 2 * 0 - 1) / 1 + 1 = n

          Output-size identity for a 1×1 conv with stride=1 and padding=0 (Nat-level formula).
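Both identities are easy to spot-check numerically (an illustrative sketch of the same Nat-level floor-division formula, not TorchLean code):

```python
def conv_out_size(n, kernel, stride, padding):
    # (n + 2*padding - kernel) // stride + 1, as in the theorems above.
    return (n + 2 * padding - kernel) // stride + 1

# 3x3 conv, stride 1, padding 1 preserves size: (n + 2*1 - 3)//1 + 1 == n
# 1x1 conv, stride 1, padding 0 preserves size: (n + 2*0 - 1)//1 + 1 == n
for n in range(1, 100):
    assert conv_out_size(n, 3, 1, 1) == n
    assert conv_out_size(n, 1, 1, 0) == n
```

The nonzero hypothesis matters because the Lean statement uses truncated Nat subtraction; the Python check simply starts at n = 1.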

          theorem Spec.pos_add_one (n : ℕ) :
          0 < n + 1

          Any Nat expression of the form n + 1 is strictly positive.

          theorem Spec.ne_zero_add_one (n : ℕ) :
          n + 1 ≠ 0

          Any Nat expression of the form n + 1 is not 0.

          structure Spec.ResNetBlockSpec (α : Type) (inChannels outChannels : ℕ) (h1 : inChannels ≠ 0) (h2 : outChannels ≠ 0) :

          A basic residual block (two 3×3 convolutions, each followed by BatchNorm; ReLU after the first BN and after the residual addition).

          This corresponds to the "basic block" used in ResNet-18/34.

          PyTorch analogy (schematic):

          y = relu( bn2(conv2( relu(bn1(conv1(x))) )) + shortcut(x) )
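The evaluation order of that schematic can be sketched at scalar level in plain Python (a toy; the stand-in layers below are arbitrary scalar functions, not the real conv/BN specs):

```python
def relu(v):
    return max(v, 0.0)

def basic_block(x, conv1, bn1, conv2, bn2, shortcut):
    # y = relu( bn2(conv2( relu(bn1(conv1(x))) )) + shortcut(x) )
    main = bn2(conv2(relu(bn1(conv1(x)))))
    return relu(main + shortcut(x))

# Toy stand-ins: each layer doubles its input; identity shortcut.
double = lambda v: 2.0 * v
ident = lambda v: v
print(basic_block(1.0, double, double, double, double, ident))  # 17.0
```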

          Instances For
            def Spec.ResNetBlockSpec.forward {α : Type} [Context α] {inChannels outChannels inH inW : ℕ} (h1 : inChannels ≠ 0) (h2 : outChannels ≠ 0) (block : ResNetBlockSpec α inChannels outChannels h1 h2) (x : MultiChannelImage inChannels inH inW α) (h3 : inH ≠ 0) (h4 : inW ≠ 0) :
            MultiChannelImage outChannels inH inW α

            Forward pass for a basic residual block.

            Type-level note: with stride=1 and padding=1, a 3×3 convolution preserves H×W (assuming H,W>0), so the block input and output share spatial dimensions.

            Instances For
              structure Spec.ResNetLayerSpec (α : Type) (inChannels outChannels blockCount : ℕ) (h1 : inChannels ≠ 0) (h2 : outChannels ≠ 0) :

              A layer is a "first" block that can change channels, followed by zero or more homogeneous blocks (same input/output channels).

              We keep the rest blocks in a list rather than a fixed-length vector so that the definition stays lightweight and easy to build in examples. The blockCount index is documentation: the list length is the source of truth for "how many blocks are actually present".
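The "first block, then fold over the rest" shape can be sketched with a plain fold (illustrative names, not TorchLean API):

```python
from functools import reduce

def layer_forward(x, first_block, rest_blocks):
    # Run the channel-changing first block, then fold the
    # homogeneous rest blocks over its output.
    y = first_block(x)
    return reduce(lambda acc, block: block(acc), rest_blocks, y)

inc = lambda v: v + 1
print(layer_forward(1, lambda v: v * 10, [inc, inc, inc]))  # 13
```

Note how the rest blocks are just a list: its length, not a type index, determines how many blocks actually run, matching the "list is the source of truth" remark above.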

              Instances For
                def Spec.ResNetLayerSpec.forward {α : Type} [Context α] {inChannels outChannels blockCount inH inW : ℕ} (h1 : inChannels ≠ 0) (h2 : outChannels ≠ 0) (layer : ResNetLayerSpec α inChannels outChannels blockCount h1 h2) (x : MultiChannelImage inChannels inH inW α) (h3 : inH ≠ 0) (h4 : inW ≠ 0) :
                MultiChannelImage outChannels inH inW α

                Forward pass for a ResNet layer: run the first block, then fold over the remaining blocks.

                Instances For

                  Gradients (explicit reverse-mode) #

                  ResNet is a good example of why the spec layer carries explicit shape structure: the residual addition only typechecks when the main path and the shortcut agree on channel count and spatial dimensions.

                  We keep things simple: gradients are explicit reverse-mode functions that recompute the intermediates they need, with no global tape.

                  structure Spec.ResNetBlockGrads (inChannels outChannels : ℕ) (α : Type) :

                  Parameter gradients for a basic residual block.

                  This mirrors the fields of ResNetBlockSpec, plus optional gradients for the optional projection shortcut.

                  Instances For
                    def Spec.ResNetBlockSpec.backward {α : Type} [Context α] {inChannels outChannels inH inW : ℕ} (h1 : inChannels ≠ 0) (h2 : outChannels ≠ 0) (block : ResNetBlockSpec α inChannels outChannels h1 h2) (x : MultiChannelImage inChannels inH inW α) (grad_output : MultiChannelImage outChannels inH inW α) (h3 : inH ≠ 0) (h4 : inW ≠ 0) :
                    ResNetBlockGrads inChannels outChannels α × MultiChannelImage inChannels inH inW α

                    Backward/VJP for a basic residual block.

                    High-level math:

                    • The block output is relu(main(x) + shortcut(x)).
                    • Backprop therefore:
                      1. multiplies by relu' at the output,
                      2. splits the upstream gradient across the + into the main path and the shortcut path,
                      3. runs BN/conv adjoints on the main path,
                      4. runs either the projection-conv adjoint or the identity/pad/slice adjoint on the shortcut,
                      5. adds the resulting dX contributions.
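Steps 1 and 2 (the relu' mask at the output, then the gradient split across the residual +) can be sketched at scalar level (a toy, not the tensor-level adjoints):

```python
def residual_add_backward(pre_activation, grad_output):
    """Backprop through y = relu(main + shortcut) at a single scalar:
    step 1 multiplies by relu' at the output, step 2 routes the same
    gradient into both the main path and the shortcut path."""
    relu_grad = grad_output if pre_activation > 0 else 0.0  # step 1
    return relu_grad, relu_grad                              # step 2

d_main, d_short = residual_add_backward(pre_activation=2.5, grad_output=1.0)
print(d_main, d_short)  # 1.0 1.0
```

The two returned gradients then feed steps 3 and 4 respectively, and step 5 sums the resulting dX contributions.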
                    Instances For
                      structure Spec.ResNetLayerGrads (inChannels outChannels : ℕ) (α : Type) :

                      Gradients for a ResNetLayerSpec: one gradient bundle per block.

                      Instances For
                        def Spec.ResNetLayerSpec.backward {α : Type} [Context α] {inChannels outChannels blockCount inH inW : ℕ} (h1 : inChannels ≠ 0) (h2 : outChannels ≠ 0) (layer : ResNetLayerSpec α inChannels outChannels blockCount h1 h2) (x : MultiChannelImage inChannels inH inW α) (grad_output : MultiChannelImage outChannels inH inW α) (h3 : inH ≠ 0) (h4 : inW ≠ 0) :
                        ResNetLayerGrads inChannels outChannels α × MultiChannelImage inChannels inH inW α

                        Backward/VJP for a ResNet layer.

                        Implementation note: the rest blocks are a list, so we explicitly reconstruct the intermediate inputs needed for each block's backward pass. This keeps the spec self-contained (no global tape).
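The tape-free recomputation amounts to re-running the forward pass while recording each block's input (an illustrative sketch of the idea, not TorchLean code):

```python
def collect_inputs(blocks, cur):
    # Re-run the forward pass, recording the input seen by each block,
    # so each block's backward can be applied at the right point.
    inputs = []
    for block in blocks:
        inputs.append(cur)
        cur = block(cur)
    return inputs

inc = lambda v: v + 1
print(collect_inputs([inc, inc, inc], 0))  # [0, 1, 2]
```

This trades recomputation for simplicity: no hidden state survives between forward and backward.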

                        Instances For
                          def Spec.ResNetLayerSpec.backward.collect_inputs {α : Type} [Context α] {outChannels inH inW : ℕ} (h2 : outChannels ≠ 0) (h3 : inH ≠ 0) (h4 : inW ≠ 0) (blocks : List (ResNetBlockSpec α outChannels outChannels h2 h2)) (cur : MultiChannelImage outChannels inH inW α) :
                          List (MultiChannelImage outChannels inH inW α)
                          Instances For
                            structure Spec.ResNetSpec (cfg : ResNetConfig) (α : Type) (inputChannels numClasses : ℕ) (h1 : inputChannels ≠ 0) (h2 : numClasses ≠ 0) (hCfg : cfg.WF) :

                            Full ResNet-18-like specification.

                            Pipeline (schematic):

                            conv7x7/stride2 -> BN -> ReLU -> maxpool/stride2 -> layer1 -> layer2 -> layer3 -> layer4 -> global_avg_pool -> linear classifier.

                            PyTorch analogy: this matches the main stages of torchvision.models.resnet18, but recall the "simplified stride=1 blocks" note from the file header.
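Tracking spatial sizes through the schematic with the usual output-size formula (a sketch assuming typical resnet18Config values; ResNetConfig lists no pool padding field, so pool padding 0 is assumed here):

```python
def conv_out_size(n, kernel, stride, padding):
    # Standard floor-division output-size formula.
    return (n + 2 * padding - kernel) // stride + 1

n = 224                                  # input H (and W)
n = conv_out_size(n, 7, 2, 3)            # stem conv7x7/stride2: 224 -> 112
n = conv_out_size(n, 3, 2, 0)            # maxpool/stride2 (padding 0 assumed)
# layer1..layer4 keep stride = 1 in this spec, so the size stays constant;
# global_avg_pool then collapses H x W to one value per channel.
print(n)
```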

                            Instances For
                              def Spec.ResNetSpec.forward {α : Type} [Context α] {cfg : ResNetConfig} {inputChannels numClasses inH inW : ℕ} (h1 : inputChannels ≠ 0) (h2 : numClasses ≠ 0) (hCfg : cfg.WF) (resnet : ResNetSpec cfg α inputChannels numClasses h1 h2 hCfg) (x : MultiChannelImage inputChannels inH inW α) (_h3 : inH ≠ 0) (_h4 : inW ≠ 0) :
                              Tensor α (Shape.dim numClasses Shape.scalar)

                              Forward pass for the ResNet spec.

                              PyTorch analogy:

                              • global_avg_pool2d_flat_spec corresponds to AdaptiveAvgPool2d((1,1)) followed by flattening.
                              • linear_spec corresponds to the final nn.Linear(cfg.stage4OutChannels, numClasses).
                              Instances For
                                structure Spec.ResNetGrads (cfg : ResNetConfig) (inputChannels numClasses : ℕ) (α : Type) :

                                Gradients for the full ResNetSpec forward pass (explicit reverse-mode).

                                Instances For
                                  def Spec.ResNetSpec.backward {α : Type} [Context α] {cfg : ResNetConfig} {inputChannels numClasses inH inW : ℕ} (h1 : inputChannels ≠ 0) (h2 : numClasses ≠ 0) (hCfg : cfg.WF) (resnet : ResNetSpec cfg α inputChannels numClasses h1 h2 hCfg) (x : MultiChannelImage inputChannels inH inW α) (grad_output : Tensor α (Shape.dim numClasses Shape.scalar)) (_h3 : inH ≠ 0) (_h4 : inW ≠ 0) :
                                  ResNetGrads cfg inputChannels numClasses α × MultiChannelImage inputChannels inH inW α

                                  Backward/VJP for the full ResNet spec.

                                  This follows the same structure as the forward pass, but in reverse:

                                  1. classifier backward,
                                  2. global avg-pool backward,
                                  3. layer4..layer1 backward,
                                  4. max-pool backward,
                                  5. initial ReLU backward,
                                  6. initial BN backward,
                                  7. initial conv backward.
                                  Instances For
                                    def Spec.ResNetSpec.zeroInit (cfg : ResNetConfig) (hCfg : cfg.WF) (α : Type) [Context α] (inputChannels numClasses : ℕ) (h1 : inputChannels ≠ 0) (h2 : numClasses ≠ 0) :
                                    ResNetSpec cfg α inputChannels numClasses h1 h2 hCfg

                                    Construct a simplified ResNet-18 spec (zero/one initialization).

                                    This is primarily a runnable/spec baseline and a shape-checking harness. It does not aim to match the trained torchvision weights or initialization schemes.

                                    Instances For
                                      def Spec.ResNet18Spec (α : Type) [Context α] (inputChannels numClasses : ℕ) (h1 : inputChannels ≠ 0) (h2 : numClasses ≠ 0) :
                                      ResNetSpec resnet18Config α inputChannels numClasses h1 h2 resnet18Config_wf

                                      ResNet-18 spec constructor (specialization of ResNetSpec.zeroInit).

                                      By default this uses resnet18Config, whose useProjectionShortcuts = false, so any channel changes are handled via the total fallback shortcut (pad/slice) when shortcut_conv = none.

                                      If you want learned 1×1 projection shortcuts on channel changes, use ResNet18SpecWithProjections.

                                      Instances For
                                        def Spec.ResNet18SpecWithProjections (α : Type) [Context α] (inputChannels numClasses : ℕ) (h1 : inputChannels ≠ 0) (h2 : numClasses ≠ 0) :

                                        ResNet-18 spec constructor using learned 1×1 projection shortcuts on channel changes.

                                        Instances For

                                          Bottleneck blocks (forward-only) #

                                          This section defines the bottleneck block used in ResNet-50/101/152.

                                          This is a forward-only baseline. The stride field is included to match the usual API, but the convolution specs here all use stride 1; wiring stride through the type-level conv shape expressions is outside this spec baseline.

                                          structure Spec.BottleneckResNetBlockSpec (α : Type) (inChannels outChannels : ℕ) (h1 : inChannels ≠ 0) (h2 : outChannels ≠ 0) (h3 : outChannels / 4 ≠ 0) :

                                          Bottleneck residual block spec (ResNet-50/101/152 style), forward-only.

                                          The bottleneck block uses a 1x1 -> 3x3 -> 1x1 conv stack with BatchNorms and a residual shortcut. The stride field is included for API shape, but this spec keeps stride fixed to 1 in the conv specs (see module header for scope note).
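The channel bookkeeping of the 1x1 -> 3x3 -> 1x1 stack, with the mid width taken as outChannels / 4 (which is why h3 requires outChannels / 4 ≠ 0), can be sketched as (illustrative; `bottleneck_channel_flow` is a hypothetical name):

```python
def bottleneck_channel_flow(in_channels, out_channels):
    # Squeeze to out/4 with a 1x1 conv, process at that width with a
    # 3x3 conv, then expand back to out with another 1x1 conv.
    mid = out_channels // 4
    assert mid != 0, "h3: outChannels / 4 must be nonzero"
    return [(in_channels, mid), (mid, mid), (mid, out_channels)]

print(bottleneck_channel_flow(64, 256))  # [(64, 64), (64, 64), (64, 256)]
```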

                                          Instances For
                                            def Spec.BottleneckResNetBlockSpec.forward {α : Type} [Context α] {inChannels outChannels inH inW : ℕ} (h1 : inChannels ≠ 0) (h2 : outChannels ≠ 0) (h3 : outChannels / 4 ≠ 0) (block : BottleneckResNetBlockSpec α inChannels outChannels h1 h2 h3) (x : MultiChannelImage inChannels inH inW α) (h4 : inH ≠ 0) (h5 : inW ≠ 0) :
                                            MultiChannelImage outChannels inH inW α

                                            Forward pass for a bottleneck residual block (forward-only baseline).

                                            Instances For
                                              structure Spec.BottleneckResNetLayerSpec (α : Type) (inChannels outChannels blockCount : ℕ) (h1 : inChannels ≠ 0) (h2 : outChannels ≠ 0) (h3 : outChannels / 4 ≠ 0) :

                                              Bottleneck ResNet layer spec: one "first" block plus a list of homogeneous "rest" blocks.

                                              Instances For
                                                def Spec.BottleneckResNetLayerSpec.forward {α : Type} [Context α] {inChannels outChannels blockCount inH inW : ℕ} (h1 : inChannels ≠ 0) (h2 : outChannels ≠ 0) (h3 : outChannels / 4 ≠ 0) (layer : BottleneckResNetLayerSpec α inChannels outChannels blockCount h1 h2 h3) (x : MultiChannelImage inChannels inH inW α) (h4 : inH ≠ 0) (h5 : inW ≠ 0) :
                                                MultiChannelImage outChannels inH inW α

                                                Forward pass for a bottleneck ResNet layer (fold over the rest list after the first block).

                                                Instances For
                                                  def Spec.ResNetSpec.depth {α : Type} {cfg : ResNetConfig} {inputChannels numClasses : ℕ} (h1 : inputChannels ≠ 0) (h2 : numClasses ≠ 0) (hCfg : cfg.WF) (resnet : ResNetSpec cfg α inputChannels numClasses h1 h2 hCfg) :

                                                  Compute the model depth metadata used by examples, counting the stem, each block, and the head.

                                                  Instances For
                                                    def Spec.ResNetSpec.parameterCount {α : Type} {cfg : ResNetConfig} {inputChannels numClasses : ℕ} (h1 : inputChannels ≠ 0) (h2 : numClasses ≠ 0) (_hCfg : cfg.WF) (_resnet : ResNetSpec cfg α inputChannels numClasses h1 h2 _hCfg) :

                                                    Compute a lightweight parameter-count estimate for model summaries.
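A hedged sketch of the kind of estimate such a summary can compute, using the standard parameter-count formulas (the helper names are hypothetical, not TorchLean's):

```python
def conv_params(k, c_in, c_out, bias=False):
    # Standard conv parameter count: k*k*c_in*c_out (+ c_out if biased).
    return k * k * c_in * c_out + (c_out if bias else 0)

def bn_params(c):
    # BatchNorm: one scale and one shift per channel.
    return 2 * c

def basic_block_params(c):
    # One basic block with equal in/out channels:
    # two 3x3 convs plus two BatchNorms.
    return 2 * (conv_params(3, c, c) + bn_params(c))

print(basic_block_params(64))  # 73984
```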

                                                    Instances For