TorchLean API

NN.Spec.Models.Unet

Unet #

U-Net (2-level) model.

This file defines a small U-Net style architecture with a single downsample + upsample step.

PyTorch mental model: a compact U-Net built from nn.Conv2d, nn.MaxPool2d, nn.ConvTranspose2d, and torch.cat along the channel axis.

Shape notes: tensors are (C,H,W) (no batch axis).

References: PyTorch docs (for API intuition, not semantics) for torch.nn.Conv2d, torch.nn.MaxPool2d, torch.nn.ConvTranspose2d, and torch.cat.

Configuration #

Architectural hyperparameters live in a dedicated config record.

PyTorch mental model: these fields play the role of the kernel_size, stride, and padding arguments you would pass to nn.MaxPool2d, nn.Conv2d, and nn.ConvTranspose2d, plus a base channel width.

U-Net (2-level) architectural hyperparameters (spec layer).

  • poolKernel : ℕ
    kernel_size for the max-pool layer (typical: 2).

  • poolStride : ℕ
    stride for the max-pool layer (typical: 2).

  • convKernel : ℕ
    kernel_size for the 2D conv blocks (typical: 3).

  • convStride : ℕ
    stride for the 2D conv blocks (typical: 1).

  • convPadding : ℕ
    symmetric zero padding for the 2D conv blocks (typical: 1).

  • upKernel : ℕ
    kernel_size for the transposed-convolution upsampler (typical: 2).

  • upStride : ℕ
    stride for the transposed-convolution upsampler (typical: 2).

  • upPadding : ℕ
    padding for the transposed-convolution upsampler (typical: 0).

  • headKernel : ℕ
    kernel_size for the final output head conv (typical: 1).

  • headStride : ℕ
    stride for the final output head conv (typical: 1).

  • headPadding : ℕ
    padding for the final output head conv (typical: 0).

  • baseC : ℕ
    Base channel count (typical: 64).

Instances For

Well-formedness conditions for UNet2Config (the few nonzero facts needed by layer specs).

Instances For

Canonical "classic U-Net-ish" defaults for our 2-level spec.

Instances For

unet2DefaultConfig satisfies the nonzero facts required by the spec layer.
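
Putting the typical values together, defaults in this style can be written as an ordinary structure literal. The sketch below is illustrative only: the field names and typical values are taken from the record above, but the actual unet2DefaultConfig is whatever the library defines, and the open Models line assumes the declarations live in the Models namespace.

import NN.Spec.Models.Unet  -- module shown at the top of this page
open Models

-- Hypothetical "classic U-Net-ish" defaults assembled from the typical values above.
def classicDefaults : UNet2Config where
  poolKernel := 2
  poolStride := 2
  convKernel := 3
  convStride := 1
  convPadding := 1
  upKernel := 2
  upStride := 2
  upPadding := 0
  headKernel := 1
  headStride := 1
  headPadding := 0
  baseC := 64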

@[reducible, inline]
abbrev Models.UNetDownH (cfg : UNet2Config) (inH : ℕ) : ℕ

Output height after MaxPool2d(kernel=2, stride=2) (no padding).

Instances For

@[reducible, inline]
abbrev Models.UNetDownW (cfg : UNet2Config) (inW : ℕ) : ℕ

Output width after MaxPool2d(kernel=2, stride=2) (no padding).

Instances For
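
With the typical poolKernel = poolStride = 2 and no padding, the pooled size follows the usual (in - kernel) / stride + 1 rule, so an even input is halved. A quick sanity check in plain Lean, with literals standing in for the cfg fields:

-- MaxPool2d(kernel=2, stride=2), no padding: an even height of 64 pools to 32.
example : (64 - 2) / 2 + 1 = 32 := by decide
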
@[reducible, inline]
abbrev Models.UNetUpH (cfg : UNet2Config) (inH : ℕ) : ℕ

Output height after MaxPool2d(2,2) then ConvTranspose2d(2,2) (with padding=0).

Instances For

@[reducible, inline]
abbrev Models.UNetUpW (cfg : UNet2Config) (inW : ℕ) : ℕ

Output width after MaxPool2d(2,2) then ConvTranspose2d(2,2) (with padding=0).

Instances For
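
Assuming the standard PyTorch ConvTranspose2d size formula (in - 1) * stride - 2 * padding + kernel, the pool-then-upsample round trip returns an even input to its original size, which is exactly what the h_upH / h_upW hypotheses below ask for. Concretely, for a height of 64 and the typical 2×2/stride-2 layers:

-- Pool 64 → 32, then ConvTranspose2d(kernel=2, stride=2, padding=0) maps 32 → 64.
example : (64 - 2) / 2 + 1 = 32 := by decide
example : (32 - 1) * 2 - 2 * 0 + 2 = 64 := by decide
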
structure Models.UNet2Spec (cfg : UNet2Config) (inC outC inH inW : ℕ) (α : Type) [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] (h_inC : inC ≠ 0) (hCfg : cfg.WF) : Type

2-level U-Net parameter record (spec).

This is a compact U-Net with one downsample and one upsample step:

• two conv + ReLU blocks at full resolution (with a skip),
• max-pooling, then two conv + ReLU blocks at the lower resolution,
• a transposed-conv upsampler,
• channel concatenation with the skip feature map,
• two more conv + ReLU blocks,
• a final 1×1 conv head.

Shape convention: tensors are (C,H,W) (no batch axis).

PyTorch analogue: a small U-Net built from nn.Conv2d, nn.MaxPool2d, nn.ConvTranspose2d, and torch.cat along the channel axis.

Instances For

Gradients #

This U-Net is small enough that we can write a fully explicit backward pass in a "mirror the forward" style: rebuild the same intermediates, then walk back through them using the existing layer-level backward specs.

Key details:

• the backward recomputes the forward intermediates instead of reading them off a stored tape;
• it returns both the parameter gradients (collected in UNet2Grads) and the gradient with respect to the input image.

PyTorch analogy: the same quantities autograd would produce for this graph, written out here as explicit per-layer backward rules.

structure Models.UNet2Grads (cfg : UNet2Config) (inC outC inH inW : ℕ) (α : Type) : Type

Parameter-gradient container for UNet2Spec.

This mirrors the parameter layout of UNet2Spec, recording kernel and bias gradients for each convolution and transposed-convolution layer.

Instances For

def Models.UNet2Spec.forward {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {cfg : UNet2Config} {inC outC inH inW : ℕ} {h_inC : inC ≠ 0} {hCfg : cfg.WF} (m : UNet2Spec cfg inC outC inH inW α h_inC hCfg) (x : Spec.MultiChannelImage inC inH inW α) (h_convH : (inH + 2 * cfg.convPadding - cfg.convKernel) / cfg.convStride + 1 = inH) (h_convW : (inW + 2 * cfg.convPadding - cfg.convKernel) / cfg.convStride + 1 = inW) (h_convH_down : (UNetDownH cfg inH + 2 * cfg.convPadding - cfg.convKernel) / cfg.convStride + 1 = UNetDownH cfg inH) (h_convW_down : (UNetDownW cfg inW + 2 * cfg.convPadding - cfg.convKernel) / cfg.convStride + 1 = UNetDownW cfg inW) (h_upH : UNetUpH cfg inH = inH) (h_upW : UNetUpW cfg inW = inW) (h_outH : (inH + 2 * cfg.headPadding - cfg.headKernel) / cfg.headStride + 1 = inH) (h_outW : (inW + 2 * cfg.headPadding - cfg.headKernel) / cfg.headStride + 1 = inW) :
  Spec.MultiChannelImage outC inH inW α

Forward pass for UNet2Spec.

Inputs/outputs use MultiChannelImage tensors of shape (C,H,W) (no batch axis).

The many h_* equalities are shape-rewrite hints: layer specs compute output sizes using explicit arithmetic (matching PyTorch's formulas), and these equalities let callers assert "this 3×3 conv preserves spatial size" or "pool then upsample returns to the original size" for a particular choice of inH, inW (typically even).

Instances For
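
For instance, with the typical hyperparameters (3×3 convs with stride 1 and padding 1, a 2×2 pool, a 2×2 stride-2 transposed conv, a 1×1 head) and inH = inW = 64, every required equality holds. A plain-Lean check with literals standing in for the cfg fields; the transposed-conv line assumes PyTorch's (in - 1) * stride - 2 * padding + kernel formula:

example : (64 + 2 * 1 - 3) / 1 + 1 = 64 := by decide  -- h_convH / h_convW: 3×3 conv preserves 64
example : (32 + 2 * 1 - 3) / 1 + 1 = 32 := by decide  -- h_convH_down / h_convW_down at the pooled size
example : (64 - 2) / 2 + 1 = 32 := by decide          -- UNetDownH / UNetDownW: pool 64 → 32
example : (32 - 1) * 2 - 2 * 0 + 2 = 64 := by decide  -- h_upH / h_upW: transposed conv 32 → 64
example : (64 + 2 * 0 - 1) / 1 + 1 = 64 := by decide  -- h_outH / h_outW: the 1×1 head preserves 64
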
def Models.UNet2Spec.backward {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {cfg : UNet2Config} {inC outC inH inW : ℕ} {h_inC : inC ≠ 0} {hCfg : cfg.WF} (m : UNet2Spec cfg inC outC inH inW α h_inC hCfg) (x : Spec.MultiChannelImage inC inH inW α) (grad_output : Spec.MultiChannelImage outC inH inW α) (h_convH : (inH + 2 * cfg.convPadding - cfg.convKernel) / cfg.convStride + 1 = inH) (h_convW : (inW + 2 * cfg.convPadding - cfg.convKernel) / cfg.convStride + 1 = inW) (h_convH_down : (UNetDownH cfg inH + 2 * cfg.convPadding - cfg.convKernel) / cfg.convStride + 1 = UNetDownH cfg inH) (h_convW_down : (UNetDownW cfg inW + 2 * cfg.convPadding - cfg.convKernel) / cfg.convStride + 1 = UNetDownW cfg inW) (h_upH : UNetUpH cfg inH = inH) (h_upW : UNetUpW cfg inW = inW) (h_outH : (inH + 2 * cfg.headPadding - cfg.headKernel) / cfg.headStride + 1 = inH) (h_outW : (inW + 2 * cfg.headPadding - cfg.headKernel) / cfg.headStride + 1 = inW) :
  UNet2Grads cfg inC outC inH inW α × Spec.MultiChannelImage inC inH inW α

Backward pass for UNet2Spec.forward.

Given:

• the model parameters m,
• the forward input image x,
• an upstream gradient grad_output = dL/dy,

it returns:

• the parameter gradients (UNet2Grads), and
• the gradient w.r.t. the input image (dL/dx).

Implementation note: this is an explicit "recompute intermediates then walk backward" spec (no mutable tape), mirroring the math behind PyTorch autograd and standard conv/pool backward rules.

Instances For
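
To make the "recompute intermediates then walk backward" shape concrete, here is a toy sketch in plain Lean, entirely separate from the TorchLean API: a two-stage pipeline y = g (f x) with hand-written layer-level backward rules, whose backward rebuilds the intermediate instead of consulting a tape.

-- Forward stages of the toy pipeline.
def toyF (x : Float) : Float := x * x
def toyG (z : Float) : Float := 2.0 * z

-- Layer-level backward rules: upstream gradient in, local gradient out.
def toyFBack (x dz : Float) : Float := 2.0 * x * dz
def toyGBack (_z dy : Float) : Float := 2.0 * dy

-- Backward for toyG ∘ toyF: recompute the intermediate z, then chain the rules.
def toyBackward (x dy : Float) : Float :=
  let z := toyF x
  let dz := toyGBack z dy
  toyFBack x dz

#eval toyBackward 3.0 1.0  -- 12.0, i.e. d(2·x²)/dx at x = 3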