TorchLean API

NN.Proofs.Autograd.Tape.Ops.Attention.ScaledDotProduct

ScaledDotProduct

End-to-end fderiv/backprop correctness for a scaled dot-product attention graph, built out of the proven tape nodes (matmul, matrix_transpose, scale, softmax_last).

This is spec-level over . It is a corollary of the general graph theorem once each node used by the graph has a NodeFDerivCorrect instance.
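
For intuition (an informal sketch, not the Lean statement, written for a chain of nodes; the DAG case is analogous): once each node's derivative is correct, the graph derivative composes by the chain rule, and taking adjoints reverses the composition order, which is exactly the reverse pass over the tape:

$$
f = f_k \circ \cdots \circ f_1,\qquad
Df(x) = Df_k(x_{k-1}) \circ \cdots \circ Df_1(x_0),\qquad
Df(x)^{\dagger} = Df_1(x_0)^{\dagger} \circ \cdots \circ Df_k(x_{k-1})^{\dagger},
$$

where $x_0 = x$ and $x_i = f_i(x_{i-1})$ are the intermediate values recorded during the forward pass.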

PyTorch correspondence / citations

@[reducible, inline]

Matrix shape for m×d Q/K/V inputs.

@[reducible, inline]

Input context shapes: [Q, K, V], each m×d.

Context index of Q in ΓQKV.

Context index of K in ΓQKV.

Context index of V in ΓQKV.

Index of the most-recently appended tensor in a DGraph context.

Scaled dot-product attention as a proved-correct DGraph.

Computes Q K V ↦ softmax(c * (Q * Kᵀ)) * V and records intermediate values needed by backprop.
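
Spelled out with shapes (a sketch, assuming the usual row-wise convention for softmax_last and a scalar scale factor c): for $Q, K, V \in \mathbb{R}^{m \times d}$,

$$
S = c\,(Q K^{\top}) \in \mathbb{R}^{m \times m},\qquad
A_{ij} = \frac{\exp S_{ij}}{\sum_{k=1}^{m} \exp S_{ik}},\qquad
O = A\,V \in \mathbb{R}^{m \times d}.
$$

Intermediates such as $S$ and $A$ are the kind of values the tape keeps around so the reverse pass does not have to recompute them.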

Corollary of the general DAG theorem: backprop equals (fderiv eval)† for the attention graph.

This is the formal statement that the tape reverse pass computes the VJP for the full attention computation.
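
Unpacking the adjoint (an informal sketch; the names here are not the Lean identifiers): write f for the graph's evaluation map (Q, K, V) ↦ softmax(c · Q Kᵀ) · V and ⟨·,·⟩ for the Frobenius inner products on the matrix spaces involved. The claim is that for every output cotangent G,

$$
\mathrm{backprop}(G) \;=\; \big(Df(Q,K,V)\big)^{\dagger} G,
\qquad\text{i.e.}\qquad
\big\langle Df(Q,K,V)[\,dQ, dK, dV\,],\, G \big\rangle
 \;=\; \big\langle (dQ, dK, dV),\, \mathrm{backprop}(G) \big\rangle
$$

for all tangents $(dQ, dK, dV)$, which is precisely the vector–Jacobian product of $f$ at $(Q, K, V)$ applied to $G$.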