Shape #
Additional analytic (HasFDerivAt) tape nodes for shape permutations.
These nodes are linear isometries and are useful for models that do explicit reshaping and axis permutations (e.g. Multi-Head Attention head splitting/combining).
reshape is linear: on vectors it is just a type cast along a Shape.size equality.
We implement it as a Node to keep the DAG theorem applicable.
reshape node: reinterpret the same underlying coordinates as a different shape.
This is only definable when Shape.size s₁ = Shape.size s₂; at the vector level it is a cast.
PyTorch analogue: view/reshape operations that do not change the total number of elements.
https://pytorch.org/docs/stable/tensor_view.html
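At the vector level the idea can be sketched as follows. This is a minimal, self-contained Lean model (assuming Mathlib): `Vec`, `reshapeVec`, and the use of `Fin.cast` are illustrative assumptions, not this library's actual definitions.

```lean
import Mathlib

-- Illustrative model: a shape of total size n is, as a vector space, Fin n → ℝ.
abbrev Vec (n : ℕ) := Fin n → ℝ

-- Reinterpreting coordinates along a size equality is a pure cast on indices,
-- so linearity holds definitionally.
def reshapeVec {m n : ℕ} (h : m = n) (v : Vec m) : Vec n :=
  fun i => v (Fin.cast h.symm i)

-- Linearity is `rfl`: the cast does not touch the coordinate values.
example {m n : ℕ} (h : m = n) (u v : Vec m) (a : ℝ) :
    reshapeVec h (a • u + v) = a • reshapeVec h u + reshapeVec h v := rfl
```

Because the map is a coordinate-preserving bijection on indices, it is also an isometry, which is what makes the derivative transport trivial.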
NodeFDerivCorrect for reshape (it is linear/isometric).
flatten node: specialization of reshape to the canonical vector shape (.dim (Shape.size s) .scalar).
PyTorch analogue: flatten when applied to a contiguous tensor.
https://pytorch.org/docs/stable/generated/torch.flatten.html
NodeFDerivCorrect for flatten.
Move reindexVec across the left argument of an inner product.
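The content of this lemma is the usual orthogonality fact for coordinate permutations: a reindexing can be moved from one side of an inner product to the other by inverting it. A hedged Mathlib analogue (with an `Equiv` standing in for this library's `reindexVec`; lemma names assumed from Mathlib) looks like:

```lean
import Mathlib

-- A coordinate permutation is orthogonal: reindexing by `e` on the left of a
-- dot product equals reindexing by `e.symm` on the right.
example {n : ℕ} (e : Fin n ≃ Fin n) (u v : Fin n → ℝ) :
    ∑ i, u (e i) * v i = ∑ i, u i * v (e.symm i) := by
  rw [← Equiv.sum_comp e fun i => u i * v (e.symm i)]
  simp
```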
Swap the first two axes of a 3D tensor shape: .dim m (.dim n rest) ↦ .dim n (.dim m rest).
This is implemented as a coordinate permutation (a linear isometry).
PyTorch analogue: transpose(0, 1) on a 3D tensor.
https://pytorch.org/docs/stable/generated/torch.transpose.html
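In curried form the permutation is just argument swapping. The following is a sketch under that assumption (the hypothetical `swapFirstTwo` below works on curried tensors, whereas the library permutes a flat index space; the linear-isometry content is the same):

```lean
import Mathlib

-- Swap the first two axes of a 3-D tensor, viewed in curried form.
def swapFirstTwo {m n k : ℕ} (t : Fin m → Fin n → Fin k → ℝ) :
    Fin n → Fin m → Fin k → ℝ :=
  fun j i => t i j

-- Involutive, as expected for transpose(0, 1): swapping twice is the identity.
example {m n k : ℕ} (t : Fin m → Fin n → Fin k → ℝ) :
    swapFirstTwo (swapFirstTwo t) = t := rfl
```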
NodeFDerivCorrect for swap_first_two3d (linear coordinate permutation).
Transpose the last two axes of a 3D tensor: .dim a (.dim b (.dim c .scalar)) ↦ .dim a (.dim c (.dim b .scalar)).
This is another coordinate permutation used in attention (switching K to Kᵀ per head while keeping the head/batch axes fixed).
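A curried sketch of this permutation (the name `transposeLastTwo` is illustrative, not the library's; the library permutes a flat index space instead):

```lean
import Mathlib

-- Transpose the last two axes while keeping the leading (head/batch) axis:
-- this is what turns K into Kᵀ per head when forming attention scores.
def transposeLastTwo {a b c : ℕ} (t : Fin a → Fin b → Fin c → ℝ) :
    Fin a → Fin c → Fin b → ℝ :=
  fun i k j => t i j k

-- Involutive, like any axis transposition.
example {a b c : ℕ} (t : Fin a → Fin b → Fin c → ℝ) :
    transposeLastTwo (transposeLastTwo t) = t := rfl
```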
NodeFDerivCorrect for transpose3d_last_two (linear coordinate permutation).