TorchLean API

NN.Proofs.Autograd.Tape.Nodes.Matrix

Matrix tape nodes

Matrix multiplication, transpose, row/column broadcasting, and row means, with VJP correctness facts stated at the vectorized tape level.

@[reducible, inline]

Flattened size of an m×n matrix shape: Shape.size (.dim m (.dim n .scalar)) = m*n.

@[reducible, inline]

Flattened size of a length-n vector shape: Shape.size (.dim n .scalar) = n.

def Proofs.Autograd.TapeNodes.Matmul.idxMN {m n : ℕ} (i : Fin m) (j : Fin n) :
Fin (matSize m n)

Convert (i, j) coordinates into a flattened index for an m×n matrix vectorization.
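
A minimal Python sketch of the flattening convention, assuming idxMN i j is the row-major index i * n + j. This assumption matches the row-major layout this module states for transpose and the divNat/modNat split used in matmulVec_apply. The names idx_mn and unidx_mn are hypothetical and are not part of TorchLean.

```python
# Hypothetical illustration of row-major flattening for an m x n matrix.
# Assumption: idxMN i j corresponds to i * n + j (not taken from TorchLean source).


def idx_mn(i: int, j: int, m: int, n: int) -> int:
    """Flatten (i, j), with 0 <= i < m and 0 <= j < n, to the row-major index i * n + j."""
    if not (0 <= i < m and 0 <= j < n):
        raise IndexError("coordinates out of range")
    return i * n + j


def unidx_mn(k: int, m: int, n: int) -> tuple[int, int]:
    """Inverse of idx_mn: row = k // n (like divNat), column = k % n (like modNat)."""
    if not (0 <= k < m * n):
        raise IndexError("flat index out of range")
    return k // n, k % n


if __name__ == "__main__":
    m, n = 2, 3
    print([idx_mn(i, j, m, n) for i in range(m) for j in range(n)])  # [0, 1, 2, 3, 4, 5]
    print(unidx_mn(4, m, n))  # (1, 1)
```

Every flat index in range(m * n) is produced exactly once, which is what makes the flattening a bijection.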

Vectorization commutes with matrix addition: toVecT (A + B) = toVecT A + toVecT B.

noncomputable def Proofs.Autograd.TapeNodes.Matmul.matmulVec {m n p : ℕ} (a : Vec (matSize m n)) (b : Vec (matSize n p)) :
Vec (matSize m p)

A bilinear map on flattened matrices: (m×n) × (n×p) → (m×p) on Vec (Shape.size ...).

@[simp]
theorem Proofs.Autograd.TapeNodes.Matmul.matmulVec_apply {m n p : ℕ} (a : Vec (matSize m n)) (b : Vec (matSize n p)) (ip : Fin (matSize m p)) :
(matmulVec a b).ofLp ip = have hp := ⋯; have i := ip.divNat; have k' := ip.modNat; have k := Fin.cast hp k'; ∑ j : Fin n, a.ofLp (idxMN i j) * b.ofLp (idxMN j k)

Matrix multiplication is developed at the vector level (flattened matrices) to integrate cleanly with CtxVec and the HasFDerivAt machinery.

PyTorch analogue: torch.matmul / @ operator on 2D tensors. https://pytorch.org/docs/stable/generated/torch.matmul.html
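
The formula in matmulVec_apply can be mirrored numerically. The sketch below is a plain-Python stand-in with no PyTorch dependency. It uses the same row-major assumption, and matmul_vec is a hypothetical name. Its output should agree with torch.matmul on the corresponding 2D tensors.

```python
# Hypothetical plain-Python analogue of matmulVec on row-major flattened matrices.


def matmul_vec(a: list[float], b: list[float], m: int, n: int, p: int) -> list[float]:
    """Multiply a (m x n) by b (n x p), all row-major flattened; returns m x p."""
    if len(a) != m * n or len(b) != n * p:
        raise ValueError("size mismatch")
    out = []
    for ip in range(m * p):
        i, k = ip // p, ip % p  # row and column of the output entry (divNat / modNat)
        # Sum over the shared index j, as in the matmulVec_apply formula.
        out.append(sum(a[i * n + j] * b[j * p + k] for j in range(n)))
    return out


if __name__ == "__main__":
    a = [1, 2, 3, 4, 5, 6]     # [[1, 2, 3], [4, 5, 6]]
    b = [7, 8, 9, 10, 11, 12]  # [[7, 8], [9, 10], [11, 12]]
    print(matmul_vec(a, b, 2, 3, 2))  # [58, 64, 139, 154]
```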

noncomputable def Proofs.Autograd.TapeNodes.Matmul.matmulCLMRight {m n p : ℕ} (a : Vec (matSize m n)) :

For a fixed left operand a, matmulCLMRight a is the continuous linear map b ↦ a * b.
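
The sketch below makes the linearity of matmulCLMRight a concrete. It fixes the left operand, then checks additivity and homogeneity on integer inputs so that the arithmetic is exact. matmul_vec, matmul_right, and is_linear_at are hypothetical names. Continuity needs no separate check, because every linear map between finite-dimensional spaces is continuous.

```python
# Hypothetical sketch: fixing the left operand of matrix multiplication
# yields a linear map b -> a @ b (row-major flattened matrices).
from typing import Callable


def matmul_vec(a, b, m, n, p):
    """Row-major (m x n) @ (n x p) -> (m x p)."""
    return [sum(a[(ip // p) * n + j] * b[j * p + ip % p] for j in range(n))
            for ip in range(m * p)]


def matmul_right(a, m: int, n: int, p: int) -> Callable[[list], list]:
    """Return the map b -> a @ b for a fixed left operand a."""
    return lambda b: matmul_vec(a, b, m, n, p)


def is_linear_at(f, b1, b2, c) -> bool:
    """Check f(b1 + c * b2) == f(b1) + c * f(b2) entrywise."""
    lhs = f([x + c * y for x, y in zip(b1, b2)])
    rhs = [x + c * y for x, y in zip(f(b1), f(b2))]
    return lhs == rhs


if __name__ == "__main__":
    f = matmul_right([1, 2, 3, 4], 2, 2, 3)  # a = [[1, 2], [3, 4]]
    print(is_linear_at(f, [1, 0, 2, 0, 1, 1], [3, -1, 0, 2, 2, 5], 7))  # True
```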

Continuous bilinear map for matrix multiplication on flattened vectors.

@[simp]

Spec.mat_mul_spec agrees with matmulVec after flattening its inputs and output.

Helper: matSize m n is definitionally m * n.

Equivalence implementing matrix transpose on flattened indices.

Transpose on flattened matrices: (m×n) flattened row-major → (n×m) flattened row-major.

Transpose is implemented as a coordinate permutation on flattened matrices.

PyTorch analogue: A.transpose(0, 1) for a 2D tensor. https://pytorch.org/docs/stable/generated/torch.transpose.html
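
The sketch below treats transpose as a pure coordinate permutation on row-major flattened data, matching A.transpose(0, 1) on a 2D tensor. transpose_vec is a hypothetical helper, and the row-major layout follows the description of the transpose above.

```python
# Hypothetical sketch: transpose of a row-major flattened m x n matrix.


def transpose_vec(x: list, m: int, n: int) -> list:
    """Map row-major (m x n) data to row-major (n x m): out[j * m + i] = x[i * n + j]."""
    if len(x) != m * n:
        raise ValueError("size mismatch")
    out = [None] * (m * n)
    for i in range(m):
        for j in range(n):
            out[j * m + i] = x[i * n + j]
    return out


if __name__ == "__main__":
    x = [1, 2, 3, 4, 5, 6]  # [[1, 2, 3], [4, 5, 6]]
    print(transpose_vec(x, 2, 3))  # [1, 4, 2, 5, 3, 6]
    # Transposing back restores the original data (the map is a permutation).
    print(transpose_vec(transpose_vec(x, 2, 3), 3, 2) == x)  # True
```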

Tape node computing matrix transpose: (m×n) ↦ (n×m).

NodeFDerivCorrect for matrix_transpose (transpose is linear and an isometry).
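
Transpose only permutes coordinates, so it preserves the Euclidean norm. Its adjoint, which is the map a VJP applies to an upstream gradient, is the inverse permutation: the (n×m) gradient is transposed back to (m×n). The hypothetical sketch below checks the adjoint identity ⟨g, Tx⟩ = ⟨Tᵀg, x⟩ numerically. It illustrates the underlying math only and is not TorchLean's NodeFDerivCorrect proof.

```python
# Hypothetical check: the VJP of transpose is transpose in the other direction.
import math


def transpose_vec(x, m, n):
    """Row-major (m x n) -> row-major (n x m)."""
    return [x[i * n + j] for j in range(n) for i in range(m)]


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def transpose_vjp(g, m, n):
    """Pull an upstream gradient g of shape (n x m) back to shape (m x n)."""
    return transpose_vec(g, n, m)


if __name__ == "__main__":
    m, n = 2, 3
    x = [1.0, -2.0, 0.5, 3.0, 4.0, -1.0]
    g = [0.25, 2.0, -1.0, 1.5, 3.0, -0.5]
    # Adjoint identity: <g, T x> == <T^T g, x>.
    print(math.isclose(dot(g, transpose_vec(x, m, n)), dot(transpose_vjp(g, m, n), x)))  # True
    # Isometry: ||T x||^2 == ||x||^2.
    print(math.isclose(dot(transpose_vec(x, m, n), transpose_vec(x, m, n)), dot(x, x)))  # True
```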

Matrix multiplication node on 2D tensors.

NodeFDerivCorrect for the matrix-matrix multiplication node.

This packages the product rule and the dot/adjointness lemmas for Spec.mat_mul_spec.
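
For f(A, B) = A·B, the product rule gives the derivative (dA, dB) ↦ dA·B + A·dB. Its adjoint therefore sends an upstream gradient G (m×p) to (G·Bᵀ, Aᵀ·G). The sketch below checks this standard identity on row-major flattened lists with integer entries, so the comparison is exact. All helper names are hypothetical, and the code is an illustration rather than the TorchLean node.

```python
# Hypothetical check of the matmul VJP via the adjoint (dot-product) identity:
# <G, dA @ B + A @ dB> == <G @ B^T, dA> + <A^T @ G, dB>.


def matmul_vec(a, b, m, n, p):
    """Row-major (m x n) @ (n x p) -> (m x p)."""
    return [sum(a[(ip // p) * n + j] * b[j * p + ip % p] for j in range(n))
            for ip in range(m * p)]


def transpose_vec(x, m, n):
    """Row-major (m x n) -> row-major (n x m)."""
    return [x[i * n + j] for j in range(n) for i in range(m)]


def dot(u, v):
    return sum(s * t for s, t in zip(u, v))


def matmul_vjp(a, b, g, m, n, p):
    """Return (grad_a, grad_b) = (G @ B^T, A^T @ G) for upstream gradient g (m x p)."""
    grad_a = matmul_vec(g, transpose_vec(b, n, p), m, p, n)
    grad_b = matmul_vec(transpose_vec(a, m, n), g, n, m, p)
    return grad_a, grad_b


if __name__ == "__main__":
    m, n, p = 2, 3, 2
    a, b = [1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]
    da, db = [1, 0, -1, 2, 0, 1], [0, 1, 1, 0, -2, 3]
    g = [1, -1, 2, 3]
    # Directional derivative of (A, B) -> A @ B along (dA, dB).
    tangent = [x + y for x, y in zip(matmul_vec(da, b, m, n, p), matmul_vec(a, db, m, n, p))]
    grad_a, grad_b = matmul_vjp(a, b, g, m, n, p)
    print(dot(g, tangent) == dot(grad_a, da) + dot(grad_b, db))  # True
```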

Broadcast a vector v : Vec m across the last axis to a flattened (m×n) matrix.

Broadcast a vector v : Vec n across the first axis to a flattened (m×n) matrix.
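
On the row-major layout, both broadcasts are index copies. Broadcasting v : Vec m across the last axis sets entry (i, j) to v[i]. Broadcasting v : Vec n across the first axis sets entry (i, j) to v[j]. The plain-Python sketch below illustrates both. The function names mirror the broadcast_row and broadcast_col nodes of this module but are hypothetical helpers, not TorchLean definitions.

```python
# Hypothetical sketch of the two broadcasts on row-major flattened (m x n) data.


def broadcast_row(v: list, m: int, n: int) -> list:
    """v has length m; entry (i, j) of the result is v[i] (copy across the last axis)."""
    if len(v) != m:
        raise ValueError("expected a length-m vector")
    return [v[i] for i in range(m) for _ in range(n)]


def broadcast_col(v: list, m: int, n: int) -> list:
    """v has length n; entry (i, j) of the result is v[j] (copy across the first axis)."""
    if len(v) != n:
        raise ValueError("expected a length-n vector")
    return [v[j] for _ in range(m) for j in range(n)]


if __name__ == "__main__":
    print(broadcast_row([1, 2], 2, 3))     # [1, 1, 1, 2, 2, 2]
    print(broadcast_col([1, 2, 3], 2, 3))  # [1, 2, 3, 1, 2, 3]
```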

Row-wise sum: flattened (m×n) matrix → vector m.

Row-wise mean: flattened (m×n) matrix → vector m.
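
The sketch below shows both row-wise reductions on row-major flattened data. row_sum and row_mean are hypothetical names. The sketch rejects the zero-column case (n = 0) as its own design choice; the Lean definition may handle that case differently.

```python
# Hypothetical sketch of the row-wise reductions on row-major flattened (m x n) data.


def row_sum(x: list, m: int, n: int) -> list:
    """Sum each row: result[i] = sum over j of x[i * n + j]."""
    if len(x) != m * n:
        raise ValueError("size mismatch")
    return [sum(x[i * n:(i + 1) * n]) for i in range(m)]


def row_mean(x: list, m: int, n: int) -> list:
    """Average each row; this sketch requires n > 0."""
    if n == 0:
        raise ValueError("n must be positive")
    return [s / n for s in row_sum(x, m, n)]


if __name__ == "__main__":
    x = [1, 2, 3, 4, 5, 6]  # [[1, 2, 3], [4, 5, 6]]
    print(row_sum(x, 2, 3))   # [6, 15]
    print(row_mean(x, 2, 3))  # [2.0, 5.0]
```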

Broadcast a vector (.dim m .scalar) across columns to (.dim m (.dim n .scalar)).

NodeFDerivCorrect for broadcast_row (linear op).
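
Because broadcast_row is linear, its VJP is its adjoint. Each output entry (i, j) copies v[i], so the gradient for v[i] collects the upstream gradients of row i, which is a row-wise sum. The sketch below checks ⟨g, broadcast_row v⟩ = ⟨row_sum g, v⟩ on small integer data. The helpers are hypothetical, and the check illustrates the math rather than the Lean proof.

```python
# Hypothetical check: the adjoint of broadcast_row is the row-wise sum.


def broadcast_row(v, m, n):
    """Length-m vector -> row-major (m x n) with entry (i, j) = v[i]."""
    return [v[i] for i in range(m) for _ in range(n)]


def row_sum(g, m, n):
    """Row-major (m x n) -> length-m vector of row sums."""
    return [sum(g[i * n:(i + 1) * n]) for i in range(m)]


def dot(u, v):
    return sum(s * t for s, t in zip(u, v))


if __name__ == "__main__":
    m, n = 2, 3
    v = [2, -1]
    g = [1, 4, -2, 0, 3, 5]  # upstream gradient, shape (m x n)
    print(row_sum(g, m, n))  # [3, 8]
    print(dot(g, broadcast_row(v, m, n)) == dot(row_sum(g, m, n), v))  # True
```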

Broadcast a vector (.dim n .scalar) across rows to (.dim m (.dim n .scalar)).

NodeFDerivCorrect for broadcast_col (linear op).
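
broadcast_col follows the same reasoning as broadcast_row. Entry (i, j) copies v[j], so its adjoint, and hence its VJP, sums the upstream gradient down each column. The sketch below uses hypothetical helpers (col_sum is not a name taken from TorchLean) and checks ⟨g, broadcast_col v⟩ = ⟨col_sum g, v⟩ exactly on integers.

```python
# Hypothetical check: the adjoint of broadcast_col is the column-wise sum.


def broadcast_col(v, m, n):
    """Length-n vector -> row-major (m x n) with entry (i, j) = v[j]."""
    return [v[j] for _ in range(m) for j in range(n)]


def col_sum(g, m, n):
    """Row-major (m x n) -> length-n vector of column sums."""
    return [sum(g[i * n + j] for i in range(m)) for j in range(n)]


def dot(u, v):
    return sum(s * t for s, t in zip(u, v))


if __name__ == "__main__":
    m, n = 2, 3
    v = [1, 0, -2]
    g = [1, 4, -2, 0, 3, 5]  # upstream gradient, shape (m x n)
    print(col_sum(g, m, n))  # [1, 7, 3]
    print(dot(g, broadcast_col(v, m, n)) == dot(col_sum(g, m, n), v))  # True
```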

Shape-only nodes (reshape, flatten, and similar) live in NN.Proofs.Autograd.Tape.Nodes.Shape (namespace TapeNodes.ShapeOps).

Row-wise mean (reduce last axis): (.dim m (.dim n .scalar)) → (.dim m .scalar).

NodeFDerivCorrect for row_mean (reduce-mean along the last axis).
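
row_mean is linear, and each output entry averages n inputs. Its VJP therefore sends an upstream gradient g : Vec m to broadcast_row(g) / n, spreading g[i] evenly over row i. The float check below uses illustrative helper names, assumes n > 0, and compares with math.isclose to tolerate rounding.

```python
# Hypothetical check: VJP of row_mean is g -> broadcast_row(g) / n.
import math


def row_mean(x, m, n):
    """Row-major (m x n) -> length-m vector of row means (n > 0 assumed)."""
    return [sum(x[i * n:(i + 1) * n]) / n for i in range(m)]


def row_mean_vjp(g, m, n):
    """Pull back an upstream gradient g (length m) to shape (m x n)."""
    return [g[i] / n for i in range(m) for _ in range(n)]


def dot(u, v):
    return sum(s * t for s, t in zip(u, v))


if __name__ == "__main__":
    m, n = 2, 3
    x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
    g = [0.5, -2.0]
    print(row_mean(x, m, n))  # [2.0, 5.0]
    print(math.isclose(dot(g, row_mean(x, m, n)), dot(row_mean_vjp(g, m, n), x)))  # True
```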
