TorchLean API

NN.Spec.Layers.FlashAttention

FlashAttention Semantic Contract #

FlashAttention is an IO-aware implementation strategy for scaled dot-product attention: it tiles the attention computation and maintains online softmax summaries so the full n × n attention matrix does not need to be materialized. TorchLean models that idea in three layers:

  • a runtime configuration (FlashAttentionConfig) that records tiling metadata such as block sizes, which the spec-level denotation deliberately ignores;

  • a proof-facing algorithmic contract (attentionScores, scaledAttentionScores, onlineSoftmaxWeights, onlineSoftmaxTiledAttention) that names the mathematical stages an online/tiled schedule computes;

  • a denotational target for the fused kernel and the semantic operators (cudaLoopFlashAttention, flashAttention, flashAttentionBackward), each proved to denote the corresponding standard scaled dot-product attention operator.

The key point is that the fused op has the same denotation as standard masked scaled dot-product attention over the spec scalar. Different tile sizes are runtime scheduling choices, not semantic choices.

What is proved here? #

The theorems in this file are deliberately small but important:

  • cudaLoopFlashAttention_eq_onlineSoftmaxTiledAttention: the CUDA denotational target computes the same value as the proof-facing online/tiled algorithm;

  • cudaLoopFlashAttention_eq_scaledDotProductAttention and onlineSoftmaxTiledAttention_eq_scaledDotProductAttention: both denotations agree with standard masked scaled dot-product attention;

  • flashAttention_eq_scaledDotProductAttention and flashAttentionBackward_eq_scaledDotProductAttentionBackward: the semantic forward and backward operators agree with the standard spec and its VJP.

These are definitional-equality theorems because the proof-facing contract spells out the same mathematical stages as standard attention. The native CUDA implementation is tested against this contract operationally and remains a runtime trust boundary, like the other CUDA kernels.

Why this is not a CUDA proof #

The definitions below are the mathematical contract for FlashAttention. They do not claim to verify the native CUDA source. Instead, they make the important theorem explicit:

onlineSoftmaxTiledAttention cfg ctx = scaledDotProductAttention ctx.

That is the theorem a compiler rewrite or fused backend relies on. A production IO-tiled CUDA kernel can be swapped in once it has been tested against, and shown to refine, this contract.
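
As an illustration of how a rewrite pass consumes that equality, here is a minimal proof-script sketch. The hypotheses simply mirror the signatures documented below; the import path follows the module name at the top of this page, and it is assumed (not verified here) that this import brings Context, AttentionContext, and scaledDotProductAttention into scope:

  import NN.Spec.Layers.FlashAttention

  example {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2]
      (cfg : FlashAttentionConfig) {nQ nK dModel : ℕ} {h1 : nQ ≠ 0} {h2 : nK ≠ 0}
      (ctx : AttentionContext α nQ nK dModel h1 h2) :
      Spec.onlineSoftmaxTiledAttention cfg ctx = Spec.scaledDotProductAttention ctx := by
    -- The equality is a simp lemma in this file, so `simp` closes the goal.
    simp [Spec.onlineSoftmaxTiledAttention_eq_scaledDotProductAttention]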

References:

  • Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, Christopher Ré. FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. NeurIPS 2022.

Runtime tiling metadata for a FlashAttention-style fused implementation.

The spec-level denotation below intentionally ignores these fields: they describe how an implementation schedules work, not what mathematical function the operator computes.

  • blockQ :

    Query block size used by a tiled implementation. 0 means "backend default".

  • blockK :

    Key/value block size used by a tiled implementation. 0 means "backend default".

Instances For

    Backend-default tiling.

    Instances For
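
    A minimal Lean sketch of the shape this configuration takes, assuming Nat-valued block sizes with 0 as the documented "backend default" sentinel; the actual field defaults, derived instances, and the name of the backend-default value above are not shown on this page and may differ:

      structure FlashAttentionConfig where
        blockQ : Nat := 0  -- query block size; 0 means "backend default"
        blockK : Nat := 0  -- key/value block size; 0 means "backend default"

      -- Hypothetical name for the backend-default tiling value documented above.
      def flashAttentionDefaultConfig : FlashAttentionConfig := {}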

      Algorithmic Contract #

      The original FlashAttention algorithm streams over blocks of keys/values and maintains a row-wise online softmax summary. The exact CUDA schedule is an implementation detail, but the mathematical result is the same as the closed-form stabilized softmax over the full masked score row.

      TorchLean names the stages below so proofs and compiler passes can point at a real algorithmic contract rather than only at an opaque fused primitive:

      1. build scaled scores QKᵀ / sqrt(d);
      2. apply the boolean mask with true hard-mask semantics (blocked numerator is zero);
      3. compute the same row-wise normalized weights that an online summary converges to;
      4. multiply by values.

      This is intentionally schedule-polymorphic: cfg.blockQ and cfg.blockK describe how a runtime may tile the work, but they do not alter the denotation.
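
      Written out in one place (a restatement of stages 1–4 above in the notation already used on this page, not an additional definition), with s[i][j] = (Q[i] · K[j]) / sqrt(dModel) the scaled score, [mask[i][j]] equal to 1 when the position is allowed and 0 when it is blocked, and rowMax_i a row-wise maximum (its exact choice cancels in the ratio):

        weights[i][j] = [mask[i][j]] · exp(s[i][j] - rowMax_i) / Σ_k [mask[i][k]] · exp(s[i][k] - rowMax_i)
        out[i]        = Σ_j weights[i][j] · V[j]

      This is one consistent reading of the true hard-mask semantics (blocked numerator is zero); the authoritative statement is scaledDotProductAttention itself, and the fully blocked row corner case is left to that spec.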

      def Spec.attentionScores {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {nQ nK dModel : ℕ} {h1 : nQ ≠ 0} {h2 : nK ≠ 0} (ctx : AttentionContext α nQ nK dModel h1 h2) :

      Unmasked attention scores QKᵀ.

      Instances For
        def Spec.scaledAttentionScores {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {nQ nK dModel : ℕ} {h1 : nQ ≠ 0} {h2 : nK ≠ 0} (ctx : AttentionContext α nQ nK dModel h1 h2) :

        Scaled attention scores before row normalization.

        Masking is applied at the weight level by onlineSoftmaxWeights, using the same true hard-mask semantics as scaledDotProductAttention.

        Instances For
          def Spec.onlineSoftmaxWeights {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] (cfg : FlashAttentionConfig) {nQ nK dModel : ℕ} {h1 : nQ ≠ 0} {h2 : nK ≠ 0} (ctx : AttentionContext α nQ nK dModel h1 h2) :

          The row-wise softmax weights produced by the online softmax summary.

          Activation.softmaxSpec already uses the stabilized form exp(x - rowMax) / Σ exp(x - rowMax). This definition is the denotation that a FlashAttention implementation must refine. It is not a formal model of Dao-style tile loops or SRAM/HBM traffic.
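
          For orientation only, the per-row running summary that an online implementation maintains, and that converges to the same weights, can be sketched as follows (a recurrence in the spirit of the FlashAttention algorithm, not something defined or verified in this file); for one row whose unmasked scaled scores s_j are streamed in key blocks B_1, …, B_T:

            m_0 = -∞,  l_0 = 0
            m_t = max(m_{t-1}, max over j ∈ B_t of s_j)
            l_t = l_{t-1} · exp(m_{t-1} - m_t) + Σ_{j ∈ B_t} exp(s_j - m_t)

          After the last block, m_T is the row maximum and l_T = Σ_k exp(s_k - m_T), so exp(s_j - m_T) / l_T is exactly the stabilized form computed by Activation.softmaxSpec.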

          Instances For
            def Spec.onlineSoftmaxTiledAttention {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] (cfg : FlashAttentionConfig) {nQ nK dModel : ℕ} {h1 : nQ ≠ 0} {h2 : nK ≠ 0} (ctx : AttentionContext α nQ nK dModel h1 h2) :

            Proof-facing FlashAttention forward algorithm.

            This is the mathematical result of the online/tiled schedule: row-wise softmax weights multiplied by V. Runtime kernels may avoid storing the full weights, but they must refine this value.

            Instances For

              CUDA Denotation #

              The native runtime kernel is intended to compute the same row program in a fused way:

              1. for each (batch, head, query) row, scan keys to build masked/scaled scores;
              2. compute the stabilized softmax normalization for that row;
              3. accumulate Σ_j softmax(score_j) * V_j directly into the output.

              cudaLoopFlashAttention is intentionally a denotational target, written with tensor combinators rather than CUDA thread/block syntax. The equalities below are definitional sanity checks: they say the named fused operator denotes standard SDPA in the spec. They do not verify the CUDA source code, the online-softmax recurrence, or the memory-IO schedule. Those remain explicit runtime/FFI contracts tested against this target.
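
              For concreteness, step 3 above can be written with the same running statistics as the online-softmax sketch under onlineSoftmaxWeights (again an illustration of the intended kernel behaviour, not something the equalities below verify): with o_0 = 0 and m_t, l_t the running row maximum and normalizer,

                o_t    = o_{t-1} · exp(m_{t-1} - m_t) + Σ_{j ∈ B_t} exp(s_j - m_t) · V[j]
                out[i] = o_T / l_T

              which equals Σ_j softmax(score_j) · V[j] without ever materializing the full weight row.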

              def Spec.cudaLoopFlashAttention {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] (cfg : FlashAttentionConfig) {nQ nK dModel : ℕ} {h1 : nQ ≠ 0} {h2 : nK ≠ 0} (ctx : AttentionContext α nQ nK dModel h1 h2) :

              Denotational target for the fused CUDA FlashAttention forward kernel.

              Instances For
                @[simp]
                theorem Spec.cudaLoopFlashAttention_eq_onlineSoftmaxTiledAttention {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] (cfg : FlashAttentionConfig) {nQ nK dModel : ℕ} {h1 : nQ ≠ 0} {h2 : nK ≠ 0} (ctx : AttentionContext α nQ nK dModel h1 h2) :
                @[simp]
                theorem Spec.cudaLoopFlashAttention_eq_scaledDotProductAttention {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] (cfg : FlashAttentionConfig) {nQ nK dModel : ℕ} {h1 : nQ ≠ 0} {h2 : nK ≠ 0} (ctx : AttentionContext α nQ nK dModel h1 h2) :

                The CUDA denotational target has the same spec meaning as standard SDPA.

                This is intentionally a definitional theorem, not a proof about CUDA machine code or an online softmax tile recurrence.

                @[simp]
                theorem Spec.onlineSoftmaxTiledAttention_eq_scaledDotProductAttention {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] (cfg : FlashAttentionConfig) {nQ nK dModel : ℕ} {h1 : nQ ≠ 0} {h2 : nK ≠ 0} (ctx : AttentionContext α nQ nK dModel h1 h2) :

                The proof-facing FlashAttention denotation equals standard SDPA.

                This theorem is useful for graph-rewrite semantics, but should not be read as a verification of a particular CUDA implementation.

                def Spec.flashAttention {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] (cfg : FlashAttentionConfig) {nQ nK dModel : ℕ} {h1 : nQ ≠ 0} {h2 : nK ≠ 0} (ctx : AttentionContext α nQ nK dModel h1 h2) :

                Semantic FlashAttention forward operator.

                At the spec layer this is the named online/tiled algorithmic contract above. Runtime implementations may use tiling, online softmax summaries, or a fused CUDA kernel, but they must refine this denotation to be considered correct.

                Instances For
                  def Spec.flashAttentionBackward {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] (cfg : FlashAttentionConfig) {nQ nK dModel : ℕ} {h1 : nQ ≠ 0} {h2 : nK ≠ 0} (ctx : AttentionContext α nQ nK dModel h1 h2) (dOut : Tensor α (Shape.dim nQ (Shape.dim dModel Shape.scalar))) :

                  Semantic FlashAttention backward/VJP operator.

                  This is the local derivative contract a fused backward kernel should refine. The actual CUDA kernel may recompute attention probabilities from row statistics rather than storing the full attention matrix, but the returned adjoints must match this spec-level VJP up to the chosen floating-point error envelope.
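
                  For orientation, one standard way to spell these adjoints, writing P for the row-normalized attention weights, O = P · V for the forward output, dO for the incoming cotangent, and dS for the adjoint of the scaled scores (a hedged sketch of the usual masked-softmax-attention VJP, which the returned adjoints are expected to match; blocked positions have P[i][j] = 0 and therefore contribute nothing):

                    dV = Pᵀ · dO
                    dP = dO · Vᵀ
                    dS[i][j] = P[i][j] · (dP[i][j] - Σ_k dP[i][k] · P[i][k])
                    dQ = (dS · K) / sqrt(dModel)
                    dK = (dSᵀ · Q) / sqrt(dModel)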

                  Instances For
                    @[simp]
                    theorem Spec.flashAttention_eq_scaledDotProductAttention {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] (cfg : FlashAttentionConfig) {nQ nK dModel : ℕ} {h1 : nQ ≠ 0} {h2 : nK ≠ 0} (ctx : AttentionContext α nQ nK dModel h1 h2) :

                    Forward semantic correctness of the fused FlashAttention spec.

                    @[simp]
                    theorem Spec.flashAttentionBackward_eq_scaledDotProductAttentionBackward {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] (cfg : FlashAttentionConfig) {nQ nK dModel : ℕ} {h1 : nQ ≠ 0} {h2 : nK ≠ 0} (ctx : AttentionContext α nQ nK dModel h1 h2) (dOut : Tensor α (Shape.dim nQ (Shape.dim dModel Shape.scalar))) :

                    Backward/VJP semantic correctness of the fused FlashAttention spec.