TorchLean API

NN.Spec.Core.Tensor.Linalg

Linear algebra primitives (spec layer)

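This file defines the basic matrix/vector operations used across the model zoo:

  • the identity matrix
  • matrix multiplication (matMulSpec)
  • matrix-vector multiplication (matVecMulSpec)
  • vector-matrix multiplication (vecMatMulSpec)
  • the outer product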

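All operations are shape-indexed in their types. A shape mismatch, such as multiplying an m x n matrix by a vector whose length is not n, is therefore a type error caught during elaboration rather than a runtime failure.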

These are intentionally simple, “obvious” definitions (folding over List.finRange), so that they stay easy to read and easy to reason about in proofs.

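PyTorch analogies:

  • identity matrix: torch.eye(n)
  • matMulSpec: A @ B, i.e. torch.matmul(A, B)
  • matVecMulSpec: torch.mv(A, v), or A @ v
  • vecMatMulSpec: v @ A
  • outer product: torch.outer(u, v)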

Create an identity matrix (n x n).

Notes:

  • The n = 0 case is an empty matrix; it still exists as a well-typed tensor.
  • We use i.val == j.val rather than DecidableEq (Fin n) to keep the definition directly executable across backends.
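A minimal, self-contained sketch of the identity matrix in core Lean 4, under stated assumptions. `Mat α m n` (a function `Fin m → Fin n → α`) is a hypothetical stand-in for TorchLean's `Tensor α (Shape.dim m (Shape.dim n Shape.scalar))`. The name `eye` is illustrative, and `OfNat α 0`/`OfNat α 1` replace whatever numeric classes the real definition requires. The sketch follows the notes above: the diagonal test compares `i.val == j.val`, and the n = 0 matrix is still a well-typed value.

```lean
-- Hypothetical stand-in for TorchLean's
-- `Tensor α (Shape.dim m (Shape.dim n Shape.scalar))`:
-- an m x n matrix as a function from (row, column) indices to entries.
abbrev Mat (α : Type) (m n : Nat) : Type := Fin m → Fin n → α

/-- Sketch of the n x n identity matrix (illustrative name `eye`).
    The diagonal test compares the underlying naturals with `==`,
    mirroring the `i.val == j.val` note, instead of `DecidableEq (Fin n)`. -/
def eye {α : Type} [OfNat α 0] [OfNat α 1] (n : Nat) : Mat α n n :=
  fun i j => if i.val == j.val then 1 else 0

-- The n = 0 case is an empty matrix, but still a well-typed value.
example : Mat Nat 0 0 := eye 0

-- Entries of the 3 x 3 identity over Nat.
#eval (eye 3 : Mat Nat 3 3) 1 1  -- 1
#eval (eye 3 : Mat Nat 3 3) 0 2  -- 0
```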
def Spec.matMulSpec {α : Type} [Add α] [Mul α] [Zero α] {m n p : ℕ} (A : Tensor α (Shape.dim m (Shape.dim n Shape.scalar))) (B : Tensor α (Shape.dim n (Shape.dim p Shape.scalar))) : Tensor α (Shape.dim m (Shape.dim p Shape.scalar))

Matrix multiplication (m x n) @ (n x p) = (m x p).

This is the simplest definitional version: entry (i, k) of the result is the sum, over the shared index j < n, of A[i][j] * B[j][k]. For performance-oriented runtime code, use the runtime layer; this spec exists for clarity and proofs.
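The sum-over-the-shared-dimension definition can be sketched in the same self-contained style. `Mat`, `finSum`, and `matMul` are hypothetical names, and `OfNat α 0` stands in for `Zero α` so the block depends only on core Lean. `finSum` uses structural recursion on `n` rather than a fold over `List.finRange`. It may therefore associate the additions differently from the real definition, which can matter when `+` is not associative (for example, floating point).

```lean
-- Hypothetical stand-in for a rank-2 TorchLean tensor: an m x n matrix
-- represented as a function from (row, column) indices to entries.
abbrev Mat (α : Type) (m n : Nat) : Type := Fin m → Fin n → α

/-- `finSum n f = f 0 + (f 1 + (... + (f (n-1) + 0)))`,
    defined by structural recursion on `n`. -/
def finSum {α : Type} [Add α] [OfNat α 0] : (n : Nat) → (Fin n → α) → α
  | 0, _ => 0
  | n + 1, f => f ⟨0, Nat.zero_lt_succ n⟩ + finSum n (fun i => f i.succ)

/-- (m x n) @ (n x p) = (m x p): entry (i, k) is the sum over the
    shared index j of `A i j * B j k`. -/
def matMul {α : Type} [Add α] [Mul α] [OfNat α 0] {m n p : Nat}
    (A : Mat α m n) (B : Mat α n p) : Mat α m p :=
  fun i k => finSum n (fun j => A i j * B j k)

-- A = [[1, 2, 3], [4, 5, 6]] and B = [[7, 8], [9, 10], [11, 12]].
def A : Mat Nat 2 3 := fun i j => 3 * i.val + j.val + 1
def B : Mat Nat 3 2 := fun j k => 2 * j.val + k.val + 7

#eval matMul A B 0 0  -- 58
#eval matMul A B 1 1  -- 154
```

The shape indices do the checking here as well. `matMul B A` typechecks (3 x 2 @ 2 x 3), but `matMul A A` is rejected during elaboration because the inner dimensions 3 and 2 differ.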

def Spec.matVecMulSpec {α : Type} [Add α] [Mul α] [Zero α] {m n : ℕ} (A : Tensor α (Shape.dim m (Shape.dim n Shape.scalar))) (v : Tensor α (Shape.dim n Shape.scalar)) : Tensor α (Shape.dim m Shape.scalar)

Matrix-vector multiplication (m x n) @ (n) = (m).
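A matching sketch for the matrix-vector case, where each output entry is the dot product of one row of A with v. As before, `Mat`, `Vec`, `finSum`, and `matVecMul` are hypothetical stand-ins written in core Lean 4, not TorchLean's own definitions.

```lean
-- Hypothetical stand-ins for rank-2 and rank-1 TorchLean tensors.
abbrev Mat (α : Type) (m n : Nat) : Type := Fin m → Fin n → α
abbrev Vec (α : Type) (n : Nat) : Type := Fin n → α

/-- Sum of `f i` over all `i : Fin n`, by structural recursion on `n`. -/
def finSum {α : Type} [Add α] [OfNat α 0] : (n : Nat) → (Fin n → α) → α
  | 0, _ => 0
  | n + 1, f => f ⟨0, Nat.zero_lt_succ n⟩ + finSum n (fun i => f i.succ)

/-- (m x n) @ (n) = (m): entry i is the sum over j of `A i j * v j`,
    i.e. the dot product of row i of A with v. -/
def matVecMul {α : Type} [Add α] [Mul α] [OfNat α 0] {m n : Nat}
    (A : Mat α m n) (v : Vec α n) : Vec α m :=
  fun i => finSum n (fun j => A i j * v j)

-- A = [[1, 2, 3], [4, 5, 6]] and v = [1, 2, 3].
def A : Mat Nat 2 3 := fun i j => 3 * i.val + j.val + 1
def v : Vec Nat 3 := fun j => j.val + 1

#eval matVecMul A v 0  -- 14
#eval matVecMul A v 1  -- 32
```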

def Spec.vecMatMulSpec {α : Type} [Add α] [Mul α] [Zero α] {m n : ℕ} (v : Tensor α (Shape.dim m Shape.scalar)) (A : Tensor α (Shape.dim m (Shape.dim n Shape.scalar))) : Tensor α (Shape.dim n Shape.scalar)

Vector-matrix multiplication (m) @ (m x n) = (n).
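The vector-matrix case sums along the rows instead, pairing v with each column of A. This is again a hypothetical core-Lean sketch (`Mat`, `Vec`, `finSum`, and `vecMatMul` are illustrative names). It keeps `v i` on the left of each product, which is the order that matters if multiplication is not commutative.

```lean
-- Hypothetical stand-ins for rank-2 and rank-1 TorchLean tensors.
abbrev Mat (α : Type) (m n : Nat) : Type := Fin m → Fin n → α
abbrev Vec (α : Type) (n : Nat) : Type := Fin n → α

/-- Sum of `f i` over all `i : Fin n`, by structural recursion on `n`. -/
def finSum {α : Type} [Add α] [OfNat α 0] : (n : Nat) → (Fin n → α) → α
  | 0, _ => 0
  | n + 1, f => f ⟨0, Nat.zero_lt_succ n⟩ + finSum n (fun i => f i.succ)

/-- (m) @ (m x n) = (n): entry j is the sum over i of `v i * A i j`,
    i.e. the dot product of v with column j of A. -/
def vecMatMul {α : Type} [Add α] [Mul α] [OfNat α 0] {m n : Nat}
    (v : Vec α m) (A : Mat α m n) : Vec α n :=
  fun j => finSum m (fun i => v i * A i j)

-- v = [1, 2] and A = [[1, 2, 3], [4, 5, 6]].
def v : Vec Nat 2 := fun i => i.val + 1
def A : Mat Nat 2 3 := fun i j => 3 * i.val + j.val + 1

#eval vecMatMul v A 0  -- 9
#eval vecMatMul v A 2  -- 15
```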


Outer product (m) ⊗ (n) = (m x n).
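For the outer product, entry (i, j) is simply `u i * v j`, with no summation. The sketch below makes that concrete in core Lean 4. `Mat`, `Vec`, and `outer` are hypothetical names, and requiring only `Mul α` is a property of this sketch, not a claim about the real signature.

```lean
-- Hypothetical stand-ins for rank-2 and rank-1 TorchLean tensors.
abbrev Mat (α : Type) (m n : Nat) : Type := Fin m → Fin n → α
abbrev Vec (α : Type) (n : Nat) : Type := Fin n → α

/-- (m) ⊗ (n) = (m x n): entry (i, j) is `u i * v j`. No summation is
    involved, so this sketch needs only `Mul α`. -/
def outer {α : Type} [Mul α] {m n : Nat} (u : Vec α m) (v : Vec α n) : Mat α m n :=
  fun i j => u i * v j

-- u = [1, 2] and v = [3, 4, 5], so outer u v = [[3, 4, 5], [6, 8, 10]].
def u : Vec Nat 2 := fun i => i.val + 1
def v : Vec Nat 3 := fun j => j.val + 3

#eval outer u v 1 2  -- 10
```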
