Mamba-style selective state-space blocks #
Mamba replaces quadratic attention with a linear-time selective state-space recurrence. In the full model, each token's content controls the discretization step and the input/output state parameters (Delta, B, C).
This file exposes two layers:
- MambaBlockSpec: a compact theorem-friendly diagonal SSM block kept for scan laws and kernel validation.
- SelectiveMambaBlockSpec: a fuller Mamba-style block with input/gate projections, causal depthwise convolution, SiLU, token-dependent Delta/B/C, diagonal selective scan, gated output, and output projection.
The compact block is intentionally retained: it is the smallest reusable core for proving scan algebra and for validating CUDA kernels. The full block builds the paper-style Mamba dataflow on top of the same affine-scan idea.
The compact block combines:
- a recurrent selective scan (h ← A ⊙ h + B ⊙ x_state),
- a gated state readout,
- tokenwise input/output projections.
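As a concrete reference for this dataflow, here is a plain-Python sketch of the compact block. All names are illustrative; this is independent of the Lean Spec.Tensor definitions and only mirrors the scan law above.

```python
import math

def compact_step(a, b, in_w, gate_w, out_w, h, x):
    """One compact diagonal-SSM token step; returns (new_state, output).

    a, b   : per-state diagonal decay / input scale (length S)
    in_w   : input projection  (I x S)
    gate_w : gate projection   (I x S)
    out_w  : output projection (S x O)
    h      : current state (length S); x : input token (length I)
    """
    S, I, O = len(a), len(x), len(out_w[0])
    # project the token into state channels: x_state = x @ in_w
    x_state = [sum(x[i] * in_w[i][s] for i in range(I)) for s in range(S)]
    # selective scan law: h <- A ⊙ h + B ⊙ x_state
    h_new = [a[s] * h[s] + b[s] * x_state[s] for s in range(S)]
    # gated readout: sigmoid(x @ gate_w) ⊙ h_new
    gate = [1.0 / (1.0 + math.exp(-sum(x[i] * gate_w[i][s] for i in range(I))))
            for s in range(S)]
    gated = [gate[s] * h_new[s] for s in range(S)]
    # output projection: y = gated @ out_w
    return h_new, [sum(gated[s] * out_w[s][o] for s in range(S)) for o in range(O)]

def compact_run(a, b, in_w, gate_w, out_w, tokens):
    """Fold compact_step over a token list: one output per input token."""
    h, ys = [0.0] * len(a), []
    for x in tokens:
        h, y = compact_step(a, b, in_w, gate_w, out_w, h, x)
        ys.append(y)
    return ys
```

By construction, the runner emits exactly one output per input token, which is the length invariant the file proves about the recurrent pass.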
References:
- Gu, Dao. "Mamba: Linear-Time Sequence Modeling with Selective State Spaces", COLM 2024.
- Dao, Gu. "Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality" (Mamba-2), ICML 2024.
Parameters for a compact diagonal Mamba-style block.
- inProj : Spec.Tensor α (Spec.Shape.dim inputDim (Spec.Shape.dim stateDim Spec.Shape.scalar))
Input projection into SSM state channels.
- gateProj : Spec.Tensor α (Spec.Shape.dim inputDim (Spec.Shape.dim stateDim Spec.Shape.scalar))
Gate projection. The gate is sigmoid(x @ gateProj).
- outProj : Spec.Tensor α (Spec.Shape.dim stateDim (Spec.Shape.dim outputDim Spec.Shape.scalar))
Output projection from gated state channels.
- ssm : NN.Spec.Dynamics.DiagonalSSM α stateDim
Diagonal state-space core.
Input-to-state projection.
Token-dependent sigmoid gate.
One Mamba-style token step, returning (new_state, output).
Run a list of tokens through the recurrent block.
A Mamba recurrent pass emits one output token per input token.
Parameters for a fuller Mamba-style selective SSM block.
Shape conventions:
- inputDim: token/input feature width,
- innerDim: expanded channel width used by Mamba's convolution and SSM path,
- stateDim: per-channel diagonal SSM state size,
- outputDim: output feature width,
- convWidth: causal depthwise-convolution width.
The recurrence state has shape [innerDim, stateDim]. This mirrors the common implementation
view of Mamba where each expanded channel carries a small diagonal state vector.
- xProj : Spec.Tensor α (Spec.Shape.dim inputDim (Spec.Shape.dim innerDim Spec.Shape.scalar))
Content/input projection x -> x_path.
- zProj : Spec.Tensor α (Spec.Shape.dim inputDim (Spec.Shape.dim innerDim Spec.Shape.scalar))
Gate projection x -> z_path.
- convKernel : Spec.Tensor α (Spec.Shape.dim convWidth (Spec.Shape.dim innerDim Spec.Shape.scalar))
Causal depthwise-convolution kernel, indexed by (tap, channel).
- convBias : Spec.Tensor α (Spec.Shape.dim innerDim Spec.Shape.scalar)
Causal depthwise-convolution bias.
- dtProj : Spec.Tensor α (Spec.Shape.dim innerDim (Spec.Shape.dim innerDim Spec.Shape.scalar))
Projection from activated convolution features to per-channel time steps Delta.
- dtBias : Spec.Tensor α (Spec.Shape.dim innerDim Spec.Shape.scalar)
Bias before the softplus time-step nonlinearity.
- A : Spec.Tensor α (Spec.Shape.dim innerDim (Spec.Shape.dim stateDim Spec.Shape.scalar))
Positive diagonal state rates A[d,n] used as exp(-Delta[d] * A[d,n]).
- bProj : Spec.Tensor α (Spec.Shape.dim innerDim (Spec.Shape.dim stateDim Spec.Shape.scalar))
Token-dependent input-state projection B_t = u_t @ bProj.
- cProj : Spec.Tensor α (Spec.Shape.dim innerDim (Spec.Shape.dim stateDim Spec.Shape.scalar))
Token-dependent state-output projection C_t = u_t @ cProj.
- dSkip : Spec.Tensor α (Spec.Shape.dim innerDim Spec.Shape.scalar)
Per-channel residual/skip coefficient.
- outProj : Spec.Tensor α (Spec.Shape.dim innerDim (Spec.Shape.dim outputDim Spec.Shape.scalar))
Output projection from expanded channels to output features.
Content path projection.
Gate path projection.
SiLU/Swish applied channelwise.
Causal depthwise convolution from a newest-first history of projected tokens.
history[0] is the current projected token, history[1] is the previous token, etc. Missing
history entries are treated as zero padding.
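A minimal Python sketch of this newest-first causal depthwise convolution (illustrative names, not the Lean definitions):

```python
def causal_depthwise_conv(kernel, bias, history):
    """Causal depthwise convolution over a newest-first history.

    kernel  : convWidth x innerDim, indexed (tap, channel);
              tap 0 multiplies the current token, tap 1 the previous, ...
    bias    : innerDim
    history : list of projected tokens, history[0] = current token;
              missing (older) entries act as zero padding.
    """
    conv_width, inner_dim = len(kernel), len(bias)
    out = list(bias)
    for tap in range(conv_width):
        if tap >= len(history):
            break  # entries older than the recorded history are zeros
        tok = history[tap]
        for d in range(inner_dim):
            out[d] += kernel[tap][d] * tok[d]
    return out
```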
Token-dependent positive time steps Delta = softplus(u @ dtProj + dtBias).
Token-dependent input-state vector B_t.
Token-dependent state-output vector C_t.
One selective diagonal SSM update:
h'[d,n] = exp(-Delta[d] * A[d,n]) * h[d,n] + (Delta[d] * B_t[n]) * u[d].
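This update, together with the softplus that produces the positive time steps Delta, can be sketched in plain Python (a hedged illustration with hypothetical names, not the Lean code):

```python
import math

def softplus(v):
    """softplus(v) = log(1 + exp(v)); keeps the time step Delta positive."""
    return math.log1p(math.exp(v))

def selective_update(A, delta, B_t, u, h):
    """h'[d][n] = exp(-delta[d] * A[d][n]) * h[d][n] + (delta[d] * B_t[n]) * u[d].

    A     : innerDim x stateDim positive rates
    delta : innerDim positive time steps (softplus outputs)
    B_t   : stateDim token-dependent input-state vector
    u     : innerDim activated channel values
    h     : innerDim x stateDim state
    """
    D, N = len(A), len(A[0])
    return [[math.exp(-delta[d] * A[d][n]) * h[d][n] + delta[d] * B_t[n] * u[d]
             for n in range(N)]
            for d in range(D)]
```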
Read out expanded channels from the updated state using C_t, plus the Mamba skip path.
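The same readout in the plain-Python style used above (illustrative names): each expanded channel contracts the updated state against C_t, then adds the per-channel skip term.

```python
def state_readout(C_t, h_new, d_skip, u):
    """y[d] = Σ_n C_t[n] * h'[d][n] + dSkip[d] * u[d], per expanded channel d.

    C_t    : stateDim token-dependent state-output vector
    h_new  : innerDim x stateDim updated state
    d_skip : innerDim skip coefficients; u : innerDim channel values
    """
    return [sum(C_t[n] * h_new[d][n] for n in range(len(C_t))) + d_skip[d] * u[d]
            for d in range(len(u))]
```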
One full Mamba token step from an already-updated convolution history.
The history argument is newest-first and must include the current projected content token.
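Putting the pieces together, one full token step from a newest-first history roughly follows the dataflow below. The helper names and the parameter dictionary are illustrative assumptions; the Lean spec threads these parameters through typed structures instead.

```python
import math

def sigmoid(v): return 1.0 / (1.0 + math.exp(-v))
def silu(v): return v * sigmoid(v)            # SiLU/Swish
def softplus(v): return math.log1p(math.exp(v))
def matvec(M, x):                             # x @ M, with M : len(x) x out
    return [sum(x[i] * M[i][j] for i in range(len(x))) for j in range(len(M[0]))]

def full_mamba_step(p, history, z, h):
    """One full Mamba-style token step; returns (new_state, output).

    p       : parameter dict (convKernel, convBias, dtProj, dtBias, A,
              bProj, cProj, dSkip, outProj) with the shapes documented above
    history : newest-first projected content tokens, history[0] = current
    z       : gate-path projection of the current input token (innerDim)
    h       : innerDim x stateDim state
    """
    D = len(p["convBias"])
    # causal depthwise convolution + SiLU on the content path
    u = [p["convBias"][d] + sum(p["convKernel"][t][d] * history[t][d]
                                for t in range(min(len(p["convKernel"]), len(history))))
         for d in range(D)]
    u = [silu(v) for v in u]
    # token-dependent Delta, B_t, C_t
    delta = [softplus(v + b) for v, b in zip(matvec(p["dtProj"], u), p["dtBias"])]
    B_t = matvec(p["bProj"], u)
    C_t = matvec(p["cProj"], u)
    # selective diagonal scan
    h_new = [[math.exp(-delta[d] * p["A"][d][n]) * h[d][n] + delta[d] * B_t[n] * u[d]
              for n in range(len(p["A"][0]))] for d in range(D)]
    # readout with C_t plus skip, gated by SiLU(z), then output projection
    y_inner = [sum(C_t[n] * h_new[d][n] for n in range(len(C_t))) + p["dSkip"][d] * u[d]
               for d in range(D)]
    gated = [silu(z[d]) * y_inner[d] for d in range(D)]
    return h_new, matvec(p["outProj"], gated)
```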
Internal recurrent runner carrying the causal convolution history.
Run a sequence through the full selective Mamba block.
The full Mamba recurrent pass emits one output token per input token.
The public full Mamba runner emits one output token per input token.