TorchLean API

NN.Spec.Layers.SelectiveScan

Selective scan specs

This file contains the small proof-facing core behind state-space sequence models such as S4 and Mamba.

The key observation, used by Mamba's hardware-aware parallel scan, is that each per-token recurrent update can be viewed as an affine map

h ↦ A_t h + b_t.

Affine maps compose associatively. A recurrent scan can therefore be implemented either by a left-to-right recurrence or by a parallel prefix scan over affine summaries. The scalar definitions below are intentionally compact so that NN/MLTheory/Proofs/StateSpace/Scan.lean can prove the algebra without depending on a particular runtime backend. The diagonal tensor definitions are the direct TorchLean spec analogue used by the model and CUDA contracts.
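The associativity claim can be checked on a small example. Composing h ↦ a₁*h + b₁ followed by h ↦ a₂*h + b₂ gives h ↦ (a₂*a₁)*h + (a₂*b₁ + b₂). The Lean sketch below is not taken from TorchLean: the `Affine` structure, the `compose` function, and the sample values are hypothetical stand-ins, specialised to `Nat` so that the check reduces by computation.

```lean
namespace AssocSketch

/-- Hypothetical stand-in for a scalar affine map `h ↦ a * h + b`. -/
structure Affine where
  a : Nat
  b : Nat

/-- `compose t₂ t₁` applies `t₁` first, then `t₂`. -/
def compose (t₂ t₁ : Affine) : Affine :=
  { a := t₂.a * t₁.a, b := t₂.a * t₁.b + t₂.b }

def t₁ : Affine := ⟨2, 3⟩
def t₂ : Affine := ⟨5, 7⟩
def t₃ : Affine := ⟨1, 4⟩

-- Both bracketings yield the same summary, which is what lets a prefix
-- scan combine neighbouring transitions in any grouping.
example : compose t₃ (compose t₂ t₁) = compose (compose t₃ t₂) t₁ := rfl

end AssocSketch
```

This only tests one instance; the general statement is the kind of fact the proof file mentioned above is meant to establish.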

A scalar affine transition h ↦ a*h + b.

  • a : α

    Linear multiplier. In diagonal SSMs this is one channel of the discretized state matrix.

  • b : α

    Additive input contribution for the current token.

@[implicit_reducible]
def Spec.ScalarAffineTransition.apply {α : Type} [Mul α] [Add α] (tr : ScalarAffineTransition α) (h : α) : α

Apply a scalar affine transition to a state h, returning a*h + b.
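A minimal sketch of what the structure and its `apply` might look like, assuming the body is the formula a*h + b given in the description. The namespace and the definition bodies are assumptions for illustration; only the field names and the signature come from the documentation above.

```lean
namespace ApplySketch

/-- Hypothetical stand-in for `Spec.ScalarAffineTransition`. -/
structure ScalarAffineTransition (α : Type) where
  a : α  -- linear multiplier
  b : α  -- additive input contribution

/-- Assumed body: evaluate `a * h + b`. -/
def ScalarAffineTransition.apply {α : Type} [Mul α] [Add α]
    (tr : ScalarAffineTransition α) (h : α) : α :=
  tr.a * h + tr.b

-- With a = 2, b = 3 and h = 5 the update is 2 * 5 + 3 = 13.
example : (ScalarAffineTransition.mk 2 3 : ScalarAffineTransition Nat).apply 5 = 13 := rfl

end ApplySketch
```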

Identity affine transition.

Compose two affine transitions.

compose t₂ t₁ means "first apply t₁, then apply t₂".
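The argument order is easy to get backwards, so here is a hedged Lean sketch of it. The definition bodies, the name `identity`, and the use of `OfNat` instances for 0 and 1 are assumptions; the documented signatures of the identity and of `compose` are not shown on this page.

```lean
namespace ComposeSketch

/-- Hypothetical stand-in for `Spec.ScalarAffineTransition`. -/
structure ScalarAffineTransition (α : Type) where
  a : α
  b : α

def ScalarAffineTransition.apply {α : Type} [Mul α] [Add α]
    (tr : ScalarAffineTransition α) (h : α) : α :=
  tr.a * h + tr.b

/-- Assumed identity: multiplier 1, offset 0 (the name is hypothetical). -/
def ScalarAffineTransition.identity {α : Type} [OfNat α 0] [OfNat α 1] :
    ScalarAffineTransition α :=
  { a := 1, b := 0 }

/-- `compose t₂ t₁`: first `t₁`, then `t₂`. -/
def ScalarAffineTransition.compose {α : Type} [Mul α] [Add α]
    (t₂ t₁ : ScalarAffineTransition α) : ScalarAffineTransition α :=
  { a := t₂.a * t₁.a, b := t₂.a * t₁.b + t₂.b }

def t₁ : ScalarAffineTransition Nat := ⟨2, 3⟩
def t₂ : ScalarAffineTransition Nat := ⟨5, 7⟩

-- Composing and then applying agrees with applying t₁, then t₂.
example : (t₂.compose t₁).apply 1 = t₂.apply (t₁.apply 1) := rfl
example : (t₂.compose t₁).apply 1 = 32 := rfl
-- Swapping the arguments runs t₂ first, which gives a different state.
example : (t₁.compose t₂).apply 1 = 27 := rfl
-- The identity leaves the state unchanged.
example : (ScalarAffineTransition.identity : ScalarAffineTransition Nat).apply 9 = 9 := rfl

end ComposeSketch
```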

def Spec.runScalarAffine {α : Type} [Mul α] [Add α] (h0 : α) :

Sequentially run a list of scalar affine transitions from an initial state.

Summarize a transition list as one affine transition.

This is the algebraic payload used by parallel selective scan: prefix summaries can be produced by any associative scan algorithm, and applying the summary to h0 gives the same result as running the recurrence sequentially from h0.
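The equivalence between the sequential run and the summary can be illustrated on a three-token example. In this sketch the function bodies, the name `summarize`, the return type given to `runScalarAffine`, and the convention that the head of the list is the first token are all assumptions, since the page does not show them.

```lean
namespace SummarySketch

/-- Hypothetical stand-in for `Spec.ScalarAffineTransition`. -/
structure ScalarAffineTransition (α : Type) where
  a : α
  b : α

def ScalarAffineTransition.apply {α : Type} [Mul α] [Add α]
    (tr : ScalarAffineTransition α) (h : α) : α :=
  tr.a * h + tr.b

def ScalarAffineTransition.compose {α : Type} [Mul α] [Add α]
    (t₂ t₁ : ScalarAffineTransition α) : ScalarAffineTransition α :=
  { a := t₂.a * t₁.a, b := t₂.a * t₁.b + t₂.b }

/-- Assumed sequential run: apply each transition in list order. -/
def runScalarAffine {α : Type} [Mul α] [Add α] (h0 : α) :
    List (ScalarAffineTransition α) → α
  | [] => h0
  | tr :: rest => runScalarAffine (tr.apply h0) rest

/-- Hypothetical name: collapse the list into a single transition. -/
def summarize {α : Type} [Mul α] [Add α] [OfNat α 0] [OfNat α 1] :
    List (ScalarAffineTransition α) → ScalarAffineTransition α
  | [] => { a := 1, b := 0 }
  | tr :: rest => (summarize rest).compose tr  -- first `tr`, then the rest

def ts : List (ScalarAffineTransition Nat) := [⟨2, 3⟩, ⟨5, 7⟩, ⟨1, 4⟩]

-- Sequential recurrence from h0 = 1: 5, then 32, then 36.
example : runScalarAffine 1 ts = 36 := rfl
-- The one-shot summary reaches the same final state.
example : (summarize ts).apply 1 = runScalarAffine 1 ts := rfl

end SummarySketch
```

A parallel implementation would build the same summary with a tree-shaped reduction instead of this right-nested fold; associativity of `compose` is what makes the two agree.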

def Spec.scalarAffineScan {α : Type} [Mul α] [Add α] (h0 : α) :

Return every recurrent state after each scalar affine transition.
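A sketch of the scan under stated assumptions: the result type `List α`, the function body, and the choice to exclude `h0` from the output (following the phrase "after each scalar affine transition") are guesses, not the documented definition.

```lean
namespace ScanSketch

/-- Hypothetical stand-in for `Spec.ScalarAffineTransition`. -/
structure ScalarAffineTransition (α : Type) where
  a : α
  b : α

def ScalarAffineTransition.apply {α : Type} [Mul α] [Add α]
    (tr : ScalarAffineTransition α) (h : α) : α :=
  tr.a * h + tr.b

/-- Assumed scan: emit the state reached after every transition. -/
def scalarAffineScan {α : Type} [Mul α] [Add α] (h0 : α) :
    List (ScalarAffineTransition α) → List α
  | [] => []
  | tr :: rest => tr.apply h0 :: scalarAffineScan (tr.apply h0) rest

def ts : List (ScalarAffineTransition Nat) := [⟨2, 3⟩, ⟨5, 7⟩, ⟨1, 4⟩]

-- From h0 = 1: 2*1+3 = 5, 5*5+7 = 32, 1*32+4 = 36.
example : scalarAffineScan 1 ts = [5, 32, 36] := rfl

end ScanSketch
```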

structure Spec.DiagonalTransition (α : Type) (stateDim : ℕ) :

A diagonal vector affine transition h ↦ a ⊙ h + b.

def Spec.DiagonalTransition.apply {α : Type} [Add α] [Mul α] {stateDim : ℕ} (tr : DiagonalTransition α stateDim) (h : Tensor α (Shape.dim stateDim Shape.scalar)) :

Apply one diagonal affine state update.

def Spec.DiagonalTransition.compose {α : Type} [Add α] [Mul α] {stateDim : ℕ} (t₂ t₁ : DiagonalTransition α stateDim) : DiagonalTransition α stateDim

Compose diagonal affine transitions channelwise.

The order is the same as ScalarAffineTransition.compose: compose t₂ t₁ is first t₁, then t₂.
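Channelwise composition can be sketched by modelling a length-n state as a function `Fin n → α`. This is a simplification, not TorchLean's `Tensor α (Shape.dim stateDim Shape.scalar)`, and the field names `a` and `b` as well as the definition bodies are assumed by analogy with the scalar case.

```lean
namespace DiagonalSketch

/-- Hypothetical stand-in for `Spec.DiagonalTransition`, with vectors
    modelled as `Fin n → α` rather than TorchLean tensors. -/
structure DiagonalTransition (α : Type) (n : Nat) where
  a : Fin n → α
  b : Fin n → α

/-- Assumed body: `h ↦ a ⊙ h + b`, evaluated per channel. -/
def DiagonalTransition.apply {α : Type} [Add α] [Mul α] {n : Nat}
    (tr : DiagonalTransition α n) (h : Fin n → α) : Fin n → α :=
  fun i => tr.a i * h i + tr.b i

/-- `compose t₂ t₁`: first `t₁`, then `t₂`, independently in each channel. -/
def DiagonalTransition.compose {α : Type} [Add α] [Mul α] {n : Nat}
    (t₂ t₁ : DiagonalTransition α n) : DiagonalTransition α n :=
  { a := fun i => t₂.a i * t₁.a i,
    b := fun i => t₂.a i * t₁.b i + t₂.b i }

def t₁ : DiagonalTransition Nat 2 := ⟨fun i => i.val + 2, fun _ => 1⟩  -- a = (2, 3), b = (1, 1)
def t₂ : DiagonalTransition Nat 2 := ⟨fun _ => 5, fun i => i.val⟩      -- a = (5, 5), b = (0, 1)
def h0 : Fin 2 → Nat := fun _ => 1

-- Channel 0: composing first agrees with applying t₁, then t₂.
example : (t₂.compose t₁).apply h0 0 = t₂.apply (t₁.apply h0) 0 := rfl
-- Channel 1: t₁ gives 3*1+1 = 4, then t₂ gives 5*4+1 = 21.
example : (t₂.compose t₁).apply h0 1 = 21 := rfl

end DiagonalSketch
```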

def Spec.runDiagonalTransitions {α : Type} [Add α] [Mul α] {stateDim : ℕ} (h0 : Tensor α (Shape.dim stateDim Shape.scalar)) : List (DiagonalTransition α stateDim) → Tensor α (Shape.dim stateDim Shape.scalar)

Sequentially run diagonal transitions and return the final state.

def Spec.diagonalSelectiveScan {α : Type} [Add α] [Mul α] {stateDim : ℕ} (h0 : Tensor α (Shape.dim stateDim Shape.scalar)) : List (DiagonalTransition α stateDim) → List (Tensor α (Shape.dim stateDim Shape.scalar))

Return every hidden state from a diagonal selective scan.
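As a rough illustration of the diagonal scan, the sketch below again models states as `Fin n → α` in place of TorchLean tensors. The function body, the structure fields, and the decision to omit `h0` from the returned list are assumptions.

```lean
namespace DiagonalScanSketch

/-- Hypothetical stand-in for `Spec.DiagonalTransition`. -/
structure DiagonalTransition (α : Type) (n : Nat) where
  a : Fin n → α
  b : Fin n → α

def DiagonalTransition.apply {α : Type} [Add α] [Mul α] {n : Nat}
    (tr : DiagonalTransition α n) (h : Fin n → α) : Fin n → α :=
  fun i => tr.a i * h i + tr.b i

/-- Assumed scan: collect the hidden state reached after every token. -/
def diagonalSelectiveScan {α : Type} [Add α] [Mul α] {n : Nat}
    (h0 : Fin n → α) : List (DiagonalTransition α n) → List (Fin n → α)
  | [] => []
  | tr :: rest => tr.apply h0 :: diagonalSelectiveScan (tr.apply h0) rest

def h0 : Fin 2 → Nat := fun _ => 1

def ts : List (DiagonalTransition Nat 2) :=
  [⟨fun i => i.val + 2, fun _ => 1⟩,  -- a = (2, 3), b = (1, 1)
   ⟨fun _ => 5, fun i => i.val⟩]      -- a = (5, 5), b = (0, 1)

-- Channel 0 across the scan: 2*1+1 = 3, then 5*3+0 = 15.
example : (diagonalSelectiveScan h0 ts).map (fun h => h 0) = [3, 15] := rfl
-- Channel 1 across the scan: 3*1+1 = 4, then 5*4+1 = 21.
example : (diagonalSelectiveScan h0 ts).map (fun h => h 1) = [4, 21] := rfl

end DiagonalScanSketch
```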