Loss
Loss-function correctness lemmas for the IR → compiled runtime bridge.
The IR node kind .mseLoss (NN.IR.OpKind.mseLoss) is compiled into an SSA node whose forward pass computes the specification-level mean squared error (MSE) loss.
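For orientation, here is a minimal sketch of what a spec-level MSE computes, written in plain Lean 4 over List Float. The name mseSpec and the flat-list representation are hypothetical: the project's specification works over shaped Spec.Tensor values, and this toy does not model any shape guards (List.zipWith simply truncates to the shorter input).

```lean
/-- Hypothetical spec-level MSE on flat lists of `Float`s (not the project's `Spec` API).
Returns `0` for empty input so that we never divide by zero. -/
def mseSpec (pred target : List Float) : Float :=
  -- Squared residuals; `List.zipWith` truncates to the shorter list.
  let sq := List.zipWith (fun p t => (p - t) * (p - t)) pred target
  -- Mean of the squared residuals.
  if sq.isEmpty then 0 else sq.foldl (· + ·) 0 / sq.length.toFloat

#eval mseSpec [1.0, 3.0] [0.0, 1.0]  -- value: (1 + 4) / 2 = 2.5
```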
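This file proves the forward-correctness lemma for that compilation step: if compilation succeeds at position i, then running the IR evaluator NN.IR.Graph.denoteAllFrom from node i, starting from the values that denoteAllState assigns to the current compiled state, produces exactly the values that denoteAllState assigns to the final compiled state st'. In other words, the IR evaluator and the compiled runtime append the same result for the MSE node.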
This is a structural correctness statement: it is not about generalization, training convergence, or statistical properties of MSE; it simply connects the IR semantics to the compiled runtime node.
Main statements
buildFrom_denoteAllFrom_mse_loss: correctness step for .mseLoss lowering.
Implementation notes
- We keep this theorem in a dedicated file because its proof is heavier than most per-op steps.
- The proof follows the compiler's guard sequence step by step, so each dependent shape check in the proof lines up with the compiler guard that establishes it.
- This file can build slowly: MSE involves two parent nodes, a scalar output shape, and a sequence of compiler guards. Repeated guard eliminations belong in focused helper lemmas, so the theorem itself can concentrate on the loss equation.
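The helper-lemma pattern from the last note can be sketched as follows. The names guardSameShape and guardSameShape_ok are hypothetical, shapes are modeled as List Nat rather than Spec.Shape, and the real compiler's guards may check different conditions; the sketch only shows the form of a focused guard-elimination lemma that turns a successful Except guard into the fact it checked.

```lean
/-- Hypothetical compiler-style guard (not the project's API): succeed only when the
two parent shapes agree, as an MSE lowering would need before comparing its inputs. -/
def guardSameShape (s₁ s₂ : List Nat) : Except String Unit :=
  if s₁ = s₂ then .ok () else .error "mse_loss: parent shapes differ"

/-- Focused guard elimination: a successful guard yields the equation it checked,
so a main theorem can use `s₁ = s₂` without re-splitting the `if` itself. -/
theorem guardSameShape_ok {s₁ s₂ : List Nat}
    (h : guardSameShape s₁ s₂ = .ok ()) : s₁ = s₂ := by
  unfold guardSameShape at h
  split at h
  · -- Success branch: the `if` condition is in context.
    assumption
  · -- Failure branch: `.error _ = .ok ()` is impossible.
    cases h
```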
Tags
mse-loss, correctness, ir, runtime, semantic equivalence
theorem Runtime.Autograd.Compiled.buildFrom_denoteAllFrom_mse_loss
    {α : Type}
    [Context α]
    [DecidableEq Spec.Shape]
    (g : NN.IR.Graph)
    (payload : NN.IR.Payload α)
    {inShape : Spec.Shape}
    {ss : List Spec.Shape}
    (gd : Proofs.Autograd.Algebra.GraphData α Unit [inShape] ss)
    (i : ℕ)
    (st' : IRExec.State α inShape)
    (x : Spec.Tensor α inShape)
    (n : NN.IR.Node)
    (hN : g.getNode i = Except.ok n)
    (hk : n.kind = NN.IR.OpKind.mseLoss)
    (hi : i < g.nodes.size)
    (hBuild : IRExec.buildFrom g payload inShape i ⟨ss, gd⟩ = Except.ok st')
    (ih :
      ∀ (st1 : IRExec.State α inShape),
        IRExec.buildFrom g payload inShape (i + 1) st1 = Except.ok st' →
        g.denoteAllFrom payload (NN.IR.DVal.mk inShape x) (i + 1) (denoteAllState inShape st1 x) =
          Except.ok (denoteAllState inShape st' x)) :
    g.denoteAllFrom payload (NN.IR.DVal.mk inShape x) i (denoteAllState inShape ⟨ss, gd⟩ x) =
      Except.ok (denoteAllState inShape st' x)
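Correctness lemma for the .mseLoss node compiler: if compiling from node i succeeds (hBuild) and the induction hypothesis ih holds at node i + 1, then evaluating the IR from node i agrees with the denotation of the compiled state st'.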
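The induction-step shape of this theorem can be illustrated with a scalar toy model. Everything below (ToyNode, toyMse, evalFrom, evalFrom_step) is hypothetical and far simpler than NN.IR.Graph.denoteAllFrom: values are Ints, the MSE of two scalar parents is just their squared difference, the evaluator cannot fail, and a compiled state is represented only by the list of values it denotes. The hypothesis hAppend stands in for "the compiled step appends the same value as the IR node", and ih stands in for the induction hypothesis at i + 1.

```lean
/-- Hypothetical toy node: an MSE node over two earlier scalar values, by index. -/
structure ToyNode where
  lhs : Nat
  rhs : Nat

/-- Toy MSE over scalars: the mean of one squared difference is that squared difference.
Out-of-range parents read as `0` (a toy simplification; the real evaluators use `Except`). -/
def toyMse (env : List Int) (n : ToyNode) : Int :=
  let d := env.getD n.lhs 0 - env.getD n.rhs 0
  d * d

/-- Toy IR evaluator: run the remaining nodes, appending one value per node. -/
def evalFrom : List ToyNode → List Int → List Int
  | [], env => env
  | n :: rest, env => evalFrom rest (env ++ [toyMse env n])

/-- Toy step lemma: if the compiled step appends the value the IR node computes
(`hAppend`), the induction hypothesis for the remaining nodes (`ih`) finishes the proof. -/
theorem evalFrom_step (n : ToyNode) (rest : List ToyNode)
    (denSt denSt1 denFinal : List Int)
    (hAppend : denSt1 = denSt ++ [toyMse denSt n])
    (ih : evalFrom rest denSt1 = denFinal) :
    evalFrom (n :: rest) denSt = denFinal := by
  -- Unfold one evaluator step, then fold the appended list back into `denSt1`.
  simp only [evalFrom]
  rw [← hAppend]
  exact ih

#eval evalFrom [⟨0, 1⟩] [3, 1]  -- [3, 1, 4], since (3 - 1) * (3 - 1) = 4
```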