Linear Algebra Helpers
Matrix transpose, 3D transposes, matmul/bmm backward specs, and shape matching.
def Spec.Tensor.matrixTransposeSpec {α : Type} {m n : ℕ}
    (t : Tensor α (Shape.dim m (Shape.dim n Shape.scalar))) :
    Tensor α (Shape.dim n (Shape.dim m Shape.scalar))
Transpose a matrix (m×n) into (n×m).
PyTorch analogy: A.transpose(0, 1) or A.T for 2D tensors.
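To make the index bookkeeping concrete, here is a minimal Lean 4 sketch that does not use the library's `Tensor` type. It assumes an m×n matrix is modeled as a function `Fin m → Fin n → α`; `Mat`, `Mat.transpose`, and `Mat.transpose_transpose` are hypothetical names used only for illustration.

```lean
-- Hypothetical stand-in for the library's Tensor type: an m×n matrix
-- is modeled as a function from (row, column) indices to entries.
abbrev Mat (α : Type) (m n : Nat) := Fin m → Fin n → α

-- Transpose: entry (j, i) of the n×m result is entry (i, j) of the input.
def Mat.transpose {α : Type} {m n : Nat} (A : Mat α m n) : Mat α n m :=
  fun j i => A i j

-- Transposing twice returns the original matrix; `rfl` suffices because
-- the two index swaps cancel definitionally.
theorem Mat.transpose_transpose {α : Type} {m n : Nat} (A : Mat α m n) :
    Mat.transpose (Mat.transpose A) = A := rfl
```

In this model a transpose moves no data; it only reorders the index arguments.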
def Spec.Tensor.transpose3DFirstToLastSpec {α : Type} {a b c : ℕ}
    (t : Tensor α (Shape.dim a (Shape.dim b (Shape.dim c Shape.scalar)))) :
    Tensor α (Shape.dim b (Shape.dim c (Shape.dim a Shape.scalar)))
Permute a 3D tensor from (a,b,c) to (b,c,a).
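In PyTorch terms this permutation is `t.permute(1, 2, 0)`. A minimal Lean 4 sketch, assuming a 3D tensor is modeled as a function `Fin a → Fin b → Fin c → α` (`T3` and `firstToLast` are hypothetical names, not the library's API), shows that it amounts to reordering the index arguments:

```lean
-- Hypothetical model: a 3D tensor as a function of three indices.
abbrev T3 (α : Type) (a b c : Nat) := Fin a → Fin b → Fin c → α

-- (a,b,c) → (b,c,a): the first axis moves to the last position,
-- so out[j][k][i] = t[i][j][k].
def firstToLast {α : Type} {a b c : Nat} (t : T3 α a b c) : T3 α b c a :=
  fun j k i => t i j k

-- The defining index equation holds by unfolding.
example {α : Type} {a b c : Nat} (t : T3 α a b c)
    (i : Fin a) (j : Fin b) (k : Fin c) :
    firstToLast t j k i = t i j k := rfl
```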
def Spec.Tensor.transpose3DLastToFirstSpec {α : Type} {a b c : ℕ}
    (t : Tensor α (Shape.dim a (Shape.dim b (Shape.dim c Shape.scalar)))) :
    Tensor α (Shape.dim c (Shape.dim a (Shape.dim b Shape.scalar)))
Permute a 3D tensor from (a,b,c) to (c,a,b).
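This permutation (PyTorch: `t.permute(2, 0, 1)`) undoes the (a,b,c) → (b,c,a) permutation, and vice versa. The Lean 4 sketch below checks both round trips under the same hypothetical function model; `T3`, `firstToLast`, and `lastToFirst` are illustrative names, not the library's definitions.

```lean
-- Hypothetical model: a 3D tensor as a function of three indices.
abbrev T3 (α : Type) (a b c : Nat) := Fin a → Fin b → Fin c → α

-- (a,b,c) → (b,c,a): out[j][k][i] = t[i][j][k].
def firstToLast {α : Type} {a b c : Nat} (t : T3 α a b c) : T3 α b c a :=
  fun j k i => t i j k

-- (a,b,c) → (c,a,b): the last axis moves to the front,
-- so out[k][i][j] = t[i][j][k].
def lastToFirst {α : Type} {a b c : Nat} (t : T3 α a b c) : T3 α c a b :=
  fun k i j => t i j k

-- Applying one permutation after the other restores the original tensor.
theorem lastToFirst_firstToLast {α : Type} {a b c : Nat} (t : T3 α a b c) :
    lastToFirst (firstToLast t) = t := rfl

theorem firstToLast_lastToFirst {α : Type} {a b c : Nat} (t : T3 α a b c) :
    firstToLast (lastToFirst t) = t := rfl
```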
def Spec.Tensor.transpose3DLastTwoSpec {α : Type} {a b c : ℕ}
    (t : Tensor α (Shape.dim a (Shape.dim b (Shape.dim c Shape.scalar)))) :
    Tensor α (Shape.dim a (Shape.dim c (Shape.dim b Shape.scalar)))
Swap the last two axes of a 3D tensor: (a,b,c) to (a,c,b).
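Swapping the last two axes (PyTorch: `t.transpose(1, 2)`) transposes every one of the `a` slices independently. The Lean 4 sketch below states this as an `example` under a hypothetical function model; `Mat`, `T3`, `Mat.transpose`, and `swapLastTwo` are illustrative names only.

```lean
-- Hypothetical models: matrices and 3D tensors as index functions.
abbrev Mat (α : Type) (m n : Nat) := Fin m → Fin n → α
abbrev T3 (α : Type) (a b c : Nat) := Fin a → Fin b → Fin c → α

def Mat.transpose {α : Type} {m n : Nat} (A : Mat α m n) : Mat α n m :=
  fun j i => A i j

-- (a,b,c) → (a,c,b): out[i][k][j] = t[i][j][k].
def swapLastTwo {α : Type} {a b c : Nat} (t : T3 α a b c) : T3 α a c b :=
  fun i k j => t i j k

-- Slice i of the result is the matrix transpose of slice i of the input.
example {α : Type} {a b c : Nat} (t : T3 α a b c) (i : Fin a) :
    swapLastTwo t i = Mat.transpose (t i : Mat α b c) := rfl
```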
def Spec.Tensor.swapAtDepthHelper {β : Type} {shape : Shape}
    (tensor : Tensor β shape) (d : ℕ) :
    Tensor β (shape.swapAdjacentAtDepth d)
Helper for swapping adjacent dims at a given depth (see Shape.swapAdjacentAtDepth).
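The definition of `Shape.swapAdjacentAtDepth` is not shown in this section, so the following is only a sketch of plausible shape-level behavior. It assumes that depth `d` swaps the dimensions at positions `d` and `d + 1` (depth 0 swaps the two outermost axes) and that a shape with no such pair is returned unchanged. It models a shape as a plain `List Nat`; `swapAdjacentAt` is a hypothetical name.

```lean
-- Hypothetical model of a shape as a list of dimension sizes.
-- Assumed convention: swap the dims at positions d and d + 1, and
-- leave the shape unchanged if there is no pair at that depth.
def swapAdjacentAt : List Nat → Nat → List Nat
  | x :: y :: rest, 0     => y :: x :: rest
  | x :: rest,      d + 1 => x :: swapAdjacentAt rest d
  | shape,          _     => shape

example : swapAdjacentAt [2, 3, 4] 0 = [3, 2, 4] := rfl
example : swapAdjacentAt [2, 3, 4] 1 = [2, 4, 3] := rfl
example : swapAdjacentAt [2, 3] 1 = [2, 3] := rfl
```

Under this assumed convention, depth 1 on a 3D shape `[a, b, c]` yields `[a, c, b]`, the same reordering that `transpose3DLastTwoSpec` performs.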
def Spec.Tensor.matMulBackwardSpec {α : Type} [Context α] {m n p : ℕ}
    (A : Tensor α (Shape.dim m (Shape.dim n Shape.scalar)))
    (B : Tensor α (Shape.dim n (Shape.dim p Shape.scalar)))
    (dC : Tensor α (Shape.dim m (Shape.dim p Shape.scalar))) :
    Tensor α (Shape.dim m (Shape.dim n Shape.scalar)) × Tensor α (Shape.dim n (Shape.dim p Shape.scalar))
Backward pass for matrix multiplication: returns (dA, dB) given dC.
PyTorch analogy: if C = A @ B, then:
dA = dC @ Bᵀ
dB = Aᵀ @ dC
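These rules follow from C[i][k] = Σ_j A[i][j]·B[j][k]. For a scalar loss L with upstream gradient dC[i][k] = ∂L/∂C[i][k], the chain rule gives ∂L/∂A[i][j] = Σ_k dC[i][k]·B[j][k], which is dC @ Bᵀ, and ∂L/∂B[j][k] = Σ_i A[i][j]·dC[i][k], which is Aᵀ @ dC. The Lean 4 sketch below writes these sums out entrywise. It is not the library's implementation: it models matrices as functions `Fin m → Fin n → Int` (specialized to `Int` for brevity instead of the spec's generic `Context α`), and `sumFin`, `Mat`, `matmul`, `matmulBackward`, `A₀`, and `dC₀` are hypothetical names.

```lean
-- Sum of f over all indices of Fin n (structural recursion on n).
def sumFin : (n : Nat) → (Fin n → Int) → Int
  | 0, _ => 0
  | n + 1, f => f (Fin.last n) + sumFin n (fun i => f i.castSucc)

-- Hypothetical model of an m×n integer matrix.
abbrev Mat (m n : Nat) := Fin m → Fin n → Int

-- Forward pass: C[i][k] = Σ_j A[i][j] * B[j][k].
def matmul {m n p : Nat} (A : Mat m n) (B : Mat n p) : Mat m p :=
  fun i k => sumFin n (fun j => A i j * B j k)

-- Backward pass: dA = dC @ Bᵀ and dB = Aᵀ @ dC, written entrywise.
def matmulBackward {m n p : Nat} (A : Mat m n) (B : Mat n p) (dC : Mat m p) :
    Mat m n × Mat n p :=
  ( fun i j => sumFin p (fun k => dC i k * B j k),   -- dA[i][j]
    fun j k => sumFin m (fun i => A i j * dC i k) )  -- dB[j][k]

-- Example: A = B = [[1, 2], [3, 4]] and dC = all ones (gradient of sum(C)).
def A₀ : Mat 2 2 := fun i j => ((2 * i.val + j.val + 1 : Nat) : Int)
def dC₀ : Mat 2 2 := fun _ _ => 1

-- dA[i][j] is the sum of row j of B; dB[j][k] is the sum of column j of A.
example : (matmulBackward A₀ A₀ dC₀).1 0 1 = 7 := by decide
example : (matmulBackward A₀ A₀ dC₀).2 1 0 = 6 := by decide
```

With dC set to all ones these are the gradients of sum(C): dA[0][1] = 3 + 4 (row 1 of B) and dB[1][0] = 2 + 4 (column 1 of A).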
def Spec.Tensor.bmmSpec {α : Type} [Add α] [Mul α] [Zero α] {batch m n p : ℕ}
    (A : Tensor α (Shape.dim batch (Shape.dim m (Shape.dim n Shape.scalar))))
    (B : Tensor α (Shape.dim batch (Shape.dim n (Shape.dim p Shape.scalar)))) :
    Tensor α (Shape.dim batch (Shape.dim m (Shape.dim p Shape.scalar)))
Batched matrix multiplication: [batch,m,n] × [batch,n,p] → [batch,m,p].
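In PyTorch this is `torch.bmm(A, B)`. Each batch index is an independent matrix product, which the Lean 4 sketch below states as an `example`. It is a sketch under stated assumptions, not the library's code: tensors are modeled as index functions specialized to `Int` (the spec itself is generic over `[Add α] [Mul α] [Zero α]`), and `sumFin`, `Mat`, `Batched`, `matmul`, and `bmm` are hypothetical names.

```lean
-- Sum of f over all indices of Fin n (structural recursion on n).
def sumFin : (n : Nat) → (Fin n → Int) → Int
  | 0, _ => 0
  | n + 1, f => f (Fin.last n) + sumFin n (fun i => f i.castSucc)

-- Hypothetical models: a matrix, and a [batch, m, n] stack of matrices.
abbrev Mat (m n : Nat) := Fin m → Fin n → Int
abbrev Batched (batch m n : Nat) := Fin batch → Mat m n

def matmul {m n p : Nat} (A : Mat m n) (B : Mat n p) : Mat m p :=
  fun i k => sumFin n (fun j => A i j * B j k)

-- out[b][i][k] = Σ_j A[b][i][j] * B[b][j][k]; batches never interact.
def bmm {batch m n p : Nat} (A : Batched batch m n) (B : Batched batch n p) :
    Batched batch m p :=
  fun b i k => sumFin n (fun j => A b i j * B b j k)

-- Each output slice is the ordinary product of the matching input slices.
example {batch m n p : Nat} (A : Batched batch m n) (B : Batched batch n p)
    (b : Fin batch) : bmm A B b = matmul (A b) (B b) := rfl
```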
def Spec.Tensor.bmmBackwardSpec {α : Type} [Add α] [Mul α] [Zero α] {batch m n p : ℕ}
    (A : Tensor α (Shape.dim batch (Shape.dim m (Shape.dim n Shape.scalar))))
    (B : Tensor α (Shape.dim batch (Shape.dim n (Shape.dim p Shape.scalar))))
    (dC : Tensor α (Shape.dim batch (Shape.dim m (Shape.dim p Shape.scalar)))) :
Backward pass for batched matrix multiplication.
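The return type of `bmmBackwardSpec` is not shown here. Assuming it mirrors `matMulBackwardSpec` and returns (dA, dB) with the shapes of A and B, the per-batch rule in PyTorch terms is dA = torch.bmm(dC, B.transpose(1, 2)) and dB = torch.bmm(A.transpose(1, 2), dC). The Lean 4 sketch below encodes that assumption entrywise; it models tensors as index functions specialized to `Int`, and `sumFin`, `Batched`, and `bmmBackward` are hypothetical names.

```lean
-- Sum of f over all indices of Fin n (structural recursion on n).
def sumFin : (n : Nat) → (Fin n → Int) → Int
  | 0, _ => 0
  | n + 1, f => f (Fin.last n) + sumFin n (fun i => f i.castSucc)

-- Hypothetical model of a [batch, m, n] integer tensor.
abbrev Batched (batch m n : Nat) := Fin batch → Fin m → Fin n → Int

-- Assumed behavior: apply the matrix-multiplication backward rule to
-- each batch slice independently, returning (dA, dB).
def bmmBackward {batch m n p : Nat}
    (A : Batched batch m n) (B : Batched batch n p) (dC : Batched batch m p) :
    Batched batch m n × Batched batch n p :=
  ( fun b i j => sumFin p (fun k => dC b i k * B b j k),   -- dA[b] = dC[b] @ B[b]ᵀ
    fun b j k => sumFin m (fun i => A b i j * dC b i k) )  -- dB[b] = A[b]ᵀ @ dC[b]
```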