TorchLean API

NN.Runtime.Autograd.Compiled.IRExec.Correctness.Ops.LinearAlgebra

Linear Algebra

Linear-algebra correctness lemmas for the IR → compiled runtime bridge.

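This file proves the forward-correctness step for compiling a .matmul IR node into a single SSA node in the compiled GraphData. Concretely, it shows that the following holds when:

- the IR node at index i has kind .matmul with parents [aId, bId],
- both parents resolve (via IRExec.mkIdx) to context indices whose shapes fit one of the supported matmul cases,
- the declared output shape n.outShape equals the shape of the product, and
- building the rest of the graph from step i + 1 succeeds.

In that situation, the value appended by the IR evaluator at step i is definitionally the same tensor as the value produced by the compiled node's forward.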
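The statement above can be illustrated with a deliberately tiny model. The sketch below is hypothetical and is not TorchLean's IR evaluator or compiler. Matrices are lists of integer rows and a context is a list of such values. `irStep` plays the role of the evaluator appending a product, and `compileMatmul` plays the role of the compiled forward closure. Because both are built from the same spec function, they agree by `rfl`, which is the toy analogue of the definitional equality claimed here.

```lean
-- Hypothetical toy model (not TorchLean's API): a matrix is a list of
-- integer rows, and an evaluation context is a list of matrices.
abbrev Mat := List (List Int)

/-- Naive row-by-column product; assumes `a` is m×k and `b` is k×n. -/
def matMul (a b : Mat) : Mat :=
  let cols := (b.headD []).length
  a.map fun row =>
    (List.range cols).map fun j =>
      (row.zip (b.map fun r => r.getD j 0)).foldl
        (fun (acc : Int) (p : Int × Int) => acc + p.1 * p.2) (0 : Int)

/-- Toy IR-evaluator step: read both parents, append their product. -/
def irStep (ctx : List Mat) (aId bId : Nat) : List Mat :=
  ctx ++ [matMul (ctx.getD aId []) (ctx.getD bId [])]

/-- Toy compiled node: a forward closure fixed at compile time. -/
def compileMatmul (aId bId : Nat) : List Mat → Mat :=
  fun ctx => matMul (ctx.getD aId []) (ctx.getD bId [])

/-- Both paths call the same spec function, so they agree definitionally. -/
theorem irStep_eq_compiled (ctx : List Mat) (aId bId : Nat) :
    irStep ctx aId bId = ctx ++ [compileMatmul aId bId ctx] := rfl
```

In the real development the analogous step needs the index-resolution and shape hypotheses, because the compiled closure reads parents through typed indices rather than a fallback lookup.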

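We handle the two shape cases supported by TorchLean's current matmul compiler:

- 2D matrix multiply: [a0, a1] × [a1, b1] → [a0, b1] (no batch dimension);
- batched matrix multiply: [a0, a1, a2] × [a0, a2, bP] → [a0, a1, bP], where a0 is the shared batch dimension.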
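The two shape cases handled by the compiler can be summarized as partial shape functions. The sketch below uses a hypothetical `Shape` type that mirrors only the `scalar`/`dim` constructors appearing in the signatures on this page. It is not TorchLean's `Spec.Shape`, and `mmOutShape`, `bmmOutShape`, and `matmulOutShape` are illustrative names. The cases are distinguished by rank (2 versus 3), so the order in which the toy dispatcher tries them does not matter.

```lean
-- Hypothetical mirror of a shape type with the `scalar`/`dim` constructors
-- used on this page; not TorchLean's `Spec.Shape`.
inductive Shape where
  | scalar : Shape
  | dim : Nat → Shape → Shape
  deriving DecidableEq, Repr

open Shape

/-- 2D case: [a0, a1] × [a1, b1] → [a0, b1]. -/
def mmOutShape : Shape → Shape → Option Shape
  | dim a0 (dim a1 scalar), dim a1' (dim b1 scalar) =>
      if a1 = a1' then some (dim a0 (dim b1 scalar)) else none
  | _, _ => none

/-- Batched case: [a0, a1, a2] × [a0, a2, bP] → [a0, a1, bP]. -/
def bmmOutShape : Shape → Shape → Option Shape
  | dim a0 (dim a1 (dim a2 scalar)), dim a0' (dim a2' (dim bP scalar)) =>
      if a0 = a0' ∧ a2 = a2' then some (dim a0 (dim a1 (dim bP scalar))) else none
  | _, _ => none

/-- Toy dispatcher: the cases differ in rank, so at most one can succeed. -/
def matmulOutShape (sa sb : Shape) : Option Shape :=
  match mmOutShape sa sb with
  | some s => some s
  | none => bmmOutShape sa sb

example : matmulOutShape (dim 2 (dim 3 scalar)) (dim 3 (dim 4 scalar)) =
    some (dim 2 (dim 4 scalar)) := rfl
example : matmulOutShape (dim 5 (dim 2 (dim 3 scalar))) (dim 5 (dim 3 (dim 4 scalar))) =
    some (dim 5 (dim 2 (dim 4 scalar))) := rfl
-- Mismatched inner dimensions (3 vs 4) are rejected.
example : matmulOutShape (dim 2 (dim 3 scalar)) (dim 4 (dim 5 scalar)) = none := rfl
```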

This module is concerned only with semantic correctness. Performance backends (e.g. external BLAS libraries or kernel fusion) belong to a separate lowering layer and are not involved here.

The shape rules match the public PyTorch APIs for matrix multiply and batched matrix multiply:

- torch.mm: (n × m) · (m × p) → (n × p);
- torch.bmm: (b × n × m) · (b × m × p) → (b × n × p).
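As a concrete check of the PyTorch shape rules for torch.mm and torch.bmm, the sketch below models both on plain lists. It is a hypothetical illustration, not TorchLean's `Spec.matMulSpec` or its bmm spec. Matrices are lists of integer rows, and a batch is a list of matrices whose slices are multiplied pairwise, which is how torch.bmm treats the leading dimension.

```lean
-- Hypothetical list-based model of torch.mm / torch.bmm semantics;
-- not TorchLean's tensor specs. Shapes are assumed to be well formed.
abbrev Mat := List (List Int)

/-- torch.mm-style product: (n × m) · (m × p) → (n × p). -/
def mm (a b : Mat) : Mat :=
  let p := (b.headD []).length
  a.map fun row =>
    (List.range p).map fun j =>
      (row.zip (b.map fun r => r.getD j 0)).foldl
        (fun (acc : Int) (xy : Int × Int) => acc + xy.1 * xy.2) (0 : Int)

/-- torch.bmm-style product: multiply matching batch slices pairwise. -/
def bmm (a b : List Mat) : List Mat :=
  (a.zip b).map fun ab => mm ab.1 ab.2

-- (2 × 2) · (2 × 2) → (2 × 2)
example : mm [[1, 2], [3, 4]] [[5, 6], [7, 8]] = [[19, 22], [43, 50]] := by decide

-- Batch of 2: identity and 2·identity applied to the same matrix.
example :
    bmm [[[1, 0], [0, 1]], [[2, 0], [0, 2]]] [[[1, 2], [3, 4]], [[1, 2], [3, 4]]] =
      [[[1, 2], [3, 4]], [[2, 4], [6, 8]]] := by decide
```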

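Main definitions

- Runtime.Autograd.Compiled.buildFrom_denoteAllFrom_matmul_mm_success: the 2D matrix-multiply case.
- Runtime.Autograd.Compiled.buildFrom_denoteAllFrom_matmul_bmm_success: the batched (bmm) case.
- Runtime.Autograd.Compiled.buildFrom_denoteAllFrom_matmul: the .matmul node lemma that dispatches to the two cases above.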

Implementation notes

Tags

matmul, bmm, correctness, ir, runtime

theorem Runtime.Autograd.Compiled.buildFrom_denoteAllFrom_matmul_mm_success
  {α : Type} [Context α] [DecidableEq Spec.Shape]
  (g : NN.IR.Graph) (payload : NN.IR.Payload α)
  {inShape : Spec.Shape} {ss : List Spec.Shape}
  (gd : Proofs.Autograd.Algebra.GraphData α Unit [inShape] ss)
  (i : ℕ) (st' : IRExec.State α inShape) (x : Spec.Tensor α inShape)
  (n : NN.IR.Node) (hN : g.getNode i = Except.ok n)
  (hk : n.kind = NN.IR.OpKind.matmul) (hi : i < g.nodes.size)
  (ih : ∀ (st1 : IRExec.State α inShape),
    IRExec.buildFrom g payload inShape (i + 1) st1 = Except.ok st' →
    g.denoteAllFrom payload (NN.IR.DVal.mk inShape x) (i + 1) (denoteAllState inShape st1 x) =
      Except.ok (denoteAllState inShape st' x))
  (aId bId a0 a1 b1 : ℕ) (hp : n.parents = [aId, bId])
  (ia : Proofs.Autograd.Algebra.Idx ([inShape] ++ ss)
    (Spec.Shape.dim a0 (Spec.Shape.dim a1 Spec.Shape.scalar)))
  (hIa : IRExec.mkIdx inShape ss aId
    (Spec.Shape.dim a0 (Spec.Shape.dim a1 Spec.Shape.scalar)) = Except.ok ia)
  (ib : Proofs.Autograd.Algebra.Idx ([inShape] ++ ss)
    (Spec.Shape.dim a1 (Spec.Shape.dim b1 Spec.Shape.scalar)))
  (hIb : IRExec.mkIdx inShape ss bId
    (Spec.Shape.dim a1 (Spec.Shape.dim b1 Spec.Shape.scalar)) = Except.ok ib)
  (hOut : Spec.Shape.dim a0 (Spec.Shape.dim b1 Spec.Shape.scalar) = n.outShape)
  (hBuildNext : IRExec.buildFrom g payload inShape (i + 1)
    ⟨ss ++ [n.outShape],
      gd.snoc (IRExec.mkFwdNode fun (ctx : Proofs.Autograd.Algebra.TList α ([inShape] ++ ss)) =>
        have aT := Proofs.Autograd.Algebra.getIdx ctx ia
        have bT := Proofs.Autograd.Algebra.getIdx ctx ib
        have y := Spec.matMulSpec aT bT
        hOut ▸ y)⟩ = Except.ok st') :
  g.denoteAllFrom payload (NN.IR.DVal.mk inShape x) i (denoteAllState inShape ⟨ss, gd⟩ x) =
    Except.ok (denoteAllState inShape st' x)

Correctness lemma for the .matmul compilation step in the 2D matrix-multiply case.

This is used when the parent shapes match Spec.matMulSpec, i.e. [a0, a1] and [a1, b1] with no batch dimension. It gives the exact equality needed to hand off to the tail-induction hypothesis ih.
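In the signature of this lemma, the forward closure ends by casting y along hOut. The product has shape [a0, b1], but the compiled node is typed by n.outShape, so the value is transported along that shape equality. The sketch below shows the same pattern on a hypothetical length-indexed vector type. `Vec` and `castVec` are illustrative names, not TorchLean definitions. Transporting along an index equality leaves the underlying data unchanged, so the cast does not obstruct the value-level equality.

```lean
-- Hypothetical length-indexed vector standing in for a shape-indexed tensor.
def Vec (α : Type) (n : Nat) := { l : List α // l.length = n }

/-- Reindex a vector along a proof that the two lengths agree. -/
def castVec {α : Type} {m n : Nat} (h : m = n) (v : Vec α m) : Vec α n :=
  h ▸ v

/-- The cast changes only the type index, never the stored data. -/
theorem castVec_val {α : Type} {m n : Nat} (h : m = n) (v : Vec α m) :
    Subtype.val (castVec h v) = Subtype.val v := by
  subst h
  rfl
```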

theorem Runtime.Autograd.Compiled.buildFrom_denoteAllFrom_matmul_bmm_success
  {α : Type} [Context α] [DecidableEq Spec.Shape]
  (g : NN.IR.Graph) (payload : NN.IR.Payload α)
  {inShape : Spec.Shape} {ss : List Spec.Shape}
  (gd : Proofs.Autograd.Algebra.GraphData α Unit [inShape] ss)
  (i : ℕ) (st' : IRExec.State α inShape) (x : Spec.Tensor α inShape)
  (n : NN.IR.Node) (hN : g.getNode i = Except.ok n)
  (hk : n.kind = NN.IR.OpKind.matmul) (hi : i < g.nodes.size)
  (ih : ∀ (st1 : IRExec.State α inShape),
    IRExec.buildFrom g payload inShape (i + 1) st1 = Except.ok st' →
    g.denoteAllFrom payload (NN.IR.DVal.mk inShape x) (i + 1) (denoteAllState inShape st1 x) =
      Except.ok (denoteAllState inShape st' x))
  (aId bId a0 a1 a2 bP : ℕ) (hp : n.parents = [aId, bId])
  (ia : Proofs.Autograd.Algebra.Idx ([inShape] ++ ss)
    (Spec.Shape.dim a0 (Spec.Shape.dim a1 (Spec.Shape.dim a2 Spec.Shape.scalar))))
  (hIa : IRExec.mkIdx inShape ss aId
    (Spec.Shape.dim a0 (Spec.Shape.dim a1 (Spec.Shape.dim a2 Spec.Shape.scalar))) = Except.ok ia)
  (ib : Proofs.Autograd.Algebra.Idx ([inShape] ++ ss)
    (Spec.Shape.dim a0 (Spec.Shape.dim a2 (Spec.Shape.dim bP Spec.Shape.scalar))))
  (hIb : IRExec.mkIdx inShape ss bId
    (Spec.Shape.dim a0 (Spec.Shape.dim a2 (Spec.Shape.dim bP Spec.Shape.scalar))) = Except.ok ib)
  (hOut : Spec.Shape.dim a0 (Spec.Shape.dim a1 (Spec.Shape.dim bP Spec.Shape.scalar)) = n.outShape)
  (hBuildNext : IRExec.buildFrom g payload inShape (i + 1)
    ⟨ss ++ [n.outShape],
      gd.snoc (IRExec.mkFwdNode fun (ctx : Proofs.Autograd.Algebra.TList α ([inShape] ++ ss)) =>
        have aT := Proofs.Autograd.Algebra.getIdx ctx ia
        have bT := Proofs.Autograd.Algebra.getIdx ctx ib
        have y := aT.bmmSpec bT
        hOut ▸ y)⟩ = Except.ok st') :
  g.denoteAllFrom payload (NN.IR.DVal.mk inShape x) i (denoteAllState inShape ⟨ss, gd⟩ x) =
    Except.ok (denoteAllState inShape st' x)

Correctness lemma for the .matmul compilation step in the batched-matmul (bmm) case.

This is used when the parent shapes match Spec.Tensor.bmmSpec, i.e. [a0, a1, a2] and [a0, a2, bP] with explicit batch dimension a0. It again yields the exact one-step equality consumed by the semantic-equivalence skeleton.

theorem Runtime.Autograd.Compiled.buildFrom_denoteAllFrom_matmul
  {α : Type} [Context α] [DecidableEq Spec.Shape]
  (g : NN.IR.Graph) (payload : NN.IR.Payload α)
  {inShape : Spec.Shape} {ss : List Spec.Shape}
  (gd : Proofs.Autograd.Algebra.GraphData α Unit [inShape] ss)
  (i : ℕ) (st' : IRExec.State α inShape) (x : Spec.Tensor α inShape)
  (n : NN.IR.Node) (hN : g.getNode i = Except.ok n)
  (hk : n.kind = NN.IR.OpKind.matmul) (hi : i < g.nodes.size)
  (hBuild : IRExec.buildFrom g payload inShape i ⟨ss, gd⟩ = Except.ok st')
  (ih : ∀ (st1 : IRExec.State α inShape),
    IRExec.buildFrom g payload inShape (i + 1) st1 = Except.ok st' →
    g.denoteAllFrom payload (NN.IR.DVal.mk inShape x) (i + 1) (denoteAllState inShape st1 x) =
      Except.ok (denoteAllState inShape st' x)) :
  g.denoteAllFrom payload (NN.IR.DVal.mk inShape x) i (denoteAllState inShape ⟨ss, gd⟩ x) =
    Except.ok (denoteAllState inShape st' x)

Correctness lemma for the .matmul node compiler.

This is a dispatcher that selects the appropriate specialized lemma: