TorchLean API

NN.Spec.Layers.Pooling.TwoD

2D Pooling

Single-channel and channels-first 2D max, average, adaptive, and smooth-max pooling specs.

structure Spec.MaxPool2DSpec (kH kW stride : ℕ) (h1 : kH ≠ 0) (h2 : kW ≠ 0) (hStride : stride ≠ 0) :

MaxPool2d configuration.

The spec uses a fixed kernel (kH,kW) and a single stride value (applied to both height and width). We require kH ≠ 0, kW ≠ 0, and stride ≠ 0 so windows are nonempty and the output-shape arithmetic is well-defined.

PyTorch analogy: F.max_pool2d(x, kernel_size=(kH,kW), stride=stride).

  • kernelHeight : ℕ

    Kernel height.

  • kernelWidth : ℕ

    Kernel width.

  • stride : ℕ

    Stride.

    structure Spec.AvgPool2DSpec (kH kW stride : ℕ) (h1 : kH ≠ 0) (h2 : kW ≠ 0) (hStride : stride ≠ 0) :

    AvgPool2d configuration.

    We treat the pooling window as kH*kW elements and divide by that count. This corresponds to PyTorch's default behavior when no padding is present.

    • kernelHeight : ℕ

      Kernel height.

    • kernelWidth : ℕ

      Kernel width.

    • stride : ℕ

      Stride.

      def Spec.pool2dOutShape (inH inW kH kW stride : ℕ) :

      Output shape for a 2D pooling op (single-channel) with no padding.

      This uses the standard "valid" pooling formula:

      outH = floor((inH - kH)/stride) + 1, outW = floor((inW - kW)/stride) + 1.

      PyTorch analogy: ceil_mode=false with no padding.
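
As a quick check of the formula above, here is a minimal Lean sketch using plain Nat arithmetic. The names validOutDim and validOutShape are hypothetical helpers for illustration, not the TorchLean definitions.

```lean
-- Hypothetical helper mirroring the "valid" formula; not the TorchLean definition.
def validOutDim (inSize k stride : Nat) : Nat :=
  (inSize - k) / stride + 1

-- Hypothetical helper: apply the same formula to height and width.
def validOutShape (inH inW kH kW stride : Nat) : Nat × Nat :=
  (validOutDim inH kH stride, validOutDim inW kW stride)

#eval validOutDim 32 2 2        -- 16
#eval validOutDim 7 3 2         -- 3
#eval validOutShape 7 9 3 3 2   -- (3, 4)

-- Nat subtraction truncates at zero, so a kernel larger than the input
-- still evaluates to 1 under this formula.
#eval validOutDim 2 3 1         -- 1
```

Because Nat division rounds down, no explicit floor is needed in the sketch.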

        def Spec.pool2dMultiOutShape (inC inH inW kH kW stride : ℕ) :

        Output shape for multi-channel 2D pooling (channels preserved).

          def Spec.pool2dOutShapePad (inH inW kH kW stride padding : ℕ) :

          Output shape for a 2D pooling op (single-channel) with symmetric padding.

          Here padding is the number of cells added on each side of the input, and the output size follows the usual PyTorch formula for the extended input. Hard max-pooling ignores padded cells (the PyTorch -∞ convention), while average-pooling below explicitly includes padded zeros.
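
A minimal sketch of the padded output size, assuming the "usual PyTorch formula" means floor((in + 2*padding - k)/stride) + 1 with ceil_mode off and dilation 1. The name paddedOutDim is a hypothetical helper, not the TorchLean definition.

```lean
-- Hypothetical helper: output size for an input extended by `padding` cells per side.
-- Assumes ceil_mode = false and dilation = 1.
def paddedOutDim (inSize k stride padding : Nat) : Nat :=
  (inSize + 2 * padding - k) / stride + 1

#eval paddedOutDim 5 3 2 1   -- 3
#eval paddedOutDim 5 3 2 0   -- 2 (no padding: same as the valid formula)
```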

            def Spec.pool2dMultiOutShapePad (inC inH inW kH kW stride padding : ℕ) :

            Output shape for multi-channel 2D pooling with symmetric padding (channels preserved).

              Smooth max pooling

              max_pool2d_spec uses max and is therefore not differentiable everywhere: it has kinks wherever two or more window entries tie for the maximum.

              For proofs that need everywhere differentiability, we provide a smooth surrogate based on log-sum-exp over each pooling window:

              smooth_max(x₁,…,xₙ) = (1 / β) * log (∑ exp (β * xᵢ))

              This is the standard log-sum-exp surrogate and is intended for β ≠ 0.
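
To make the formula concrete, here is a minimal sketch over a flat list of window values, using Lean's built-in Float. The name smoothMax is hypothetical and the code is illustrative only; the actual spec is generic over a scalar type α.

```lean
-- Hypothetical helper: (1 / beta) * log (sum of exp (beta * x)) over one window.
def smoothMax (beta : Float) (xs : List Float) : Float :=
  let total : Float :=
    xs.foldl (fun (acc : Float) x => acc + Float.exp (beta * x)) 0.0
  (1.0 / beta) * Float.log total

#eval smoothMax 1.0 [1.0, 2.0, 3.0]    -- about 3.41
#eval smoothMax 50.0 [1.0, 2.0, 3.0]   -- very close to 3.0, the hard max
```

This naive form can overflow in floating point when beta * x is large. A numerically careful version would typically subtract the window maximum before exponentiating.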

              def Spec.smoothMaxPool2dSpec {α : Type} [Context α] {kH kW inH inW stride : ℕ} {h1 : kH ≠ 0} {h2 : kW ≠ 0} {hStride : stride ≠ 0} (layer : MaxPool2DSpec kH kW stride h1 h2 hStride) (beta : α) (input : Image inH inW α) :
              Image ((inH - kH) / stride + 1) ((inW - kW) / stride + 1) α

              Smooth max-pooling (single-channel) using a log-sum-exp surrogate.

              This is useful in proof settings that want a differentiable alternative to max_pool2d_spec. For large beta, the output approaches hard max pooling.

                def Spec.smoothMaxPool2dMultiSpec {α : Type} [Context α] {kH kW inH inW inC stride : ℕ} {h1 : kH ≠ 0} {h2 : kW ≠ 0} {hStride : stride ≠ 0} (layer : MaxPool2DSpec kH kW stride h1 h2 hStride) (beta : α) (input : MultiChannelImage inC inH inW α) :
                MultiChannelImage inC ((inH - kH) / stride + 1) ((inW - kW) / stride + 1) α

                Smooth max-pooling (multi-channel): apply smooth_max_pool2d_spec per channel.

                  def Spec.smoothMaxPool2dJvpSpec {α : Type} [Context α] {kH kW inH inW stride : ℕ} {h1 : kH ≠ 0} {h2 : kW ≠ 0} {hStride : stride ≠ 0} (layer : MaxPool2DSpec kH kW stride h1 h2 hStride) (beta : α) (input tangent : Image inH inW α) :
                  Image ((inH - kH) / stride + 1) ((inW - kW) / stride + 1) α

                  Forward-mode JVP for smooth max-pooling (single-channel).

                  For each pooling window this is the differential of the log-sum-exp surrogate, Σᵢ softmax(beta*xᵢ) * dxᵢ. This mirrors the VJP weights below but pushes an input tangent forward instead of pulling an output cotangent backward.
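
The per-window differential can be sketched on flat lists as follows. The name smoothMaxJvp is hypothetical; the sketch uses Float for readability and is not the TorchLean definition.

```lean
-- Hypothetical helper: sum over i of softmax(beta * x)_i * dx_i for one window.
def smoothMaxJvp (beta : Float) (xs dxs : List Float) : Float :=
  -- Unnormalized softmax weights exp (beta * x_i).
  let ws : List Float := xs.map (fun x => Float.exp (beta * x))
  -- Normalizing constant.
  let z : Float := ws.foldl (fun (acc : Float) w => acc + w) 0.0
  -- Weighted sum of the input tangents.
  (List.zipWith (fun w dx => (w / z) * dx) ws dxs).foldl
    (fun (acc : Float) t => acc + t) 0.0

#eval smoothMaxJvp 1.0 [0.0, 0.0] [1.0, 3.0]    -- 2.0: tied inputs get weight 0.5 each
#eval smoothMaxJvp 50.0 [0.0, 1.0] [1.0, 3.0]   -- close to 3.0: nearly all weight on the larger input
```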

                    def Spec.smoothMaxPool2dMultiJvpSpec {α : Type} [Context α] {kH kW inH inW inC stride : ℕ} {h1 : kH ≠ 0} {h2 : kW ≠ 0} {hStride : stride ≠ 0} (layer : MaxPool2DSpec kH kW stride h1 h2 hStride) (beta : α) (input tangent : MultiChannelImage inC inH inW α) :
                    MultiChannelImage inC ((inH - kH) / stride + 1) ((inW - kW) / stride + 1) α

                    Multi-channel JVP for smooth max-pooling (channel-wise application).

                      def Spec.maxPool2dSpec {α : Type} [Context α] {kH kW inH inW stride : ℕ} {h1 : kH ≠ 0} {h2 : kW ≠ 0} {hStride : stride ≠ 0} (layer : MaxPool2DSpec kH kW stride h1 h2 hStride) (input : Image inH inW α) :
                      Image ((inH - kH) / stride + 1) ((inW - kW) / stride + 1) α

                      MaxPool2d forward pass (single-channel).

                      This takes the maximum over each kH×kW window sampled with the given stride. The return type encodes the standard output spatial size formula.
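
A minimal sketch of the window maximum at output position (i, j), modelling an image as a function from row and column to Int. The names windowMax and ramp are hypothetical; the real spec works on Image inH inW α.

```lean
-- Hypothetical helper: maximum over the kH×kW window whose top-left corner
-- is (i * stride, j * stride). The fold is seeded with the window's first cell.
def windowMax (img : Nat → Nat → Int) (kH kW stride i j : Nat) : Int :=
  let r0 := i * stride
  let c0 := j * stride
  (List.range kH).foldl (fun (acc : Int) di =>
    (List.range kW).foldl (fun (acc' : Int) dj =>
      let v := img (r0 + di) (c0 + dj)
      if v > acc' then v else acc') acc) (img r0 c0)

-- Hypothetical test image: a 4-column ramp with value r * 4 + c.
def ramp (r c : Nat) : Int := Int.ofNat (r * 4 + c)

#eval windowMax ramp 2 2 2 0 0   -- 5  (cells 0, 1, 4, 5)
#eval windowMax ramp 2 2 2 1 1   -- 15 (cells 10, 11, 14, 15)
```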

                        def Spec.maxPool2dMultiSpec {α : Type} [Context α] {kH kW inH inW inC stride : ℕ} {h1 : kH ≠ 0} {h2 : kW ≠ 0} {hStride : stride ≠ 0} (layer : MaxPool2DSpec kH kW stride h1 h2 hStride) (input : MultiChannelImage inC inH inW α) :
                        MultiChannelImage inC ((inH - kH) / stride + 1) ((inW - kW) / stride + 1) α

                        MaxPool2d forward pass (multi-channel): apply max_pool2d_spec per channel.

                          def Spec.maxPool2dJvpSpec {α : Type} [Context α] {kH kW inH inW stride : ℕ} {h1 : kH ≠ 0} {h2 : kW ≠ 0} {hStride : stride ≠ 0} (layer : MaxPool2DSpec kH kW stride h1 h2 hStride) (input tangent : Image inH inW α) :
                          Image ((inH - kH) / stride + 1) ((inW - kW) / stride + 1) α

                          Forward-mode JVP for hard max-pooling (single-channel).

                          The tangent is read at the argmax chosen by the primal input. At ties the first row-major maximizer is used, matching maxPool2dBackwardSpec and PyTorch's index convention.
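
The tie-breaking rule can be illustrated on a single window flattened in row-major order. The names argmaxFirst and hardMaxJvp are hypothetical helpers, not the TorchLean definitions.

```lean
-- Hypothetical helper: index of the first maximizer in a row-major window.
def argmaxFirst : List Int → Nat
  | [] => 0
  | x :: xs =>
    let step := fun (acc : Int × Nat × Nat) (y : Int) =>
      -- acc = (best value so far, its index, index of y)
      if y > acc.1 then (y, acc.2.2, acc.2.2 + 1) else (acc.1, acc.2.1, acc.2.2 + 1)
    (xs.foldl step (x, 0, 1)).2.1

-- Hypothetical helper: read the tangent at the primal argmax.
def hardMaxJvp (xs dxs : List Int) : Int :=
  dxs.getD (argmaxFirst xs) 0

#eval argmaxFirst [3, 7, 7, 1]                   -- 1 (first of the two tied 7s)
#eval hardMaxJvp [3, 7, 7, 1] [10, 20, 30, 40]   -- 20
```

The strict comparison is what keeps the earlier position on ties.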

                            def Spec.maxPool2dMultiJvpSpec {α : Type} [Context α] {kH kW inH inW inC stride : ℕ} {h1 : kH ≠ 0} {h2 : kW ≠ 0} {hStride : stride ≠ 0} (layer : MaxPool2DSpec kH kW stride h1 h2 hStride) (input tangent : MultiChannelImage inC inH inW α) :
                            MultiChannelImage inC ((inH - kH) / stride + 1) ((inW - kW) / stride + 1) α

                            Multi-channel JVP for hard max-pooling (channel-wise application).

                              def Spec.avgPool2dSpec {α : Type} [Context α] {kH kW inH inW stride : ℕ} {h1 : kH ≠ 0} {h2 : kW ≠ 0} {hStride : stride ≠ 0} (layer : AvgPool2DSpec kH kW stride h1 h2 hStride) (input : Image inH inW α) :
                              Image ((inH - kH) / stride + 1) ((inW - kW) / stride + 1) α

                              AvgPool2d forward pass (single-channel).

                              We sum all values in the window and divide by kH*kW. PyTorch analogy: avg_pool2d. The count_include_pad flag only affects padded pooling, so in the unpadded case this matches the usual definition.
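
A minimal sketch of the window average at output position (i, j), with the image modelled as a function into Float. The names windowAvg and rampF are hypothetical and for illustration only.

```lean
-- Hypothetical helper: sum the kH×kW window at (i * stride, j * stride)
-- and divide by the element count kH * kW.
def windowAvg (img : Nat → Nat → Float) (kH kW stride i j : Nat) : Float :=
  let total : Float :=
    (List.range kH).foldl (fun (acc : Float) di =>
      (List.range kW).foldl (fun (acc' : Float) dj =>
        acc' + img (i * stride + di) (j * stride + dj)) acc) 0.0
  total / (kH * kW).toFloat

-- Hypothetical test image: a 4-column ramp with value r * 4 + c.
def rampF (r c : Nat) : Float := (r * 4 + c).toFloat

#eval windowAvg rampF 2 2 2 0 0   -- 2.5 (mean of 0, 1, 4, 5)
```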

                                def Spec.avgPool2dMultiSpec {α : Type} [Context α] {kH kW inH inW inC stride : ℕ} (h1 : kH ≠ 0) (h2 : kW ≠ 0) {hStride : stride ≠ 0} (layer : AvgPool2DSpec kH kW stride h1 h2 hStride) (input : MultiChannelImage inC inH inW α) :
                                MultiChannelImage inC ((inH - kH) / stride + 1) ((inW - kW) / stride + 1) α

                                AvgPool2d forward pass (multi-channel): apply avg_pool2d_spec per channel.

                                  structure Spec.AdaptiveAvgPool2DSpec (outH outW : ℕ) :

                                  Spec record for adaptive average pooling to a fixed output size.

                                  • outputHeight : ℕ

                                    Output height.

                                  • outputWidth : ℕ

                                    Output width.

                                    structure Spec.AdaptiveMaxPool2DSpec (outH outW : ℕ) :

                                    Spec record for adaptive max pooling to a fixed output size.

                                    • outputHeight : ℕ

                                      Output height.

                                    • outputWidth : ℕ

                                      Output width.

                                      Adaptive pooling

                                      PyTorch defines adaptive pooling by partitioning the input into out bins. For output index i, the pooling region is the half-open range [start, end), where:

                                      start = floor(i * in / out), end = ceil((i+1) * in / out).

                                      This matters when in is not divisible by out: region sizes vary by at most 1.

                                      def Spec.adaptiveStart (inSize outSize i : ℕ) :

                                      Adaptive-pooling region start index: floor(i * in / out) (PyTorch definition).

                                        def Spec.adaptiveEnd (inSize outSize i : ℕ) :

                                        Adaptive-pooling region end index: ceil((i+1) * in / out) (PyTorch definition).
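
The start and end formulas can be checked with plain Nat arithmetic. The primed names below are hypothetical stand-ins, not the TorchLean definitions, and the ceiling is written as the usual (a + b - 1) / b idiom, which assumes outSize > 0.

```lean
-- Hypothetical stand-in: floor (i * inSize / outSize).
def adaptiveStart' (inSize outSize i : Nat) : Nat :=
  i * inSize / outSize

-- Hypothetical stand-in: ceil ((i + 1) * inSize / outSize), assuming outSize > 0.
def adaptiveEnd' (inSize outSize i : Nat) : Nat :=
  ((i + 1) * inSize + outSize - 1) / outSize

-- Regions [start, end) for input size 5 and output size 3.
#eval (List.range 3).map (fun i => (adaptiveStart' 5 3 i, adaptiveEnd' 5 3 i))
-- [(0, 2), (1, 4), (3, 5)]

-- When the sizes divide evenly, the regions do not overlap.
#eval (List.range 2).map (fun i => (adaptiveStart' 4 2 i, adaptiveEnd' 4 2 i))
-- [(0, 2), (2, 4)]
```

In the 5-to-3 case the regions have sizes 2, 3, and 2, and neighbouring regions share a cell.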

                                          def Spec.adaptiveAvgPool2dSpec {α : Type} [Context α] {inH inW inC : ℕ} (outH outW : ℕ) (_layer : AdaptiveAvgPool2DSpec outH outW) (input : MultiChannelImage inC inH inW α) (_h_inH : inH > 0 := by norm_num) (_h_inW : inW > 0 := by norm_num) (_h_outH : outH > 0 := by norm_num) (_h_outW : outW > 0 := by norm_num) :
                                          MultiChannelImage inC outH outW α

                                          AdaptiveAvgPool2d forward pass.

                                          Unlike fixed-kernel pooling, adaptive pooling chooses a window for each output position so that the whole input is covered by outH×outW bins. This follows the PyTorch start/end formula (see the section comment above).

                                            def Spec.adaptiveMaxPool2dSpec {α : Type} [Context α] {inH inW inC : ℕ} (outH outW : ℕ) (_layer : AdaptiveMaxPool2DSpec outH outW) (input : MultiChannelImage inC inH inW α) (_h_inH : inH > 0 := by norm_num) (_h_inW : inW > 0 := by norm_num) (_h_outH : outH > 0 := by norm_num) (_h_outW : outW > 0 := by norm_num) :
                                            MultiChannelImage inC outH outW α

                                            AdaptiveMaxPool2d forward pass (same binning as adaptive avg, but with max instead of mean).

                                            We intentionally do not use a numeric sentinel value to seed the max fold; we seed from the first element of the region via getValueAtPosition. That keeps the spec meaningful across different scalar backends.

                                              def Spec.maxPool2dBackwardSpec {α : Type} [Context α] {kH kW inH inW stride : ℕ} {h1 : kH ≠ 0} {h2 : kW ≠ 0} {hStride : stride ≠ 0} (_layer : MaxPool2DSpec kH kW stride h1 h2 hStride) (input : Image inH inW α) (grad_output : Image ((inH - kH) / stride + 1) ((inW - kW) / stride + 1) α) :
                                              Image inH inW α

                                              Backward/VJP for max_pool2d_spec.

                                              This propagates each output gradient to the argmax location inside the corresponding window. Tie-breaking: if multiple values in the window are equal to the maximum, we keep the first position in row-major order (same convention as PyTorch's max-pool indices).

                                                def Spec.maxPool2dMultiBackwardSpec {α : Type} [Context α] {kH kW inH inW inC stride : ℕ} {h1 : kH ≠ 0} {h2 : kW ≠ 0} {hStride : stride ≠ 0} (layer : MaxPool2DSpec kH kW stride h1 h2 hStride) (input : MultiChannelImage inC inH inW α) (grad_output : MultiChannelImage inC ((inH - kH) / stride + 1) ((inW - kW) / stride + 1) α) :
                                                MultiChannelImage inC inH inW α

                                                Multi-channel max-pooling backward (channel-wise application of max_pool2d_backward_spec).

                                                  def Spec.avgPool2dBackwardSpec {α : Type} [Context α] {kH kW inH inW stride : ℕ} (_h1 : kH ≠ 0) (_h2 : kW ≠ 0) {hStride : stride ≠ 0} (_layer : AvgPool2DSpec kH kW stride _h1 _h2 hStride) (grad_output : Image ((inH - kH) / stride + 1) ((inW - kW) / stride + 1) α) :
                                                  Image inH inW α

                                                  Backward/VJP for avg_pool2d_spec (single-channel).

                                                  Each output gradient is evenly distributed across its corresponding input window.
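
For a single window, "evenly distributed" means each of the kH*kW cells receives the output gradient divided by kH*kW. A minimal sketch, with the hypothetical name avgWindowBackward and Float scalars for readability:

```lean
-- Hypothetical helper: the contribution of one output gradient g
-- to the kH×kW cells of its input window.
def avgWindowBackward (kH kW : Nat) (g : Float) : List (List Float) :=
  List.replicate kH (List.replicate kW (g / (kH * kW).toFloat))

#eval avgWindowBackward 2 2 1.0   -- a 2×2 grid in which every entry is 0.25
```

When windows overlap (stride smaller than the kernel), an input cell belongs to several windows and its gradient is the sum of their contributions. The sketch shows only the per-window share.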
