Batched #
Additional HasFDerivAt-level nodes for batched (3D) ops.
These are useful for MultiHeadAttention graphs where the head dimension is explicit:
- batched matrix multiplication: (h×m×n) × (h×n×p) → (h×m×p)
- batched (row-wise) softmax: h × (m×n) → h × (m×n)
All results here are spec-level over ℝ.
Split a flattened h * n vector into h “heads” of length n.
This is the vector-level analogue of reshaping (..., h*n) into (..., h, n).
It is used to define batched operations as head-wise operations.
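As a minimal Lean sketch of this reshaping, one could use Mathlib's `finProdFinEquiv`; the name `splitHeads` and the curried `Fin`-indexed representation are illustrative assumptions, not necessarily the repo's own API:

```lean
import Mathlib

-- Hypothetical sketch: view a flattened `h * n` vector as `h` heads of
-- length `n` (head-major), via `finProdFinEquiv : Fin h × Fin n ≃ Fin (h * n)`.
def splitHeads {h n : ℕ} (x : Fin (h * n) → ℝ) : Fin h → Fin n → ℝ :=
  fun i j => x (finProdFinEquiv (i, j))
```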
Flattened size of h matrices of shape m×n (row-major): h * (m * n).
Batched matmul node (head-wise): (h×m×n) × (h×n×p) → (h×m×p).
PyTorch analogue: torch.matmul with leading batch dimension h.
https://pytorch.org/docs/stable/generated/torch.matmul.html
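A spec-level sketch of the head-wise product: for each head k, an ordinary (m×n)·(n×p) matrix multiplication. The name `batchedMatmul` and the curried `Fin`-indexed encoding are assumptions for illustration, not necessarily the repo's definitions:

```lean
import Mathlib

-- Hypothetical spec: per head k, entry (i, j) is the usual dot product
-- of row i of A's k-th head with column j of B's k-th head.
def batchedMatmul {h m n p : ℕ}
    (A : Fin h → Fin m → Fin n → ℝ) (B : Fin h → Fin n → Fin p → ℝ) :
    Fin h → Fin m → Fin p → ℝ :=
  fun k i j => ∑ l : Fin n, A k i l * B k l j
```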
NodeFDerivCorrect for the batched matmul node.
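Head-wise, matmul is bilinear in (A, B), so the Fréchet derivative such a correctness statement records is the usual product rule, applied independently per head k:

```latex
D_{(A,B)}(A \cdot B)[dA, dB] = dA \cdot B + A \cdot dB,
\qquad
\bigl(dA \cdot B + A \cdot dB\bigr)_{k,i,j}
  = \sum_{l=1}^{n} dA_{k,i,l}\,B_{k,l,j} + A_{k,i,l}\,dB_{k,l,j}.
```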
Batched row-wise softmax node: apply softmax_last independently per head.
Shape: h × (m×n) → h × (m×n), where each head contains an m×n matrix and softmax is along the
last axis (size n) within each row.
PyTorch analogue: torch.nn.functional.softmax(x, dim=-1) with a leading batch dimension.
https://pytorch.org/docs/stable/generated/torch.nn.functional.softmax.html
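A spec-level sketch of the head-wise application; the names `softmaxLast` and `batchedSoftmax` are illustrative assumptions, not necessarily the repo's identifiers:

```lean
import Mathlib

-- Hypothetical spec: softmax along the last axis (size n) of each row.
noncomputable def softmaxLast {n : ℕ} (x : Fin n → ℝ) : Fin n → ℝ :=
  fun j => Real.exp (x j) / ∑ l : Fin n, Real.exp (x l)

-- Apply it independently to every row of every head.
noncomputable def batchedSoftmax {h m n : ℕ}
    (x : Fin h → Fin m → Fin n → ℝ) : Fin h → Fin m → Fin n → ℝ :=
  fun k i => softmaxLast (x k i)
```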
NodeFDerivCorrect for softmax_last in the batched/head-wise setting.
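Within a single row, writing s = softmax_last(x), the Jacobian underlying such a statement is the standard softmax derivative; across rows and heads it is block-diagonal, since each row is processed independently:

```latex
\frac{\partial s_i}{\partial x_j} = s_i \left( \delta_{ij} - s_j \right)
```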