Linear algebra primitives (spec layer)
This file defines the basic matrix/vector operations used across the model zoo:
- `matMulSpec` (matrix × matrix)
- `matVecMulSpec` (matrix × vector)
- `vecMatMulSpec` (vector × matrix)
- `outerProductSpec` (vector ⊗ vector)
- `identityTensorSpec` (n × n identity matrix)
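All operations carry their shapes in their types. A dimension mismatch is therefore rejected during elaboration rather than at run time. For example, multiplying an (m x n) matrix by a vector whose length is not n fails to elaborate.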
These are intentionally simple, “obvious” definitions (folds over `List.finRange`), so that:
- they are easy to reason about in proofs, and
- they can be instantiated over many scalar backends (`Float`, `ℚ`, `IEEE32Exec`, `ℝ`, …).
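The fold style can be illustrated with a minimal Lean 4 sketch of a dot product. This is the sum-of-products pattern that the matrix operations below repeat. The sketch is not the library's code, and it makes three substitutions:
- it models a length-n vector as a plain function `Fin n → α` instead of `Tensor`;
- it folds with core's `Fin.foldl` rather than over `List.finRange`;
- it uses `[OfNat α 0]` in place of `[Zero α]`, so it needs only Lean core.

The names `dotSketch`, `u`, and `w` are hypothetical.

```lean
-- Hypothetical stand-in for a length-n vector: a function `Fin n → α`.
-- `[OfNat α 0]` replaces the spec's `[Zero α]` so this needs only Lean core.
def dotSketch {α : Type} [Add α] [Mul α] [OfNat α 0] {n : Nat}
    (a b : Fin n → α) : α :=
  -- `Fin.foldl` visits j = 0, 1, …, n-1 in order, accumulating left to right from 0.
  Fin.foldl n (fun acc j => acc + a j * b j) 0

def u : Fin 3 → Nat := fun i => i.val + 1                      -- [1, 2, 3]
def w : Fin 2 → Int := fun i => if i.val = 0 then 1 else -2    -- [1, -2]

#eval dotSketch u (fun _ => 2)                  -- 12  (1*2 + 2*2 + 3*2)
#eval dotSketch w (fun i => (i.val : Int) + 3)  -- -5  (1*3 + (-2)*4)
```

The same definition runs unchanged at `Nat` and `Int`. This is the kind of scalar-genericity the spec relies on.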
PyTorch analogies:
- `matMulSpec A B` is `A @ B`
- `matVecMulSpec A v` is `A @ v`
- `vecMatMulSpec v A` is `v @ A`
- `outerProductSpec a b` is like `a.unsqueeze(1) * b.unsqueeze(0)` (the result has shape `(m, n)`).
def Spec.identityTensorSpec {α : Type} [Zero α] [One α] (n : ℕ) : Tensor α (Shape.dim n (Shape.dim n Shape.scalar))
Create an identity matrix (n x n).
Notes:
- The `n = 0` case is an empty matrix; it is still a well-typed tensor.
- We use `i.val == j.val` rather than `DecidableEq (Fin n)` to keep the definition directly executable across backends.
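The same construction can be sketched under a hypothetical function-based encoding, with an n x n matrix written as `Fin n → Fin n → α` instead of `Tensor`. Here `identitySketch` and `toRows` are illustrative names, and `[OfNat α 0]`/`[OfNat α 1]` stand in for `[Zero α]`/`[One α]`. The sketch uses the `i.val == j.val` comparison from the note above. It also shows that `n = 0` still produces a value, which prints as an empty array.

```lean
-- Hypothetical stand-in: an n x n matrix as a function `Fin n → Fin n → α`.
def identitySketch {α : Type} [OfNat α 0] [OfNat α 1] (n : Nat) : Fin n → Fin n → α :=
  -- Compare the underlying naturals with `==`, as in the spec's note.
  fun i j => if i.val == j.val then 1 else 0

-- Materialise a function-style matrix as nested arrays for printing.
def toRows {α : Type} {m n : Nat} (M : Fin m → Fin n → α) : Array (Array α) :=
  Array.ofFn fun i => Array.ofFn (M i)

#eval toRows (identitySketch (α := Nat) 3)  -- #[#[1, 0, 0], #[0, 1, 0], #[0, 0, 1]]
#eval toRows (identitySketch (α := Nat) 0)  -- #[]
```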
def Spec.matMulSpec {α : Type} [Add α] [Mul α] [Zero α] {m n p : ℕ} (A : Tensor α (Shape.dim m (Shape.dim n Shape.scalar))) (B : Tensor α (Shape.dim n (Shape.dim p Shape.scalar))) : Tensor α (Shape.dim m (Shape.dim p Shape.scalar))
Matrix multiplication (m x n) @ (n x p) = (m x p).
This is the simplest definitional version. Entry (i, k) of the result is the sum, over the shared index j < n, of A[i][j] * B[j][k].
This spec prioritizes clarity and provability; for performance-critical runtime code, use the runtime layer.
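The sum over the shared dimension can be sketched as follows. The sketch again assumes the hypothetical `Fin m → Fin n → α` encoding rather than `Tensor`, with `[OfNat α 0]` in place of `[Zero α]`. The names `matMulSketch`, `toRows`, `exA`, and `exB` are illustrative.

```lean
-- Hypothetical stand-in: an m x n matrix as `Fin m → Fin n → α`.
-- `[OfNat α 0]` stands in for the spec's `[Zero α]`.
def matMulSketch {α : Type} [Add α] [Mul α] [OfNat α 0] {m n p : Nat}
    (A : Fin m → Fin n → α) (B : Fin n → Fin p → α) : Fin m → Fin p → α :=
  -- C[i][k] = sum over j < n of A[i][j] * B[j][k], accumulated from 0.
  fun i k => Fin.foldl n (fun acc j => acc + A i j * B j k) 0

-- Materialise a function-style matrix as nested arrays for printing.
def toRows {α : Type} {m n : Nat} (M : Fin m → Fin n → α) : Array (Array α) :=
  Array.ofFn fun i => Array.ofFn (M i)

-- exA = [[1, 2, 3], [4, 5, 6]] (2 x 3); exB = [[7, 8], [9, 10], [11, 12]] (3 x 2)
def exA : Fin 2 → Fin 3 → Nat := fun i j => 3 * i.val + j.val + 1
def exB : Fin 3 → Fin 2 → Nat := fun j k => 2 * j.val + k.val + 7

#eval toRows (matMulSketch exA exB)  -- #[#[58, 64], #[139, 154]]
```

The sizes m = 2, n = 3, p = 2 are inferred from the argument types. An ill-shaped call such as `matMulSketch exA exA` therefore fails to elaborate, mirroring the shape-indexing of the spec.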
def Spec.matVecMulSpec {α : Type} [Add α] [Mul α] [Zero α] {m n : ℕ} (A : Tensor α (Shape.dim m (Shape.dim n Shape.scalar))) (v : Tensor α (Shape.dim n Shape.scalar)) : Tensor α (Shape.dim m Shape.scalar)
Matrix-vector multiplication (m x n) @ (n) = (m).
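Under the same hypothetical encoding, matrix-vector multiplication folds over the column index for each row. Multiplying by an all-ones vector yields the row sums, which makes the result easy to check by hand. The names `matVecSketch`, `exA`, and `ones` are illustrative.

```lean
-- Hypothetical stand-in: matrices as `Fin m → Fin n → α`, vectors as `Fin n → α`.
def matVecSketch {α : Type} [Add α] [Mul α] [OfNat α 0] {m n : Nat}
    (A : Fin m → Fin n → α) (v : Fin n → α) : Fin m → α :=
  -- result[i] = sum over j < n of A[i][j] * v[j]
  fun i => Fin.foldl n (fun acc j => acc + A i j * v j) 0

def exA : Fin 2 → Fin 3 → Nat := fun i j => 3 * i.val + j.val + 1  -- [[1, 2, 3], [4, 5, 6]]
def ones : Fin 3 → Nat := fun _ => 1

#eval Array.ofFn (matVecSketch exA ones)  -- #[6, 15]  (row sums)
```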
def Spec.vecMatMulSpec {α : Type} [Add α] [Mul α] [Zero α] {m n : ℕ} (v : Tensor α (Shape.dim m Shape.scalar)) (A : Tensor α (Shape.dim m (Shape.dim n Shape.scalar))) : Tensor α (Shape.dim n Shape.scalar)
Vector-matrix multiplication (m) @ (m x n) = (n).
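Vector-matrix multiplication folds over the row index instead. The sketch below uses the same hypothetical encoding. It also checks, on one example, that `v @ A` agrees with `Aᵀ @ v`, which holds whenever scalar multiplication is commutative. The names `vecMatSketch`, `matVecSketch`, `transposeSketch`, `exA`, and `exV` are illustrative.

```lean
-- Hypothetical stand-in: matrices as `Fin m → Fin n → α`, vectors as `Fin m → α`.
def vecMatSketch {α : Type} [Add α] [Mul α] [OfNat α 0] {m n : Nat}
    (v : Fin m → α) (A : Fin m → Fin n → α) : Fin n → α :=
  -- result[j] = sum over i < m of v[i] * A[i][j]
  fun j => Fin.foldl m (fun acc i => acc + v i * A i j) 0

def matVecSketch {α : Type} [Add α] [Mul α] [OfNat α 0] {m n : Nat}
    (A : Fin m → Fin n → α) (v : Fin n → α) : Fin m → α :=
  fun i => Fin.foldl n (fun acc j => acc + A i j * v j) 0

-- Swap the two indices: an m x n matrix becomes n x m.
def transposeSketch {α : Type} {m n : Nat} (A : Fin m → Fin n → α) : Fin n → Fin m → α :=
  fun j i => A i j

def exA : Fin 2 → Fin 3 → Nat := fun i j => 3 * i.val + j.val + 1  -- [[1, 2, 3], [4, 5, 6]]
def exV : Fin 2 → Nat := fun i => i.val + 1                        -- [1, 2]

#eval Array.ofFn (vecMatSketch exV exA)  -- #[9, 12, 15]
#eval Array.ofFn (vecMatSketch exV exA) == Array.ofFn (matVecSketch (transposeSketch exA) exV)  -- true
```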
def Spec.outerProductSpec {α : Type} [Mul α] {m n : ℕ} (a : Tensor α (Shape.dim m Shape.scalar)) (b : Tensor α (Shape.dim n Shape.scalar)) : Tensor α (Shape.dim m (Shape.dim n Shape.scalar))
Outer product (m) ⊗ (n) = (m x n).
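The outer product involves no sum, consistent with the spec requiring only `[Mul α]`. Here is a sketch under the same hypothetical encoding, where `outerSketch`, `toRows`, `exa`, and `exb` are illustrative names. Entry (i, j) is `a i * b j`, matching the broadcast `a.unsqueeze(1) * b.unsqueeze(0)`.

```lean
-- Hypothetical stand-in: vectors as `Fin n → α`, matrices as `Fin m → Fin n → α`.
-- Only multiplication is needed: every entry is a single product.
def outerSketch {α : Type} [Mul α] {m n : Nat}
    (a : Fin m → α) (b : Fin n → α) : Fin m → Fin n → α :=
  fun i j => a i * b j

-- Materialise a function-style matrix as nested arrays for printing.
def toRows {α : Type} {m n : Nat} (M : Fin m → Fin n → α) : Array (Array α) :=
  Array.ofFn fun i => Array.ofFn (M i)

def exa : Fin 2 → Nat := fun i => i.val + 1         -- [1, 2]
def exb : Fin 3 → Nat := fun j => 10 * (j.val + 1)  -- [10, 20, 30]

#eval toRows (outerSketch exa exb)  -- #[#[10, 20, 30], #[20, 40, 60]]
```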