Matrix tape nodes
Tape nodes for matrix multiplication, transpose, row/column broadcasting, and row means. Vector-Jacobian product (VJP) correctness is stated at the level of the vectorized (flattened) tape.
Flattened size of an m×n matrix shape: Shape.size (.dim m (.dim n .scalar)) = m*n.
Flattened size of a length-n vector shape: Shape.size (.dim n .scalar) = n.
Vectorization commutes with matrix addition: toVecT (A + B) = toVecT A + toVecT B.
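To make the row-major convention concrete, here is a minimal Python sketch that models an m×n matrix as a list of rows. The names `to_vec`, `mat_add`, and `vec_add` are hypothetical stand-ins chosen for illustration; they are not the Lean definitions (the library's name is `toVecT`). The sketch checks the size fact and additivity on one example only.

```python
# Hypothetical Python stand-ins for illustration, not the Lean API.

def to_vec(A):
    """Row-major flattening: entry (i, j) of an m x n matrix lands at index i*n + j."""
    return [x for row in A for x in row]

def mat_add(A, B):
    """Entrywise sum of two matrices given as lists of rows."""
    return [[a + b for a, b in zip(ra, rb)] for ra, rb in zip(A, B)]

def vec_add(u, v):
    """Entrywise sum of two flat vectors."""
    return [a + b for a, b in zip(u, v)]

A = [[1, 2, 3], [4, 5, 6]]
B = [[10, 20, 30], [40, 50, 60]]

# The flattened 2x3 shape has size 2*3 = 6.
assert len(to_vec(A)) == 2 * 3
# Flattening commutes with addition: to_vec(A + B) == to_vec(A) + to_vec(B).
assert to_vec(mat_add(A, B)) == vec_add(to_vec(A), to_vec(B))
```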
Matrix multiplication is developed at the vector level (flattened matrices) to integrate cleanly
with CtxVec and the HasFDerivAt machinery.
PyTorch analogue: torch.matmul / @ operator on 2D tensors.
https://pytorch.org/docs/stable/generated/torch.matmul.html
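A minimal Python sketch of matrix multiplication on row-major flattened inputs. It assumes plain lists of numbers in place of the library's vector type; `matmul_flat` and its argument order are hypothetical. On the same data, `torch.matmul` applied to the corresponding 2D tensors produces the same entries.

```python
# Hypothetical helper, not the Lean definition: operands are row-major flat lists.

def matmul_flat(m, k, n, a, b):
    """Multiply a (m x k) by b (k x n); the result is (m x n), row-major flattened.

    C[i, j] = sum_p A[i, p] * B[p, j], with A[i, p] at a[i*k + p] and B[p, j] at b[p*n + j].
    """
    assert len(a) == m * k and len(b) == k * n
    return [sum(a[i * k + p] * b[p * n + j] for p in range(k))
            for i in range(m) for j in range(n)]

a = [1, 2, 3,
     4, 5, 6]        # 2x3
b = [7, 8,
     9, 10,
     11, 12]         # 3x2
c = matmul_flat(2, 3, 2, a, b)  # 2x2: [58, 64, 139, 154]
```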
Helper: matSize m n is definitionally m * n.
Transpose on flattened matrices: (m×n) flattened row-major → (n×m) flattened row-major.
Transpose is implemented as a coordinate permutation on flattened matrices.
PyTorch analogue: A.transpose(0, 1) for a 2D tensor.
https://pytorch.org/docs/stable/generated/torch.transpose.html
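The coordinate permutation can be sketched in Python as follows, assuming row-major flat lists; `transpose_flat` is a hypothetical name. Output entry (j, i) of the n×m result reads input entry (i, j), so applying the map with swapped dimensions undoes it.

```python
# Hypothetical helper, not the Lean definition.

def transpose_flat(m, n, x):
    """Row-major flattened (m x n) -> row-major flattened (n x m).

    Output entry (j, i) sits at index j*m + i and reads input entry (i, j) at index i*n + j.
    """
    assert len(x) == m * n
    return [x[i * n + j] for j in range(n) for i in range(m)]

x = [1, 2, 3,
     4, 5, 6]                 # 2x3
y = transpose_flat(2, 3, x)   # 3x2: [1, 4, 2, 5, 3, 6]
```

Because every output slot reads exactly one input slot, the map is a permutation of coordinates, which is what makes it linear and norm-preserving.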
Tape node computing matrix transpose: (m×n) ↦ (n×m).
NodeFDerivCorrect for matrix_transpose (transpose is linear and, being a coordinate permutation, an isometry).
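Since transpose only permutes coordinates, its VJP is the inverse permutation, which is again a transpose, and it preserves the Euclidean norm. The Python check below tests both facts on one example, assuming row-major flat lists; `transpose_flat`, `transpose_vjp`, and `dot` are hypothetical helpers, not the library's lemmas.

```python
# Hypothetical Python stand-ins, not the Lean API.

def transpose_flat(m, n, x):
    """Row-major (m x n) -> row-major (n x m)."""
    return [x[i * n + j] for j in range(n) for i in range(m)]

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def transpose_vjp(m, n, g):
    """VJP of the (m x n) -> (n x m) transpose: transpose the (n x m) cotangent back."""
    return transpose_flat(n, m, g)

m, n = 2, 3
x = [1.0, -2.0, 0.5, 3.0, 4.0, -1.0]   # input, 2x3
g = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]     # cotangent on the 3x2 output

# Adjointness: <T x, g> == <x, T^T g> (up to float rounding from summation order).
assert abs(dot(transpose_flat(m, n, x), g) - dot(x, transpose_vjp(m, n, g))) < 1e-12
# Isometry: permuting coordinates leaves the squared norm unchanged.
tx = transpose_flat(m, n, x)
assert abs(dot(tx, tx) - dot(x, x)) < 1e-12
```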
Matrix multiplication node on 2D tensors.
NodeFDerivCorrect for the matrix-matrix multiplication node.
This packages the product rule and the dot/adjointness lemmas for Spec.mat_mul_spec.
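For C = A B, the product rule gives the directional derivative U B + A V in direction (U, V). Adjointness of the dot product then turns a cotangent G on C into G Bᵀ for A and Aᵀ G for B. The Python sketch below checks that pairing on integer data, so the comparison is exact. It assumes row-major flat lists, and every helper name is hypothetical; it does not reproduce Spec.mat_mul_spec or its lemmas.

```python
# Hypothetical Python stand-ins, not the Lean API; all matrices are row-major flat lists.

def matmul(m, k, n, a, b):
    """(m x k) @ (k x n) -> (m x n)."""
    return [sum(a[i * k + p] * b[p * n + j] for p in range(k))
            for i in range(m) for j in range(n)]

def transpose(m, n, x):
    """(m x n) -> (n x m)."""
    return [x[i * n + j] for j in range(n) for i in range(m)]

def dot(u, v):
    return sum(s * t for s, t in zip(u, v))

def add(u, v):
    return [s + t for s, t in zip(u, v)]

def matmul_vjp(m, k, n, a, b, g):
    """Pull back the cotangent G (m x n) of C = A @ B.

    Returns (G @ B^T, A^T @ G), with shapes (m x k) and (k x n).
    """
    da = matmul(m, n, k, g, transpose(k, n, b))
    db = matmul(k, m, n, transpose(m, k, a), g)
    return da, db

m, k, n = 2, 3, 2
a = [1, 2, 3, 4, 5, 6]      # A: 2x3
b = [7, 8, 9, 10, 11, 12]   # B: 3x2
g = [1, -1, 2, 0]           # cotangent G: 2x2
u = [0, 1, 0, 0, 0, 2]      # tangent direction U for A
v = [1, 0, 0, 1, -1, 0]     # tangent direction V for B

da, db = matmul_vjp(m, k, n, a, b, g)
# Product rule: the derivative of (A, B) -> A @ B in direction (U, V) is U B + A V.
jvp = add(matmul(m, k, n, u, b), matmul(m, k, n, a, v))
# Adjointness: <G, U B + A V> == <G B^T, U> + <A^T G, V>.
assert dot(g, jvp) == dot(da, u) + dot(db, v)
```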
Broadcast a vector v : Vec m across the last axis to a flattened (m×n) matrix.
Broadcast a vector v : Vec n across the first axis to a flattened (m×n) matrix.
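The two broadcasts differ only in which index of the output entry (i, j) they read. The following Python sketch assumes row-major flat lists; the function names mirror the node names broadcast_row and broadcast_col but are hypothetical Python stand-ins. In PyTorch terms they correspond roughly to `v.unsqueeze(1).expand(m, n)` and `v.unsqueeze(0).expand(m, n)`.

```python
# Hypothetical Python stand-ins, not the Lean definitions.

def broadcast_row(m, n, v):
    """Vec m -> flattened (m x n): entry (i, j) = v[i], so row i is constant."""
    assert len(v) == m
    return [v[i] for i in range(m) for _ in range(n)]

def broadcast_col(m, n, v):
    """Vec n -> flattened (m x n): entry (i, j) = v[j], so every row is a copy of v."""
    assert len(v) == n
    return [v[j] for _ in range(m) for j in range(n)]

r = broadcast_row(2, 3, [1, 2])     # [1, 1, 1, 2, 2, 2]
c = broadcast_col(2, 3, [7, 8, 9])  # [7, 8, 9, 7, 8, 9]
```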
Row-wise sum: flattened (m×n) matrix → vector m.
Row-wise mean: flattened (m×n) matrix → vector m.
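Both reductions act on the contiguous slice that holds row i in the row-major layout. The Python sketch below uses hypothetical names and assumes n > 0; how the Lean definition treats n = 0 is not stated here. The PyTorch analogues are `x.sum(dim=1)` and `x.mean(dim=1)` on a 2D tensor.

```python
# Hypothetical Python stand-ins, not the Lean definitions.

def row_sum(m, n, x):
    """Flattened (m x n) -> Vec m: out[i] = sum_j x[i*n + j]."""
    assert len(x) == m * n
    return [sum(x[i * n:(i + 1) * n]) for i in range(m)]

def row_mean(m, n, x):
    """Flattened (m x n) -> Vec m: out[i] = row_sum[i] / n (this sketch assumes n > 0)."""
    return [s / n for s in row_sum(m, n, x)]

x = [1, 2, 3,
     4, 5, 6]              # 2x3
sums = row_sum(2, 3, x)    # [6, 15]
means = row_mean(2, 3, x)  # [2.0, 5.0]
```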
Broadcast a vector (.dim m .scalar) across columns to (.dim m (.dim n .scalar)).
NodeFDerivCorrect for broadcast_row (linear op).
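Since broadcast_row copies v[i] into the n entries of row i, its adjoint (and hence its VJP) sums each row of the cotangent. The Python check below verifies ⟨broadcast_row v, G⟩ = ⟨v, row_sum G⟩ on integer data. It assumes row-major flat lists, and all names are hypothetical stand-ins.

```python
# Hypothetical Python stand-ins, not the Lean API.

def broadcast_row(m, n, v):
    """Vec m -> flattened (m x n) with entry (i, j) = v[i]."""
    return [v[i] for i in range(m) for _ in range(n)]

def row_sum(m, n, g):
    """Flattened (m x n) -> Vec m: out[i] = sum_j g[i*n + j]."""
    return [sum(g[i * n:(i + 1) * n]) for i in range(m)]

def dot(u, v):
    return sum(s * t for s, t in zip(u, v))

m, n = 2, 3
v = [5, -1]
g = [1, 2, 3, 4, 5, 6]   # cotangent on the (2 x 3) output

# VJP of broadcast_row is row_sum: each v[i] contributes to every entry of row i.
assert dot(broadcast_row(m, n, v), g) == dot(v, row_sum(m, n, g))
```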
Broadcast a vector (.dim n .scalar) across rows to (.dim m (.dim n .scalar)).
NodeFDerivCorrect for broadcast_col (linear op).
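Symmetrically, broadcast_col copies v[j] into every row at column j, so its VJP sums the cotangent down each column. The Python sketch below checks the adjoint identity on integer data; `col_sum` is a hypothetical helper introduced only for this illustration, and the page does not name a column-sum definition.

```python
# Hypothetical Python stand-ins, not the Lean API.

def broadcast_col(m, n, v):
    """Vec n -> flattened (m x n) with entry (i, j) = v[j]."""
    return [v[j] for _ in range(m) for j in range(n)]

def col_sum(m, n, g):
    """Flattened (m x n) -> Vec n: out[j] = sum_i g[i*n + j]."""
    return [sum(g[i * n + j] for i in range(m)) for j in range(n)]

def dot(u, v):
    return sum(s * t for s, t in zip(u, v))

m, n = 2, 3
v = [1, 0, -1]
g = [1, 2, 3, 4, 5, 6]   # cotangent on the (2 x 3) output

# VJP of broadcast_col is col_sum: each v[j] contributes to every entry of column j.
assert dot(broadcast_col(m, n, v), g) == dot(v, col_sum(m, n, g))
```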
Shape-only nodes (reshape, flatten, and similar) live in NN.Proofs.Autograd.Tape.Nodes.Shape
(namespace TapeNodes.ShapeOps).
Row-wise mean (reduce last axis): (.dim m (.dim n .scalar)) → (.dim m .scalar).
NodeFDerivCorrect for row_mean (reduce-mean along the last axis).
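Row mean is row sum scaled by 1/n, so its VJP broadcasts each cotangent entry g[i] across row i and divides by n. The Python sketch below checks ⟨row_mean x, g⟩ = ⟨x, VJP(g)⟩ exactly by using `fractions.Fraction`. It assumes n > 0 and row-major flat lists, and `row_mean_vjp` is a hypothetical name.

```python
from fractions import Fraction

# Hypothetical Python stand-ins, not the Lean API.

def row_mean(m, n, x):
    """Flattened (m x n) -> Vec m: out[i] = (1/n) * sum_j x[i*n + j] (assumes n > 0)."""
    return [sum(x[i * n:(i + 1) * n], Fraction(0)) / n for i in range(m)]

def row_mean_vjp(m, n, g):
    """Pull back a cotangent g (Vec m): entry (i, j) receives g[i] / n."""
    return [Fraction(g[i]) / n for i in range(m) for _ in range(n)]

def dot(u, v):
    return sum(s * t for s, t in zip(u, v))

m, n = 2, 3
x = [1, 2, 3, 4, 5, 6]   # input, 2x3
g = [3, -6]              # cotangent on the Vec 2 output

# Adjointness with exact rational arithmetic.
assert dot(row_mean(m, n, x), g) == dot(x, row_mean_vjp(m, n, g))
```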