Selective scan specs
This file contains the small proof-facing core behind state-space sequence models such as S4 and Mamba.
The key observation, used by Mamba's hardware-aware parallel scan, is that each per-token recurrent update can be viewed as an affine map
h ↦ A_t h + b_t.
Composition of affine maps is associative, so a recurrent scan can be computed either by a
left-to-right recurrence or by a parallel prefix scan over composed affine summaries; in exact
arithmetic both yield the same states. The scalar definitions below are intentionally compact so that
NN/MLTheory/Proofs/StateSpace/Scan.lean can prove the algebra without depending on a particular
runtime backend. The diagonal tensor definitions are their direct TorchLean spec counterparts, used by
the model and CUDA contracts.
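As a worked check of the associativity claim, write a scalar transition as the pair (a, b) for h ↦ a·h + b, and read t₂ ∘ t₁ as "apply t₁ first, then t₂" (the same order as compose below). The pair notation and indices are chosen here for illustration only:

```latex
\begin{aligned}
(t_2 \circ t_1)(h) &= a_2 (a_1 h + b_1) + b_2 = (a_2 a_1)\,h + (a_2 b_1 + b_2),\\
\text{so}\quad (a_2, b_2) \circ (a_1, b_1) &= (a_2 a_1,\; a_2 b_1 + b_2),\\[4pt]
\big((a_3, b_3) \circ (a_2, b_2)\big) \circ (a_1, b_1) &= \big(a_3 a_2 a_1,\; a_3 a_2 b_1 + a_3 b_2 + b_3\big)\\
&= (a_3, b_3) \circ \big((a_2, b_2) \circ (a_1, b_1)\big).
\end{aligned}
```

Only associativity of multiplication and addition and left distributivity are used, so in exact arithmetic the same identity holds channelwise for the diagonal case, with ⊙ in place of scalar multiplication.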
References:
- Gu, Goel, Ré. "Efficiently Modeling Long Sequences with Structured State Spaces" (S4), ICLR 2022.
- Gu, Dao. "Mamba: Linear-Time Sequence Modeling with Selective State Spaces", COLM 2024.
- Dao, Gu. "Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality" (Mamba-2), ICML 2024.
A scalar affine transition h ↦ a*h + b.
- a : α
Linear multiplier. In diagonal SSMs this is one diagonal entry (one channel) of the discretized state matrix.
- b : α
Additive input contribution for the current token.
Apply a scalar affine transition.
Identity affine transition h ↦ h, i.e. a = 1 and b = 0.
Compose two affine transitions.
compose t₂ t₁ means "first apply t₁, then apply t₂".
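A minimal Lean sketch of the scalar operations documented above, assuming a plain structure with fields a and b and only Add, Mul, and OfNat instances on α. The names Affine, apply, identity, and compose are hypothetical stand-ins, not this file's declarations. The #eval lines check the argument order of compose on Int values:

```lean
-- Hypothetical stand-in for the scalar transition above; `Affine`, `apply`,
-- `identity`, and `compose` are illustrative names, not the file's declarations.
structure Affine (α : Type) where
  a : α  -- linear multiplier
  b : α  -- additive input contribution

namespace Affine

variable {α : Type}

/-- Apply `h ↦ a * h + b`. -/
def apply [Add α] [Mul α] (t : Affine α) (h : α) : α :=
  t.a * h + t.b

/-- Identity transition: `a = 1`, `b = 0`. -/
def identity [OfNat α 0] [OfNat α 1] : Affine α :=
  ⟨1, 0⟩

/-- `compose t₂ t₁` applies `t₁` first, then `t₂`:
    `a₂ * (a₁ * h + b₁) + b₂ = (a₂ * a₁) * h + (a₂ * b₁ + b₂)`. -/
def compose [Add α] [Mul α] (t₂ t₁ : Affine α) : Affine α :=
  ⟨t₂.a * t₁.a, t₂.a * t₁.b + t₂.b⟩

end Affine

-- Example transitions over `Int`: t₁ : h ↦ 2h + 1, t₂ : h ↦ 3h - 4.
def t₁ : Affine Int := ⟨2, 1⟩
def t₂ : Affine Int := ⟨3, -4⟩

#eval (Affine.compose t₂ t₁).apply 5               -- 29 = t₂ (t₁ 5) = 3 * 11 - 4
#eval t₂.apply (t₁.apply 5)                        -- 29
#eval (Affine.compose t₁ t₂).apply 5               -- 23 = t₁ (t₂ 5): order matters
#eval (Affine.compose Affine.identity t₁).apply 5  -- 11 = t₁ 5
```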
Sequentially run a list of scalar affine transitions from an initial state.
Summarize a transition list as one affine transition.
This is the algebraic payload used by parallel selective scan: because composition is associative,
prefix summaries can be produced by any associative scan algorithm, and applying the summary to h0
gives the same final state as running the recurrence from h0.
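The following hypothetical Int-specialized sketch places a sequential run next to two summaries that group the same compositions differently (a left fold and a right fold). The names run, summarize, and summarizeRight are illustrative assumptions, not the file's API:

```lean
-- Hypothetical Int-specialized sketch; all names are illustrative.
structure Affine where
  a : Int
  b : Int

namespace Affine

def apply (t : Affine) (h : Int) : Int := t.a * h + t.b

-- `compose t₂ t₁` applies `t₁` first, then `t₂`.
def compose (t₂ t₁ : Affine) : Affine := ⟨t₂.a * t₁.a, t₂.a * t₁.b + t₂.b⟩

def identity : Affine := ⟨1, 0⟩

/-- Left-to-right recurrence from `h₀`; returns the final state. -/
def run (ts : List Affine) (h₀ : Int) : Int :=
  ts.foldl (fun h t => t.apply h) h₀

/-- Left fold: for three transitions this nests as `t₃ ∘ (t₂ ∘ (t₁ ∘ id))`. -/
def summarize (ts : List Affine) : Affine :=
  ts.foldl (fun acc t => compose t acc) identity

/-- Right fold: for three transitions this nests as `((id ∘ t₃) ∘ t₂) ∘ t₁`. -/
def summarizeRight (ts : List Affine) : Affine :=
  ts.foldr (fun t acc => compose acc t) identity

end Affine

def ts : List Affine := [⟨2, 1⟩, ⟨3, -4⟩, ⟨-1, 5⟩]

#eval Affine.run ts 5                     -- -24
#eval (Affine.summarize ts).apply 5       -- -24
#eval (Affine.summarizeRight ts).apply 5  -- -24
```

All three agree here because Int arithmetic is exact. With Float, regrouping the compositions (as a parallel scan does) can change the result by rounding, since floating-point addition and multiplication are not associative.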
Return the recurrent state reached after each scalar affine transition.
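A hypothetical sketch, again over Int, that computes the per-step states both by recurrence and by applying inclusive prefix summaries to the initial state. Whether the real scan also returns h0 is not stated here, so this sketch simply leaves it out; the names states and prefixSummaries are assumptions:

```lean
-- Hypothetical Int-specialized sketch; all names are illustrative.
structure Affine where
  a : Int
  b : Int

namespace Affine

def apply (t : Affine) (h : Int) : Int := t.a * h + t.b

def compose (t₂ t₁ : Affine) : Affine := ⟨t₂.a * t₁.a, t₂.a * t₁.b + t₂.b⟩

def identity : Affine := ⟨1, 0⟩

/-- State after each transition; `h₀` itself is not included in this sketch. -/
def states : List Affine → Int → List Int
  | [], _ => []
  | t :: ts, h => let h' := t.apply h; h' :: states ts h'

/-- Inclusive prefix summaries, computed sequentially here. Because `compose`
    is associative, a parallel prefix scan could produce the same list. -/
def prefixSummaries : Affine → List Affine → List Affine
  | _, [] => []
  | acc, t :: ts => let acc' := compose t acc; acc' :: prefixSummaries acc' ts

end Affine

def ts : List Affine := [⟨2, 1⟩, ⟨3, -4⟩, ⟨-1, 5⟩]

#eval Affine.states ts 5  -- [11, 29, -24]
-- Applying each prefix summary to the same h₀ gives the same list: [11, 29, -24]
#eval (Affine.prefixSummaries Affine.identity ts).map (fun (s : Affine) => s.apply 5)
```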
A diagonal vector affine transition h ↦ a ⊙ h + b.
- a : Tensor α (Shape.dim stateDim Shape.scalar)
Elementwise recurrent multiplier.
- b : Tensor α (Shape.dim stateDim Shape.scalar)
Elementwise additive token contribution.
Apply one diagonal affine state update.
Compose diagonal affine transitions channelwise.
The order is the same as ScalarAffineTransition.compose: compose t₂ t₁ is first t₁, then t₂.
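A minimal sketch of the diagonal case, assuming the TorchLean vectors can be modeled as functions Fin n → Int (an assumption made only so the example runs without TorchLean). The names DiagAffine and vec2 are hypothetical. It checks that channelwise composition matches applying t₁ and then t₂:

```lean
-- Hypothetical stand-in: tensors of shape `Shape.dim n Shape.scalar`
-- are modeled as plain functions `Fin n → Int`.
structure DiagAffine (n : Nat) where
  a : Fin n → Int  -- elementwise recurrent multiplier
  b : Fin n → Int  -- elementwise additive token contribution

namespace DiagAffine

/-- `h ↦ a ⊙ h + b`, channel by channel. -/
def apply {n : Nat} (t : DiagAffine n) (h : Fin n → Int) : Fin n → Int :=
  fun i => t.a i * h i + t.b i

/-- Scalar composition in every channel; `compose t₂ t₁` applies `t₁` first. -/
def compose {n : Nat} (t₂ t₁ : DiagAffine n) : DiagAffine n :=
  ⟨fun i => t₂.a i * t₁.a i, fun i => t₂.a i * t₁.b i + t₂.b i⟩

end DiagAffine

-- Hypothetical helper for building two-channel vectors.
def vec2 (x y : Int) : Fin 2 → Int := fun i => if i.val = 0 then x else y

def t₁ : DiagAffine 2 := ⟨vec2 2 (-1), vec2 1 3⟩
def t₂ : DiagAffine 2 := ⟨vec2 3 4, vec2 (-4) 0⟩
def h₀ : Fin 2 → Int := vec2 5 (-2)

#eval Array.ofFn ((DiagAffine.compose t₂ t₁).apply h₀)  -- #[29, 20]
#eval Array.ofFn (t₂.apply (t₁.apply h₀))                -- #[29, 20]
```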
Sequentially run diagonal transitions and return the final state.
Return every hidden state from a diagonal selective scan.
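A hypothetical sketch of the diagonal final-state and all-states scans, using the same Fin n → Int stand-in for tensors. The names run, states, and vec2 are illustrative, and excluding the initial state from the state list is an assumption of this sketch:

```lean
-- Hypothetical stand-in: tensors of shape `Shape.dim n Shape.scalar` are
-- modeled as functions `Fin n → Int`; all names are illustrative.
structure DiagAffine (n : Nat) where
  a : Fin n → Int
  b : Fin n → Int

namespace DiagAffine

def apply {n : Nat} (t : DiagAffine n) (h : Fin n → Int) : Fin n → Int :=
  fun i => t.a i * h i + t.b i

/-- Final state of the left-to-right recurrence. -/
def run {n : Nat} (ts : List (DiagAffine n)) (h₀ : Fin n → Int) : Fin n → Int :=
  ts.foldl (fun h t => t.apply h) h₀

/-- Hidden state after each transition (`h₀` itself is not included here). -/
def states {n : Nat} : List (DiagAffine n) → (Fin n → Int) → List (Fin n → Int)
  | [], _ => []
  | t :: ts, h => let h' := t.apply h; h' :: states ts h'

end DiagAffine

def vec2 (x y : Int) : Fin 2 → Int := fun i => if i.val = 0 then x else y

def ts : List (DiagAffine 2) :=
  [⟨vec2 2 (-1), vec2 1 3⟩, ⟨vec2 3 4, vec2 (-4) 0⟩]

#eval Array.ofFn (DiagAffine.run ts (vec2 5 (-2)))         -- #[29, 20]
#eval (DiagAffine.states ts (vec2 5 (-2))).map Array.ofFn  -- [#[11, 5], #[29, 20]]
```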