Linear Algebra #
Linear-algebra correctness lemmas for the IR → compiled runtime bridge.
This file proves the forward-correctness step for compiling a .matmul IR node into a single SSA
node in the compiled GraphData. Concretely, it shows that:
- if buildFrom successfully compiles a .matmul node at position i, and
- we compare the IR evaluator NN.IR.Graph.denoteAllFrom against the compiled evaluator denoteAllState,

then the value appended by the IR evaluator at step i is definitionally the same tensor as the value produced by the compiled node's forward function.
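The one-step statement can be illustrated with a minimal Lean sketch. It assumes a toy model in which matrices are shape-indexed functions with integer entries. Mat, finSum, mm, irMatmul, CompiledNode, and compileMatmul are hypothetical stand-ins, not TorchLean's definitions. Both evaluators unfold to the same reference product, so the equality closes by rfl, in the spirit of the "definitionally the same tensor" claim.

```lean
-- Toy model only: every name below is a hypothetical stand-in,
-- not a TorchLean definition.

/-- A shape-indexed matrix: entries are addressed by bounded indices. -/
abbrev Mat (m n : Nat) := Fin m → Fin n → Int

/-- `finSum k f` computes `f 0 + f 1 + ... + f (k-1)`. -/
def finSum : (k : Nat) → (Fin k → Int) → Int
  | 0, _ => 0
  | k + 1, f => f ⟨0, Nat.succ_pos k⟩ + finSum k (fun i => f i.succ)

/-- Reference 2D matrix multiply: `(m×k) · (k×n) → (m×n)`. -/
def mm {m k n : Nat} (A : Mat m k) (B : Mat k n) : Mat m n :=
  fun i j => finSum k (fun l => A i l * B l j)

/-- Toy IR semantics: the value the IR evaluator appends for a matmul node. -/
def irMatmul {m k n : Nat} (A : Mat m k) (B : Mat k n) : Mat m n := mm A B

/-- A compiled node packages a forward function between typed values. -/
structure CompiledNode (α β : Type) where
  forward : α → β

/-- Toy compiler output for a 2D matmul node. -/
def compileMatmul (m k n : Nat) : CompiledNode (Mat m k × Mat k n) (Mat m n) :=
  { forward := fun p => mm p.1 p.2 }

/-- One-step correctness: both sides unfold to `mm A B`, so `rfl` suffices. -/
theorem matmul_step {m k n : Nat} (A : Mat m k) (B : Mat k n) :
    irMatmul A B = (compileMatmul m k n).forward (A, B) := rfl
```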
We handle the two shape cases supported by TorchLean's current matmul compiler:
- matrix multiply (2D): Spec.mat_mul_spec,
- batched matrix multiply (3D): Tensor.bmm_spec.
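The relationship between the two cases can be sketched in the same kind of toy model. The names are hypothetical, and the definitions of the real specifications Spec.mat_mul_spec and Tensor.bmm_spec are not reproduced here. A batched product is a 2D product applied slice by slice, and the types force both operands to share the batch size b.

```lean
-- Toy model only: `Mat`, `finSum`, `mm`, and `bmm` are hypothetical
-- stand-ins for Spec.mat_mul_spec and Tensor.bmm_spec.
abbrev Mat (m n : Nat) := Fin m → Fin n → Int

/-- `finSum k f` computes `f 0 + f 1 + ... + f (k-1)`. -/
def finSum : (k : Nat) → (Fin k → Int) → Int
  | 0, _ => 0
  | k + 1, f => f ⟨0, Nat.succ_pos k⟩ + finSum k (fun i => f i.succ)

/-- 2D case: `(m×k) · (k×n) → (m×n)`. -/
def mm {m k n : Nat} (A : Mat m k) (B : Mat k n) : Mat m n :=
  fun i j => finSum k (fun l => A i l * B l j)

/-- 3D case: `(b×m×k) · (b×k×n) → (b×m×n)`; both inputs share `b`. -/
def bmm {b m k n : Nat} (A : Fin b → Mat m k) (B : Fin b → Mat k n) :
    Fin b → Mat m n :=
  fun t => mm (A t) (B t)

/-- Each batch slice of `bmm` is an ordinary 2D product. -/
theorem bmm_slice {b m k n : Nat} (A : Fin b → Mat m k) (B : Fin b → Mat k n)
    (t : Fin b) : bmm A B t = mm (A t) (B t) := rfl
```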
This module is about semantic correctness. Performance backends (e.g. external BLAS libraries or kernel fusion) belong to a separate lowering layer and are not involved here.
The shape rules match the corresponding 2D and 3D cases of the public PyTorch APIs for matrix multiply and batched matrix multiply:
- torch.matmul: https://pytorch.org/docs/stable/generated/torch.matmul.html
- torch.bmm: https://pytorch.org/docs/stable/generated/torch.bmm.html
Main definitions #
- buildFrom_denoteAllFrom_matmul_mm_success: correctness step for 2D matmul.
- buildFrom_denoteAllFrom_matmul_bmm_success: correctness step for batched matmul.
Implementation notes #
- We split the 2D and 3D cases explicitly. This keeps shape distinctions visible instead of hiding them behind one very generic theorem.
- Each theorem follows the compiler's control flow, so failed typing/shape branches are discharged quickly.
- Matmul proofs can be slow because the compiler has to distinguish 2D matrix multiplication from batched 3D multiplication while preserving exact type-level shapes. We therefore keep the shape-dispatch lemmas separate from the semantic equality proof.
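The separation of shape dispatch from semantics can be sketched as follows. MatmulCase and classify are hypothetical names, not TorchLean definitions. The accepted shapes follow the 2D rule for torch.matmul and the 3D rule for torch.bmm, and broadcasting and 1D inputs are left out. Ill-shaped inputs are rejected by evaluation alone, so those branches close with rfl and never reach a semantic argument.

```lean
-- Hypothetical sketch: `MatmulCase` and `classify` are illustrative names.
-- Shape dispatch happens once, up front. Semantic lemmas can then assume a
-- fixed, well-typed case.
inductive MatmulCase where
  | mm  (m k n : Nat)      -- [m, k] · [k, n] → [m, n]
  | bmm (b m k n : Nat)    -- [b, m, k] · [b, k, n] → [b, m, n]
  deriving Repr, DecidableEq

def classify : List Nat → List Nat → Option MatmulCase
  | [m, k], [k', n] => if k = k' then some (.mm m k n) else none
  | [b, m, k], [b', k', n] =>
      if b = b' ∧ k = k' then some (.bmm b m k n) else none
  | _, _ => none

-- Failed shape branches reduce to `none` by evaluation alone.
example : classify [2, 3] [4, 5] = none := rfl
example : classify [5, 2, 3] [6, 3, 4] = none := rfl
-- Well-shaped inputs land in exactly one case.
example : classify [2, 3] [3, 4] = some (MatmulCase.mm 2 3 4) := rfl
example : classify [7, 2, 3] [7, 3, 4] = some (MatmulCase.bmm 7 2 3 4) := rfl
```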
Tags #
matmul, bmm, correctness, ir, runtime
Correctness lemma for the .matmul compilation step in the 2D matrix-multiply case.
This is used when the parent shapes match Spec.mat_mul_spec (no batch dimension), and gives the
exact equality needed to hand off to the tail-induction hypothesis.
Correctness lemma for the .matmul compilation step in the batched-matmul (bmm) case.
This is used when the parent shapes match Tensor.bmm_spec with an explicit batch dimension, and
again yields the exact one-step equality consumed by the semantic equivalence skeleton.
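Both lemmas are consumed as the step case of an induction over the node list. The sketch below shows that skeleton under an assumed model in which each evaluator appends one value per node to an environment list. irEval, compiledEval, irStep, and fwd are hypothetical and only loosely model denoteAllFrom and denoteAllState.

```lean
-- Hypothetical skeleton: these evaluators are illustrative stand-ins,
-- not TorchLean's denoteAllFrom / denoteAllState.
variable {Node V : Type}

/-- IR evaluator: walk the node list, appending each node's value. -/
def irEval (irStep : List V → Node → V) : List V → List Node → List V
  | env, [] => env
  | env, nd :: rest => irEval irStep (env ++ [irStep env nd]) rest

/-- Compiled evaluator: each node has been compiled to a forward function. -/
def compiledEval (fwd : Node → List V → V) : List V → List Node → List V
  | env, [] => env
  | env, nd :: rest => compiledEval fwd (env ++ [fwd nd env]) rest

/-- If every node satisfies the one-step equality, the evaluators agree.
    One-step lemmas such as the matmul ones would supply `step`. -/
theorem eval_agree (irStep : List V → Node → V) (fwd : Node → List V → V)
    (step : ∀ env nd, irStep env nd = fwd nd env) :
    ∀ (nodes : List Node) (env : List V),
      irEval irStep env nodes = compiledEval fwd env nodes := by
  intro nodes
  induction nodes with
  | nil => intro env; rfl
  | cons nd rest ih =>
      intro env
      -- Unfold one step, then use the one-step equality for `nd`.
      simp only [irEval, compiledEval, step env nd]
      -- The remaining goal is the tail-induction hypothesis.
      exact ih _
```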
Correctness lemma for the .matmul node compiler.
This is a dispatcher that selects the appropriate specialized lemma:
- buildFrom_denoteAllFrom_matmul_mm_success for 2D matmul, or
- buildFrom_denoteAllFrom_matmul_bmm_success for batched matmul.
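The dispatcher's structure can be sketched with a hypothetical MatmulCase type. The two specialized lemmas are abstracted as hypotheses mmOk and bmmOk, which are illustrative names. In this sketch the proof is a case split and does no shape reasoning of its own.

```lean
-- Hypothetical sketch. In TorchLean the two hypotheses would correspond to
-- buildFrom_denoteAllFrom_matmul_mm_success and
-- buildFrom_denoteAllFrom_matmul_bmm_success.
inductive MatmulCase where
  | mm  (m k n : Nat)
  | bmm (b m k n : Nat)

/-- The dispatcher splits on the case and forwards to the matching lemma. -/
theorem matmul_success (P : MatmulCase → Prop)
    (mmOk : ∀ m k n, P (.mm m k n))
    (bmmOk : ∀ b m k n, P (.bmm b m k n)) :
    ∀ c, P c := by
  intro c
  cases c with
  | mm m k n => exact mmOk m k n
  | bmm b m k n => exact bmmOk b m k n
```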