TorchLean API

NN.Runtime.Autograd.Compiled.IRExec.Correctness.Ops.Loss

Loss

Loss-function correctness lemmas for the IR → compiled runtime bridge.

The IR node kind .mseLoss (NN.IR.OpKind.mseLoss) is compiled into an SSA node whose forward pass computes the specification-level mean squared error loss.
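For reference, the usual definition of mean squared error over n elements is MSE(ŷ, y) = (1/n) Σᵢ (ŷᵢ − yᵢ)². The Lean sketch below assumes that mean reduction and works on plain lists of Float. It is a hypothetical illustration: the namespace MseSketch and the function mse are not TorchLean names, and this is not the spec-level definition over Spec.Tensor α inShape that the compiled node computes.

```lean
namespace MseSketch

/-- Hypothetical list-based mean squared error, assuming mean reduction over
all elements. Not TorchLean's spec-level definition. -/
def mse (pred target : List Float) : Float :=
  -- squared element-wise differences (zipWith truncates to the shorter list)
  let sq := List.zipWith (fun p t => (p - t) * (p - t)) pred target
  -- sum of squares divided by the element count (an empty list gives 0/0 = NaN)
  sq.foldl (· + ·) 0 / sq.length.toFloat

#eval mse [1.0, 2.0, 3.0] [1.0, 2.0, 5.0]  -- (0 + 0 + 4) / 3, about 1.333

end MseSketch
```

A list version silently truncates mismatched inputs; a shape-indexed tensor type such as the one in the signature below rules that case out at the type level.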

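This file proves the forward-correctness lemma for that compilation step. Suppose node i of the graph is an .mseLoss node and IRExec.buildFrom, started at position i from the compiled state ⟨ss, gd⟩, succeeds with final state st'. Then running the IR evaluator NN.IR.Graph.denoteAllFrom from position i, starting from the values that denoteAllState assigns to ⟨ss, gd⟩, returns exactly the values that denoteAllState assigns to st'. In other words, the IR evaluator and the compiled evaluator append the same result for node i, and the hypothesis ih carries that agreement through positions i + 1 onward.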

This is a structural correctness statement: it is not about generalization, training convergence, or statistical properties of MSE; it simply connects the IR semantics to the compiled runtime node.

Tags

mse-loss, correctness, ir, runtime, semantic equivalence

theorem Runtime.Autograd.Compiled.buildFrom_denoteAllFrom_mse_loss {α : Type} [Context α] [DecidableEq Spec.Shape]
    (g : NN.IR.Graph) (payload : NN.IR.Payload α) {inShape : Spec.Shape} {ss : List Spec.Shape}
    (gd : Proofs.Autograd.Algebra.GraphData α Unit [inShape] ss) (i : ℕ) (st' : IRExec.State α inShape)
    (x : Spec.Tensor α inShape) (n : NN.IR.Node)
    (hN : g.getNode i = Except.ok n) (hk : n.kind = NN.IR.OpKind.mseLoss) (hi : i < g.nodes.size)
    (hBuild : IRExec.buildFrom g payload inShape i ⟨ss, gd⟩ = Except.ok st')
    (ih : ∀ (st1 : IRExec.State α inShape),
      IRExec.buildFrom g payload inShape (i + 1) st1 = Except.ok st' →
        g.denoteAllFrom payload (NN.IR.DVal.mk inShape x) (i + 1) (denoteAllState inShape st1 x) =
          Except.ok (denoteAllState inShape st' x)) :
    g.denoteAllFrom payload (NN.IR.DVal.mk inShape x) i (denoteAllState inShape ⟨ss, gd⟩ x) =
      Except.ok (denoteAllState inShape st' x)

Correctness lemma for the .mseLoss node compiler.
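The shape of this lemma (a per-node step that receives the agreement for positions i + 1 onward as the hypothesis ih) can be illustrated on a deliberately tiny model. The sketch below is hypothetical throughout: ToyIR, Op, sqDiffAt, evalIR, compileOp, runCompiled and the lemma names are invented for illustration and are not TorchLean definitions. Values are Int, and sqDiff is a one-element stand-in for a loss node. evalIR plays the role of the IR evaluator and runCompiled the role of the compiled evaluator; each appends one value per node, and the step lemmas are combined by induction over the node list.

```lean
namespace ToyIR

-- Hypothetical toy IR: two node kinds, values are Int.
inductive Op where
  | const (v : Int)        -- a constant node
  | sqDiff (i j : Nat)     -- (x_i - x_j)^2, a one-element stand-in for a loss node

-- Squared difference of the values at earlier positions i and j (0 if out of range).
def sqDiffAt (env : List Int) (i j : Nat) : Int :=
  (env.getD i 0 - env.getD j 0) * (env.getD i 0 - env.getD j 0)

-- Reference ("IR") semantics: evaluate nodes in order, appending each result.
def evalIR : List Op → List Int → List Int
  | [], env => env
  | .const v :: ops, env => evalIR ops (env ++ [v])
  | .sqDiff i j :: ops, env => evalIR ops (env ++ [sqDiffAt env i j])

-- "Compiled" form: each node becomes a function of the values computed so far.
def compileOp : Op → List Int → Int
  | .const v, _ => v
  | .sqDiff i j, env => sqDiffAt env i j

-- Compiled evaluator: run the compiled nodes in order, appending each result.
def runCompiled : List (List Int → Int) → List Int → List Int
  | [], env => env
  | f :: fs, env => runCompiled fs (env ++ [f env])

-- Per-node step lemmas: each takes the agreement for the remaining nodes
-- as a hypothesis `ih`, mirroring the shape of the lemma above.
theorem step_const (v : Int) (ops : List Op) (env : List Int)
    (ih : ∀ env', runCompiled (ops.map compileOp) env' = evalIR ops env') :
    runCompiled ((Op.const v :: ops).map compileOp) env = evalIR (Op.const v :: ops) env := by
  simp only [List.map_cons, runCompiled, compileOp, evalIR, ih]

theorem step_sqDiff (i j : Nat) (ops : List Op) (env : List Int)
    (ih : ∀ env', runCompiled (ops.map compileOp) env' = evalIR ops env') :
    runCompiled ((Op.sqDiff i j :: ops).map compileOp) env
      = evalIR (Op.sqDiff i j :: ops) env := by
  simp only [List.map_cons, runCompiled, compileOp, evalIR, ih]

-- Whole-program correctness: induction over the node list, one step lemma per kind.
theorem compile_correct (ops : List Op) (env : List Int) :
    runCompiled (ops.map compileOp) env = evalIR ops env := by
  induction ops generalizing env with
  | nil => simp only [List.map_nil, runCompiled, evalIR]
  | cons op ops ih =>
    cases op with
    | const v => exact step_const v ops env ih
    | sqDiff i j => exact step_sqDiff i j ops env ih

#eval evalIR [.const 3, .const 1, .sqDiff 0 1] []  -- [3, 1, 4]

end ToyIR
```

The toy omits everything that makes the real lemma harder: compilation here cannot fail (the real results live in Except), nodes are not looked up by index through g.getNode, and values are not shape-indexed tensors, so the hypotheses hN, hk, hi and hBuild have no counterpart in it.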