TorchLean API

NN.Spec.Core.TensorReductionShape.LinearAlgebra

Linear Algebra Helpers

Matrix transpose, 3D transposes, matmul/bmm backward specs, and shape matching.

Transpose a matrix (m×n) into (n×m).

PyTorch analogy: A.transpose(0, 1) or A.T for 2D tensors.

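To make the indexing rule concrete, here is a minimal, self-contained Lean 4 sketch. It does not use TorchLean's Tensor type. It assumes a matrix is modelled as an index function Fin m → Fin n → α, and the names Mat, matTranspose, vecToList, matToList, and A23 are hypothetical helpers introduced only for this illustration.

```lean
-- Hypothetical stand-in for a 2D tensor (NOT TorchLean's `Tensor`):
-- an m×n matrix is a function from (row, column) indices to entries.
abbrev Mat (α : Type) (m n : Nat) : Type := Fin m → Fin n → α

-- Transpose: entry (j, i) of the result is entry (i, j) of the input.
def matTranspose {α : Type} {m n : Nat} (A : Mat α m n) : Mat α n m :=
  fun j i => A i j

-- Read a length-n index function out as a List (structural recursion on n).
def vecToList {α : Type} : (n : Nat) → (Fin n → α) → List α
  | 0, _ => []
  | n + 1, f => f ⟨0, Nat.succ_pos n⟩ :: vecToList n (fun i => f i.succ)

-- Read a matrix out row by row, for printing and testing.
def matToList {α : Type} {m n : Nat} (A : Mat α m n) : List (List α) :=
  vecToList m (fun i => vecToList n (A i))

-- Example 2×3 matrix [[1, 2, 3], [4, 5, 6]].
def A23 : Mat Int 2 3 := fun i j => Int.ofNat (3 * i.val + j.val + 1)

#eval matToList A23                 -- [[1, 2, 3], [4, 5, 6]]
#eval matToList (matTranspose A23)  -- [[1, 4], [2, 5], [3, 6]]
```

In this function model, transposing twice gives back the original matrix definitionally, so matTranspose (matTranspose A) = A holds by rfl.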

Permute a 3D tensor from (a,b,c) to (b,c,a).

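A minimal Lean 4 sketch of the (a,b,c) → (b,c,a) index rule, which in PyTorch terms is t.permute(1, 2, 0). It assumes a hypothetical function model T3 of a 3D tensor rather than TorchLean's Tensor type; T3, permuteBCA, vecToList, t3ToList, and t213 are illustrative names only.

```lean
-- Hypothetical stand-in for a 3D tensor of shape (a, b, c).
abbrev T3 (α : Type) (a b c : Nat) : Type := Fin a → Fin b → Fin c → α

-- (a, b, c) → (b, c, a): result[j][k][i] = t[i][j][k].
def permuteBCA {α : Type} {a b c : Nat} (t : T3 α a b c) : T3 α b c a :=
  fun j k i => t i j k

-- Read a length-n index function out as a List (structural recursion on n).
def vecToList {α : Type} : (n : Nat) → (Fin n → α) → List α
  | 0, _ => []
  | n + 1, f => f ⟨0, Nat.succ_pos n⟩ :: vecToList n (fun i => f i.succ)

-- Read a 3D tensor out as nested lists, outermost axis first.
def t3ToList {α : Type} {a b c : Nat} (t : T3 α a b c) : List (List (List α)) :=
  vecToList a (fun i => vecToList b (fun j => vecToList c (t i j)))

-- Example tensor of shape (2, 1, 3) with t[i][j][k] = 100*i + 10*j + k.
def t213 : T3 Nat 2 1 3 := fun i j k => 100 * i.val + 10 * j.val + k.val

#eval t3ToList t213               -- [[[0, 1, 2]], [[100, 101, 102]]]
#eval t3ToList (permuteBCA t213)  -- [[[0, 100], [1, 101], [2, 102]]]
```

The result has shape (1, 3, 2): the old leading axis of size 2 has moved to the end.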

Permute a 3D tensor from (a,b,c) to (c,a,b).

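The (a,b,c) → (c,a,b) rule can be sketched the same way (PyTorch: t.permute(2, 0, 1)). Under the same hypothetical function model it undoes the (b,c,a) rotation, which the sketch checks by rfl. T3, permuteBCA, permuteCAB, vecToList, t3ToList, and t213 are illustrative names, not TorchLean's API.

```lean
-- Hypothetical stand-in for a 3D tensor of shape (a, b, c).
abbrev T3 (α : Type) (a b c : Nat) : Type := Fin a → Fin b → Fin c → α

-- (a, b, c) → (b, c, a): result[j][k][i] = t[i][j][k].
def permuteBCA {α : Type} {a b c : Nat} (t : T3 α a b c) : T3 α b c a :=
  fun j k i => t i j k

-- (a, b, c) → (c, a, b): result[k][i][j] = t[i][j][k].
def permuteCAB {α : Type} {a b c : Nat} (t : T3 α a b c) : T3 α c a b :=
  fun k i j => t i j k

-- Rotating the axes one way and then back is the identity (β/η-reduction).
theorem permuteCAB_permuteBCA {α : Type} {a b c : Nat} (t : T3 α a b c) :
    permuteCAB (permuteBCA t) = t := rfl

-- Read a length-n index function out as a List (structural recursion on n).
def vecToList {α : Type} : (n : Nat) → (Fin n → α) → List α
  | 0, _ => []
  | n + 1, f => f ⟨0, Nat.succ_pos n⟩ :: vecToList n (fun i => f i.succ)

-- Read a 3D tensor out as nested lists, outermost axis first.
def t3ToList {α : Type} {a b c : Nat} (t : T3 α a b c) : List (List (List α)) :=
  vecToList a (fun i => vecToList b (fun j => vecToList c (t i j)))

-- Example tensor of shape (2, 1, 3) with t[i][j][k] = 100*i + 10*j + k.
def t213 : T3 Nat 2 1 3 := fun i j k => 100 * i.val + 10 * j.val + k.val

#eval t3ToList (permuteCAB t213)  -- [[[0], [100]], [[1], [101]], [[2], [102]]]
```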

Swap the last two axes of a 3D tensor: (a,b,c) to (a,c,b).

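A minimal Lean 4 sketch of swapping the last two axes (PyTorch: t.transpose(1, 2), equivalently t.transpose(-2, -1)), using the same hypothetical function model. T3, swapLastTwo, vecToList, t3ToList, and t123 are illustrative names only.

```lean
-- Hypothetical stand-in for a 3D tensor of shape (a, b, c).
abbrev T3 (α : Type) (a b c : Nat) : Type := Fin a → Fin b → Fin c → α

-- (a, b, c) → (a, c, b): result[i][k][j] = t[i][j][k]; the leading axis is untouched.
def swapLastTwo {α : Type} {a b c : Nat} (t : T3 α a b c) : T3 α a c b :=
  fun i k j => t i j k

-- Read a length-n index function out as a List (structural recursion on n).
def vecToList {α : Type} : (n : Nat) → (Fin n → α) → List α
  | 0, _ => []
  | n + 1, f => f ⟨0, Nat.succ_pos n⟩ :: vecToList n (fun i => f i.succ)

-- Read a 3D tensor out as nested lists, outermost axis first.
def t3ToList {α : Type} {a b c : Nat} (t : T3 α a b c) : List (List (List α)) :=
  vecToList a (fun i => vecToList b (fun j => vecToList c (t i j)))

-- Example tensor of shape (1, 2, 3) with t[0][j][k] = 10*j + k.
def t123 : T3 Nat 1 2 3 := fun _ j k => 10 * j.val + k.val

#eval t3ToList t123                -- [[[0, 1, 2], [10, 11, 12]]]
#eval t3ToList (swapLastTwo t123)  -- [[[0, 10], [1, 11], [2, 12]]]
```

Applied to a stack of matrices, this transposes each matrix independently while leaving the leading axis in place.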
def Spec.Tensor.swapFirstTwoSpec {α : Type} {m n : ℕ} {s : Shape} (t : Tensor α (Shape.dim m (Shape.dim n s))) :

Swap the first two dimensions of a tensor (m,n,...) to (n,m,...).

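For swapFirstTwoSpec only the first two indices move, and the trailing shape s is carried along unchanged. A minimal Lean 4 sketch, assuming each (i, j) entry is a sub-tensor modelled by an arbitrary type β; swapFirstTwo, vecToList, and t232 are hypothetical names. In PyTorch this corresponds to t.transpose(0, 1). The sketch also shows that swapping twice is definitionally the identity.

```lean
-- Hypothetical model: the trailing shape `s` is abstracted as an element type β,
-- so a tensor of shape (m, n, ...s) is a function Fin m → Fin n → β.
def swapFirstTwo {β : Type} {m n : Nat} (t : Fin m → Fin n → β) : Fin n → Fin m → β :=
  fun j i => t i j

-- Swapping the first two axes twice gives back the original tensor.
theorem swapFirstTwo_involutive {β : Type} {m n : Nat} (t : Fin m → Fin n → β) :
    swapFirstTwo (swapFirstTwo t) = t := rfl

-- Read a length-n index function out as a List (structural recursion on n).
def vecToList {α : Type} : (n : Nat) → (Fin n → α) → List α
  | 0, _ => []
  | n + 1, f => f ⟨0, Nat.succ_pos n⟩ :: vecToList n (fun i => f i.succ)

-- Example of shape (2, 3, 2): each (i, j) entry is the trailing vector [i, j].
def t232 : Fin 2 → Fin 3 → List Nat := fun i j => [i.val, j.val]

#eval vecToList 3 (fun j => vecToList 2 (swapFirstTwo t232 j))
-- [[[0, 0], [1, 0]], [[0, 1], [1, 1]], [[0, 2], [1, 2]]]
```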
def Spec.Tensor.swapAtDepthHelper {β : Type} {shape : Shape} (tensor : Tensor β shape) (d : ℕ) :

Helper for swapping adjacent dims at a given depth (see Shape.swapAdjacentAtDepth).

def Spec.Tensor.swapAtDepthSpec {α : Type} {n : ℕ} {s : Shape} (t : Tensor α (Shape.dim n s)) (depth : ℕ) :

Swap adjacent dimensions at a given depth inside a leading batch dimension.

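How depth is counted is not spelled out here, so the following Lean 4 sketch works at the shape level only and rests on an explicit assumption: a shape is a list of sizes, depth d swaps the sizes at 0-based positions d and d+1 of the per-example shape s, and a shape too short to have both positions is left unchanged. swapAdjacentAt and swapAtDepthShape are hypothetical names; they are not TorchLean's Shape.swapAdjacentAtDepth.

```lean
-- ASSUMPTION (not confirmed by the TorchLean docs): `depth` is a 0-based
-- position d, and the sizes at positions d and d+1 are exchanged.
-- Shapes with fewer than d + 2 dimensions are returned unchanged.
def swapAdjacentAt : Nat → List Nat → List Nat
  | 0, x :: y :: rest => y :: x :: rest
  | d + 1, x :: rest => x :: swapAdjacentAt d rest
  | _, s => s

-- The leading batch size is kept; the swap happens inside the per-example shape.
def swapAtDepthShape (batch : Nat) (s : List Nat) (depth : Nat) : List Nat :=
  batch :: swapAdjacentAt depth s

#eval swapAtDepthShape 8 [3, 4, 5] 0  -- [8, 4, 3, 5]
#eval swapAtDepthShape 8 [3, 4, 5] 1  -- [8, 3, 5, 4]
#eval swapAtDepthShape 8 [3, 4, 5] 2  -- [8, 3, 4, 5] (no position 3, so unchanged)
```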

Backward pass for matrix multiplication: returns (dA, dB) given dC.

PyTorch analogy: if C = A @ B, then:

• dA = dC @ Bᵀ
• dB = Aᵀ @ dC
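The two formulas for dA and dB can be checked on a small example. This Lean 4 sketch assumes the same hypothetical function model of matrices (Mat), a hand-rolled sum sumFin, and illustrative names matMul, matMulBackward, exA, exB, and exdC; it is not TorchLean's matmul backward spec. The shapes line up as expected: with A : m×n, B : n×p and dC : m×p, dC @ Bᵀ is m×n and Aᵀ @ dC is n×p.

```lean
-- Hypothetical stand-in for a 2D tensor: an m×n matrix as an index function.
abbrev Mat (α : Type) (m n : Nat) : Type := Fin m → Fin n → α

-- f 0 + (f 1 + (... + (f (n-1) + 0))), by structural recursion on n.
def sumFin {α : Type} [Add α] [OfNat α 0] : (n : Nat) → (Fin n → α) → α
  | 0, _ => 0
  | n + 1, f => f ⟨0, Nat.succ_pos n⟩ + sumFin n (fun i => f i.succ)

def matTranspose {α : Type} {m n : Nat} (A : Mat α m n) : Mat α n m :=
  fun j i => A i j

-- (A @ B)[i][j] = Σ_k A[i][k] * B[k][j]
def matMul {α : Type} [Add α] [Mul α] [OfNat α 0] {m n p : Nat}
    (A : Mat α m n) (B : Mat α n p) : Mat α m p :=
  fun i j => sumFin n (fun k => A i k * B k j)

-- Given C = A @ B and the upstream gradient dC (same shape as C),
-- return (dA, dB) = (dC @ Bᵀ, Aᵀ @ dC).
def matMulBackward {α : Type} [Add α] [Mul α] [OfNat α 0] {m n p : Nat}
    (A : Mat α m n) (B : Mat α n p) (dC : Mat α m p) : Mat α m n × Mat α n p :=
  (matMul dC (matTranspose B), matMul (matTranspose A) dC)

-- Read a length-n index function out as a List (structural recursion on n).
def vecToList {α : Type} : (n : Nat) → (Fin n → α) → List α
  | 0, _ => []
  | n + 1, f => f ⟨0, Nat.succ_pos n⟩ :: vecToList n (fun i => f i.succ)

def matToList {α : Type} {m n : Nat} (A : Mat α m n) : List (List α) :=
  vecToList m (fun i => vecToList n (A i))

def exA : Mat Int 1 2 := fun _ j => Int.ofNat (j.val + 1)      -- [[1, 2]]
def exB : Mat Int 2 3 := fun i j => Int.ofNat (i.val + j.val)  -- [[0, 1, 2], [1, 2, 3]]
def exdC : Mat Int 1 3 := fun _ _ => 1                         -- [[1, 1, 1]]

#eval matToList (matMul exA exB)                 -- [[2, 5, 8]]
#eval matToList (matMulBackward exA exB exdC).1  -- [[3, 6]]
#eval matToList (matMulBackward exA exB exdC).2  -- [[1, 1, 1], [2, 2, 2]]
```

Each gradient has the shape of the input it belongs to: dA is 1×2 like A, and dB is 2×3 like B.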
def Spec.Tensor.bmmSpec {α : Type} [Add α] [Mul α] [Zero α] {batch m n p : ℕ} (A : Tensor α (Shape.dim batch (Shape.dim m (Shape.dim n Shape.scalar)))) (B : Tensor α (Shape.dim batch (Shape.dim n (Shape.dim p Shape.scalar)))) :

Batched matrix multiplication: [batch,m,n] × [batch,n,p] → [batch,m,p].

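In PyTorch terms this is torch.bmm(A, B). As a minimal Lean 4 sketch under the same hypothetical function model, batched matmul is the 2D product applied independently at each batch index b. BMat, sumFin, bmm, vecToList, t3ToList, bA, and bB are illustrative names, not TorchLean's bmmSpec.

```lean
-- Hypothetical stand-in for a batch of m×n matrices.
abbrev BMat (α : Type) (batch m n : Nat) : Type := Fin batch → Fin m → Fin n → α

-- f 0 + (f 1 + (... + (f (n-1) + 0))), by structural recursion on n.
def sumFin {α : Type} [Add α] [OfNat α 0] : (n : Nat) → (Fin n → α) → α
  | 0, _ => 0
  | n + 1, f => f ⟨0, Nat.succ_pos n⟩ + sumFin n (fun i => f i.succ)

-- C[b][i][j] = Σ_k A[b][i][k] * B[b][k][j]; different batch entries never interact.
def bmm {α : Type} [Add α] [Mul α] [OfNat α 0] {batch m n p : Nat}
    (A : BMat α batch m n) (B : BMat α batch n p) : BMat α batch m p :=
  fun b i j => sumFin n (fun k => A b i k * B b k j)

-- Read a length-n index function out as a List (structural recursion on n).
def vecToList {α : Type} : (n : Nat) → (Fin n → α) → List α
  | 0, _ => []
  | n + 1, f => f ⟨0, Nat.succ_pos n⟩ :: vecToList n (fun i => f i.succ)

def t3ToList {α : Type} {a b c : Nat} (t : BMat α a b c) : List (List (List α)) :=
  vecToList a (fun i => vecToList b (fun j => vecToList c (t i j)))

-- batch = 2, m = 1, n = 2, p = 1.
def bA : BMat Int 2 1 2 := fun b _ k => Int.ofNat (b.val + k.val + 1)  -- [[[1, 2]], [[2, 3]]]
def bB : BMat Int 2 2 1 := fun _ _ _ => 1                              -- all ones

#eval t3ToList (bmm bA bB)  -- [[[3]], [[5]]]
```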
def Spec.Tensor.bmmBackwardSpec {α : Type} [Add α] [Mul α] [Zero α] {batch m n p : ℕ} (A : Tensor α (Shape.dim batch (Shape.dim m (Shape.dim n Shape.scalar)))) (B : Tensor α (Shape.dim batch (Shape.dim n (Shape.dim p Shape.scalar)))) (dC : Tensor α (Shape.dim batch (Shape.dim m (Shape.dim p Shape.scalar)))) :

Backward pass for batched matrix multiplication.

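Because batch entries never interact, the batched gradient is the 2D matmul rule applied at each batch index: dA[b] = dC[b] @ B[b]ᵀ and dB[b] = A[b]ᵀ @ dC[b]. A minimal Lean 4 sketch under the same hypothetical function model; bmmBackward and the example tensors are illustrative names, not TorchLean's bmmBackwardSpec.

```lean
-- Hypothetical stand-in for a batch of m×n matrices.
abbrev BMat (α : Type) (batch m n : Nat) : Type := Fin batch → Fin m → Fin n → α

-- f 0 + (f 1 + (... + (f (n-1) + 0))), by structural recursion on n.
def sumFin {α : Type} [Add α] [OfNat α 0] : (n : Nat) → (Fin n → α) → α
  | 0, _ => 0
  | n + 1, f => f ⟨0, Nat.succ_pos n⟩ + sumFin n (fun i => f i.succ)

-- Per batch index b:
--   dA[b][i][k] = Σ_j dC[b][i][j] * B[b][k][j]   (dC[b] @ B[b]ᵀ)
--   dB[b][k][j] = Σ_i A[b][i][k] * dC[b][i][j]   (A[b]ᵀ @ dC[b])
def bmmBackward {α : Type} [Add α] [Mul α] [OfNat α 0] {batch m n p : Nat}
    (A : BMat α batch m n) (B : BMat α batch n p) (dC : BMat α batch m p) :
    BMat α batch m n × BMat α batch n p :=
  ((fun b i k => sumFin p (fun j => dC b i j * B b k j)),
   (fun b k j => sumFin m (fun i => A b i k * dC b i j)))

-- Read a length-n index function out as a List (structural recursion on n).
def vecToList {α : Type} : (n : Nat) → (Fin n → α) → List α
  | 0, _ => []
  | n + 1, f => f ⟨0, Nat.succ_pos n⟩ :: vecToList n (fun i => f i.succ)

def t3ToList {α : Type} {a b c : Nat} (t : BMat α a b c) : List (List (List α)) :=
  vecToList a (fun i => vecToList b (fun j => vecToList c (t i j)))

-- batch = 2, m = 1, n = 2, p = 1.
def bA : BMat Int 2 1 2 := fun b _ k => Int.ofNat (b.val + k.val + 1)  -- [[[1, 2]], [[2, 3]]]
def bB : BMat Int 2 2 1 := fun _ _ _ => 1                              -- all ones
def bdC : BMat Int 2 1 1 := fun b _ _ => Int.ofNat (b.val + 1)         -- [[[1]], [[2]]]

#eval t3ToList (bmmBackward bA bB bdC).1  -- [[[1, 1]], [[2, 2]]]
#eval t3ToList (bmmBackward bA bB bdC).2  -- [[[1], [2]], [[4], [6]]]
```

For batch index 1, A[1] = [[2, 3]] and dC[1] = [[2]], so dB[1] = A[1]ᵀ @ dC[1] = [[4], [6]], matching the per-batch 2D rule.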