Unet #
U-Net (2-level) model.

This file defines a small U-Net style architecture (a single downsample + upsample):

- down path: two `Conv2d(3x3, stride=1, padding=1) + ReLU` blocks,
- downsample: `MaxPool2d(kernel=2, stride=2)`,
- bottleneck: two more conv blocks,
- upsample: `ConvTranspose2d(kernel=2, stride=2)`,
- skip connection: concatenate channels and run two conv blocks,
- output head: `Conv2d(1x1)` to map `baseC -> outC`.
PyTorch mental model:

- this matches the common "U-Net block diagram" but written without a batch axis, so our tensor convention is `(C,H,W)` rather than `(N,C,H,W)`;
- the skip connection concatenates on the channel axis (in PyTorch with a batch axis that would be `torch.cat([skip, up], dim=1)`; here it is `concat_dim0_spec` because channels are axis 0).
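As a concrete illustration of why channel concatenation is plain axis-0 concatenation in the `(C,H,W)` convention, here is a minimal plain-Python sketch (tensors modeled as nested lists; the helper name `concat_channels` is illustrative, not from the spec):

```python
def concat_channels(skip, up):
    """Axis-0 concatenation: with no batch axis, channels are axis 0,
    so concatenating feature maps is just list concatenation."""
    # Both inputs must share the same spatial size H x W.
    assert len(skip[0]) == len(up[0]) and len(skip[0][0]) == len(up[0][0])
    return skip + up

skip = [[[1, 2], [3, 4]]]         # shape (1, 2, 2)
up   = [[[5, 6], [7, 8]]]         # shape (1, 2, 2)
cat  = concat_channels(skip, up)  # shape (2, 2, 2): channels stack on axis 0
```

With a batch axis present, the same operation would instead target `dim=1`, which is exactly the `torch.cat([skip, up], dim=1)` idiom mentioned above.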
Shape notes:

- the 3x3 conv blocks are set up to preserve `H×W` (stride=1, padding=1),
- the pool/upsample pair is the usual `2x` down then `2x` up, but for odd spatial sizes the `ConvTranspose2d` formula can produce an off-by-one; we surface this as explicit equalities (`h_upH`, `h_upW`) so the caller can pick compatible `inH`, `inW` (typically even).
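The off-by-one can be checked directly from PyTorch's documented output-size formulas; a plain-Python sketch (function names `pool_out`/`convT_out` are illustrative):

```python
def pool_out(n, kernel=2, stride=2):
    # MaxPool2d output size (no padding, no dilation):
    # floor((n - kernel) / stride) + 1
    return (n - kernel) // stride + 1

def convT_out(n, kernel=2, stride=2, padding=0):
    # ConvTranspose2d output size (no dilation, no output_padding):
    # (n - 1) * stride - 2 * padding + kernel
    return (n - 1) * stride - 2 * padding + kernel

# Even sizes round-trip exactly; odd sizes come back one short.
even = convT_out(pool_out(8))  # 8 -> 4 -> 8
odd  = convT_out(pool_out(9))  # 9 -> 4 -> 8, one short of the input
```

This is why the forward spec exposes the `h_upH`/`h_upW` equalities instead of silently assuming the round-trip holds.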
References:

- Ronneberger et al., "U-Net: Convolutional Networks for Biomedical Image Segmentation" (MICCAI 2015).

PyTorch docs (for API intuition, not semantics):

- `torch.nn.Conv2d`: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
- `torch.nn.MaxPool2d`: https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html
- `torch.nn.ConvTranspose2d`: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
Configuration #
Architectural hyperparameters live in a dedicated config record.

PyTorch mental model:

- this mirrors the way you would pass `kernel_size`/`stride`/`padding` to `nn.Conv2d`, `nn.MaxPool2d`, and `nn.ConvTranspose2d`, plus the base channel width.
U-Net (2-level) architectural hyperparameters (spec layer).

- poolKernel : ℕ
  `kernel_size` for the max-pool layer (typical: `2`).
- poolStride : ℕ
  `stride` for the max-pool layer (typical: `2`).
- convKernel : ℕ
  `kernel_size` for the 2D conv blocks (typical: `3`).
- convStride : ℕ
  `stride` for the 2D conv blocks (typical: `1`).
- convPadding : ℕ
  symmetric zero `padding` for the 2D conv blocks (typical: `1`).
- upKernel : ℕ
  `kernel_size` for the transposed-convolution upsampler (typical: `2`).
- upStride : ℕ
  `stride` for the transposed-convolution upsampler (typical: `2`).
- upPadding : ℕ
  `padding` for the transposed-convolution upsampler (typical: `0`).
- headKernel : ℕ
  `kernel_size` for the final output head conv (typical: `1`).
- headStride : ℕ
  `stride` for the final output head conv (typical: `1`).
- headPadding : ℕ
  `padding` for the final output head conv (typical: `0`).
- baseC : ℕ
  Base channel count (typical: `64`).
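For intuition, the record could be mirrored by a plain Python dataclass (a sketch only; field names follow the spec record, the class name `UNet2Config` is illustrative, and the defaults are the "classic U-Net-ish" values listed above):

```python
from dataclasses import dataclass

@dataclass
class UNet2Config:
    # Max-pool layer.
    poolKernel: int = 2
    poolStride: int = 2
    # 3x3 conv blocks (stride 1, padding 1 preserves H x W).
    convKernel: int = 3
    convStride: int = 1
    convPadding: int = 1
    # Transposed-conv upsampler.
    upKernel: int = 2
    upStride: int = 2
    upPadding: int = 0
    # Final 1x1 output head.
    headKernel: int = 1
    headStride: int = 1
    headPadding: int = 0
    # Base channel width.
    baseC: int = 64

cfg = UNet2Config()  # the default configuration
```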
Instances For
Well-formedness conditions for UNet2Config (the few nonzero facts needed by layer specs).
Canonical "classic U-Net-ish" defaults for our 2-level spec.
unet2DefaultConfig satisfies the nonzero facts required by the spec layer.
Output height after MaxPool2d(kernel=2, stride=2) (no padding).
Output width after MaxPool2d(kernel=2, stride=2) (no padding).
Output height after MaxPool2d(2,2) then ConvTranspose2d(2,2) (with padding=0).
Output width after MaxPool2d(2,2) then ConvTranspose2d(2,2) (with padding=0).
2-level U-Net parameter record (spec).
This is a compact U-Net with one downsample and one upsample step:
- two conv + ReLU blocks at full resolution (with a skip),
- max-pooling, then two conv + ReLU blocks at the lower resolution,
- a transposed-conv upsampler,
- channel concatenation with the skip feature map,
- two more conv + ReLU blocks,
- a final `1×1` conv head.
Shape convention: tensors are (C,H,W) (no batch axis).
PyTorch analogue: a small U-Net built from nn.Conv2d, nn.MaxPool2d, nn.ConvTranspose2d,
and torch.cat along the channel axis.
- down1_1 : Spec.Conv2DSpec inC cfg.baseC cfg.convKernel cfg.convKernel cfg.convStride cfg.convPadding α h_inC ⋯ ⋯
  First 3×3 conv in the first down block (`inC -> baseC`).
- down1_2 : Spec.Conv2DSpec cfg.baseC cfg.baseC cfg.convKernel cfg.convKernel cfg.convStride cfg.convPadding α ⋯ ⋯ ⋯
  Second 3×3 conv in the first down block (`baseC -> baseC`).
- down2_1 : Spec.Conv2DSpec cfg.baseC (2 * cfg.baseC) cfg.convKernel cfg.convKernel cfg.convStride cfg.convPadding α ⋯ ⋯ ⋯
  First 3×3 conv in the bottleneck block (`baseC -> 2*baseC`).
- down2_2 : Spec.Conv2DSpec (2 * cfg.baseC) (2 * cfg.baseC) cfg.convKernel cfg.convKernel cfg.convStride cfg.convPadding α ⋯ ⋯ ⋯
  Second 3×3 conv in the bottleneck block (`2*baseC -> 2*baseC`).
- upT : Spec.ConvTranspose2DSpec (2 * cfg.baseC) cfg.baseC cfg.upKernel cfg.upKernel cfg.upStride cfg.upPadding α ⋯ ⋯ ⋯
  Transposed-convolution upsampler (`2*baseC -> baseC`, `kernel=2`, `stride=2`).
- up1_1 : Spec.Conv2DSpec (cfg.baseC + cfg.baseC) cfg.baseC cfg.convKernel cfg.convKernel cfg.convStride cfg.convPadding α ⋯ ⋯ ⋯
  First 3×3 conv after skip concatenation (`(baseC+baseC) -> baseC`).
- up1_2 : Spec.Conv2DSpec cfg.baseC cfg.baseC cfg.convKernel cfg.convKernel cfg.convStride cfg.convPadding α ⋯ ⋯ ⋯
  Second 3×3 conv after skip concatenation (`baseC -> baseC`).
- out1x1 : Spec.Conv2DSpec cfg.baseC outC cfg.headKernel cfg.headKernel cfg.headStride cfg.headPadding α ⋯ ⋯ ⋯
  Final 1×1 conv head (`baseC -> outC`).
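The channel and spatial bookkeeping in these fields can be traced end-to-end with a small plain-Python sketch (shape arithmetic only, no real tensors; the helper names `conv_out` and `unet2_shapes` are illustrative, with the default `3x3/s1/p1` conv, `2x2/s2` pool and upsample, and `1x1` head):

```python
def conv_out(n, kernel, stride, padding):
    # Conv2d output size: floor((n + 2p - k) / s) + 1
    return (n + 2 * padding - kernel) // stride + 1

def unet2_shapes(inC, outC, H, W, baseC=64):
    """Trace (C, H, W) shapes through the 2-level U-Net."""
    c3 = lambda n: conv_out(n, 3, 1, 1)  # 3x3/s1/p1 preserves size
    shapes = {}
    shapes["down1"] = (baseC, c3(c3(H)), c3(c3(W)))           # two 3x3 convs
    pH, pW = (H - 2) // 2 + 1, (W - 2) // 2 + 1               # 2x2/s2 pool
    shapes["bottleneck"] = (2 * baseC, c3(c3(pH)), c3(c3(pW)))
    uH, uW = (pH - 1) * 2 + 2, (pW - 1) * 2 + 2               # 2x2/s2 convT
    shapes["up"] = (baseC, uH, uW)
    shapes["concat"] = (baseC + baseC, uH, uW)                # skip ++ up
    shapes["out"] = (outC, uH, uW)                            # 1x1 head
    return shapes

s = unet2_shapes(inC=3, outC=1, H=64, W=64)
```

For even `H`, `W` the output spatial size matches the input, so the skip concatenation is well-typed without any cropping.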
Gradients #
This U-Net is small enough that we can write a fully explicit backward pass in a "mirror the forward" style: rebuild the same intermediates, then walk back through them using the existing layer-level backward specs.
Key details:

- `concat_dim0_spec` is split via `concat_dim0_backward_spec`,
- pooling backward uses `max_pool2d_multi_backward_spec`,
- ReLU is handled via elementwise gating `dZ = dY ⊙ ReLU'(Z)`.

PyTorch analogy:

- each `conv2d_backward_spec` call corresponds to the gradients PyTorch computes for `Conv2d` (weight, bias);
- `max_pool2d_multi_backward_spec` corresponds to max-pool backward using the argmax locations from the forward (our spec computes it from the inputs).
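The ReLU gating rule `dZ = dY ⊙ ReLU'(Z)` is just an elementwise mask; a minimal plain-Python sketch (flattened tensors, illustrative helper name, and the usual `ReLU'(0) = 0` convention):

```python
def relu_backward(dY, Z):
    """Gate the upstream gradient by ReLU'(Z): pass it through
    where the pre-activation was positive, zero it elsewhere."""
    return [dy if z > 0 else 0.0 for dy, z in zip(dY, Z)]

Z  = [1.5, -0.2, 0.0, 3.0]   # pre-activations from the forward pass
dY = [0.1, 0.4, 0.7, -0.5]   # upstream gradient
dZ = relu_backward(dY, Z)    # [0.1, 0.0, 0.0, -0.5]
```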
Parameter-gradient container for UNet2Spec.
This mirrors the parameter layout of UNet2Spec, recording kernel and bias gradients for each
convolution and transposed-convolution layer.
- d_down1_1_kernel : Spec.Tensor α (Spec.Shape.dim cfg.baseC (Spec.Shape.dim inC (Spec.Shape.dim cfg.convKernel (Spec.Shape.dim cfg.convKernel Spec.Shape.scalar))))
d down 1 1 kernel.
- d_down1_1_bias : Spec.Tensor α (Spec.Shape.dim cfg.baseC Spec.Shape.scalar)
d down 1 1 bias.
- d_down1_2_kernel : Spec.Tensor α (Spec.Shape.dim cfg.baseC (Spec.Shape.dim cfg.baseC (Spec.Shape.dim cfg.convKernel (Spec.Shape.dim cfg.convKernel Spec.Shape.scalar))))
d down 1 2 kernel.
- d_down1_2_bias : Spec.Tensor α (Spec.Shape.dim cfg.baseC Spec.Shape.scalar)
d down 1 2 bias.
- d_down2_1_kernel : Spec.Tensor α (Spec.Shape.dim (2 * cfg.baseC) (Spec.Shape.dim cfg.baseC (Spec.Shape.dim cfg.convKernel (Spec.Shape.dim cfg.convKernel Spec.Shape.scalar))))
d down 2 1 kernel.
- d_down2_1_bias : Spec.Tensor α (Spec.Shape.dim (2 * cfg.baseC) Spec.Shape.scalar)
d down 2 1 bias.
- d_down2_2_kernel : Spec.Tensor α (Spec.Shape.dim (2 * cfg.baseC) (Spec.Shape.dim (2 * cfg.baseC) (Spec.Shape.dim cfg.convKernel (Spec.Shape.dim cfg.convKernel Spec.Shape.scalar))))
d down 2 2 kernel.
- d_down2_2_bias : Spec.Tensor α (Spec.Shape.dim (2 * cfg.baseC) Spec.Shape.scalar)
d down 2 2 bias.
d up T kernel.
- d_upT_bias : Spec.Tensor α (Spec.Shape.dim cfg.baseC Spec.Shape.scalar)
d up T bias.
- d_up1_1_kernel : Spec.Tensor α (Spec.Shape.dim cfg.baseC (Spec.Shape.dim (cfg.baseC + cfg.baseC) (Spec.Shape.dim cfg.convKernel (Spec.Shape.dim cfg.convKernel Spec.Shape.scalar))))
d up 1 1 kernel.
- d_up1_1_bias : Spec.Tensor α (Spec.Shape.dim cfg.baseC Spec.Shape.scalar)
d up 1 1 bias.
- d_up1_2_kernel : Spec.Tensor α (Spec.Shape.dim cfg.baseC (Spec.Shape.dim cfg.baseC (Spec.Shape.dim cfg.convKernel (Spec.Shape.dim cfg.convKernel Spec.Shape.scalar))))
d up 1 2 kernel.
- d_up1_2_bias : Spec.Tensor α (Spec.Shape.dim cfg.baseC Spec.Shape.scalar)
d up 1 2 bias.
- d_out1x1_kernel : Spec.Tensor α (Spec.Shape.dim outC (Spec.Shape.dim cfg.baseC (Spec.Shape.dim cfg.headKernel (Spec.Shape.dim cfg.headKernel Spec.Shape.scalar))))
d out 1 x 1 kernel.
- d_out1x1_bias : Spec.Tensor α (Spec.Shape.dim outC Spec.Shape.scalar)
d out 1 x 1 bias.
Forward pass for UNet2Spec.
Inputs/outputs use MultiChannelImage tensors of shape (C,H,W) (no batch axis).
The many h_* equalities are shape-rewrite hints: layer specs compute output sizes using explicit
arithmetic (matching PyTorch's formulas), and these equalities let callers assert "this 3×3 conv
preserves spatial size" or "pool then upsample returns to the original size" for a particular
choice of `inH`, `inW` (typically even).
Backward pass for UNet2Spec.forward.
Given:

- the model parameters `m`,
- the forward input image `x`,
- an upstream gradient `grad_output = dL/dy`,

returns:

- parameter gradients (`UNet2Grads`), and
- the gradient w.r.t. the input image (`dL/dx`).
Implementation note: this is an explicit "recompute intermediates then walk backward" spec (no mutable tape), mirroring the math behind PyTorch autograd and standard conv/pool backward rules.
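The "recompute intermediates then walk backward" style can be illustrated on a toy two-weight chain (a hypothetical stand-in for the full conv/pool network; function names are illustrative), with a finite-difference check of one parameter gradient:

```python
def forward(w1, w2, x):
    # Recompute all intermediates explicitly (no mutable tape).
    z1 = w1 * x
    a1 = z1 if z1 > 0 else 0.0   # ReLU
    y  = w2 * a1
    return z1, a1, y

def backward(w1, w2, x, grad_output):
    # Mirror the forward: rebuild the intermediates, then walk back.
    z1, a1, y = forward(w1, w2, x)
    d_w2 = grad_output * a1               # y = w2 * a1
    d_a1 = grad_output * w2
    d_z1 = d_a1 if z1 > 0 else 0.0        # ReLU gate on the pre-activation
    d_w1 = d_z1 * x
    d_x  = d_z1 * w1
    return d_w1, d_w2, d_x

# Compare the explicit d_w1 against a central finite difference.
w1, w2, x, g = 0.5, -1.5, 2.0, 1.0
d_w1, d_w2, d_x = backward(w1, w2, x, g)
eps = 1e-6
num = (forward(w1 + eps, w2, x)[2] - forward(w1 - eps, w2, x)[2]) / (2 * eps)
```

The same pattern scales up: each layer's backward consumes the recomputed forward intermediates, exactly as the conv/pool backward specs do above.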