TorchLean API

NN.MLTheory.Proofs.StateSpace.MambaCausality

Causality of S4/Mamba-style recurrent blocks

This file proves the sequence-causality property expected of state-space sequence models: appending future tokens cannot change outputs already emitted for a prefix.

We state the theorem at the list-runner level rather than for a particular CUDA kernel. Runtime implementations may use chunked or parallel selective scan, but they must refine these spec runners. Combined with NN.MLTheory.StateSpace.diagonalSelectiveScan_append, this gives the proof-facing contract for Mamba/S4-style causal sequence processing.
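The refinement contract described above can be sketched generically. The following is a hypothetical lemma, not part of the library: for any runner satisfying an append/state-threading law (`h`, in the style of diagonalSelectiveScan_append) and a one-output-per-token length law (`hlen`), prefix causality follows. The names `run`, `h`, and `hlen` are illustrative assumptions; `List.take_left` is the Mathlib lemma `take l₁.length (l₁ ++ l₂) = l₁` (its exact name may vary by Mathlib version).

```lean
-- Sketch: any append-compatible, length-preserving list runner is prefix-causal.
theorem prefix_causal_of_append {S I O : Type}
    (run : S → List I → S × List O)
    -- append law: running `xs ++ ys` threads the state through `xs` first
    (h : ∀ s xs ys, (run s (xs ++ ys)).2
        = (run s xs).2 ++ (run (run s xs).1 ys).2)
    -- length law: one output per input token
    (hlen : ∀ s xs, ((run s xs).2).length = xs.length)
    (s : S) (xs ys : List I) :
    List.take xs.length (run s (xs ++ ys)).2 = (run s xs).2 := by
  rw [h s xs ys, ← hlen s xs]
  exact List.take_left _ _
```

Any chunked or parallel selective-scan implementation that refines a runner satisfying `h` and `hlen` therefore inherits the causality theorems below.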

theorem NN.MLTheory.StateSpace.diagonalS4_runList_append_outputs_prefix {α : Type} [Context α] {inputDim stateDim outputDim : ℕ} (m : Models.DiagonalS4Spec α inputDim stateDim outputDim) (h0 : Spec.Tensor α (Spec.Shape.dim stateDim Spec.Shape.scalar)) (xs ys : List (Spec.Tensor α (Spec.Shape.dim inputDim Spec.Shape.scalar))) :
List.take xs.length (m.runList h0 (xs ++ ys)).2 = (m.runList h0 xs).2

Diagonal S4 prefix causality.

Appending future tokens ys cannot change the outputs already produced for prefix xs.
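In particular, the single-token case gives the streaming-decoding guarantee directly. The following sketch (assuming this file's imports and namespaces) instantiates the theorem with `ys := [y]`:

```lean
-- Sketch: emitting one more token `y` never rewrites earlier outputs.
open NN.MLTheory.StateSpace in
example {α : Type} [Context α] {inputDim stateDim outputDim : ℕ}
    (m : Models.DiagonalS4Spec α inputDim stateDim outputDim)
    (h0 : Spec.Tensor α (Spec.Shape.dim stateDim Spec.Shape.scalar))
    (xs : List (Spec.Tensor α (Spec.Shape.dim inputDim Spec.Shape.scalar)))
    (y : Spec.Tensor α (Spec.Shape.dim inputDim Spec.Shape.scalar)) :
    List.take xs.length (m.runList h0 (xs ++ [y])).2 = (m.runList h0 xs).2 :=
  diagonalS4_runList_append_outputs_prefix m h0 xs [y]
```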

theorem NN.MLTheory.StateSpace.compactMamba_runList_append_outputs_prefix {α : Type} [Context α] {inputDim stateDim outputDim : ℕ} (m : Models.MambaBlockSpec α inputDim stateDim outputDim) (h0 : Spec.Tensor α (Spec.Shape.dim stateDim Spec.Shape.scalar)) (xs ys : List (Spec.Tensor α (Spec.Shape.dim inputDim Spec.Shape.scalar))) :
List.take xs.length (m.runList h0 (xs ++ ys)).2 = (m.runList h0 xs).2

Compact Mamba prefix causality.

If a sequence xs has already been processed, appending future tokens ys cannot change the outputs for xs. This is the recurrent-model analogue of causal attention non-anticipation.
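One way to read non-anticipation: any two futures agree on the outputs for the shared prefix. The following hypothetical corollary (assuming this file's context) follows by rewriting both sides with the theorem above:

```lean
-- Sketch: futures `ys` and `zs` produce identical outputs over the prefix `xs`.
open NN.MLTheory.StateSpace in
example {α : Type} [Context α] {inputDim stateDim outputDim : ℕ}
    (m : Models.MambaBlockSpec α inputDim stateDim outputDim)
    (h0 : Spec.Tensor α (Spec.Shape.dim stateDim Spec.Shape.scalar))
    (xs ys zs : List (Spec.Tensor α (Spec.Shape.dim inputDim Spec.Shape.scalar))) :
    List.take xs.length (m.runList h0 (xs ++ ys)).2
      = List.take xs.length (m.runList h0 (xs ++ zs)).2 := by
  rw [compactMamba_runList_append_outputs_prefix,
      compactMamba_runList_append_outputs_prefix]
```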

theorem NN.MLTheory.StateSpace.selectiveMamba_runListAux_append_outputs_prefix {α : Type} [Context α] {inputDim stateDim outputDim innerDim convWidth : ℕ} (m : Models.SelectiveMambaBlockSpec α inputDim innerDim stateDim outputDim convWidth) (h0 : Spec.Tensor α (Spec.Shape.dim innerDim (Spec.Shape.dim stateDim Spec.Shape.scalar))) (history : List (Spec.Tensor α (Spec.Shape.dim innerDim Spec.Shape.scalar))) (xs ys : List (Spec.Tensor α (Spec.Shape.dim inputDim Spec.Shape.scalar))) :
List.take xs.length (m.runListAux h0 history (xs ++ ys)).2 = (m.runListAux h0 history xs).2

Full selective Mamba prefix causality for the internal runner.

The internal runner carries a newest-first causal convolution history. Even with that extra state, future input tokens only affect future outputs.

theorem NN.MLTheory.StateSpace.selectiveMamba_runList_append_outputs_prefix {α : Type} [Context α] {inputDim stateDim outputDim innerDim convWidth : ℕ} (m : Models.SelectiveMambaBlockSpec α inputDim innerDim stateDim outputDim convWidth) (h0 : Spec.Tensor α (Spec.Shape.dim innerDim (Spec.Shape.dim stateDim Spec.Shape.scalar))) (xs ys : List (Spec.Tensor α (Spec.Shape.dim inputDim Spec.Shape.scalar))) :
List.take xs.length (m.runList h0 (xs ++ ys)).2 = (m.runList h0 xs).2

Full selective Mamba prefix causality for the public runner.

This is the user-facing theorem: extending the input stream preserves all previously produced outputs.
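Because the guarantee quantifies over arbitrary futures, it composes across repeated extension of the stream. The following sketch (assuming this file's context, with `List.append_assoc` from the standard library) shows that extending twice still preserves the first prefix's outputs:

```lean
-- Sketch: two successive extensions reduce to one via append associativity.
open NN.MLTheory.StateSpace in
example {α : Type} [Context α] {inputDim stateDim outputDim innerDim convWidth : ℕ}
    (m : Models.SelectiveMambaBlockSpec α inputDim innerDim stateDim outputDim convWidth)
    (h0 : Spec.Tensor α (Spec.Shape.dim innerDim (Spec.Shape.dim stateDim Spec.Shape.scalar)))
    (xs ys zs : List (Spec.Tensor α (Spec.Shape.dim inputDim Spec.Shape.scalar))) :
    List.take xs.length (m.runList h0 ((xs ++ ys) ++ zs)).2 = (m.runList h0 xs).2 := by
  rw [List.append_assoc]
  exact selectiveMamba_runList_append_outputs_prefix m h0 xs (ys ++ zs)
```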