TorchLean API

NN.Runtime.Autograd.Compiled.IRExec.Correctness.Ops.Reductions

Reductions and Broadcasting

Correctness lemmas for IR nodes whose primary behavior is broadcasting or reduction: .broadcastTo, .reduceSum, .reduceMean, and .sum.

Each lemma matches the compiler control flow closely: we validate the parent structure and the side-condition checks that buildFrom enforces, then construct the compiled forward closure and show that it matches NN.IR.Graph.evalAt at the current node. We finish by appealing to the shared buildFrom_denoteAllFrom_finish lemma for the tail of the graph.
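The proof pattern described above can be illustrated on a much smaller toy. This sketch is hypothetical and is not TorchLean code: ToyOp, toyDenoteFrom, toyBuildFrom and the lemma names are illustrative. The real buildFrom threads an IRExec.State through Except, and the real per-op lemmas also discharge parent and side-condition checks. The toy keeps only the shape of the argument: a per-node lemma that receives the tail result ih as a hypothesis, and an induction that chains those steps.

```lean
-- Hypothetical toy IR: a straight-line program of unary integer ops.
inductive ToyOp where
  | addConst (c : Int)
  | scale (c : Int)

/-- Reference semantics for one node (the role `evalAt` plays). -/
def ToyOp.eval : ToyOp → Int → Int
  | .addConst c, x => x + c
  | .scale c, x => c * x

/-- Denotation of the program tail starting at the current node. -/
def toyDenoteFrom : List ToyOp → Int → Int
  | [], x => x
  | op :: rest, x => toyDenoteFrom rest (op.eval x)

/-- Toy "compiler" step: one forward closure per node. -/
def toyCompileOp : ToyOp → (Int → Int)
  | .addConst c => fun x => x + c
  | .scale c => fun x => c * x

/-- Compose the closures so the current node runs first, then the tail. -/
def toyBuildFrom : List ToyOp → (Int → Int)
  | [] => id
  | op :: rest => toyBuildFrom rest ∘ toyCompileOp op

/-- Current-node step: the compiled closure agrees with the reference semantics. -/
theorem toyCompileOp_eval (op : ToyOp) (x : Int) : toyCompileOp op x = op.eval x := by
  cases op <;> rfl

/-- Per-op lemma: takes the tail result `ih` as a hypothesis and only
    discharges the current node, like the lemmas in this module. -/
theorem toyBuildFrom_cons (op : ToyOp) (rest : List ToyOp) (x : Int)
    (ih : ∀ y, toyBuildFrom rest y = toyDenoteFrom rest y) :
    toyBuildFrom (op :: rest) x = toyDenoteFrom (op :: rest) x := by
  show toyBuildFrom rest (toyCompileOp op x) = toyDenoteFrom rest (op.eval x)
  rw [toyCompileOp_eval, ih]

/-- Whole-program correctness: induction over the program, one per-op step each. -/
theorem toyBuildFrom_correct (ops : List ToyOp) (x : Int) :
    toyBuildFrom ops x = toyDenoteFrom ops x := by
  induction ops generalizing x with
  | nil => rfl
  | cons op rest ih => exact toyBuildFrom_cons op rest x ih

-- (1 + 2) * 3 = 9
example : toyBuildFrom [.addConst 2, .scale 3] 1 = 9 := by decide
```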

Build note: reductions are among the more expensive op proofs to elaborate, because the axis argument changes the output shape. Lean has to track both the input and output shapes, normalize the axis side conditions, and then compare the compiled reduction with the IR denotation. Axis and shape arithmetic should therefore live in small lemmas, so that the semantic proof reads more like the compiler code.
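As a minimal sketch of the kind of small axis/shape lemma meant here, assume (hypothetically) that a shape is a plain list of dimension sizes and that a reduction drops the reduced dimension. reduceShape and reduceShape_length are illustrative names, not TorchLean definitions. Proving the length fact once lets a semantic proof cite it instead of redoing the arithmetic inline.

```lean
-- Hypothetical model (not TorchLean's `Spec.Shape` API): a shape is a list of
-- dimension sizes, and a reduction drops the reduced dimension.

/-- Output shape after reducing along `axis`. Out-of-range axes leave the shape
    unchanged here; a real compiler would reject them as a side condition. -/
def reduceShape : List Nat → Nat → List Nat
  | [], _ => []
  | _ :: ds, 0 => ds
  | d :: ds, k + 1 => d :: reduceShape ds k

/-- Small shape lemma: reducing along a valid axis removes exactly one dimension. -/
theorem reduceShape_length : ∀ (s : List Nat) (axis : Nat),
    axis < s.length → (reduceShape s axis).length + 1 = s.length
  | [], _, h => absurd h (Nat.not_lt_zero _)
  | _ :: _, 0, _ => rfl
  | _ :: ds, k + 1, h => by
    -- The axis is valid for the tail as well.
    have hk : k < ds.length := by
      simp only [List.length_cons] at h
      omega
    have ih := reduceShape_length ds k hk
    simp only [reduceShape, List.length_cons]
    omega

example : reduceShape [2, 3, 4] 1 = [2, 4] := by decide
```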

theorem Runtime.Autograd.Compiled.buildFrom_denoteAllFrom_broadcastTo {α : Type} [Context α] [DecidableEq Spec.Shape] (g : NN.IR.Graph) (payload : NN.IR.Payload α) {inShape : Spec.Shape} {ss : List Spec.Shape} (gd : Proofs.Autograd.Algebra.GraphData α Unit [inShape] ss) (i : ℕ) (st' : IRExec.State α inShape) (x : Spec.Tensor α inShape) (n : NN.IR.Node) (s₁ s₂ : Spec.Shape) (hN : g.getNode i = Except.ok n) (hk : n.kind = NN.IR.OpKind.broadcastTo s₁ s₂) (hi : i < g.nodes.size) (hBuild : IRExec.buildFrom g payload inShape i ⟨ss, gd⟩ = Except.ok st') (ih : ∀ (st1 : IRExec.State α inShape), IRExec.buildFrom g payload inShape (i + 1) st1 = Except.ok st' → g.denoteAllFrom payload (NN.IR.DVal.mk inShape x) (i + 1) (denoteAllState inShape st1 x) = Except.ok (denoteAllState inShape st' x)) :
g.denoteAllFrom payload (NN.IR.DVal.mk inShape x) i (denoteAllState inShape ⟨ss, gd⟩ x) = Except.ok (denoteAllState inShape st' x)

Correctness lemma for .broadcastTo s₁ s₂ lowering.
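For intuition about the shape side conditions a broadcastTo lowering has to respect, here is a sketch of a shape check that assumes NumPy-style broadcasting rules and reads s₁ as the source shape and s₂ as the target. Both readings are assumptions; canBroadcastTo is a hypothetical helper, not the check buildFrom actually performs.

```lean
/-- Hypothetical side condition, assuming NumPy-style rules: align both shapes
    on the right; every source dimension must equal the matching target
    dimension or be 1, and the source may not have more dimensions. -/
def canBroadcastTo (src dst : List Nat) : Bool :=
  decide (src.length ≤ dst.length) &&
    (List.zip src.reverse dst.reverse).all fun (a, b) => a == b || a == 1

example : canBroadcastTo [3, 1] [2, 3, 4] = true := by decide
example : canBroadcastTo [3, 2] [3, 4] = false := by decide    -- 2 vs 4
example : canBroadcastTo [2, 3, 4] [3, 4] = false := by decide -- too many dims
```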

theorem Runtime.Autograd.Compiled.buildFrom_denoteAllFrom_reduceSum {α : Type} [Context α] [DecidableEq Spec.Shape] (g : NN.IR.Graph) (payload : NN.IR.Payload α) {inShape : Spec.Shape} {ss : List Spec.Shape} (gd : Proofs.Autograd.Algebra.GraphData α Unit [inShape] ss) (i : ℕ) (st' : IRExec.State α inShape) (x : Spec.Tensor α inShape) (n : NN.IR.Node) (axis : ℕ) (hN : g.getNode i = Except.ok n) (hk : n.kind = NN.IR.OpKind.reduceSum axis) (hi : i < g.nodes.size) (hBuild : IRExec.buildFrom g payload inShape i ⟨ss, gd⟩ = Except.ok st') (ih : ∀ (st1 : IRExec.State α inShape), IRExec.buildFrom g payload inShape (i + 1) st1 = Except.ok st' → g.denoteAllFrom payload (NN.IR.DVal.mk inShape x) (i + 1) (denoteAllState inShape st1 x) = Except.ok (denoteAllState inShape st' x)) :
g.denoteAllFrom payload (NN.IR.DVal.mk inShape x) i (denoteAllState inShape ⟨ss, gd⟩ x) = Except.ok (denoteAllState inShape st' x)

Correctness lemma for .reduceSum axis lowering.
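To make the effect of the axis argument concrete, the sketch below models a 2-D tensor as a list of rows and sums along each axis. It is a hypothetical toy (reduceSumAxis0 and reduceSumAxis1 are illustrative names, and it assumes the reduced axis is dropped from the result), not TorchLean's tensor representation.

```lean
-- Hypothetical 2-D model: a matrix is a list of rows.

/-- Sum along axis 1: collapse each row to its total. -/
def reduceSumAxis1 (m : List (List Int)) : List Int :=
  m.map fun row => row.foldl (· + ·) 0

/-- Sum along axis 0: add the rows elementwise. Assumes equal-length rows
    (`List.zipWith` truncates otherwise); an empty matrix carries no column
    count in this model, so it returns `[]`. -/
def reduceSumAxis0 : List (List Int) → List Int
  | [] => []
  | row :: rows => rows.foldl (fun acc r => List.zipWith (· + ·) acc r) row

example : reduceSumAxis1 [[1, 2, 3], [4, 5, 6]] = [6, 15] := by decide
example : reduceSumAxis0 [[1, 2, 3], [4, 5, 6]] = [5, 7, 9] := by decide
```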

theorem Runtime.Autograd.Compiled.buildFrom_denoteAllFrom_reduceMean {α : Type} [Context α] [DecidableEq Spec.Shape] (g : NN.IR.Graph) (payload : NN.IR.Payload α) {inShape : Spec.Shape} {ss : List Spec.Shape} (gd : Proofs.Autograd.Algebra.GraphData α Unit [inShape] ss) (i : ℕ) (st' : IRExec.State α inShape) (x : Spec.Tensor α inShape) (n : NN.IR.Node) (axis : ℕ) (hN : g.getNode i = Except.ok n) (hk : n.kind = NN.IR.OpKind.reduceMean axis) (hi : i < g.nodes.size) (hBuild : IRExec.buildFrom g payload inShape i ⟨ss, gd⟩ = Except.ok st') (ih : ∀ (st1 : IRExec.State α inShape), IRExec.buildFrom g payload inShape (i + 1) st1 = Except.ok st' → g.denoteAllFrom payload (NN.IR.DVal.mk inShape x) (i + 1) (denoteAllState inShape st1 x) = Except.ok (denoteAllState inShape st' x)) :
g.denoteAllFrom payload (NN.IR.DVal.mk inShape x) i (denoteAllState inShape ⟨ss, gd⟩ x) = Except.ok (denoteAllState inShape st' x)

Correctness lemma for .reduceMean axis lowering.
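A mean reduction can be viewed as a sum reduction followed by division by the length of the reduced axis. The sketch below writes that decomposition for the row-list toy model; reduceMeanAxis1 is a hypothetical name, and it is an assumption that TorchLean's reduceMean lowering follows this decomposition.

```lean
/-- Mean along axis 1 as "sum, then divide by the axis length".
    Generic over any type with `+`, `0`, `/` and a `Nat` cast; an empty row
    divides by `(0 : α)`, whose meaning depends on `α`. -/
def reduceMeanAxis1 {α : Type} [Add α] [OfNat α 0] [Div α] [NatCast α]
    (m : List (List α)) : List α :=
  m.map fun row => row.foldl (· + ·) 0 / (row.length : α)

-- `Int` division rounds, so the example uses rows whose means are exact.
example : reduceMeanAxis1 ([[2, 4], [3, 9]] : List (List Int)) = [3, 6] := by decide
```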

theorem Runtime.Autograd.Compiled.buildFrom_denoteAllFrom_sum {α : Type} [Context α] [DecidableEq Spec.Shape] (g : NN.IR.Graph) (payload : NN.IR.Payload α) {inShape : Spec.Shape} {ss : List Spec.Shape} (gd : Proofs.Autograd.Algebra.GraphData α Unit [inShape] ss) (i : ℕ) (st' : IRExec.State α inShape) (x : Spec.Tensor α inShape) (n : NN.IR.Node) (hN : g.getNode i = Except.ok n) (hk : n.kind = NN.IR.OpKind.sum) (hi : i < g.nodes.size) (hBuild : IRExec.buildFrom g payload inShape i ⟨ss, gd⟩ = Except.ok st') (ih : ∀ (st1 : IRExec.State α inShape), IRExec.buildFrom g payload inShape (i + 1) st1 = Except.ok st' → g.denoteAllFrom payload (NN.IR.DVal.mk inShape x) (i + 1) (denoteAllState inShape st1 x) = Except.ok (denoteAllState inShape st' x)) :
g.denoteAllFrom payload (NN.IR.DVal.mk inShape x) i (denoteAllState inShape ⟨ss, gd⟩ x) = Except.ok (denoteAllState inShape st' x)

Correctness lemma for .sum lowering (sum-reduction to scalar).
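A full sum to a scalar can be related to per-row sums by a small algebraic lemma. The sketch below proves that fact for a hypothetical row-list model (total, flat, and the lemma names are illustrative); it shows one way such a lemma could look, not the structure of the actual .sum proof.

```lean
-- Hypothetical 2-D model: a matrix is a list of rows.

/-- Sum of a list of entries. -/
def total (xs : List Int) : Int := xs.foldr (· + ·) 0

/-- Row-major flattening of a matrix. -/
def flat : List (List Int) → List Int
  | [] => []
  | row :: rows => row ++ flat rows

/-- Totals distribute over concatenation. -/
theorem total_append (xs ys : List Int) :
    total (xs ++ ys) = total xs + total ys := by
  induction xs with
  | nil =>
    show total ys = 0 + total ys
    omega
  | cons x xs ih =>
    show x + total (xs ++ ys) = (x + total xs) + total ys
    omega

/-- Summing every entry equals summing the per-row totals. -/
theorem total_flat (m : List (List Int)) :
    total (flat m) = total (m.map total) := by
  induction m with
  | nil => rfl
  | cons row rows ih =>
    show total (row ++ flat rows) = total row + total (rows.map total)
    rw [total_append, ih]

example : total (flat [[1, 2, 3], [4, 5, 6]]) = 21 := by decide
```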