2D Pooling
Single-channel and channels-first 2D max, average, adaptive, and smooth-max pooling specs.
MaxPool2d configuration.
The spec uses a fixed kernel (kH,kW) and a single stride value (applied to both height and width).
We require kH ≠ 0, kW ≠ 0, and stride ≠ 0 so windows are nonempty and the output-shape
arithmetic is well-defined.
PyTorch analogy: F.max_pool2d(x, kernel_size=(kH,kW), stride=stride).
AvgPool2d configuration.
We treat the pooling window as kH*kW elements and divide by that count.
This corresponds to PyTorch's default behavior when no padding is present.
Output shape for a 2D pooling op (single-channel) with no padding.
This uses the standard "valid" pooling formula:
outH = floor((inH - kH)/stride) + 1, outW = floor((inW - kW)/stride) + 1.
PyTorch analogy: ceil_mode=false with no padding.
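As a quick check of the formula above, here is a minimal Python sketch. The name pool_out_shape is hypothetical (it is not part of the spec), and the sketch assumes inH ≥ kH and inW ≥ kW so that the result is a positive size.

```python
def pool_out_shape(in_h, in_w, k_h, k_w, stride):
    """Valid-pooling output size: no padding, floor rounding."""
    # Hypothetical helper: mirrors the nonzero requirements on kH, kW, stride.
    if k_h == 0 or k_w == 0 or stride == 0:
        raise ValueError("kH, kW, and stride must be nonzero")
    # Assumption of this sketch: the kernel fits inside the input.
    if in_h < k_h or in_w < k_w:
        raise ValueError("kernel larger than input")
    out_h = (in_h - k_h) // stride + 1
    out_w = (in_w - k_w) // stride + 1
    return out_h, out_w


shape = pool_out_shape(7, 7, 3, 3, 2)  # (7 - 3) // 2 + 1 = 3 along each axis
```

Integer floor division (//) plays the role of floor in the formula.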
Output shape for multi-channel 2D pooling (channels preserved).
Output shape for a 2D pooling op (single-channel) with symmetric padding.
The padding parameter extends the input by padding cells on each side, so the usual PyTorch
output-size formula applies: outH = floor((inH + 2*padding - kH)/stride) + 1, and likewise for outW.
Hard max-pooling ignores padded cells (the PyTorch -∞ convention), while
average-pooling below explicitly includes padded zeros.
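The difference between the two padding conventions is easiest to see on a corner window. The following is a minimal Python sketch; padded_out_size and window_values are hypothetical helper names, and the window is addressed in unpadded coordinates, so indices outside the input stand for padded cells.

```python
import math


def padded_out_size(n, k, stride, padding):
    """Output size along one axis for symmetric padding (floor rounding)."""
    return (n + 2 * padding - k) // stride + 1


def window_values(x, top, left, k_h, k_w, pad_value):
    """Collect one window; out-of-range cells are replaced by pad_value."""
    h, w = len(x), len(x[0])
    vals = []
    for r in range(top, top + k_h):
        for c in range(left, left + k_w):
            inside = 0 <= r < h and 0 <= c < w
            vals.append(x[r][c] if inside else pad_value)
    return vals


x = [[-5.0, -2.0],
     [-3.0, -4.0]]

# With padding=1 the top-left 2x2 window starts one cell outside the input,
# so it covers three padded cells and x[0][0].
max_corner = max(window_values(x, -1, -1, 2, 2, -math.inf))  # padding ignored
avg_corner = sum(window_values(x, -1, -1, 2, 2, 0.0)) / 4    # zeros included
```

Because every real entry is negative here, padding with -∞ leaves the max at -5.0, while the average counts the three padded zeros and divides by the full window size of 4.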
Output shape for multi-channel 2D pooling with symmetric padding (channels preserved).
Smooth max pooling
max_pool2d_spec uses max and is therefore not differentiable everywhere: it has kinks at inputs where two or more window entries tie for the maximum.
For proofs that need everywhere differentiability, we provide a smooth surrogate based on log-sum-exp over each pooling window:
smooth_max(x₁,…,xₙ) = (1 / β) * log (∑ exp (β * xᵢ))
This is the standard log-sum-exp surrogate. It is defined for β ≠ 0 and approximates max for β > 0 (for β < 0 it approximates min instead).
Smooth max-pooling (single-channel) using a log-sum-exp surrogate.
This is useful in proof settings that want a differentiable alternative to max_pool2d_spec.
As beta → +∞, the output approaches hard max pooling.
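To make the surrogate concrete, here is a minimal Python sketch of log-sum-exp over a single flattened window. The name smooth_max follows the formula above, but this implementation is an assumption, not the spec itself. It subtracts the largest scaled entry before exponentiating, a common way to avoid overflow that does not change the value.

```python
import math


def smooth_max(xs, beta):
    """(1 / beta) * log(sum(exp(beta * x))) over one window, for beta != 0."""
    if beta == 0:
        raise ValueError("beta must be nonzero")
    scaled = [beta * x for x in xs]
    m = max(scaled)  # shift for numerical stability
    return (m + math.log(sum(math.exp(s - m) for s in scaled))) / beta


window = [1.0, 3.0, 2.0]
soft = smooth_max(window, 1.0)    # strictly above the hard max of 3.0
sharp = smooth_max(window, 50.0)  # very close to 3.0
```

For β > 0 the surrogate is always at least the hard max, and the gap is at most log(n) / β for a window of n elements.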
Smooth max-pooling (multi-channel): apply smooth_max_pool2d_spec per channel.
Forward-mode JVP for smooth max-pooling (single-channel).
For each pooling window this is the differential of the log-sum-exp surrogate,
Σᵢ softmax(beta*x)ᵢ * dxᵢ, where the softmax is taken over the entries of that window.
This mirrors the VJP weights below but pushes an input tangent
forward instead of pulling an output cotangent backward.
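The softmax-weighted form of the differential can be checked numerically. The sketch below is a hedged illustration on one flattened window; smooth_max, softmax, and smooth_max_jvp are hypothetical Python names, and the central finite difference is only a sanity check, not part of the spec.

```python
import math


def smooth_max(xs, beta):
    """Log-sum-exp surrogate over one window (beta != 0)."""
    scaled = [beta * x for x in xs]
    m = max(scaled)
    return (m + math.log(sum(math.exp(s - m) for s in scaled))) / beta


def softmax(zs):
    """Numerically stable softmax of a list of floats."""
    m = max(zs)
    es = [math.exp(z - m) for z in zs]
    total = sum(es)
    return [e / total for e in es]


def smooth_max_jvp(xs, dxs, beta):
    """Directional derivative of smooth_max at xs along the tangent dxs."""
    weights = softmax([beta * x for x in xs])
    return sum(w * dx for w, dx in zip(weights, dxs))


xs = [1.0, 3.0, 2.0]
dxs = [0.5, -1.0, 2.0]
beta = 2.0
jvp = smooth_max_jvp(xs, dxs, beta)

# Central finite difference along the same tangent direction.
eps = 1e-5
plus = smooth_max([x + eps * d for x, d in zip(xs, dxs)], beta)
minus = smooth_max([x - eps * d for x, d in zip(xs, dxs)], beta)
fd = (plus - minus) / (2 * eps)
```

Since the softmax weights sum to 1, a tangent of all ones is pushed forward to exactly 1 (up to rounding).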
Multi-channel JVP for smooth max-pooling (channel-wise application).
MaxPool2d forward pass (single-channel).
This takes the maximum over each kH×kW window sampled with the given stride.
The return type encodes the standard output spatial size formula.
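A minimal Python sketch of the same forward pass on nested lists may help fix the indexing. The function name max_pool2d and the list-of-rows representation are assumptions of this example; the spec itself encodes the output size in its return type, which plain Python cannot express.

```python
def max_pool2d(x, k_h, k_w, stride):
    """Single-channel max pooling, no padding, floor rounding."""
    in_h, in_w = len(x), len(x[0])
    out_h = (in_h - k_h) // stride + 1
    out_w = (in_w - k_w) // stride + 1
    return [
        [
            # Window (i, j) starts at row i*stride, column j*stride.
            max(
                x[i * stride + di][j * stride + dj]
                for di in range(k_h)
                for dj in range(k_w)
            )
            for j in range(out_w)
        ]
        for i in range(out_h)
    ]


x = [[1, 2, 3, 4],
     [5, 6, 7, 8],
     [9, 10, 11, 12],
     [13, 14, 15, 16]]
pooled = max_pool2d(x, 2, 2, 2)
```

With a 2×2 kernel and stride 2, the four windows do not overlap and each contributes its bottom-right entry.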
MaxPool2d forward pass (multi-channel): apply max_pool2d_spec per channel.
Forward-mode JVP for hard max-pooling (single-channel).
The tangent is read at the argmax chosen by the primal input. At ties the first row-major
maximizer is used, matching maxPool2dBackwardSpec and PyTorch's index convention.
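The tie-breaking rule can be sketched in Python as follows. The names argmax_first and max_pool2d_jvp are hypothetical; the strict comparison is what keeps the first row-major maximizer when several entries are equal.

```python
def argmax_first(x, top, left, k_h, k_w):
    """Position of the first row-major maximizer inside one window."""
    best = (top, left)
    for r in range(top, top + k_h):
        for c in range(left, left + k_w):
            # Strict '>' means later ties never replace an earlier maximizer.
            if x[r][c] > x[best[0]][best[1]]:
                best = (r, c)
    return best


def max_pool2d_jvp(x, dx, k_h, k_w, stride):
    """Read the tangent dx at the argmax selected by the primal input x."""
    out_h = (len(x) - k_h) // stride + 1
    out_w = (len(x[0]) - k_w) // stride + 1
    out = []
    for i in range(out_h):
        row = []
        for j in range(out_w):
            r, c = argmax_first(x, i * stride, j * stride, k_h, k_w)
            row.append(dx[r][c])
        out.append(row)
    return out


dx = [[10.0, 20.0],
      [30.0, 40.0]]
tied = max_pool2d_jvp([[1, 1], [1, 1]], dx, 2, 2, 2)     # all entries tie
two_max = max_pool2d_jvp([[1, 5], [5, 2]], dx, 2, 2, 2)  # two maximizers
```

In the second call the maximum 5 occurs at positions (0, 1) and (1, 0); the tangent is read at (0, 1), the earlier one in row-major order.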
Multi-channel JVP for hard max-pooling (channel-wise application).
AvgPool2d forward pass (single-channel).
We sum all values in the window and divide by kH*kW.
PyTorch analogy: F.avg_pool2d. The count_include_pad=true setting only matters for padded pooling;
in the unpadded case the result matches the usual definition.
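For the unpadded case, the computation can be sketched in a few lines of Python. The name avg_pool2d and the nested-list input are assumptions of this example.

```python
def avg_pool2d(x, k_h, k_w, stride):
    """Single-channel average pooling, no padding: window sum / (kH * kW)."""
    out_h = (len(x) - k_h) // stride + 1
    out_w = (len(x[0]) - k_w) // stride + 1
    count = k_h * k_w  # every unpadded window has exactly this many elements
    return [
        [
            sum(
                x[i * stride + di][j * stride + dj]
                for di in range(k_h)
                for dj in range(k_w)
            ) / count
            for j in range(out_w)
        ]
        for i in range(out_h)
    ]


x = [[1, 2, 3, 4],
     [5, 6, 7, 8],
     [9, 10, 11, 12],
     [13, 14, 15, 16]]
averaged = avg_pool2d(x, 2, 2, 2)
```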
AvgPool2d forward pass (multi-channel): apply avg_pool2d_spec per channel.
Adaptive pooling
PyTorch defines adaptive pooling by dividing the input into out bins.
For output index i, the pooling region is the half-open index range [start, end), where:
start = floor(i * in / out)
end = ceil((i+1) * in / out)
This matters when in is not divisible by out: region sizes then differ by at most 1, and adjacent regions may overlap.
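The start/end formulas are easy to evaluate by hand for small sizes. Here is a minimal Python sketch using integer arithmetic only; the helper names adaptive_start, adaptive_end, and regions are hypothetical, and end is treated as exclusive.

```python
def adaptive_start(i, n_in, n_out):
    """floor(i * in / out) via integer floor division."""
    return (i * n_in) // n_out


def adaptive_end(i, n_in, n_out):
    """ceil((i + 1) * in / out) via the usual integer ceiling trick."""
    return ((i + 1) * n_in + n_out - 1) // n_out


def regions(n_in, n_out):
    """All (start, end) pairs along one axis, with end exclusive."""
    return [
        (adaptive_start(i, n_in, n_out), adaptive_end(i, n_in, n_out))
        for i in range(n_out)
    ]


uneven = regions(5, 3)  # sizes 2, 3, 2; index 1 is shared by two regions
even = regions(6, 3)    # sizes 2, 2, 2; no overlap
```

For in = 5 and out = 3 the regions are [0, 2), [1, 4), and [3, 5), so the sizes differ by one and neighbouring regions share an index.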
Adaptive-pooling region start index: floor(i * in / out) (PyTorch definition).
Adaptive-pooling region end index: ceil((i+1) * in / out) (PyTorch definition).
AdaptiveAvgPool2d forward pass.
Unlike fixed-kernel pooling, adaptive pooling chooses a window for each output position so that
the whole input is covered by outH×outW bins.
This follows the PyTorch start/end formula (see the section comment above).
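Putting the binning together with the mean gives a short Python sketch of the forward pass. The names region and adaptive_avg_pool2d are assumptions of this example, and the input is a plain list of rows.

```python
def region(i, n_in, n_out):
    """(start, end) for output index i, end exclusive (floor / ceil formula)."""
    start = (i * n_in) // n_out
    end = ((i + 1) * n_in + n_out - 1) // n_out
    return start, end


def adaptive_avg_pool2d(x, out_h, out_w):
    """Mean over each adaptive bin; bin sizes can differ between positions."""
    in_h, in_w = len(x), len(x[0])
    out = []
    for i in range(out_h):
        r0, r1 = region(i, in_h, out_h)
        row = []
        for j in range(out_w):
            c0, c1 = region(j, in_w, out_w)
            vals = [x[r][c] for r in range(r0, r1) for c in range(c0, c1)]
            row.append(sum(vals) / len(vals))  # divide by this bin's own size
        out.append(row)
    return out


x = [[1, 2, 3],
     [4, 5, 6],
     [7, 8, 9]]
result = adaptive_avg_pool2d(x, 2, 2)
```

With a 3×3 input and a 2×2 output, each bin is a 2×2 block and neighbouring bins share the middle row or column.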
AdaptiveMaxPool2d forward pass (same binning as adaptive avg, but with max instead of mean).
We intentionally do not use a numeric sentinel value to seed the max fold; we seed from the first
element of the region via getValueAtPosition. That keeps the spec meaningful across different
scalar backends.
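The seeding choice can be illustrated with a fold in Python. This is a sketch under stated assumptions: region and adaptive_max_pool2d are hypothetical names, and direct list indexing stands in for the spec's getValueAtPosition. Each region is nonempty whenever the input and output sizes are positive, so a first element always exists.

```python
from functools import reduce


def region(i, n_in, n_out):
    """(start, end) for output index i, end exclusive (floor / ceil formula)."""
    return (i * n_in) // n_out, ((i + 1) * n_in + n_out - 1) // n_out


def adaptive_max_pool2d(x, out_h, out_w):
    """Max over each adaptive bin, folding from the bin's first element."""
    in_h, in_w = len(x), len(x[0])
    out = []
    for i in range(out_h):
        r0, r1 = region(i, in_h, out_h)
        row = []
        for j in range(out_w):
            c0, c1 = region(j, in_w, out_w)
            vals = [x[r][c] for r in range(r0, r1) for c in range(c0, c1)]
            # Seed is vals[0], not a sentinel such as 0 or a large negative.
            row.append(reduce(lambda a, b: b if b > a else a, vals[1:], vals[0]))
        out.append(row)
    return out


x = [[-1, -2, -3],
     [-4, -5, -6],
     [-7, -8, -9]]
result = adaptive_max_pool2d(x, 2, 2)
```

An all-negative input shows why the seed matters: a fold seeded with 0 would return 0 for every bin, which is not an element of the input.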
Backward/VJP for max_pool2d_spec.
This propagates each output gradient to the argmax location inside the corresponding window.
Tie-breaking: if multiple values in the window are equal to the maximum, we keep the first position in row-major order (same convention as PyTorch's max-pool indices).
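The routing of gradients can be sketched in Python as below. The function name max_pool2d_backward is hypothetical, and accumulating with += is an assumption of this sketch for the case where windows overlap (stride smaller than the kernel) and one input position is the argmax of several windows.

```python
def max_pool2d_backward(x, grad_out, k_h, k_w, stride):
    """Send each output gradient to the first row-major argmax of its window."""
    in_h, in_w = len(x), len(x[0])
    grad_in = [[0.0] * in_w for _ in range(in_h)]
    for i, grad_row in enumerate(grad_out):
        for j, g in enumerate(grad_row):
            top, left = i * stride, j * stride
            best = (top, left)
            for r in range(top, top + k_h):
                for c in range(left, left + k_w):
                    # Strict '>' keeps the earliest maximizer on ties.
                    if x[r][c] > x[best[0]][best[1]]:
                        best = (r, c)
            grad_in[best[0]][best[1]] += g
    return grad_in


x = [[1, 2, 3, 4],
     [5, 6, 7, 8],
     [9, 10, 11, 12],
     [13, 14, 15, 16]]
grad_in = max_pool2d_backward(x, [[1.0, 2.0], [3.0, 4.0]], 2, 2, 2)
tie_grad = max_pool2d_backward([[0, 0], [0, 0]], [[5.0]], 2, 2, 2)
```

In the tied case the whole gradient lands on position (0, 0), the first entry of the window in row-major order.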
Multi-channel max-pooling backward (channel-wise application of max_pool2d_backward_spec).
Backward/VJP for avg_pool2d_spec (single-channel).
Each output gradient is evenly distributed across its corresponding input window.
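A minimal Python sketch of this even distribution follows. The name avg_pool2d_backward is hypothetical, the input size is passed explicitly because it cannot always be recovered from the output size, and contributions are accumulated in case windows overlap.

```python
def avg_pool2d_backward(grad_out, in_h, in_w, k_h, k_w, stride):
    """Spread each output gradient evenly over its kH x kW input window."""
    grad_in = [[0.0] * in_w for _ in range(in_h)]
    count = k_h * k_w
    for i, grad_row in enumerate(grad_out):
        for j, g in enumerate(grad_row):
            share = g / count  # each window element receives the same share
            for di in range(k_h):
                for dj in range(k_w):
                    grad_in[i * stride + di][j * stride + dj] += share
    return grad_in


grad_in = avg_pool2d_backward([[4.0, 8.0], [12.0, 16.0]], 4, 4, 2, 2, 2)
```

With non-overlapping 2×2 windows, each output gradient g shows up as a 2×2 block of g / 4 in the input gradient.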