Pooling layers (spec layer) #
This file defines a small set of pooling operators in TorchLean's spec layer.
PyTorch analogies:
- `max_pool2d_*` corresponds to `torch.nn.functional.max_pool2d`.
- `avg_pool2d_*` corresponds to `torch.nn.functional.avg_pool2d`.
- `adaptive_*` corresponds to `torch.nn.functional.adaptive_{avg,max}_pool2d` (output size fixed, pooling regions vary per output position).
We also include a smooth log-sum-exp surrogate for max pooling. It is useful when you want an everywhere-differentiable approximation for proofs or analysis, without changing the rest of the pooling API.
MaxPool2d configuration.
The spec uses a fixed kernel (kH,kW) and a single stride value (applied to both height and width).
We require kH ≠ 0, kW ≠ 0, and stride ≠ 0 so windows are nonempty and the output-shape
arithmetic is well-defined.
PyTorch analogy: F.max_pool2d(x, kernel_size=(kH,kW), stride=stride).
AvgPool2d configuration.
We treat the pooling window as kH*kW elements and divide by that count.
This corresponds to PyTorch's default behavior when no padding is present.
Output shape for a 2D pooling op (single-channel) with no padding.
This uses the standard "valid" pooling formula:
outH = floor((inH - kH)/stride) + 1, outW = floor((inW - kW)/stride) + 1.
PyTorch analogy: ceil_mode=false with no padding.
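For concreteness, here is that arithmetic as a standalone Lean sketch (`poolOutDim` is an illustrative name, not the spec's actual definition); Lean's truncating `Nat` subtraction keeps it total even for degenerate inputs:

```lean
-- Sketch of the "valid" output-size formula. Nat subtraction truncates at 0,
-- so inDim < k degenerates to an output size of 1 instead of erroring.
def poolOutDim (inDim k stride : Nat) : Nat :=
  (inDim - k) / stride + 1

#eval poolOutDim 7 3 2  -- (7 - 3) / 2 + 1 = 3
```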
Output shape for multi-channel 2D pooling (channels preserved).
Output shape for a 2D pooling op (single-channel) with symmetric padding.
`padding` means we use the usual PyTorch output-size formula for an input extended by `padding`
cells on each side. Hard max-pooling ignores padded cells (the PyTorch -∞ convention), while
average-pooling below explicitly includes padded zeros.
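A padded variant of the same sketch (again with illustrative names) simply extends each side before applying the valid formula:

```lean
-- Symmetric padding adds pad cells on each side before the valid formula.
def poolOutDimPad (inDim k stride pad : Nat) : Nat :=
  (inDim + 2 * pad - k) / stride + 1

#eval poolOutDimPad 7 3 2 1  -- (7 + 2*1 - 3) / 2 + 1 = 4
```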
Output shape for multi-channel 2D pooling with symmetric padding (channels preserved).
Smooth max pooling #
max_pool2d_spec uses max and is non-differentiable (ties and kink points).
For proofs that need everywhere differentiability, we provide a smooth surrogate based on log-sum-exp over each pooling window:
smooth_max(x₁,…,xₙ) = (1 / β) * log (∑ exp (β * xᵢ))
This is the standard log-sum-exp surrogate and is intended for β ≠ 0.
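A minimal executable sketch over one window, using Lean's `Float` (the spec itself is scalar-polymorphic, and `smoothMax` is an illustrative name):

```lean
-- Log-sum-exp surrogate for the max of one pooling window.
def smoothMax (beta : Float) (xs : List Float) : Float :=
  (1.0 / beta) * Float.log (xs.foldl (fun acc x => acc + Float.exp (beta * x)) 0.0)

#eval smoothMax 1.0 [1.0, 2.0, 3.0]   -- ≈ 3.408
#eval smoothMax 10.0 [1.0, 2.0, 3.0]  -- ≈ 3.000005; larger beta approaches the hard max
```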
Smooth max-pooling (single-channel) using a log-sum-exp surrogate.
This is useful in proof settings that want a differentiable alternative to max_pool2d_spec.
For large beta, the output approaches hard max pooling.
Smooth max-pooling (multi-channel): apply smooth_max_pool2d_spec per channel.
Forward-mode JVP for smooth max-pooling (single-channel).
For each pooling window this is the differential of the log-sum-exp surrogate,
Σᵢ softmax(beta*xᵢ) * dxᵢ. This mirrors the VJP weights below but pushes an input tangent
forward instead of pulling an output cotangent backward.
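Per window, this JVP is a softmax-weighted sum of the tangents; a sketch with illustrative names:

```lean
-- JVP of the log-sum-exp surrogate over one window: Σᵢ softmax(beta*xᵢ) * dxᵢ.
def smoothMaxJvp (beta : Float) (xs dxs : List Float) : Float :=
  let es := xs.map fun x => Float.exp (beta * x)
  let s  := es.foldl (· + ·) 0.0
  ((es.zip dxs).map fun (e, dx) => (e / s) * dx).foldl (· + ·) 0.0
```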
Multi-channel JVP for smooth max-pooling (channel-wise application).
MaxPool2d forward pass (single-channel).
This takes the maximum over each kH×kW window sampled with the given stride.
The return type encodes the standard output spatial size formula.
MaxPool2d forward pass (multi-channel): apply max_pool2d_spec per channel.
Forward-mode JVP for hard max-pooling (single-channel).
The tangent is read at the argmax chosen by the primal input. At ties the first row-major
maximizer is used, matching maxPool2dBackwardSpec and PyTorch's index convention.
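The tie rule amounts to a fold that only replaces the incumbent on a strict improvement; a sketch over a flat row-major window (`argmaxFirst` is an illustrative name):

```lean
-- First row-major maximizer: strict `>` keeps the earliest index on ties.
def argmaxFirst (xs : List Float) : Option Nat :=
  (xs.enum.foldl
    (fun best (i, x) =>
      match best with
      | none => some (i, x)
      | some (_, bx) => if x > bx then some (i, x) else best)
    (none : Option (Nat × Float))).map (·.1)

#eval argmaxFirst [1.0, 3.0, 3.0, 2.0]  -- some 1 (first of the tied maxima)
```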
Multi-channel JVP for hard max-pooling (channel-wise application).
AvgPool2d forward pass (single-channel).
We sum all values in the window and divide by kH*kW.
PyTorch analogy: avg_pool2d. The count_include_pad=true setting only matters for padded pooling;
in the unpadded case this matches the usual definition.
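A one-window sketch of this convention (illustrative names; the divisor is always the full kH*kW count):

```lean
-- Average over one window: sum, then divide by the fixed window size kH*kW.
def avgWindow (vals : List Float) (kH kW : Nat) : Float :=
  vals.foldl (· + ·) 0.0 / (kH * kW).toFloat
```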
AvgPool2d forward pass (multi-channel): apply avg_pool2d_spec per channel.
Adaptive pooling #
PyTorch defines adaptive pooling by partitioning the input into out bins.
For output index i, the pooling region is:
start = floor(i * in / out)
end = ceil((i+1) * in / out)
This matters when in is not divisible by out: region sizes vary by at most 1.
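Both boundaries are plain Nat arithmetic; a sketch with illustrative names, plus a worked 5-into-3 binning:

```lean
-- Adaptive bin boundaries: floor for the start, ceiling for the end.
def adaptiveStart (i inDim outDim : Nat) : Nat :=
  i * inDim / outDim
def adaptiveEnd (i inDim outDim : Nat) : Nat :=
  ((i + 1) * inDim + outDim - 1) / outDim  -- ceil((i+1) * in / out)

-- Five input cells into three bins: half-open regions [0,2), [1,4), [3,5).
#eval (List.range 3).map fun i => (adaptiveStart i 5 3, adaptiveEnd i 5 3)
```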
Adaptive-pooling region start index: floor(i * in / out) (PyTorch definition).
Adaptive-pooling region end index: ceil((i+1) * in / out) (PyTorch definition).
AdaptiveAvgPool2d forward pass.
Unlike fixed-kernel pooling, adaptive pooling chooses a window for each output position so that
the whole input is covered by outH×outW bins.
This follows the PyTorch start/end formula (see the section comment above).
AdaptiveMaxPool2d forward pass (same binning as adaptive avg, but with max instead of mean).
We intentionally do not use a numeric sentinel value to seed the max fold; we seed from the first
element of the region via getValueAtPosition. That keeps the spec meaningful across different
scalar backends.
Backward/VJP for max_pool2d_spec.
This propagates each output gradient to the argmax location inside the corresponding window. Tie-breaking: if multiple values in the window are equal to the maximum, we keep the first position in row-major order (same convention as PyTorch's max-pool indices).
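Per window, the VJP is a one-hot scatter of the upstream gradient; a self-contained sketch over a flat row-major window (illustrative names):

```lean
-- Route the upstream gradient g to the first row-major maximizer
-- (strict `>` in the fold keeps the earliest index on ties).
def maxWindowVjp (xs : List Float) (g : Float) : List Float :=
  let best := xs.enum.foldl
    (fun acc (i, x) =>
      match acc with
      | none => some (i, x)
      | some (_, bx) => if x > bx then some (i, x) else acc)
    (none : Option (Nat × Float))
  match best with
  | none => []
  | some (j, _) => xs.enum.map fun (i, _) => if i = j then g else 0.0
```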
Multi-channel max-pooling backward (channel-wise application of max_pool2d_backward_spec).
Backward/VJP for avg_pool2d_spec (single-channel).
Each output gradient is evenly distributed across its corresponding input window.
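Per window this is a constant scatter; a sketch with illustrative names:

```lean
-- Every cell of the kH×kW window receives an equal share of the gradient.
def avgWindowVjp (g : Float) (kH kW : Nat) : List Float :=
  List.replicate (kH * kW) (g / (kH * kW).toFloat)
```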
Padded pooling (symmetric padding) #
For max-pooling, padded locations are not real input elements and are ignored when selecting the
maximum. This is the scalar-polymorphic way to model PyTorch's -∞ max-pool padding without adding
a backend-specific infinity constant to Context α.
For average pooling, this corresponds to including padded zeros in the average (PyTorch's default
count_include_pad = true).
Remove symmetric zero-padding from a multi-channel image (channel-wise unpad_image).
Multi-channel max-pooling forward pass with PyTorch-style padding (-∞ outside bounds).
Forward-mode JVP for padded hard max-pooling.
Padding cells are ignored exactly as in maxPool2dMultiSpecPad, so the tangent is taken from the
primal winner among real input locations only. If a window contains no real input cells, the
forward value and tangent are both 0.
Multi-channel average pooling forward pass with symmetric zero padding.
Multi-channel max-pooling backward pass with PyTorch-style padding (-∞ outside bounds).
Multi-channel average-pooling backward pass with symmetric padding (backprop then unpad).
Backward/VJP for smooth_max_pool2d_spec (log-sum-exp surrogate).
Multi-channel backward for smooth_max_pool2d_multi_spec (apply per channel).
Generic N-D pooling (channels-first, no batch) #
These operators generalize the existing 2D pooling specs to an arbitrary spatial rank d.
Conventions:
- Input is channels-first: shape `[C] ++ spatialDims`.
- Pooling is applied independently per channel (like the existing 2D specs).
- `kernel`, `stride`, and `padding` are per-axis vectors (`Vector Nat d`).
- Padding is symmetric and uses zeros.
PyTorch comparisons (conceptual, without batch axis):
- `max_pool_spec` corresponds to `torch.nn.functional.max_poolNd`.
- `avg_pool_spec` corresponds to `torch.nn.functional.avg_poolNd`.
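Combining the per-axis formulas, the output-shape computation is a single zip-and-map; a sketch using `List` in place of the spec's `Vector Nat d` (names illustrative):

```lean
-- Per-axis padded output sizes for N-D pooling (axes outermost to innermost).
def poolOutShape (spatial kernel stride padding : List Nat) : List Nat :=
  (spatial.zip (kernel.zip (stride.zip padding))).map
    fun (n, k, s, p) => (n + 2 * p - k) / s + 1

#eval poolOutShape [7, 9] [3, 3] [2, 2] [1, 0]  -- [4, 4]
```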
Layer configs + output shapes #
Kernel/stride/padding configuration for N-D max pooling.
Kernel sizes per spatial axis (outermost to innermost).
Strides per spatial axis (outermost to innermost).
Symmetric zero padding per spatial axis (outermost to innermost).
Kernel/stride/padding configuration for N-D average pooling.
Kernel sizes per spatial axis (outermost to innermost).
Strides per spatial axis (outermost to innermost).
Symmetric zero padding per spatial axis (outermost to innermost).
Input lookup for average/smooth pooling.
For average-style pooling, padded cells contribute numeric zero and are still counted by the
denominator chosen by the surrounding pooling spec. We keep this separate from
getPaddedMaxInputVal?, where padded cells must be ignored rather than treated as zero.
Input lookup for hard max-pooling.
Unlike average-pooling, max-pooling should not insert a numeric zero for padded cells: PyTorch's
max-pool semantics treat padding as -∞. TorchLean keeps the spec scalar-polymorphic by returning
none for padded coordinates and letting the max fold ignore them.
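The two lookup conventions differ only in what they return at padded coordinates; a 2-D sketch of the max-pooling variant (hypothetical names, row-major `Array` layout), where an avg-style lookup would instead return `some 0`:

```lean
-- Return none for padded coordinates so the max fold can skip them.
def paddedLookup? (img : Array Float) (h w : Nat) (r c : Int) : Option Float :=
  if 0 ≤ r ∧ r < (h : Int) ∧ 0 ≤ c ∧ c < (w : Int) then
    img[r.toNat * w + c.toNat]?
  else
    none
```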
Directional derivative of hard max-pooling for one N-D window.
The derivative is taken along the same winner selected by maxPoolValue. At ties we keep the first
winner in row-major order, matching the VJP convention below and PyTorch's index convention.
Directional derivative of the smooth log-sum-exp pooling value.
For y = beta⁻¹ log Σ exp(beta*xᵢ), the directional derivative is
Σ softmax(beta*xᵢ) * dxᵢ, using the same zero-padding convention as smoothMaxPoolValue.
Forward (single-channel spatial tensor) #
N-D max pooling on a spatial tensor (no explicit channel axis).
Forward-mode JVP for N-D hard max-pooling on a spatial tensor.
The derivative follows the same primal argmax as maxPoolSpatialSpec; at ties it keeps the first
row-major maximizer. This is the correct directional derivative for TorchLean's chosen subgradient
convention and matches the VJP tie policy.
N-D average pooling on a spatial tensor (no explicit channel axis).
Backward (single-channel spatial tensor) #
These are the VJPs of the forward pooling specs above.
Conventions:
- For max pooling, ties are broken by first occurrence in row-major order (same as the 2D spec).
- For max pooling, padded cells are ignored, modeling PyTorch's `-∞` padding without requiring a scalar-polymorphic infinity constant.
- For average pooling, gradients are evenly distributed across the full kernel window (`count_include_pad=true` behavior when padding is present).
Backward/VJP for max_pool_spatial_spec.
Each output gradient is propagated to the argmax location in the corresponding input window. Ties keep the first position in row-major order.
Backward/VJP for avg_pool_spatial_spec (single-channel).
Each output gradient is evenly distributed across its kernel window.
Forward (channels-first: C × spatial...) #
N-D max pooling on a channels-first tensor: shape [C] ++ spatial.
N-D hard max-pool JVP on a channels-first tensor (channel-wise application).
N-D average pooling on a channels-first tensor: shape [C] ++ spatial.
Backward (channels-first: C × spatial...) #
Multi-channel VJP for max_pool_spec (apply spatial backward per channel).
Multi-channel VJP for avg_pool_spec (apply spatial backward per channel).
Smooth max pooling (log-sum-exp surrogate) #
Smooth log-sum-exp max pooling on a spatial tensor (no explicit channel axis).
Forward-mode JVP for N-D smooth max-pooling on a spatial tensor.
For the log-sum-exp surrogate this is the softmax-weighted sum of the input tangent over each
window. It is the forward-mode counterpart of smoothMaxPoolSpatialBackwardSpec.
Smooth log-sum-exp max pooling on a channels-first tensor (channel-wise application).
N-D smooth max-pool JVP on a channels-first tensor (channel-wise application).
Smooth max pooling backward #
Backward/VJP for smooth_max_pool_spatial_spec (log-sum-exp surrogate).
For a window x₁,…,xₙ, the surrogate is:
y = (1/beta) * log(∑ exp(beta*xᵢ))
and the VJP distributes upstream gradient proportionally to exp(beta*xᵢ).
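Per window, the VJP weights are exactly the softmax of beta*x; a sketch with illustrative names:

```lean
-- Each input receives g * softmax(beta*xᵢ); the weights sum to 1.
def smoothMaxWindowVjp (beta : Float) (xs : List Float) (g : Float) : List Float :=
  let es := xs.map fun x => Float.exp (beta * x)
  let s  := es.foldl (· + ·) 0.0
  es.map fun e => g * (e / s)
```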
Multi-channel VJP for smooth_max_pool_spec (apply spatial backward per channel).
Friendly aliases #
Alias for max_pool_spec.
Alias for avg_pool_spec.
Alias for smooth_max_pool_spec.
Alias for max_pool_backward_spec.
Alias for avg_pool_backward_spec.
Alias for smooth_max_pool_backward_spec.