TorchLean API

NN.Proofs.Autograd.Tape.Nodes.Batched

Batched #

Additional HasFDerivAt-level nodes for batched (3D) ops.

These are useful for MultiHeadAttention graphs where the head dimension is explicit.

All results here are spec-level over ℝ.

noncomputable def Proofs.Autograd.TapeNodes.Batched.heads {h n : ℕ} (x : Vec (h * n)) :
    Fin h → Vec n

Split a flattened h * n vector into h “heads” of length n.

This is the vector-level analogue of reshaping (..., h*n) into (..., h, n). It is used to define batched operations as head-wise operations.
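The reshape underlying heads is pure index arithmetic. Here is a hedged sketch, assuming Vec n unfolds to Fin n → ℝ; the name headsSketch and the proof details are illustrative, not the module's actual definition:

```lean
import Mathlib

-- Illustrative sketch: head `i` of a flattened `h * n` vector is the slice
-- starting at offset `i * n` (row-major layout, matching the docstring).
def headsSketch {h n : ℕ} (x : Fin (h * n) → ℝ) : Fin h → Fin n → ℝ :=
  fun i j => x ⟨i.val * n + j.val, by
    calc i.val * n + j.val
        < i.val * n + n := Nat.add_lt_add_left j.isLt _
      _ = (i.val + 1) * n := (Nat.succ_mul i.val n).symm
      _ ≤ h * n := Nat.mul_le_mul_right n i.isLt⟩
```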

Instances For
noncomputable def Proofs.Autograd.TapeNodes.Batched.unheads {h n : ℕ} (r : Fin h → Vec n) :
    Vec (h * n)

Inverse of heads: concatenate head vectors back into one flattened vector.
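Concatenation recovers each flat entry from head k / n at offset k % n. A hedged sketch over plain function types (names illustrative; the n = 0 case is vacuous since Fin (h * 0) is empty, which is what the positivity proof below exploits):

```lean
import Mathlib

-- Illustrative sketch of an `unheads`-style concatenation (row-major).
def unheadsSketch {h n : ℕ} (r : Fin h → Fin n → ℝ) : Fin (h * n) → ℝ :=
  fun k =>
    have hn : 0 < n := by
      rcases Nat.eq_zero_or_pos n with h0 | h0
      · exact absurd k.isLt (by simp [h0])
      · exact h0
    r ⟨k.val / n, Nat.div_lt_of_lt_mul (Nat.mul_comm h n ▸ k.isLt)⟩
      ⟨k.val % n, Nat.mod_lt _ hn⟩
```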

Instances For
noncomputable def Proofs.Autograd.TapeNodes.Batched.headsCLM {h n : ℕ} :
    Vec (h * n) →L[ℝ] (Fin h → Vec n)

Continuous linear map version of heads.

Instances For

noncomputable def Proofs.Autograd.TapeNodes.Batched.unheadsCLM {h n : ℕ} :
    (Fin h → Vec n) →L[ℝ] Vec (h * n)

Continuous linear map version of unheads.
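Because headsCLM and unheadsCLM are the linear packagings of mutually inverse reshapes, one expects their composite to be the identity. A hedged statement sketch in the module's own vocabulary; the lemma and its exact form are hypothetical (hence the sorry), and this is not asserted to exist in the module:

```lean
-- Hypothetical round-trip statement, for orientation only.
example {h n : ℕ} :
    (unheadsCLM.comp headsCLM : Vec (h * n) →L[ℝ] Vec (h * n)) =
      ContinuousLinearMap.id ℝ (Vec (h * n)) := by
  sorry
```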

Instances For

@[reducible, inline]

Flattened size of h many m×n matrices (row-major): h * (m*n).

Instances For
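The size function described above is simple arithmetic. A sketch under the assumption that it is a plain reducible definition; the name batchSize is hypothetical, not the module's:

```lean
-- Hypothetical reconstruction: flattened size of h row-major m×n matrices.
@[reducible, inline]
def batchSize (h m n : ℕ) : ℕ := h * (m * n)

example : batchSize 8 16 64 = 8192 := rfl
```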

Bilinear map for batched matmul, packaged as A →L[ℝ] (B →L[ℝ] A ⬝ B) in flattened form.

Instances For

Batched matmul node (head-wise): (h×m×n) × (h×n×p) → (h×m×p).

PyTorch analogue: torch.matmul with leading batch dimension h. https://pytorch.org/docs/stable/generated/torch.matmul.html

Instances For
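Head-wise application is just mapping an ordinary matmul over the head index. A hedged sketch over unflattened Mathlib matrix types (the name is illustrative; the actual node operates on flattened vectors):

```lean
import Mathlib

-- Illustrative: batched matmul as per-head ordinary matmul,
-- (h×m×n) × (h×n×p) → (h×m×p), before flattening.
def batchedMatmulSketch {h m n p : ℕ}
    (A : Fin h → Matrix (Fin m) (Fin n) ℝ)
    (B : Fin h → Matrix (Fin n) (Fin p) ℝ) :
    Fin h → Matrix (Fin m) (Fin p) ℝ :=
  fun i => A i * B i
```

This mirrors how torch.matmul broadcasts over the leading h dimension.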

NodeFDerivCorrect for the batched matmul node.

Instances For

Batched row-wise softmax node: apply softmax_last independently per head.

Shape: h × (m×n) → h × (m×n), where each head contains an m×n matrix and softmax is along the last axis (size n) within each row.

PyTorch analogue: torch.nn.functional.softmax(x, dim=-1) with a leading batch dimension. https://pytorch.org/docs/stable/generated/torch.nn.functional.softmax.html
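Per-head, per-row softmax is likewise a map over the outer indices. A hedged spec-level sketch (names illustrative; the module's softmax_last may differ in details such as numerical stabilization, which this plain definition omits):

```lean
import Mathlib

-- Illustrative row-wise softmax along the last axis.
noncomputable def softmaxRow {n : ℕ} (v : Fin n → ℝ) : Fin n → ℝ :=
  fun j => Real.exp (v j) / ∑ k : Fin n, Real.exp (v k)

-- Applied independently per head `i` and per row `r`.
noncomputable def batchedSoftmaxSketch {h m n : ℕ}
    (x : Fin h → Fin m → Fin n → ℝ) : Fin h → Fin m → Fin n → ℝ :=
  fun i r => softmaxRow (x i r)
```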

Instances For

NodeFDerivCorrect for softmax_last in the batched/head-wise setting.

Instances For