FlashAttention Semantic Contract #
FlashAttention is an IO-aware implementation strategy for scaled dot-product attention: it tiles
the attention computation and maintains online softmax summaries so the full n × n attention
matrix does not need to be materialized. TorchLean models that idea in three layers:
- this file gives the proof-facing semantic contract for a fused FlashAttention operator;
- `NN/Runtime/Autograd/Engine/Cuda/Kernels.lean` exposes native CUDA/stub FFI kernels for the runtime path;
- the CUDA FFI boundary is documented separately because Lean does not verify CUDA machine code.
The key point is that the fused op has the same denotation as standard masked scaled dot-product attention over the spec scalar. Different tile sizes are runtime scheduling choices, not semantic choices.
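Concretely, the denotation both paths must agree on is standard masked scaled dot-product attention. Writing `M` for the boolean mask and `d` for the head dimension (a sketch in the notation of the references below):

```latex
S = \frac{Q K^{\top}}{\sqrt{d}}, \qquad
P = \operatorname{softmax}_{\mathrm{row}}\bigl(\operatorname{mask}(S, M)\bigr), \qquad
O = P V
```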
What is proved here? #
The theorems in this file are deliberately small but important:
- `onlineSoftmaxTiledAttention_eq_scaledDotProductAttention` proves the named FlashAttention algorithmic contract has the same denotation as standard attention.
- `flashAttention_eq_scaledDotProductAttention` proves the fused forward operator is semantically equal to TorchLean's existing standard attention spec.
- `flashAttentionBackward_eq_scaledDotProductAttentionBackward` proves the fused VJP contract is semantically equal to the existing standard attention backward spec.
- `cudaLoopFlashAttention_eq_onlineSoftmaxTiledAttention` gives a Lean functional model of the native kernel loops and proves that this loop model denotes the same online/tiled contract.
These are definitional-equality theorems because the proof-facing contract spells out the same mathematical stages as standard attention. The native CUDA implementation is tested against this contract operationally and remains a runtime trust boundary, like the other CUDA kernels.
Why this is not a CUDA proof #
The definitions below are the mathematical contract for FlashAttention. They do not claim to verify the native CUDA source. Instead, they make the important theorem explicit:
`onlineSoftmaxTiledAttention cfg ctx = scaledDotProductAttention ctx`.
That is the theorem a compiler rewrite or fused backend relies on. A production IO-tiled CUDA kernel can be swapped in under the same contract once it is tested/refined.
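As a sketch of how a rewrite pass would consume the contract (type names `FlashAttentionConfig` and `AttentionContext` are illustrative, and the actual theorem may bind its arguments differently):

```lean
-- Sketch: a graph-rewrite pass may replace the fused/tiled attention node with
-- the standard one (or vice versa) by rewriting along the contract theorem.
-- `rw` unifies the lemma's arguments, so the exact signature does not matter here.
example (cfg : FlashAttentionConfig) (ctx : AttentionContext) :
    onlineSoftmaxTiledAttention cfg ctx = scaledDotProductAttention ctx := by
  rw [onlineSoftmaxTiledAttention_eq_scaledDotProductAttention]
```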
References:
- Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, Christopher Ré, "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness", arXiv:2205.14135.
- Tri Dao, "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning", arXiv:2307.08691.
- Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao, "FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision", arXiv:2407.08608.
Runtime tiling metadata for a FlashAttention-style fused implementation.
The spec-level denotation below intentionally ignores these fields: they describe how an implementation schedules work, not what mathematical function the operator computes.
- `blockQ : ℕ`: query block size used by a tiled implementation; `0` means "backend default".
- `blockK : ℕ`: key/value block size used by a tiled implementation; `0` means "backend default".
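The metadata can be pictured as a small structure (field names from the text above; the actual declaration may differ in detail):

```lean
/-- Illustrative shape of the tiling metadata. The spec-level denotation
ignores both fields; they only steer how a runtime schedules work. -/
structure FlashAttentionConfig where
  /-- Query block size; `0` means "backend default". -/
  blockQ : ℕ := 0
  /-- Key/value block size; `0` means "backend default". -/
  blockK : ℕ := 0
```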
Backend-default tiling.
Algorithmic Contract #
The original FlashAttention algorithm streams over blocks of keys/values and maintains a row-wise online softmax summary. The exact CUDA schedule is an implementation detail, but the mathematical result is the same as the closed-form stabilized softmax over the full masked score row.
TorchLean names the stages below so proofs and compiler passes can point at a real algorithmic contract rather than only at an opaque fused primitive:
- build the scaled scores `QKᵀ / sqrt(d)`;
- apply the boolean mask with `true` hard-mask semantics (a blocked entry's numerator is zero);
- compute the same row-wise normalized weights that an online summary converges to;
- multiply by the values.
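Under the `true` hard-mask semantics, the row weights these stages produce can be written explicitly, with `m_i` the row maximum over unmasked entries (matching the stabilized softmax form):

```latex
m_i = \max_{k \,:\, M_{ik}} S_{ik}, \qquad
P_{ij} =
\begin{cases}
\dfrac{\exp(S_{ij} - m_i)}{\sum_{k \,:\, M_{ik}} \exp(S_{ik} - m_i)} & \text{if } M_{ij} \\[1ex]
0 & \text{otherwise.}
\end{cases}
```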
This is intentionally schedule-polymorphic: `cfg.blockQ` and `cfg.blockK` describe how a runtime may tile the work, but they do not alter the denotation.
Unmasked attention scores QKᵀ.
Scaled attention scores before row normalization.
Masking is applied at the weight level by `onlineSoftmaxWeights`, using the same `true` hard-mask semantics as `scaledDotProductAttention`.
The row-wise softmax weights produced by the online softmax summary.
`Activation.softmaxSpec` already uses the stabilized form `exp(x - rowMax) / Σ exp(x - rowMax)`.
This definition is the denotation that a FlashAttention implementation must refine. It is not a
formal model of Dao-style tile loops or SRAM/HBM traffic.
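For intuition only (this file deliberately does not model it), the Dao-style block recurrence that converges to these weights maintains, per row, a running max `m`, normalizer `ℓ`, and unnormalized output `o` across key blocks `t`:

```latex
m^{(t)} = \max\bigl(m^{(t-1)},\, \max_j s^{(t)}_j\bigr), \qquad
\ell^{(t)} = \ell^{(t-1)} e^{m^{(t-1)} - m^{(t)}} + \sum_j e^{s^{(t)}_j - m^{(t)}}, \qquad
o^{(t)} = o^{(t-1)} e^{m^{(t-1)} - m^{(t)}} + \sum_j e^{s^{(t)}_j - m^{(t)}} v_j
```

The final row output is `o^{(T)} / ℓ^{(T)}`, identical to the closed-form stabilized softmax applied to the full row and multiplied by the values.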
Proof-facing FlashAttention forward algorithm.
This is the mathematical result of the online/tiled schedule: row-wise softmax weights multiplied by
V. Runtime kernels may avoid storing the full weights, but they must refine this value.
CUDA Denotation #
The native runtime kernel is intended to compute the same row program in a fused way:
- for each `(batch, head, query)` row, scan keys to build masked/scaled scores;
- compute the stabilized softmax normalization for that row;
- accumulate `Σ_j softmax(score_j) * V_j` directly into the output.
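A hypothetical one-dimensional Lean sketch of that row program, using scalar "values" for brevity (the real `cudaLoopFlashAttention` works on full tensors via combinators, and the spec treats fully-masked rows separately):

```lean
/-- Hypothetical sketch of one fused output row: hard-masked, stabilized
softmax weights contracted against scalar values. Not the actual definition. -/
def fusedRowSketch (scores : List Float) (mask : List Bool)
    (vals : List Float) : Float :=
  -- row max over unmasked scores (the stabilization point)
  let kept := (mask.zip scores).filterMap fun (b, s) => if b then some s else none
  let m := kept.foldl max (-(1.0 / 0.0))
  -- hard mask: a blocked entry's numerator is zero
  let w := (mask.zip scores).map fun (b, s) => if b then Float.exp (s - m) else 0.0
  let z := w.foldl (· + ·) 0.0
  -- accumulate Σ_j softmax(score_j) * v_j directly into the output
  (w.zip vals).foldl (fun acc (wi, vi) => acc + wi / z * vi) 0.0
```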
`cudaLoopFlashAttention` is intentionally a denotational target, written with tensor
combinators rather than CUDA thread/block syntax. The equalities below are definitional sanity
checks: they say the named fused operator denotes standard SDPA in the spec. They do not verify the
CUDA source code, the online-softmax recurrence, or the memory-IO schedule. Those remain explicit
runtime/FFI contracts tested against this target.
Denotational target for the fused CUDA FlashAttention forward kernel.
The CUDA denotational target has the same spec meaning as standard SDPA.
This is intentionally a definitional theorem, not a proof about CUDA machine code or an online softmax tile recurrence.
The proof-facing FlashAttention denotation equals standard SDPA.
This theorem is useful for graph-rewrite semantics, but should not be read as a verification of a particular CUDA implementation.
Semantic FlashAttention forward operator.
At the spec layer this is the named online/tiled algorithmic contract above. Runtime implementations may use tiling, online softmax summaries, or a fused CUDA kernel, but they must refine this denotation to be considered correct.
Semantic FlashAttention backward/VJP operator.
This is the local derivative contract a fused backward kernel should refine. The actual CUDA kernel may recompute attention probabilities from row statistics rather than storing the full attention matrix, but the returned adjoints must match this spec-level VJP up to the chosen floating-point error envelope.
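For reference, the standard softmax-attention VJP this contract mirrors can be written as follows, with `P` the forward weights, `dO` the incoming adjoint, `⊙` the elementwise product, and `rowsum` broadcast back over each row (a sketch of the usual derivation, not the file's literal definition):

```latex
dV = P^{\top} dO, \qquad
dP = dO\, V^{\top}, \qquad
dS = P \odot \bigl(dP - \operatorname{rowsum}(dP \odot P)\bigr), \qquad
dQ = \frac{dS\, K}{\sqrt{d}}, \qquad
dK = \frac{dS^{\top} Q}{\sqrt{d}}.
```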
Forward semantic correctness of the fused FlashAttention spec.
Backward/VJP semantic correctness of the fused FlashAttention spec.