TorchLean API

NN.Runtime.Autograd.Compiled.IRExec.Correctness.Ops.Random

Random

Correctness lemmas for the random IR nodes (.randUniform and .bernoulliMask) in the IR -> compiled runtime bridge.

These lemmas keep the end-to-end semantic-equivalence proof in Correctness.SemanticEquivalence small. The top-level proof only dispatches to per-operator branch theorems, and this file checks the branch-specific behavior of the compiler and the evaluator.

Build note: once the seed and node id are fixed, the random operators are deterministic in the semantics. The proof still has to show that the compiler and the IR evaluator derive the same key, append a value of the same dependent shape, and continue with the same tail graph. The seed/key helper lemmas make adding further deterministic random primitives a mechanical exercise.
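The determinism argument in the build note can be illustrated with a small, self-contained Lean 4 model. Everything in it is hypothetical: deriveKey, lcgStep, draws, Vec, uniformVec, compiledValue and evaluatedValue are illustrative stand-ins rather than TorchLean definitions, and the linear congruential generator is a toy, not the generator the semantics actually uses. The sketch shows the shape of the argument. The appended value depends only on the seed, the node id and the shape, its length is fixed by the shape (draws_length), and two independently written derivations therefore agree by rfl.

```lean
namespace KeySketch

-- Hypothetical key derivation: mixes the op seed with the node index.
-- Illustrative only; not TorchLean's actual derivation.
def deriveKey (seed nodeId : Nat) : Nat :=
  seed * 1000003 + nodeId

-- One step of a toy linear congruential generator (not a quality PRNG).
def lcgStep (s : Nat) : Nat :=
  (s * 1103515245 + 12345) % 2147483648

-- Draw `n` pseudo-random values, threading the generator state.
def draws : Nat → Nat → List Nat
  | 0, _ => []
  | n + 1, s => lcgStep s :: draws n (lcgStep s)

-- The output length depends only on `n`, never on the key.
theorem draws_length (n s : Nat) : (draws n s).length = n := by
  induction n generalizing s with
  | zero => rfl
  | succ n ih => simp [draws, ih]

-- A toy "tensor of shape n": a list packaged with a proof of its length.
abbrev Vec (n : Nat) := { v : List Nat // v.length = n }

-- The value a random node appends: fixed by (seed, nodeId) and the shape.
def uniformVec (seed nodeId n : Nat) : Vec n :=
  ⟨draws n (deriveKey seed nodeId), draws_length n _⟩

-- "Compiler" side: reuses uniformVec.
def compiledValue (seed nodeId n : Nat) : Vec n :=
  uniformVec seed nodeId n

-- "Evaluator" side: derives the key on its own, written separately.
def evaluatedValue (seed nodeId n : Nat) : Vec n :=
  let key := deriveKey seed nodeId
  ⟨draws n key, draws_length n key⟩

-- Same seed, same node id, same shape: same value, definitionally.
theorem compiled_eq_evaluated (seed nodeId n : Nat) :
    compiledValue seed nodeId n = evaluatedValue seed nodeId n := rfl

end KeySketch
```

Because the shape index n lives in the type Vec n, the "same dependent shape" obligation is discharged by draws_length rather than by a separate runtime check.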

Main definitions

theorem Runtime.Autograd.Compiled.buildFrom_denoteAllFrom_rand_uniform {α : Type} [Context α] [DecidableEq Spec.Shape] (g : NN.IR.Graph) (payload : NN.IR.Payload α) {inShape : Spec.Shape} {ss : List Spec.Shape} (gd : Proofs.Autograd.Algebra.GraphData α Unit [inShape] ss) (i : ℕ) (st' : IRExec.State α inShape) (x : Spec.Tensor α inShape) (n : NN.IR.Node) (seed : ℕ) (hN : g.getNode i = Except.ok n) (hk : n.kind = NN.IR.OpKind.randUniform seed) (hi : i < g.nodes.size) (hBuild : IRExec.buildFrom g payload inShape i ⟨ss, gd⟩ = Except.ok st') (ih : ∀ (st1 : IRExec.State α inShape), IRExec.buildFrom g payload inShape (i + 1) st1 = Except.ok st' → g.denoteAllFrom payload (NN.IR.DVal.mk inShape x) (i + 1) (denoteAllState inShape st1 x) = Except.ok (denoteAllState inShape st' x)) :
g.denoteAllFrom payload (NN.IR.DVal.mk inShape x) i (denoteAllState inShape ⟨ss, gd⟩ x) = Except.ok (denoteAllState inShape st' x)

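Correctness lemma for lowering a .randUniform seed node. Suppose node i exists with kind .randUniform seed (hN, hk, hi), building from node i with state ⟨ss, gd⟩ succeeds with final state st' (hBuild), and evaluation from node i + 1 already matches st' for every intermediate state st1 that builds to st' (ih). Then evaluating the IR from node i yields denoteAllState inShape st' x.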
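The structure of this branch lemma can be mirrored in a toy setting, as a sketch under stated assumptions. Here a graph is just a List of hypothetical OpKind values, toyBuild plays the role of IRExec.buildFrom, toyDenote plays the role of denoteAllFrom, and the state is a List Nat. None of these names come from TorchLean, and nodeValue is a toy sampler. As in the real lemma, the hypothesis ih says the two walks agree from node i + 1 for every intermediate state, and the proof unfolds exactly one node before handing the tail to ih.

```lean
namespace UniformStepSketch

-- Hypothetical op kinds for a toy IR.
inductive OpKind where
  | randUniform (seed : Nat)
  | other

-- Hypothetical key derivation from the op seed and the node index.
def deriveKey (seed nodeId : Nat) : Nat :=
  seed * 1000003 + nodeId

-- Toy value produced by node `i`; randUniform draws from its key.
def nodeValue : OpKind → Nat → Nat
  | .randUniform seed, i => (deriveKey seed i * 1103515245 + 12345) % 2147483648
  | .other, _ => 0

-- Toy "compiler": walk the nodes from index `i`, appending one value each.
def toyBuild : List OpKind → Nat → List Nat → List Nat
  | [], _, acc => acc
  | k :: rest, i, acc => toyBuild rest (i + 1) (acc ++ [nodeValue k i])

-- Toy "evaluator": the same walk, written as a separate function.
def toyDenote : List OpKind → Nat → List Nat → List Nat
  | [], _, env => env
  | k :: rest, i, env => toyDenote rest (i + 1) (env ++ [nodeValue k i])

-- Branch lemma for a randUniform node, shaped like the real one:
-- if the tail (from `i + 1`) agrees for every intermediate state,
-- then the walk starting at node `i` agrees as well.
theorem step_randUniform (seed i : Nat) (rest : List OpKind) (acc : List Nat)
    (ih : ∀ acc', toyBuild rest (i + 1) acc' = toyDenote rest (i + 1) acc') :
    toyBuild (OpKind.randUniform seed :: rest) i acc =
      toyDenote (OpKind.randUniform seed :: rest) i acc := by
  -- Unfold one node: both sides append the same keyed value at index `i`.
  simp only [toyBuild, toyDenote]
  -- What remains is exactly the tail hypothesis.
  exact ih _

end UniformStepSketch
```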

theorem Runtime.Autograd.Compiled.buildFrom_denoteAllFrom_bernoulli_mask {α : Type} [Context α] [DecidableEq Spec.Shape] (g : NN.IR.Graph) (payload : NN.IR.Payload α) {inShape : Spec.Shape} {ss : List Spec.Shape} (gd : Proofs.Autograd.Algebra.GraphData α Unit [inShape] ss) (i : ℕ) (st' : IRExec.State α inShape) (x : Spec.Tensor α inShape) (n : NN.IR.Node) (seed : ℕ) (hN : g.getNode i = Except.ok n) (hk : n.kind = NN.IR.OpKind.bernoulliMask seed) (hi : i < g.nodes.size) (hBuild : IRExec.buildFrom g payload inShape i ⟨ss, gd⟩ = Except.ok st') (ih : ∀ (st1 : IRExec.State α inShape), IRExec.buildFrom g payload inShape (i + 1) st1 = Except.ok st' → g.denoteAllFrom payload (NN.IR.DVal.mk inShape x) (i + 1) (denoteAllState inShape st1 x) = Except.ok (denoteAllState inShape st' x)) :
g.denoteAllFrom payload (NN.IR.DVal.mk inShape x) i (denoteAllState inShape ⟨ss, gd⟩ x) = Except.ok (denoteAllState inShape st' x)

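Correctness lemma for lowering a .bernoulliMask seed node. It has the same hypotheses and conclusion as buildFrom_denoteAllFrom_rand_uniform, except that hk states n.kind = NN.IR.OpKind.bernoulliMask seed.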
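A toy model with a bernoulliMask case shows both the mask branch and how a top-level proof can dispatch to per-operator branch lemmas, as Correctness.SemanticEquivalence does with the real theorems. The names OpKind, deriveKey, nodeValue, toyBuild, toyDenote and the step lemmas are illustrative assumptions, not TorchLean definitions, and the mask value (key modulo 2) is a stand-in for a real Bernoulli sample. The top-level theorem inducts on the node list. In the cons case it splits on the node kind and applies the matching branch lemma to the induction hypothesis.

```lean
namespace MaskStepSketch

-- Hypothetical op kinds: the two random primitives covered in this file.
inductive OpKind where
  | randUniform (seed : Nat)
  | bernoulliMask (seed : Nat)

-- Hypothetical key derivation from the op seed and the node index.
def deriveKey (seed nodeId : Nat) : Nat :=
  seed * 1000003 + nodeId

-- Toy sampler: uniform takes an LCG output, the mask keeps one bit (0 or 1).
def nodeValue : OpKind → Nat → Nat
  | .randUniform seed, i => (deriveKey seed i * 1103515245 + 12345) % 2147483648
  | .bernoulliMask seed, i => deriveKey seed i % 2

-- Toy "compiler" and "evaluator": walk nodes from index `i`, one value each.
def toyBuild : List OpKind → Nat → List Nat → List Nat
  | [], _, acc => acc
  | k :: rest, i, acc => toyBuild rest (i + 1) (acc ++ [nodeValue k i])

def toyDenote : List OpKind → Nat → List Nat → List Nat
  | [], _, env => env
  | k :: rest, i, env => toyDenote rest (i + 1) (env ++ [nodeValue k i])

-- Branch lemma for randUniform: unfold one node, then use the tail hypothesis.
theorem step_randUniform (seed i : Nat) (rest : List OpKind) (acc : List Nat)
    (ih : ∀ acc', toyBuild rest (i + 1) acc' = toyDenote rest (i + 1) acc') :
    toyBuild (OpKind.randUniform seed :: rest) i acc =
      toyDenote (OpKind.randUniform seed :: rest) i acc := by
  simp only [toyBuild, toyDenote]
  exact ih _

-- Branch lemma for bernoulliMask: same shape, different node kind.
theorem step_bernoulliMask (seed i : Nat) (rest : List OpKind) (acc : List Nat)
    (ih : ∀ acc', toyBuild rest (i + 1) acc' = toyDenote rest (i + 1) acc') :
    toyBuild (OpKind.bernoulliMask seed :: rest) i acc =
      toyDenote (OpKind.bernoulliMask seed :: rest) i acc := by
  simp only [toyBuild, toyDenote]
  exact ih _

-- Top-level proof: induct on the node list and dispatch on the node kind.
theorem toyBuild_eq_toyDenote (nodes : List OpKind) :
    ∀ (i : Nat) (acc : List Nat), toyBuild nodes i acc = toyDenote nodes i acc := by
  induction nodes with
  | nil => intro i acc; rfl
  | cons k rest ih =>
      intro i acc
      cases k with
      | randUniform seed => exact step_randUniform seed i rest acc (ih (i + 1))
      | bernoulliMask seed => exact step_bernoulliMask seed i rest acc (ih (i + 1))

end MaskStepSketch
```

In this toy model, adding another deterministic random primitive takes one new constructor, one nodeValue case, one step lemma and one new arm in the final case split.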