Common
Internal helper lemmas for NN.Runtime.Autograd.Compiled.IRExec.Correctness.
These lemmas relate the typed runtime context (TList) to the untyped IR value table (Array DVal),
and provide small “building block” correctness steps that are reused across the per-op proofs.
Reading map:
- dValsOfCtx* lemmas: relate the typed context produced by GraphData.eval to an untyped Array (NN.IR.DVal α) (this is what the IR evaluator uses).
- denoteAllState* lemmas: package the compiled evaluator (ExecGraphData.denoteAll) in the form expected by IR-style semantic equivalence proofs.
These lemmas are infrastructure: they should not encode op-specific logic. Per-op correctness files (Matmul/Pool2d/LayerNorm/MSELoss) should depend on this module and not re-prove these bridges.
Main definitions
- throw_bind_ne_ok: eliminates impossible success branches after throw.
- NoMSELoss: side condition for semantic equivalence theorems that intentionally exclude .mse_loss.
- dValsOfCtx_*: typed-context to IR-array bridge lemmas.
- denoteAllState_* helpers: semantic equivalence bridges between compiled state and IR denotation tables.
Implementation notes
- We keep this module deliberately stable: it is shared infrastructure, so predictable proof contracts matter more than clever proof tricks.
- Many lemmas here are proof-irrelevance/indexing bridges; these are repetitive but they remove a lot of friction from op-specific proofs.
- Collecting these utilities in one place avoids duplicated brittle simp chains across correctness modules.
- These files can build slowly because they connect two representations at once: typed TList contexts on the compiled side and dynamically shaped DVal arrays on the IR side. Most of the cost is not arithmetic; it is Lean checking that shape casts, array indices, and proof-irrelevant casts line up exactly.
- When the same proof pattern appears in multiple operator files, the best long-term direction is fewer ad-hoc simp scripts in the op proofs and more named lemmas with clear contracts.
Tags
correctness, infrastructure, tlist, dval, bridge-lemmas
throw msg in the Except monad is .error msg.
Shared side conditions
These predicates describe the exact fragment covered by a theorem. Keeping them in Common lets
the per-op lemmas, the semantic equivalence proof, and the chapter index refer to the same public contract
without import cycles.
Core semantic equivalence side condition: the IR graph contains no .mse_loss nodes.
The compiled runtime has a correct .mse_loss step lemma in Correctness.Ops.Loss; the existing
end-to-end semantic equivalence theorem keeps this condition so its branch proof stays small and predictable.
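A minimal sketch of what such a side condition can look like, assuming a hypothetical Op enumeration in place of the real IR node type and a plain List of ops in place of the graph. Stating the predicate as a reducible Prop over list membership keeps it decidable, so small concrete graphs can be checked with decide.

```lean
-- Hypothetical op tags; the real IR node type has more constructors and payloads.
inductive Op where
  | matmul | reduceSum | layerNorm | mse_loss
  deriving DecidableEq, Repr

-- Hypothetical form of the side condition: no node in the list is `.mse_loss`.
-- `abbrev` keeps the definition reducible, so instance search sees the
-- bounded `∀` and finds a `Decidable` instance for it.
abbrev NoMSELoss (nodes : List Op) : Prop :=
  ∀ n ∈ nodes, n ≠ Op.mse_loss

example : NoMSELoss [Op.matmul, Op.reduceSum] := by decide
example : ¬ NoMSELoss [Op.layerNorm, Op.mse_loss] := by decide
```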
If a do-chain begins with throw, it cannot produce an .ok result.
This is a small convenience lemma used throughout the compiled-correctness proofs to close impossible branches where compilation would have thrown an error message.
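A standalone sketch of both facts using only core Lean's Except monad. The primed names are hypothetical restatements, not the lemmas exported by this module. The first holds by rfl because throw in Except unfolds to .error; the second uses that same reduction to turn the impossible equation into a clash between two constructors.

```lean
-- Hypothetical restatement: `throw` in `Except` is definitionally `.error`.
theorem throw_eq_error' {ε α : Type} (msg : ε) :
    (throw msg : Except ε α) = Except.error msg := rfl

-- Hypothetical restatement of the "impossible success branch" lemma.
theorem throw_bind_ne_ok' {ε α β : Type} (msg : ε) (f : α → Except ε β) (b : β) :
    (throw msg >>= f : Except ε β) ≠ Except.ok b := by
  intro h
  -- `throw msg >>= f` reduces to `.error msg`, so `h` equates two
  -- different constructors of `Except`.
  change (Except.error msg : Except ε β) = Except.ok b at h
  cases h
```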
Array indexing is proof-irrelevant.
This is a small technical lemma: in Lean, xs[i]'h carries a proof h : i < xs.size. Different
proofs cannot change the value returned by indexing, since Lean treats any two proofs of the same proposition as definitionally equal.
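A minimal standalone check of this fact in core Lean; the theorem name is hypothetical.

```lean
-- Two different proofs of `i < xs.size` select the same element.
theorem getElem_proof_irrel' {α : Type} (xs : Array α) (i : Nat)
    (h₁ h₂ : i < xs.size) : xs[i]'h₁ = xs[i]'h₂ := rfl
```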
Relate xs[i]! (defaulting lookup) and xs[i]'h (bounded lookup) when the index is in-bounds.
This is a small bridge lemma used throughout the IR/runtime context comparison proofs.
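A minimal sketch of the same bridge, assuming a hypothetical getOr in place of xs[i]! so that the statement does not depend on how a particular Lean version unfolds the ! notation. Once the bounds check succeeds, dif_pos selects the bounded branch.

```lean
-- Hypothetical stand-in for `xs[i]!`: look up `i`, falling back to `d`
-- when the index is out of bounds.
def getOr {α : Type} (xs : Array α) (i : Nat) (d : α) : α :=
  if h : i < xs.size then xs[i]'h else d

-- In bounds, the defaulting lookup agrees with the bounded lookup.
theorem getOr_eq_getElem {α : Type} (xs : Array α) (i : Nat) (d : α)
    (h : i < xs.size) : getOr xs i d = xs[i]'h := by
  unfold getOr
  rw [dif_pos h]
```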
dValsOfCtx ignores type-level casts of the underlying TList.
GraphData.eval introduces a definitional cast when extending contexts; this lemma lets us erase it
before reasoning about the corresponding Array of DVals.
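The underlying fact can be sketched in isolation, assuming a type family F standing in for shape-indexed tensors and a hypothetical function g standing in for dValsOfCtx. After subst, the cast is along rfl and reduces away.

```lean
-- Hypothetical model: `F i` plays the role of "values of shape `i`" and
-- `g` plays the role of a function that forgets the type index.
theorem erase_cast {ι : Type} (F : ι → Type) (g : (i : ι) → F i → Nat)
    {i j : ι} (h : i = j) (x : F i) :
    g j (cast (congrArg F h) x) = g i x := by
  -- Once `j` is replaced by `i`, the cast is along `rfl` and computes away.
  subst h
  rfl
```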
dValsOfCtx for a snoc’d context corresponds to Array.push of the appended tensor.
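A minimal sketch of why this holds, under a hypothetical model where the context is built by appending at the end (a Ctx indexed by its shape list) and erased to an Array of sigma pairs. In this model the snoc-to-push lemma is definitional; the repository's TList and dValsOfCtx are not modeled here.

```lean
-- Hypothetical model of a typed context that grows at the end:
-- `Ctx T ss` stores one value of type `T s` for each shape `s` in `ss`.
inductive Ctx (T : Nat → Type) : List Nat → Type where
  | nil : Ctx T []
  | snoc {ss : List Nat} {s : Nat} (ctx : Ctx T ss) (x : T s) : Ctx T (ss ++ [s])

namespace Ctx

variable {T : Nat → Type}

-- Erase the shape index: each entry becomes a sigma pair `⟨shape, value⟩`,
-- a stand-in for the untyped `Array DVal` table.
def toArray : {ss : List Nat} → Ctx T ss → Array (Sigma T)
  | _, .nil => #[]
  | _, .snoc ctx x => (toArray ctx).push ⟨_, x⟩

-- In this model, erasing a snoc'd context is definitionally a push.
theorem toArray_snoc {ss : List Nat} {s : Nat} (ctx : Ctx T ss) (x : T s) :
    toArray (Ctx.snoc ctx x) = (toArray ctx).push ⟨s, x⟩ := rfl

end Ctx
```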
Indexing dValsOfCtx agrees with indexing the underlying TList context.
This is the main bridge between the typed runtime context and the untyped IR value table.
Indexing dValsOfCtx by a typed Idx agrees with getIdx on the underlying TList.
This packages dValsOfCtx_getElem! into the repository’s Idx wrapper.
Graph.expectShape succeeds on a DVal built with the same shape.
Same as Graph.expectShape_mk, but for the sigma-style constructor ⟨s, t⟩.
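A simplified sketch, assuming a hypothetical non-dependent value type that pairs a runtime shape tag with a Float payload (the real DVal carries a typed tensor). The check reduces to if s = s, which simp discharges.

```lean
-- Hypothetical, simplified untyped value: a runtime shape tag plus a payload.
structure DValSketch where
  shape : Nat
  data : Array Float

-- Check the shape tag; return the payload on success.
def expectShape (s : Nat) (v : DValSketch) : Except String (Array Float) :=
  if v.shape = s then .ok v.data else .error "shape mismatch"

-- A value constructed with shape `s` always passes the check for `s`.
theorem expectShape_mk (s : Nat) (d : Array Float) :
    expectShape s ⟨s, d⟩ = .ok d := by
  simp [expectShape]
```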
NN.IR.Graph.evalAt for a .matmul node specialized to 2D matrix multiply.
This is a proof-only helper that records the exact Spec.mat_mul_spec term produced by evalAt
in the well-typed success case.
NN.IR.Graph.evalAt for a .matmul node specialized to batched matmul (bmm).
Like evalAt_matmul_mm_ok, this is used to relate the IR evaluator’s result to the compiled node’s
forward closure during the semantic equivalence correctness proof.
NN.IR.Graph.evalAt for a .reduceSum axis node, specialized to a well-typed success case.
This helper records the exact Tensor.reduceSum term produced by the IR evaluator once:
- the parent has the expected shape s,
- the axis validity check succeeds, and
- the node's declared outShape matches shapeAfterSum s axis.
The final cast to n.outShape comes from the evalAt "shape-tag normalization" step.
NN.IR.Graph.evalAt for a .reduceMean axis node, specialized to a well-typed success case.
This is the mean analogue of evalAt_reduceSum_ok.
Repackage a compiled State as an ExecGraphData so we can call its evaluator helpers.
Evaluate the compiled prefix state and convert its typed runtime context into an IR-style table.
denoteAllState commutes with extending the SSA graph by one node (GraphData.snoc).
This is the key step for proving that the compiler’s prefix-building loop stays in semantic equivalence with the IR denotation table.
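A minimal sketch of the commuting step, assuming a hypothetical denoteAll that folds over a list of nodes and pushes each node's value, computed from the table built so far. Appending one node then pushes exactly one value, which follows from List.foldl_append.

```lean
-- Hypothetical model: `eval tbl n` computes node `n`'s value from the
-- table `tbl` of earlier results; `denoteAll` builds the whole table.
def denoteAll {Node V : Type} (eval : Array V → Node → V) (nodes : List Node) : Array V :=
  nodes.foldl (fun tbl n => tbl.push (eval tbl n)) #[]

-- Extending the node list by one node pushes exactly one new value,
-- computed from the table of the prefix.
theorem denoteAll_snoc {Node V : Type} (eval : Array V → Node → V)
    (nodes : List Node) (n : Node) :
    denoteAll eval (nodes ++ [n]) =
      (denoteAll eval nodes).push (eval (denoteAll eval nodes) n) := by
  simp [denoteAll, List.foldl_append]
```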
Build a typed runtime index (Idx) for a numeric IR parent id.
The compiled runtime context is typed by a list of shapes [inShape] ++ ss. mkIdx checks that:
- id is in bounds, and
- the context shape at that position matches the expected shape s.
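A minimal sketch of the two checks, assuming a hypothetical mkIdx that returns Option (Fin shapes.length) rather than the repository's Idx. The bounds check is recorded in the Fin, and success also certifies the shape stored at that position.

```lean
-- Hypothetical stand-in: `shapes` is the context's shape list and the result
-- is a plain `Fin` index (the real `Idx` also carries the shape).
def mkIdx (shapes : List Nat) (id s : Nat) : Option (Fin shapes.length) :=
  if h : id < shapes.length then
    if shapes[id] = s then some ⟨id, h⟩ else none
  else none

-- Success certifies both checks: the index is in bounds (by its type)
-- and the shape stored there is the expected one.
theorem mkIdx_shape {shapes : List Nat} {id s : Nat} {i : Fin shapes.length}
    (hi : mkIdx shapes id s = some i) : shapes[i.val]'i.isLt = s := by
  unfold mkIdx at hi
  split at hi
  · split at hi
    · -- Both checks passed, so `i` is `⟨id, _⟩` and the shape check applies.
      cases hi
      assumption
    · cases hi
  · cases hi
```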
Lookup lemma: denoteAllState[..][pid]! agrees with getIdx when mkIdx pid s succeeds.
This is used when proving correctness of the per-node compiler step: we translate parent ids in the IR into typed indices into the compiled context.
One-step finishing lemma for the buildFrom/denoteAllFrom semantic equivalence proof.
If we know that:
- the tail recursion at i+1 is correct (hTail),
- the IR evaluator step at i matches the compiled node’s forward (hEval), and
- the compiled table at i is the previous table plus the pushed node value (hStep),
then denoteAllFrom at i returns the final compiled table.