TorchLean API

NN.Runtime.Autograd.Compiled.IRExec.Correctness.Ops.Pooling

Pooling

Pooling correctness lemmas for the IR → compiled runtime bridge.

TorchLean models pooling on rank-3 image tensors C × H × W (and batched variants upstream in the runtime), matching the usual PyTorch max-pooling convention: torch.nn.functional.max_pool2d / torch.nn.MaxPool2d.
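As a concrete reference point, PyTorch documents the output height of MaxPool2d (dilation 1, ceil_mode=False) as ⌊(H + 2·padding − kH) / stride⌋ + 1, and the width likewise with W and kW. The Lean sketch below computes that formula. poolOutSize is a hypothetical helper, not a TorchLean definition, and we assume TorchLean's shape rule agrees with PyTorch's, which this page does not spell out.

```lean
/-- Hypothetical helper (not part of TorchLean): output size of 2D pooling
along one spatial axis, using PyTorch's floor-mode formula with dilation 1:
  out = (inSize + 2 * padding - kernel) / stride + 1
Nat division truncates, which equals floor for these nonnegative values.
Assumes `stride > 0` and `kernel ≤ inSize + 2 * padding`; otherwise Nat
subtraction truncates at 0 and the result is meaningless. -/
def poolOutSize (inSize kernel stride padding : Nat) : Nat :=
  (inSize + 2 * padding - kernel) / stride + 1

-- A 3 × 32 × 32 input pooled with a 2 × 2 window, stride 2, no padding,
-- and with a 3 × 3 window, stride 2, padding 1: both give 16 per axis.
#eval (poolOutSize 32 2 2 0, poolOutSize 32 3 2 1)  -- (16, 16)
```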

This file proves forward correctness for the compiler path that lowers each pooling IR node into a single SSA node whose forward pass computes the corresponding spec-level pooling operation. Each lemma is one induction step over graph indices: if compilation succeeds at graph index i and the IR evaluator NN.IR.Graph.denoteAllFrom agrees with the compiled evaluator denoteAllState from index i + 1 onward (the hypothesis ih), then the two evaluators also agree from index i.
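Written out with abbreviations introduced here only for readability (IR_j(s) for g.denoteAllFrom payload (NN.IR.DVal.mk inShape x) j s, C(st) for denoteAllState inShape st x, and build_j(st) for IRExec.buildFrom g payload inShape j st), each of the four lemmas has the same shape once hN, hk, hi and hBuild are in scope:

```latex
\underbrace{\Big(\forall\, st_1,\ \mathrm{build}_{i+1}(st_1) = \mathrm{ok}\ st'
  \;\Rightarrow\; \mathrm{IR}_{i+1}\big(C(st_1)\big) = \mathrm{ok}\ C(st')\Big)}_{\texttt{ih}}
\;\Longrightarrow\;
\mathrm{IR}_{i}\big(C(\langle ss, gd\rangle)\big) = \mathrm{ok}\ C(st')
```

Only the node-kind hypothesis hk (and the extra padding argument) differs between the lemmas.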

Main definitions

Implementation notes

Tags

pool2d, correctness, ir, runtime, semantic equivalence

theorem Runtime.Autograd.Compiled.buildFrom_denoteAllFrom_max_pool2d {α : Type} [Context α] [DecidableEq Spec.Shape] (g : NN.IR.Graph) (payload : NN.IR.Payload α) {inShape : Spec.Shape} {ss : List Spec.Shape} (gd : Proofs.Autograd.Algebra.GraphData α Unit [inShape] ss) (i : ℕ) (st' : IRExec.State α inShape) (x : Spec.Tensor α inShape) (n : NN.IR.Node) (kH kW stride : ℕ) (hN : g.getNode i = Except.ok n) (hk : n.kind = NN.IR.OpKind.maxPool2d kH kW stride) (hi : i < g.nodes.size) (hBuild : IRExec.buildFrom g payload inShape i ⟨ss, gd⟩ = Except.ok st') (ih : ∀ (st1 : IRExec.State α inShape), IRExec.buildFrom g payload inShape (i + 1) st1 = Except.ok st' → g.denoteAllFrom payload (NN.IR.DVal.mk inShape x) (i + 1) (denoteAllState inShape st1 x) = Except.ok (denoteAllState inShape st' x)) :
g.denoteAllFrom payload (NN.IR.DVal.mk inShape x) i (denoteAllState inShape ⟨ss, gd⟩ x) = Except.ok (denoteAllState inShape st' x)

Correctness lemma for the .max_pool2d node compiler (no padding).
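To make the spec-level operation concrete, here is a minimal Lean sketch of unpadded max pooling on a single channel. maxPool2dRef and its list-of-rows representation are hypothetical, chosen for readability; TorchLean's actual Spec.Tensor definition is not shown on this page. The sketch assumes kH, kW and stride are at least 1 and that every window fits inside the image.

```lean
/-- Hypothetical reference model, not TorchLean's `Spec` definition:
unpadded 2D max pooling on one H × W channel stored as a list of rows.
Assumes `kH, kW, stride ≥ 1` and that the image has at least `kH` rows
and `kW` columns, so every window lies inside the image. -/
def maxPool2dRef (kH kW stride : Nat) (img : List (List Int)) : List (List Int) :=
  let h := img.length
  let w := (img.headD []).length
  -- Floor-mode output size: (size - kernel) / stride + 1.
  let outH := (h - kH) / stride + 1
  let outW := (w - kW) / stride + 1
  (List.range outH).map fun oh =>
    (List.range outW).map fun ow =>
      -- Maximum of each window row, then the maximum over those row maxima.
      let rowMaxes := (List.range kH).map fun dh =>
        let row := img.getD (oh * stride + dh) []
        let vals := (List.range kW).map fun dw => row.getD (ow * stride + dw) 0
        vals.foldl max (vals.headD 0)
      rowMaxes.foldl max (rowMaxes.headD 0)

-- 2 × 2 windows with stride 2 on a 4 × 4 image.
#eval maxPool2dRef 2 2 2 [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
-- [[6, 8], [14, 16]]
```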

theorem Runtime.Autograd.Compiled.buildFrom_denoteAllFrom_max_pool2d_pad {α : Type} [Context α] [DecidableEq Spec.Shape] (g : NN.IR.Graph) (payload : NN.IR.Payload α) {inShape : Spec.Shape} {ss : List Spec.Shape} (gd : Proofs.Autograd.Algebra.GraphData α Unit [inShape] ss) (i : ℕ) (st' : IRExec.State α inShape) (x : Spec.Tensor α inShape) (n : NN.IR.Node) (kH kW stride padding : ℕ) (hN : g.getNode i = Except.ok n) (hk : n.kind = NN.IR.OpKind.maxPool2dPad kH kW stride padding) (hi : i < g.nodes.size) (hBuild : IRExec.buildFrom g payload inShape i ⟨ss, gd⟩ = Except.ok st') (ih : ∀ (st1 : IRExec.State α inShape), IRExec.buildFrom g payload inShape (i + 1) st1 = Except.ok st' → g.denoteAllFrom payload (NN.IR.DVal.mk inShape x) (i + 1) (denoteAllState inShape st1 x) = Except.ok (denoteAllState inShape st' x)) :
g.denoteAllFrom payload (NN.IR.DVal.mk inShape x) i (denoteAllState inShape ⟨ss, gd⟩ x) = Except.ok (denoteAllState inShape st' x)

Correctness lemma for the .max_pool2d_pad node compiler (explicit padding).
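With explicit padding, PyTorch's max pooling behaves as if the border were filled with −∞, so padded cells never win the maximum. The sketch below models that by returning none for padded cells. padGet and maxPool2dPadRef are hypothetical names, and we assume, without this page confirming it, that TorchLean's maxPool2dPad follows the same −∞ convention.

```lean
/-- Hypothetical helper: pixel `(r, c)` of the image after padding by `padding`
on every side, or `none` for padding cells outside the original grid. -/
def padGet (img : List (List Int)) (padding r c : Nat) : Option Int :=
  if r < padding ∨ c < padding then none
  else (img[r - padding]?).bind fun row => row[c - padding]?

/-- Hypothetical reference model of padded max pooling. Padding cells are
`none`, which acts like −∞ padding: they never win the maximum. -/
def maxPool2dPadRef (kH kW stride padding : Nat) (img : List (List Int)) :
    List (List Int) :=
  let h := img.length
  let w := (img.headD []).length
  let outH := (h + 2 * padding - kH) / stride + 1
  let outW := (w + 2 * padding - kW) / stride + 1
  -- Combine a running maximum with one (possibly padded) cell.
  let step : Option Int → Option Int → Option Int := fun acc v =>
    match acc, v with
    | some a, some b => some (max a b)
    | none, some b => some b
    | a, none => a
  (List.range outH).map fun oh =>
    (List.range outW).map fun ow =>
      let cells := (List.range kH).map fun dh =>
        (List.range kW).map fun dw =>
          padGet img padding (oh * stride + dh) (ow * stride + dw)
      -- `getD 0` is only reached if a window is entirely padding, which
      -- PyTorch rules out by requiring padding ≤ kernel / 2.
      (cells.foldl (fun (acc : Option Int) row => row.foldl step acc) none).getD 0

-- 2 × 2 windows, stride 2, padding 1 on a 3 × 3 image.
#eval maxPool2dPadRef 2 2 2 1 [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
-- [[1, 3], [7, 9]]
```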

theorem Runtime.Autograd.Compiled.buildFrom_denoteAllFrom_avg_pool2d {α : Type} [Context α] [DecidableEq Spec.Shape] (g : NN.IR.Graph) (payload : NN.IR.Payload α) {inShape : Spec.Shape} {ss : List Spec.Shape} (gd : Proofs.Autograd.Algebra.GraphData α Unit [inShape] ss) (i : ℕ) (st' : IRExec.State α inShape) (x : Spec.Tensor α inShape) (n : NN.IR.Node) (kH kW stride : ℕ) (hN : g.getNode i = Except.ok n) (hk : n.kind = NN.IR.OpKind.avgPool2d kH kW stride) (hi : i < g.nodes.size) (hBuild : IRExec.buildFrom g payload inShape i ⟨ss, gd⟩ = Except.ok st') (ih : ∀ (st1 : IRExec.State α inShape), IRExec.buildFrom g payload inShape (i + 1) st1 = Except.ok st' → g.denoteAllFrom payload (NN.IR.DVal.mk inShape x) (i + 1) (denoteAllState inShape st1 x) = Except.ok (denoteAllState inShape st' x)) :
g.denoteAllFrom payload (NN.IR.DVal.mk inShape x) i (denoteAllState inShape ⟨ss, gd⟩ x) = Except.ok (denoteAllState inShape st' x)

Correctness lemma for the .avg_pool2d node compiler (no padding).
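Average pooling replaces the window maximum with the window mean. The following Lean sketch, using Float entries, divides each window sum by kH · kW. avgPool2dRef is a hypothetical name, not TorchLean's spec definition, and it assumes every window fits inside the image, so no padding convention is involved.

```lean
/-- Hypothetical reference model, not TorchLean's `Spec` definition:
unpadded 2D average pooling on one H × W channel. Each output cell is the
sum of its kH × kW window divided by kH * kW. Assumes every window fits. -/
def avgPool2dRef (kH kW stride : Nat) (img : List (List Float)) :
    List (List Float) :=
  let h := img.length
  let w := (img.headD []).length
  let outH := (h - kH) / stride + 1
  let outW := (w - kW) / stride + 1
  (List.range outH).map fun oh =>
    (List.range outW).map fun ow =>
      -- Sum the window row by row.
      let total := (List.range kH).foldl (fun (acc : Float) dh =>
        let row := img.getD (oh * stride + dh) []
        (List.range kW).foldl (fun (acc' : Float) dw =>
          acc' + row.getD (ow * stride + dw) 0.0) acc) 0.0
      total / (kH * kW).toFloat

-- Window averages: (1 + 2 + 5 + 6) / 4 = 3.5, then 5.5, 11.5 and 13.5.
#eval avgPool2dRef 2 2 2 [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
```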

theorem Runtime.Autograd.Compiled.buildFrom_denoteAllFrom_avg_pool2d_pad {α : Type} [Context α] [DecidableEq Spec.Shape] (g : NN.IR.Graph) (payload : NN.IR.Payload α) {inShape : Spec.Shape} {ss : List Spec.Shape} (gd : Proofs.Autograd.Algebra.GraphData α Unit [inShape] ss) (i : ℕ) (st' : IRExec.State α inShape) (x : Spec.Tensor α inShape) (n : NN.IR.Node) (kH kW stride padding : ℕ) (hN : g.getNode i = Except.ok n) (hk : n.kind = NN.IR.OpKind.avgPool2dPad kH kW stride padding) (hi : i < g.nodes.size) (hBuild : IRExec.buildFrom g payload inShape i ⟨ss, gd⟩ = Except.ok st') (ih : ∀ (st1 : IRExec.State α inShape), IRExec.buildFrom g payload inShape (i + 1) st1 = Except.ok st' → g.denoteAllFrom payload (NN.IR.DVal.mk inShape x) (i + 1) (denoteAllState inShape st1 x) = Except.ok (denoteAllState inShape st' x)) :
g.denoteAllFrom payload (NN.IR.DVal.mk inShape x) i (denoteAllState inShape ⟨ss, gd⟩ x) = Except.ok (denoteAllState inShape st' x)

Correctness lemma for the .avg_pool2d_pad node compiler (explicit padding).
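Padded average pooling needs two choices: the padding value and the divisor. The sketch below pads with zeros and divides by the full window size kH · kW, which matches PyTorch's default count_include_pad=True; with count_include_pad=False PyTorch instead divides by the number of in-bounds cells. This page does not say which convention avgPool2dPad uses, so both the convention and the name avgPool2dPadRef are assumptions for illustration.

```lean
/-- Hypothetical reference model of padded average pooling: zero padding,
divisor kH * kW (PyTorch's `count_include_pad = True` default). -/
def avgPool2dPadRef (kH kW stride padding : Nat) (img : List (List Float)) :
    List (List Float) :=
  let h := img.length
  let w := (img.headD []).length
  let outH := (h + 2 * padding - kH) / stride + 1
  let outW := (w + 2 * padding - kW) / stride + 1
  -- Zero padding: any coordinate outside the original grid reads as 0.0.
  let pixel (r c : Nat) : Float :=
    if r < padding ∨ c < padding then 0.0
    else (img.getD (r - padding) []).getD (c - padding) 0.0
  (List.range outH).map fun oh =>
    (List.range outW).map fun ow =>
      let total := (List.range kH).foldl (fun (acc : Float) dh =>
        (List.range kW).foldl (fun (acc' : Float) dw =>
          acc' + pixel (oh * stride + dh) (ow * stride + dw)) acc) 0.0
      -- Assumed divisor: the full window size, padded cells included.
      total / (kH * kW).toFloat

-- Window sums 1, 5, 11 and 28, each divided by 4: 0.25, 1.25, 2.75 and 7.0.
#eval avgPool2dPadRef 2 2 2 1 [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
```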