Pooling #
Pooling correctness lemmas for the IR → compiled runtime bridge.
TorchLean models pooling on rank-3 image tensors C × H × W (batched variants are modeled upstream in the
runtime), matching the usual PyTorch max-pooling convention:
torch.nn.functional.max_pool2d / torch.nn.MaxPool2d.
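For intuition about the max-pooling convention being matched, here is a minimal one-axis sketch in plain Lean 4. It assumes PyTorch's defaults for max_pool2d (padding 0, dilation 1, floor rounding via ceil_mode=False) and works on List Nat instead of TorchLean tensors. The names windowMax and maxPool1d are hypothetical illustrations, not TorchLean declarations, and this is not the spec-level operation the lemmas below refer to.

```lean
/-- Hypothetical helper: maximum of a window, or `none` for an empty window. -/
def windowMax : List Nat → Option Nat
  | [] => none
  | x :: xs => some (xs.foldl max x)

/-- Hypothetical one-axis max pool with no padding, dilation 1, and floor
rounding: window `i` covers indices `i * stride ..< i * stride + kernel`. -/
def maxPool1d (xs : List Nat) (kernel stride : Nat) : List Nat :=
  if kernel = 0 ∨ stride = 0 ∨ xs.length < kernel then []
  else
    (List.range ((xs.length - kernel) / stride + 1)).filterMap fun i =>
      windowMax ((xs.drop (i * stride)).take kernel)

-- Windows [1, 3], [2, 5], [4, 0]; the trailing 9 has no complete window,
-- so floor rounding drops it, as with PyTorch's default `ceil_mode=False`.
example : maxPool1d [1, 3, 2, 5, 4, 0, 9] 2 2 = [3, 5, 4] := by decide
```

The 2-D case applies the same windowing independently along H and W for each channel.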
This file proves the forward-correctness statement for the compiler path that lowers a pooling IR
node into a single SSA node whose forward pass computes the corresponding spec-level pooling
operation. Concretely, successful compilation at graph index i implies that the IR evaluator
NN.IR.Graph.denoteAllFrom and the compiled evaluator denoteAllState remain semantically equivalent.
References:
- PyTorch functional max-pool docs: https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.max_pool2d
- PyTorch module max-pool docs: https://docs.pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html
Main statements #
- buildFrom_denoteAllFrom_max_pool2d: correctness step for unpadded max-pool lowering.
- buildFrom_denoteAllFrom_max_pool2d_pad: correctness step for padded max-pool lowering.
- buildFrom_denoteAllFrom_avg_pool2d: correctness step for unpadded average-pool lowering.
- buildFrom_denoteAllFrom_avg_pool2d_pad: correctness step for padded average-pool lowering.
Implementation notes #
- The proof follows the compiler's control flow closely (shape checks, parent checks, guard conditions); this one-to-one structure keeps the proofs easy to maintain as the lowering rules evolve.
- Impossible branches are discharged early via throw_bind_ne_ok, which keeps the success path readable.
- Pooling proofs build slowly because the output height and width are computed from the kernel, stride, and padding parameters and then appear in dependent tensor shapes. Focused helper lemmas should isolate those output-shape arithmetic facts so that the semantic proof only has to compare the compiled pooling call with the IR pooling evaluator.
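The output-shape arithmetic and the early discharge of impossible branches described in the implementation notes can be sketched in plain Lean 4 (core only). Everything in this block is hypothetical: poolOutDim, poolOutDim_zero_pad, poolOutDim_pos, and throw_bind_ne_ok_sketch are illustrative names, not TorchLean declarations, and Lean's Except monad stands in for whatever error monad the compiler actually uses. poolOutDim assumes PyTorch-style floor rounding with dilation 1, i.e. ⌊(size + 2·pad − kernel) / stride⌋ + 1 along one spatial axis.

```lean
/-- Hypothetical helper (not a TorchLean declaration): output extent of one
spatial axis for pooling with dilation 1 and floor rounding.
Meaningful only when `0 < stride` and `kernel ≤ size + 2 * pad`; outside that
range, `Nat` truncation still returns a value, but it is not a real shape. -/
def poolOutDim (size kernel stride pad : Nat) : Nat :=
  (size + 2 * pad - kernel) / stride + 1

-- 32-wide axis, kernel 2, stride 2, no padding: 30 / 2 + 1 = 16.
example : poolOutDim 32 2 2 0 = 16 := by decide
-- 7-wide axis, kernel 3, stride 2, padding 1: 6 / 2 + 1 = 4.
example : poolOutDim 7 3 2 1 = 4 := by decide

/-- The padded formula specializes to the unpadded one at `pad = 0`, so facts
proved for the unpadded lowering can be reused for the padded one. -/
theorem poolOutDim_zero_pad (size kernel stride : Nat) :
    poolOutDim size kernel stride 0 = (size - kernel) / stride + 1 := by
  simp [poolOutDim]

/-- The formula always yields a positive extent: a small arithmetic fact
stated once instead of being re-derived inside each semantic proof. -/
theorem poolOutDim_pos (size kernel stride pad : Nat) :
    0 < poolOutDim size kernel stride pad := by
  unfold poolOutDim
  exact Nat.succ_pos _

/-- Hypothetical analogue of `throw_bind_ne_ok` for Lean's `Except` monad:
a branch that starts with `throw` can never produce `.ok`, so that case of a
correctness proof closes immediately. -/
theorem throw_bind_ne_ok_sketch {ε α β : Type} (e : ε) (f : α → Except ε β)
    (b : β) : (throw e >>= f : Except ε β) ≠ .ok b := by
  intro h
  -- `throw e` unfolds to `.error e`, and `bind` propagates errors unchanged.
  have hErr : (throw e >>= f : Except ε β) = .error e := rfl
  have hBad : (Except.error e : Except ε β) = .ok b := hErr.symm.trans h
  cases hBad
```

Keeping the shape arithmetic in standalone lemmas like these means the pooling proofs can rewrite with them instead of unfolding the formula inside dependent tensor types.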
Tags #
pool2d, correctness, ir, runtime, semantic equivalence
Correctness lemma for the .max_pool2d node compiler (no padding).
Correctness lemma for the .max_pool2d_pad node compiler (explicit padding).
Correctness lemma for the .avg_pool2d node compiler (no padding).
Correctness lemma for the .avg_pool2d_pad node compiler (explicit padding).