ResNet (spec model) #
Defines a small ResNet‑style architecture with residual/skip connections.
PyTorch analogy: this mirrors the high-level structure of `torchvision.models.resnet*`:
- convolution + normalization + ReLU,
- residual blocks with an identity/projection shortcut,
- global average pooling,
- a final linear classifier.
Important scope note:
- Residual blocks below keep stride fixed to `1`, so spatial resolution is preserved across `layer1..layer4`. (Standard ResNet down-samples at the start of `layer2..layer4`; adding that is possible, but it complicates the type-level shape discipline.)
- Channel changes can be handled via an optional projection shortcut (`shortcut_conv`) plus optional shortcut batch-norm parameters. Builders such as `ResNetSpec.zeroInit` gate this behind `ResNetConfig.useProjectionShortcuts` (default `false`).
- If `shortcut_conv = none`, we still define a total shortcut by falling back to identity / channel padding / channel slicing.
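The total fallback behaviour can be illustrated with a small sketch in plain Python (illustrative only, not the TorchLean API), acting on a per-pixel channel vector:

```python
# Illustration of the total fallback shortcut: with no projection conv,
# the shortcut must still be defined for any (inChannels, outChannels) pair:
#   equal channels  -> identity
#   fewer channels  -> zero-pad up to outChannels
#   more channels   -> slice down to outChannels
def fallback_shortcut(x, out_channels):
    in_channels = len(x)
    if in_channels == out_channels:
        return list(x)                                          # identity
    if in_channels < out_channels:
        return list(x) + [0.0] * (out_channels - in_channels)   # pad
    return list(x[:out_channels])                               # slice
```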
Torchvision compatibility note:
- In standard ResNet (as implemented in `torchvision.models.resnet*`), whenever input/output channels differ, the shortcut path is a learnable 1×1 projection (typically with BatchNorm).
- In TorchLean, that corresponds to enabling `cfg.useProjectionShortcuts = true` when constructing a `ResNetSpec` via helpers like `ResNetSpec.zeroInit`, or otherwise supplying a `some` value for `shortcut_conv` in the block parameters.
This is a spec model: operations are written in terms of `Spec.Tensor` and layer specs from `NN/Spec/Layers/*`.
ResNet architectural hyperparameters (simplified, spec-layer).
PyTorch mental model: `torchvision.models.resnet.ResNet` fixes a few "stem" choices (7×7 conv, stride 2, etc.) and varies the per-stage widths and block counts based on a small config.
TorchLean’s spec ResNet is intentionally smaller in scope:
- blocks keep `stride = 1` so spatial resolution stays constant inside `layer1..layer4`,
- we still expose the stem / stage widths / stage block counts as explicit configuration so the model definition does not hide numeric architecture choices in its types.
- stemOutChannels : ℕ
  Output channels of the initial conv stem (typical: `64`).
- stemKernel : ℕ
  Kernel size of the initial conv stem (typical: `7`).
- stemStride : ℕ
  Stride of the initial conv stem (typical: `2`).
- stemPadding : ℕ
  Symmetric padding of the initial conv stem (typical: `3`).
- poolKernel : ℕ
  MaxPool kernel size (typical: `3`).
- poolStride : ℕ
  MaxPool stride (typical: `2`).
- stage1OutChannels : ℕ
  Stage 1 output channels (typical: `64`).
- stage1Blocks : ℕ
  Stage 1 block count (typical: `2` for ResNet-18).
- stage2OutChannels : ℕ
  Stage 2 output channels (typical: `128`).
- stage2Blocks : ℕ
  Stage 2 block count (typical: `2` for ResNet-18).
- stage3OutChannels : ℕ
  Stage 3 output channels (typical: `256`).
- stage3Blocks : ℕ
  Stage 3 block count (typical: `2` for ResNet-18).
- stage4OutChannels : ℕ
  Stage 4 output channels (typical: `512`).
- stage4Blocks : ℕ
  Stage 4 block count (typical: `2` for ResNet-18).
- useProjectionShortcuts : Bool
  If `true`, the first block in each stage whose channel count changes will use a learned `1×1` projection shortcut (and optional shortcut BN params) when constructed by helpers like `ResNetSpec.zeroInit`. If `false` (default), we omit the projection parameters and rely on the total fallback shortcut used by `ResNetBlockSpec.forward` when `shortcut_conv = none`: identity / channel padding / channel slicing. Note: this simplified Spec ResNet keeps the main-path stride fixed to `1`, so the projection shortcut (when enabled) also uses `stride = 1` here.
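For orientation, the field list above can be mirrored as a hypothetical Python record (not the TorchLean definition) with the quoted "typical" ResNet-18 values as defaults:

```python
from dataclasses import dataclass

# Hypothetical Python mirror of ResNetConfig, defaulted to the
# "typical" values quoted in the field docs above.
@dataclass
class ResNetConfigSketch:
    stemOutChannels: int = 64
    stemKernel: int = 7
    stemStride: int = 2
    stemPadding: int = 3
    poolKernel: int = 3
    poolStride: int = 2
    stage1OutChannels: int = 64
    stage1Blocks: int = 2
    stage2OutChannels: int = 128
    stage2Blocks: int = 2
    stage3OutChannels: int = 256
    stage3Blocks: int = 2
    stage4OutChannels: int = 512
    stage4Blocks: int = 2
    useProjectionShortcuts: bool = False

cfg = ResNetConfigSketch()
```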
Well-formedness conditions for ResNetConfig.
We keep these separate from the data record so "PyTorch-like configs" stay ergonomic, while still letting the spec model use the nonzero facts needed by some layer specs.
Torchvision-style ResNet-18 hyperparameters (for our simplified spec).
`resnet18Config` satisfies the nonzero facts required by the spec layer.
`resnet18Config`, but with torchvision-style projection shortcuts enabled for stage transitions.
This keeps the same block counts and widths, but changes the default shortcut used by `ResNetSpec.zeroInit` when channel counts change:
- `useProjectionShortcuts = false` (default): pad/slice fallback when `shortcut_conv = none`.
- `useProjectionShortcuts = true`: build a learned `1×1` shortcut conv (and shortcut BN params).
Any `Nat` expression of the form `n + 1` is strictly positive.
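Stated as a Lean sketch (the core lemma `Nat.succ_pos` proves it directly, since `n + 1` is definitionally `n.succ`):

```lean
-- Sketch: any `n + 1` is strictly positive.
example (n : ℕ) : 0 < n + 1 := Nat.succ_pos n
```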
A basic residual block (two 3×3 convolutions, each followed by BatchNorm; ReLU after the first BN and after the residual addition).
This corresponds to the "basic block" used in ResNet-18/34.
PyTorch analogy (schematic):
y = relu( bn2(conv2( relu(bn1(conv1(x))) )) + shortcut(x) )
- conv1 : Conv2DSpec inChannels outChannels 3 3 1 1 α h1 _proof_1 _proof_1
  First 3×3 convolution (inChannels → outChannels).
- conv2 : Conv2DSpec outChannels outChannels 3 3 1 1 α h2 _proof_1 _proof_1
  Second 3×3 convolution (outChannels → outChannels).
- bn1_gamma : Tensor α (Shape.dim outChannels Shape.scalar)
  BatchNorm 1 scale (gamma).
- bn1_beta : Tensor α (Shape.dim outChannels Shape.scalar)
  BatchNorm 1 shift (beta).
- bn2_gamma : Tensor α (Shape.dim outChannels Shape.scalar)
  BatchNorm 2 scale (gamma).
- bn2_beta : Tensor α (Shape.dim outChannels Shape.scalar)
  BatchNorm 2 shift (beta).
- shortcut_conv : Option (Conv2DSpec inChannels outChannels 1 1 1 0 α h1 _proof_2 _proof_2)
  Optional learned 1×1 projection shortcut; `none` selects the identity/pad/slice fallback.
- shortcut_bn_gamma : Option (Tensor α (Shape.dim outChannels Shape.scalar))
  Optional shortcut BatchNorm scale (gamma).
- shortcut_bn_beta : Option (Tensor α (Shape.dim outChannels Shape.scalar))
  Optional shortcut BatchNorm shift (beta).
Forward pass for a basic residual block.
Type-level note: with `stride=1` and `padding=1`, a 3×3 convolution preserves H×W (assuming H,W > 0), so the block input and output share spatial dimensions.
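The shape-preservation claim follows from the usual convolution output-size formula; a quick sketch:

```python
# Standard conv output-size formula: out = floor((n + 2p - k) / s) + 1.
# With k=3, p=1, s=1 this reduces to n, so 3x3/stride-1/pad-1 convs
# preserve H and W (for n > 0), which is what the block's type relies on.
def conv_out(n, kernel, stride, padding):
    return (n + 2 * padding - kernel) // stride + 1
```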
A layer is a "first" block that can change channels, followed by zero or more homogeneous blocks (same input/output channels).
We keep the rest blocks in a list rather than a fixed-length vector so that the definition stays lightweight and easy to build in examples. The `blockCount` index is documentation: the list length is the source of truth for how many blocks are actually present.
- first : ResNetBlockSpec α inChannels outChannels h1 h2
- rest : List (ResNetBlockSpec α outChannels outChannels h2 h2)
Forward pass for a ResNet layer: run the first block, then fold over the remaining blocks.
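The fold can be sketched with ordinary function composition, where blocks stand in as plain `x -> x` maps (illustrative only):

```python
from functools import reduce

# Sketch of the layer forward: apply the "first" block, then left-fold
# the homogeneous "rest" blocks over its output.
def layer_forward(first, rest, x):
    return reduce(lambda acc, block: block(acc), rest, first(x))
```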
Gradients (explicit reverse-mode) #
ResNet is a good example of why the spec layer carries explicit shape structure:
- Residual connections force us to be precise about shapes and casting, otherwise the definition no longer states exactly which tensors are being added.
- The backward pass follows the same discipline: split gradients across the residual branches, run each layer's backward rule, then add the contributions back together.
We keep things simple:
- stride is fixed to `1` inside blocks (matching the forward spec above),
- we recompute intermediates locally (no global tape),
- optional projection shortcuts (`shortcut_conv`) are handled when present, and "fallback" shortcuts (pad/slice/identity) have explicit adjoints.
Parameter gradients for a basic residual block.
This mirrors the fields of ResNetBlockSpec, plus optional gradients for the optional projection
shortcut.
- d_conv1_kernel : Tensor α (Shape.dim outChannels (Shape.dim inChannels (Shape.dim 3 (Shape.dim 3 Shape.scalar))))
d conv 1 kernel.
- d_conv1_bias : Tensor α (Shape.dim outChannels Shape.scalar)
d conv 1 bias.
- d_conv2_kernel : Tensor α (Shape.dim outChannels (Shape.dim outChannels (Shape.dim 3 (Shape.dim 3 Shape.scalar))))
d conv 2 kernel.
- d_conv2_bias : Tensor α (Shape.dim outChannels Shape.scalar)
d conv 2 bias.
- d_bn1_gamma : Tensor α (Shape.dim outChannels Shape.scalar)
d bn 1 gamma.
- d_bn1_beta : Tensor α (Shape.dim outChannels Shape.scalar)
d bn 1 beta.
- d_bn2_gamma : Tensor α (Shape.dim outChannels Shape.scalar)
d bn 2 gamma.
- d_bn2_beta : Tensor α (Shape.dim outChannels Shape.scalar)
d bn 2 beta.
- d_shortcut_conv : Option (Tensor α (Shape.dim outChannels (Shape.dim inChannels (Shape.dim 1 (Shape.dim 1 Shape.scalar)))) × Tensor α (Shape.dim outChannels Shape.scalar))
d shortcut conv.
- d_shortcut_bn_gamma : Option (Tensor α (Shape.dim outChannels Shape.scalar))
d shortcut bn gamma.
- d_shortcut_bn_beta : Option (Tensor α (Shape.dim outChannels Shape.scalar))
d shortcut bn beta.
Backward/VJP for a basic residual block.
High-level math:
- The block output is `relu(main(x) + shortcut(x))`.
- Backprop therefore:
  - multiplies by `relu'` at the output,
  - splits the upstream gradient across the `+` into the main path and the shortcut path,
  - runs BN/conv adjoints on the main path,
  - runs either the projection-conv adjoint or the identity/pad/slice adjoint on the shortcut,
  - adds the resulting `dX` contributions.
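This gradient bookkeeping can be checked on a scalar toy model (names illustrative: the main path is a hypothetical linear map `a*x` and the shortcut is the identity, so their adjoints are multiplication by the same coefficients):

```python
# Scalar sketch of the residual backward rule for y = relu(main(x) + x).
def residual_backward(x, g, a=2.0):
    pre = a * x + x                       # main(x) + shortcut(x)
    g_pre = g if pre > 0 else 0.0         # multiply by relu'
    d_main, d_shortcut = g_pre, g_pre     # split across the `+`
    return d_main * a + d_shortcut * 1.0  # add the dX contributions
```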
Gradients for a ResNetLayerSpec: one gradient bundle per block.
- first : ResNetBlockGrads inChannels outChannels α
  Gradients for the first block.
- rest : List (ResNetBlockGrads outChannels outChannels α)
  Gradients for the remaining homogeneous blocks.
Backward/VJP for a ResNet layer.
Implementation note: the `rest` blocks are a list, so we explicitly reconstruct the intermediate inputs needed for each block's backward pass. This keeps the spec self-contained (no global tape).
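A toy sketch of this recompute-then-reverse pattern (blocks stand in as `x -> x**2`, whose VJP `g -> 2*x*g` genuinely needs the recomputed block input):

```python
# Sketch: replay the forward pass to recover each block's input, then
# walk the blocks in reverse applying each block's VJP.
def layer_backward(n_blocks, x0, g):
    xs = [x0]
    for _ in range(n_blocks):              # recompute intermediates locally
        xs.append(xs[-1] ** 2)             # toy block: x -> x^2
    for i in reversed(range(n_blocks)):    # reverse sweep over the blocks
        g = 2 * xs[i] * g                  # toy VJP uses the block input
    return g
```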
Full ResNet-18-like specification.
Pipeline (schematic):
conv7x7/stride2 -> BN -> ReLU -> maxpool/stride2 -> layer1 -> layer2 -> layer3 -> layer4 -> global_avg_pool -> linear classifier.
PyTorch analogy: this matches the main stages of `torchvision.models.resnet18`, but recall the "simplified `stride=1` blocks" note from the file header.
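The spatial bookkeeping through the stem can be checked with the standard output-size formula (assuming the spec's max-pool uses no padding, consistent with the config fields above; torchvision pads its max-pool by 1, so it reaches 56 where this sketch reaches 55):

```python
# Spatial sizes through the stem of the simplified spec, for a 224x224 input.
def out_size(n, k, s, p=0):
    return (n + 2 * p - k) // s + 1

h = 224
h = out_size(h, 7, 2, 3)   # stem conv, kernel 7 / stride 2 / pad 3
h = out_size(h, 3, 2)      # max-pool, kernel 3 / stride 2, no padding
# layer1..layer4 keep stride 1, so h stays fixed through all four stages.
```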
- initial_conv : Conv2DSpec inputChannels cfg.stemOutChannels cfg.stemKernel cfg.stemKernel cfg.stemStride cfg.stemPadding α h1 ⋯ ⋯
initial conv.
- initial_bn_gamma : Tensor α (Shape.dim cfg.stemOutChannels Shape.scalar)
initial bn gamma.
- initial_bn_beta : Tensor α (Shape.dim cfg.stemOutChannels Shape.scalar)
initial bn beta.
- initial_pool : MaxPool2DSpec cfg.poolKernel cfg.poolKernel cfg.poolStride ⋯ ⋯ ⋯
initial pool.
- layer1 : ResNetLayerSpec α cfg.stemOutChannels cfg.stage1OutChannels cfg.stage1Blocks ⋯ ⋯
layer 1.
- layer2 : ResNetLayerSpec α cfg.stage1OutChannels cfg.stage2OutChannels cfg.stage2Blocks ⋯ ⋯
layer 2.
- layer3 : ResNetLayerSpec α cfg.stage2OutChannels cfg.stage3OutChannels cfg.stage3Blocks ⋯ ⋯
layer 3.
- layer4 : ResNetLayerSpec α cfg.stage3OutChannels cfg.stage4OutChannels cfg.stage4Blocks ⋯ ⋯
layer 4.
- classifier : LinearSpec α cfg.stage4OutChannels numClasses
classifier.
Forward pass for the ResNet spec.
PyTorch analogy:
- `global_avg_pool2d_flat_spec` corresponds to `AdaptiveAvgPool2d((1,1))` followed by flattening.
- `linear_spec` corresponds to the final `nn.Linear(cfg.stage4OutChannels, numClasses)`.
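The head's semantics can be sketched on nested lists (illustrative only, not the spec functions themselves):

```python
# Per-channel global average pooling: (C, H, W) -> length-C vector.
def global_avg_pool_flat(x):
    return [sum(map(sum, ch)) / (len(ch) * len(ch[0])) for ch in x]

# Linear classifier: logits_i = sum_j W[i][j] * v[j] + b[i].
def linear(W, b, v):
    return [sum(wi * vj for wi, vj in zip(row, v)) + bi
            for row, bi in zip(W, b)]
```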
Gradients for the full ResNetSpec forward pass (explicit reverse-mode).
- d_initial_kernel : Tensor α (Shape.dim cfg.stemOutChannels (Shape.dim inputChannels (Shape.dim cfg.stemKernel (Shape.dim cfg.stemKernel Shape.scalar))))
d initial kernel.
- d_initial_bias : Tensor α (Shape.dim cfg.stemOutChannels Shape.scalar)
d initial bias.
- d_initial_bn_gamma : Tensor α (Shape.dim cfg.stemOutChannels Shape.scalar)
d initial bn gamma.
- d_initial_bn_beta : Tensor α (Shape.dim cfg.stemOutChannels Shape.scalar)
d initial bn beta.
- d_layer1 : ResNetLayerGrads cfg.stemOutChannels cfg.stage1OutChannels α
d layer 1.
- d_layer2 : ResNetLayerGrads cfg.stage1OutChannels cfg.stage2OutChannels α
d layer 2.
- d_layer3 : ResNetLayerGrads cfg.stage2OutChannels cfg.stage3OutChannels α
d layer 3.
- d_layer4 : ResNetLayerGrads cfg.stage3OutChannels cfg.stage4OutChannels α
d layer 4.
- d_classifier_W : Tensor α (Shape.dim numClasses (Shape.dim cfg.stage4OutChannels Shape.scalar))
d classifier W.
- d_classifier_b : Tensor α (Shape.dim numClasses Shape.scalar)
d classifier b.
Backward/VJP for the full ResNet spec.
This follows the same structure as the forward pass, but in reverse:
- classifier backward,
- global avg-pool backward,
- `layer4..layer1` backward,
- max-pool backward,
- initial ReLU backward,
- initial BN backward,
- initial conv backward.
Construct a simplified ResNet-18 spec (zero/one initialization).
This is primarily a runnable/spec baseline and a shape-checking harness. It does not aim to match the trained torchvision weights or initialization schemes.
ResNet-18 spec constructor (specialization of `ResNetSpec.zeroInit`).
By default this uses `resnet18Config`, whose `useProjectionShortcuts = false`, so any channel changes are handled via the total fallback shortcut (pad/slice) when `shortcut_conv = none`.
If you want learned `1×1` projection shortcuts on channel changes, use `ResNet18SpecWithProjections`.
ResNet-18 spec constructor using learned 1×1 projection shortcuts on channel changes.
Bottleneck blocks (forward-only) #
This section defines the bottleneck block used in ResNet-50/101/152.
This is a forward-only baseline. The stride field is included to match the usual API, but the
convolution specs here all use stride 1; wiring stride through the type-level conv shape
expressions is outside this spec baseline.
Bottleneck residual block spec (ResNet-50/101/152 style), forward-only.
The bottleneck block uses a 1x1 -> 3x3 -> 1x1 conv stack with BatchNorms and a residual shortcut.
The stride field is included for API shape, but this spec keeps stride fixed to 1 in
the conv specs (see module header for scope note).
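The channel bookkeeping of the 1x1 -> 3x3 -> 1x1 stack can be sketched as follows (hypothetical helper, mirroring the `outChannels / 4` squeeze visible in the field types below):

```python
# Bottleneck channel bookkeeping: the two 1x1 convs squeeze to and expand
# from a quarter of the output width, so the 3x3 conv runs on out/4 channels.
def bottleneck_widths(in_ch, out_ch):
    mid = out_ch // 4   # matches `outChannels / 4` in the spec
    # (in_channels, out_channels, kernel) for conv1, conv2, conv3
    return [(in_ch, mid, 1), (mid, mid, 3), (mid, out_ch, 1)]
```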
- conv1 : Conv2DSpec inChannels (outChannels / 4) 1 1 1 0 α h1 ResNetBlockSpec._proof_2 ResNetBlockSpec._proof_2
conv 1.
- conv2 : Conv2DSpec (outChannels / 4) (outChannels / 4) 3 3 1 1 α h3 ResNetBlockSpec._proof_1 ResNetBlockSpec._proof_1
conv 2.
- conv3 : Conv2DSpec (outChannels / 4) outChannels 1 1 1 0 α h3 ResNetBlockSpec._proof_2 ResNetBlockSpec._proof_2
conv 3.
- bn1_gamma : Tensor α (Shape.dim (outChannels / 4) Shape.scalar)
bn 1 gamma.
- bn1_beta : Tensor α (Shape.dim (outChannels / 4) Shape.scalar)
bn 1 beta.
- bn2_gamma : Tensor α (Shape.dim (outChannels / 4) Shape.scalar)
bn 2 gamma.
- bn2_beta : Tensor α (Shape.dim (outChannels / 4) Shape.scalar)
bn 2 beta.
- bn3_gamma : Tensor α (Shape.dim outChannels Shape.scalar)
bn 3 gamma.
- bn3_beta : Tensor α (Shape.dim outChannels Shape.scalar)
bn 3 beta.
- shortcut_conv : Option (Conv2DSpec inChannels outChannels 1 1 1 0 α h1 ResNetBlockSpec._proof_2 ResNetBlockSpec._proof_2)
shortcut conv.
- shortcut_bn_gamma : Option (Tensor α (Shape.dim outChannels Shape.scalar))
shortcut bn gamma.
- shortcut_bn_beta : Option (Tensor α (Shape.dim outChannels Shape.scalar))
shortcut bn beta.
- stride : ℕ
Stride.
Forward pass for a bottleneck residual block (forward-only baseline).
Bottleneck ResNet layer spec: one "first" block plus a list of homogeneous "rest" blocks.
- first : BottleneckResNetBlockSpec α inChannels outChannels h1 h2 h3
- rest : List (BottleneckResNetBlockSpec α outChannels outChannels h2 h2 h3)
Forward pass for a bottleneck ResNet layer (fold over the rest list after the first block).
Compute the model depth metadata used by examples, counting the stem, each block, and the head.
Compute a lightweight parameter-count estimate for model summaries.
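As an illustration of such an estimate, here is a per-block count over the basic block's fields (conv kernels plus biases, plus the BatchNorm gamma/beta pairs; a sketch of the arithmetic, not the actual TorchLean counter):

```python
# Rough parameter count for one basic block:
#   conv1: k*k*in*out weights + out biases
#   conv2: k*k*out*out weights + out biases
#   bn1, bn2: (gamma, beta) pairs of length out each
def basic_block_params(in_ch, out_ch, k=3):
    conv1 = k * k * in_ch * out_ch + out_ch
    conv2 = k * k * out_ch * out_ch + out_ch
    bn = 2 * (2 * out_ch)
    return conv1 + conv2 + bn
```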