GraphM Pooling Ops
N-dimensional and two-dimensional pooling builders with forward, JVP, and VJP payloads.
N-D max pooling (channels-first) on a single sample tensor (no batch axis).
PyTorch comparison: torch.nn.functional.max_pool1d, max_pool2d, or max_pool3d for spatial rank d = 1, 2, or 3 respectively (PyTorch provides no max pooling function for d > 3).
Forward-mode status: implemented. The JVP follows the primal argmax selected by Spec.maxPoolJvpSpec, including its documented first-winner tie convention: when several entries in a window share the maximum, the tangent is taken from the first of them.
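The following is a minimal pure-Python sketch of the forward pass and JVP described above, for illustration only. It assumes a hypothetical flat row-major layout for the (C, *spatial) tensor and no padding. It also reads "first winner" as the first maximum in window scan order. The function names are illustrative and are not part of GraphM.

```python
from itertools import product

def row_major_strides(shape):
    """Strides (in elements) of a flat row-major list with the given shape."""
    strides, acc = [], 1
    for n in reversed(shape):
        strides.append(acc)
        acc *= n
    return strides[::-1]

def max_pool_nd_jvp(x, dx, shape, kernel, stride):
    """Unpadded N-D max pooling with its JVP (hypothetical helper, not the GraphM API).

    x and dx are channels-first (C, *spatial) tensors stored as flat row-major
    lists; returns the pooled values y and the output tangent dy."""
    spatial = shape[1:]
    out_spatial = [(n - k) // s + 1 for n, k, s in zip(spatial, kernel, stride)]
    st = row_major_strides(shape)
    y, dy = [], []
    for ch in range(shape[0]):
        for out_idx in product(*(range(n) for n in out_spatial)):
            best_val, best_pos = None, None
            for off in product(*(range(k) for k in kernel)):
                pos = ch * st[0] + sum(
                    (o * s + d) * t for o, s, d, t in zip(out_idx, stride, off, st[1:])
                )
                # Strict '>' keeps the earliest maximum in scan order on ties.
                if best_val is None or x[pos] > best_val:
                    best_val, best_pos = x[pos], pos
            y.append(best_val)
            dy.append(dx[best_pos])  # the tangent flows only through the primal winner
    return y, dy
```

For example, max_pool_nd_jvp([5, 5, 1, 0], [1, 2, 3, 4], (1, 4), (2,), (2,)) returns ([5, 1], [1, 3]). In the tied first window, the tangent is taken from index 0.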
N-D average pooling (channels-first) on a single sample tensor (no batch axis).
PyTorch comparison: torch.nn.functional.avg_pool1d, avg_pool2d, or avg_pool3d for spatial rank d = 1, 2, or 3 respectively (PyTorch provides no average pooling function for d > 3).
Forward-mode status: implemented. Average pooling is linear, so the JVP is the same average-pool map applied to the input tangent.
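Because average pooling is linear, its JVP can reuse the forward map unchanged. The sketch below makes three assumptions: a hypothetical flat row-major layout, unpadded windows, and Fraction arithmetic so that the linearity check is exact. The names are illustrative only.

```python
import math
from fractions import Fraction
from itertools import product

def avg_pool_nd(x, shape, kernel, stride):
    """Unpadded N-D average pooling (hypothetical helper, not the GraphM API).

    x is a channels-first (C, *spatial) tensor stored as a flat row-major list;
    each output is the exact (Fraction) mean of its window."""
    spatial = shape[1:]
    out_spatial = [(n - k) // s + 1 for n, k, s in zip(spatial, kernel, stride)]
    # Row-major strides of the spatial axes; one channel occupies prod(spatial) slots.
    sp_strides = [1] * len(spatial)
    for i in range(len(spatial) - 2, -1, -1):
        sp_strides[i] = sp_strides[i + 1] * spatial[i + 1]
    plane, window = math.prod(spatial), math.prod(kernel)
    y = []
    for ch in range(shape[0]):
        for out_idx in product(*(range(n) for n in out_spatial)):
            total = sum(
                x[ch * plane + sum((o * s + d) * t
                                   for o, s, d, t in zip(out_idx, stride, off, sp_strides))]
                for off in product(*(range(k) for k in kernel))
            )
            y.append(Fraction(total) / window)
    return y

def avg_pool_nd_jvp(dx, shape, kernel, stride):
    # Linear map: the JVP is the same average-pool map applied to the tangent.
    return avg_pool_nd(dx, shape, kernel, stride)
```

For instance, avg_pool_nd([1, 3, 5, 7], (1, 4), (2,), (2,)) gives [2, 6] as Fraction values.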
N-D smooth max pooling (log-sum-exp surrogate) on a single sample tensor (no batch axis).
PyTorch comparison: there is no direct primitive; this is a differentiable approximation to max pooling.
Forward-mode status: implemented. The JVP is the softmax-weighted tangent of the log-sum-exp pooling window.
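The per-window rule can be checked numerically. This sketch assumes the surrogate has the common form (1/beta) * log(sum(exp(beta * x))) with beta > 0; both the formula and the beta argument are assumptions here. It compares the softmax-weighted tangent against a central finite difference. Pooling would apply the same rule independently to every window.

```python
import math

def lse_window(xs, beta=1.0):
    """Assumed smooth max of one window: (1/beta) * log(sum(exp(beta * x))).
    Shifting by the max keeps exp() from overflowing."""
    m = max(xs)
    return m + math.log(sum(math.exp(beta * (x - m)) for x in xs)) / beta

def lse_window_jvp(xs, dxs, beta=1.0):
    """Tangent of lse_window: sum_i softmax(beta * x)_i * dx_i."""
    m = max(xs)
    w = [math.exp(beta * (x - m)) for x in xs]
    z = sum(w)
    return sum(wi / z * dxi for wi, dxi in zip(w, dxs))
```

When all window values are equal, the weights are uniform. For example, lse_window_jvp([1.0, 1.0], [2.0, 4.0]) returns 3.0, the plain average of the tangents.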
2D max pooling (channels-first) on a single image tensor (no batch axis).
PyTorch comparison: torch.nn.functional.max_pool2d (without a batch dimension).
Forward-mode status: implemented. Each output tangent is the input tangent at the argmax position that the primal input selects in the corresponding window.
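The JVP only needs to know which input cell won each window, so the forward pass can record the winners once and reuse them for any tangent. This sketch works on nested lists x[c][h][w] and assumes a square kernel k, a stride s, and no padding. Python's max() returns the first maximal item, which stands in for a first-winner tie rule. The helper names are hypothetical.

```python
def max_pool2d_with_winners(x, k, s):
    """Forward 2D max pooling on x[c][h][w] (channels-first, no batch axis,
    no padding). Returns (y, winners), where winners[c][i][j] is the (row, col)
    of the input cell selected for output (i, j). Hypothetical helper."""
    y, winners = [], []
    for plane in x:
        h, w = len(plane), len(plane[0])
        y_p, win_p = [], []
        for i in range(0, h - k + 1, s):
            y_row, win_row = [], []
            for j in range(0, w - k + 1, s):
                cells = [(i + a, j + b) for a in range(k) for b in range(k)]
                # max() returns the first maximal cell, i.e. the first winner on ties.
                r, c = max(cells, key=lambda rc: plane[rc[0]][rc[1]])
                y_row.append(plane[r][c])
                win_row.append((r, c))
            y_p.append(y_row)
            win_p.append(win_row)
        y.append(y_p)
        winners.append(win_p)
    return y, winners

def max_pool2d_jvp(winners, dx):
    """Each output tangent is the input tangent at the primal winner."""
    return [[[dx[ch][r][c] for (r, c) in row] for row in plane]
            for ch, plane in enumerate(winners)]
```

The winners depend only on the primal input. Every tangent pushed through afterwards reuses them without repeating the comparisons.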
2D max pooling with explicit padding.
PyTorch comparison: torch.nn.functional.max_pool2d with padding.
Forward-mode status: implemented. Padding is fixed, and the JVP follows the primal winner among the real (unpadded) cells. Padded cells are ignored exactly as in the forward pass, so they never contribute a tangent.
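The sketch below covers the padded case. It assumes a square kernel k, a stride s, and symmetric padding p with p <= k // 2, so every window contains at least one real cell (PyTorch imposes the same limit on max_pool2d). Padded positions are skipped rather than filled with a value, so they can never be selected and contribute no tangent. The helper name is hypothetical.

```python
def max_pool2d_padded_jvp(x, dx, k, s, p):
    """2D max pooling on x[c][h][w] with p cells of implicit padding per side,
    returning (y, dy). Padded positions are skipped, not filled, so they never
    win and carry no tangent. Hypothetical helper; assumes p <= k // 2."""
    y, dy = [], []
    for xp, dxp in zip(x, dx):
        h, w = len(xp), len(xp[0])
        oh, ow = (h + 2 * p - k) // s + 1, (w + 2 * p - k) // s + 1
        yp, dyp = [], []
        for i in range(oh):
            yrow, dyrow = [], []
            for j in range(ow):
                # Only window positions that fall inside the real image.
                real = [(r, c)
                        for r in range(i * s - p, i * s - p + k)
                        for c in range(j * s - p, j * s - p + k)
                        if 0 <= r < h and 0 <= c < w]
                # max() keeps the first winner among ties.
                r, c = max(real, key=lambda rc: xp[rc[0]][rc[1]])
                yrow.append(xp[r][c])
                dyrow.append(dxp[r][c])
            yp.append(yrow)
            dyp.append(dyrow)
        y.append(yp)
        dy.append(dyp)
    return y, dy
```

On an all-negative input, zero-filled padding would win at the borders, but skipping padded cells keeps the result equal to the real maxima. For example, with x = [[[-1, -2], [-3, -4]]], k = 2, s = 2, and p = 1, each output is the single real cell in its window.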
Smooth (soft) max pooling with a log-sum-exp window, controlled by the parameter beta.
This is a differentiable approximation to max pooling.
Forward-mode status: implemented. The JVP is the softmax-weighted tangent of the log-sum-exp pooling window.
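A small experiment shows how beta controls the approximation. It assumes the surrogate (1/beta) * log(sum(exp(beta * v))) over each window; this formula is an assumption, since the exact GraphM definition is not given here. For n window entries, that value lies between the window maximum and the maximum plus log(n) / beta, so it approaches max pooling as beta grows. The helper name and the square-window, no-padding setup are hypothetical.

```python
import math

def smooth_max_pool2d(x, k, s, beta):
    """Assumed log-sum-exp pooling on x[c][h][w] (channels-first, no batch axis,
    no padding): (1/beta) * log(sum(exp(beta * v))) over each k-by-k window."""
    out = []
    for plane in x:
        h, w = len(plane), len(plane[0])
        rows = []
        for i in range(0, h - k + 1, s):
            row = []
            for j in range(0, w - k + 1, s):
                vals = [plane[i + a][j + b] for a in range(k) for b in range(k)]
                m = max(vals)  # shift by the max for numerical stability
                row.append(m + math.log(sum(math.exp(beta * (v - m)) for v in vals)) / beta)
            rows.append(row)
        out.append(rows)
    return out
```

With x = [[[1.0, 2.0], [0.5, -1.0]]] and k = 2, the single output falls from about 2.495 at beta = 1 toward the true maximum 2.0 as beta increases.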
2D average pooling (channels-first) on a single image tensor (no batch axis).
PyTorch comparison: torch.nn.functional.avg_pool2d (without a batch dimension).
Forward-mode status: implemented. Average pooling is linear, so the JVP is average pooling of the input tangent.
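Dual numbers offer one way to see why the JVP is average pooling of the tangent. Run the forward pass on pairs (primal, tangent) and read off the tangent part. The sketch assumes nested lists x[c][h][w], a square window k, a stride s, and no padding. The Dual class and avg_pool2d are illustrative, not GraphM definitions.

```python
from fractions import Fraction

class Dual:
    """Minimal dual number (primal p, tangent t) for forward-mode differentiation."""
    def __init__(self, p, t):
        self.p, self.t = p, t

    def __add__(self, other):
        return Dual(self.p + other.p, self.t + other.t)

    def __truediv__(self, n):  # division by a plain constant only
        return Dual(self.p / n, self.t / n)

def avg_pool2d(x, k, s):
    """2D average pooling on x[c][h][w] (channels-first, no batch axis) with a
    k-by-k window and stride s; works on numbers, Fractions, or Dual values."""
    out = []
    for plane in x:
        h, w = len(plane), len(plane[0])
        rows = []
        for i in range(0, h - k + 1, s):
            row = []
            for j in range(0, w - k + 1, s):
                cells = [plane[i + a][j + b] for a in range(k) for b in range(k)]
                total = cells[0]
                for c in cells[1:]:
                    total = total + c
                row.append(total / (k * k))
            rows.append(row)
        out.append(rows)
    return out
```

The tangent components of avg_pool2d applied to Dual(x, dx) equal avg_pool2d(dx). This is the statement that the JVP is the same pooling map applied to the tangent.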
Average pooling with explicit padding.
PyTorch comparison: torch.nn.functional.avg_pool2d with padding.
Forward-mode status: implemented. Padding is fixed and average pooling is linear, so the JVP is the padded average-pool map applied to the input tangent.
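The padded case stays linear because padding inserts constant zeros and the divisor depends only on window geometry. This sketch assumes zero padding with p <= k // 2. It exposes both divisor conventions through a count_include_pad flag, named after PyTorch's avg_pool2d parameter (whose default is True). This section does not state which convention GraphM uses, and the helper names are hypothetical.

```python
from fractions import Fraction

def avg_pool2d_padded(x, k, s, p, count_include_pad=True):
    """2D average pooling on x[c][h][w] with p cells of zero padding per side.
    count_include_pad=True divides by k*k; False divides by the number of real
    cells in the window. Either divisor depends only on geometry, so the map is
    linear in x. Hypothetical helper; assumes p <= k // 2."""
    out = []
    for plane in x:
        h, w = len(plane), len(plane[0])
        oh, ow = (h + 2 * p - k) // s + 1, (w + 2 * p - k) // s + 1
        rows = []
        for i in range(oh):
            row = []
            for j in range(ow):
                # Padded zeros add nothing to the sum, so only real cells are collected.
                real = [plane[r][c]
                        for r in range(i * s - p, i * s - p + k)
                        for c in range(j * s - p, j * s - p + k)
                        if 0 <= r < h and 0 <= c < w]
                divisor = k * k if count_include_pad else len(real)
                row.append(Fraction(sum(real)) / divisor)
            rows.append(row)
        out.append(rows)
    return out

def avg_pool2d_padded_jvp(dx, k, s, p, count_include_pad=True):
    # Padding contributes constants with zero tangent, so the JVP is the same map on dx.
    return avg_pool2d_padded(dx, k, s, p, count_include_pad)
```

With x = [[[1, 2], [3, 4]]], k = 2, s = 2, and p = 1, each window holds one real cell. So count_include_pad=True yields a quarter of each input, while False returns the inputs themselves.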