TorchLean API

NN.Spec.Models.Mamba

Mamba-style selective state-space blocks

Mamba replaces quadratic attention with a linear-time selective state-space recurrence. In the full model, each token controls the discretization step Delta and the input/output state parameters B and C.

This file exposes two layers:

• MambaBlockSpec: a compact diagonal Mamba-style block with a per-channel affine state recurrence,
• SelectiveMambaBlockSpec: a fuller paper-style selective SSM block with channel expansion, a causal depthwise convolution, and token-dependent Delta, B, and C.

The compact block is intentionally retained: it is the smallest reusable core for proving scan algebra and for validating CUDA kernels. The full block builds the paper-style Mamba dataflow on top of the same affine-scan idea.
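The affine-scan idea can be illustrated in isolation. The following is a hypothetical standalone Lean sketch (`AffineStep` is not one of this file's definitions): a diagonal step h ↦ a * h + b composes into another affine step, which is what makes parallel scans over the recurrence valid.

```lean
-- Hypothetical standalone sketch; `AffineStep` is not part of this file.
-- A single-channel diagonal affine step h ↦ a * h + b.
structure AffineStep (α : Type) where
  a : α
  b : α

def AffineStep.apply {α : Type} [Mul α] [Add α] (s : AffineStep α) (h : α) : α :=
  s.a * h + s.b

/-- Composing two steps (apply `s₁` first, then `s₂`) is again affine:
    `s₂.apply (s₁.apply h) = (s₂.a * s₁.a) * h + (s₂.a * s₁.b + s₂.b)`. -/
def AffineStep.comp {α : Type} [Mul α] [Add α] (s₂ s₁ : AffineStep α) : AffineStep α :=
  ⟨s₂.a * s₁.a, s₂.a * s₁.b + s₂.b⟩

-- Sanity check over `Int`: composing agrees with sequential application.
#eval (AffineStep.comp ⟨2, 3⟩ ⟨5, 7⟩).apply (1 : Int)      -- 27
#eval AffineStep.apply ⟨2, 3⟩ (AffineStep.apply ⟨5, 7⟩ (1 : Int))  -- 27
```

In both blocks below, the per-coordinate state update has exactly this shape, so the composition law applies coordinatewise.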

References:

structure Models.MambaBlockSpec (α : Type) (inputDim stateDim outputDim : ℕ) :

Parameters for a compact diagonal Mamba-style block.
def Models.MambaBlockSpec.projectInput {α : Type} [Context α] {inputDim stateDim outputDim : ℕ} (m : MambaBlockSpec α inputDim stateDim outputDim) (x : Spec.Tensor α (Spec.Shape.dim inputDim Spec.Shape.scalar)) :

Input-to-state projection.
def Models.MambaBlockSpec.gate {α : Type} [Context α] {inputDim stateDim outputDim : ℕ} (m : MambaBlockSpec α inputDim stateDim outputDim) (x : Spec.Tensor α (Spec.Shape.dim inputDim Spec.Shape.scalar)) :

Token-dependent sigmoid gate.
def Models.MambaBlockSpec.step {α : Type} [Context α] {inputDim stateDim outputDim : ℕ} (m : MambaBlockSpec α inputDim stateDim outputDim) (h : Spec.Tensor α (Spec.Shape.dim stateDim Spec.Shape.scalar)) (x : Spec.Tensor α (Spec.Shape.dim inputDim Spec.Shape.scalar)) :

One Mamba-style token step, returning (new_state, output).
def Models.MambaBlockSpec.runList {α : Type} [Context α] {inputDim stateDim outputDim : ℕ} (m : MambaBlockSpec α inputDim stateDim outputDim) (h0 : Spec.Tensor α (Spec.Shape.dim stateDim Spec.Shape.scalar)) :

Run a list of tokens through the recurrent block.
@[simp]
theorem Models.MambaBlockSpec.runList_nil {α : Type} [Context α] {inputDim stateDim outputDim : ℕ} (m : MambaBlockSpec α inputDim stateDim outputDim) (h0 : Spec.Tensor α (Spec.Shape.dim stateDim Spec.Shape.scalar)) :
m.runList h0 [] = (h0, [])
@[simp]
theorem Models.MambaBlockSpec.runList_cons {α : Type} [Context α] {inputDim stateDim outputDim : ℕ} (m : MambaBlockSpec α inputDim stateDim outputDim) (h0 : Spec.Tensor α (Spec.Shape.dim stateDim Spec.Shape.scalar)) (x : Spec.Tensor α (Spec.Shape.dim inputDim Spec.Shape.scalar)) (xs : List (Spec.Tensor α (Spec.Shape.dim inputDim Spec.Shape.scalar))) :
m.runList h0 (x :: xs) = match m.step h0 x with | (h1, y) => match m.runList h1 xs with | (hN, ys) => (hN, y :: ys)
theorem Models.MambaBlockSpec.runList_outputs_length {α : Type} [Context α] {inputDim stateDim outputDim : ℕ} (m : MambaBlockSpec α inputDim stateDim outputDim) (h0 : Spec.Tensor α (Spec.Shape.dim stateDim Spec.Shape.scalar)) (xs : List (Spec.Tensor α (Spec.Shape.dim inputDim Spec.Shape.scalar))) :
(m.runList h0 xs).2.length = xs.length

A Mamba recurrent pass emits one output token per input token.
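The length lemma is purely structural: any recursion of this shape emits exactly one output per input. A hypothetical standalone Lean sketch (`run` is not this file's `runList`, but has the same recursion shape):

```lean
-- Hypothetical generic runner; same shape as `runList`, stated abstractly.
def run {σ τ υ : Type} (step : σ → τ → σ × υ) (s : σ) : List τ → σ × List υ
  | [] => (s, [])
  | x :: xs =>
    let p := step s x
    let q := run step p.1 xs
    (q.1, p.2 :: q.2)

-- One output per input, by induction on the token list.
theorem run_length {σ τ υ : Type} (step : σ → τ → σ × υ) (s : σ) (xs : List τ) :
    (run step s xs).2.length = xs.length := by
  induction xs generalizing s with
  | nil => rfl
  | cons x xs ih => simp [run, ih]
```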

structure Models.SelectiveMambaBlockSpec (α : Type) (inputDim innerDim stateDim outputDim convWidth : ℕ) :

            Parameters for a fuller Mamba-style selective SSM block.

            Shape conventions:

            • inputDim: token/input feature width,
            • innerDim: expanded channel width used by Mamba's convolution and SSM path,
            • stateDim: per-channel diagonal SSM state size,
            • outputDim: output feature width,
            • convWidth: causal depthwise-convolution width.

            The recurrence state has shape [innerDim, stateDim]. This mirrors the common implementation view of Mamba where each expanded channel carries a small diagonal state vector.

def Models.SelectiveMambaBlockSpec.projectX {α : Type} [Context α] {inputDim innerDim stateDim outputDim convWidth : ℕ} (m : SelectiveMambaBlockSpec α inputDim innerDim stateDim outputDim convWidth) (x : Spec.Tensor α (Spec.Shape.dim inputDim Spec.Shape.scalar)) :

Content path projection.

def Models.SelectiveMambaBlockSpec.projectZ {α : Type} [Context α] {inputDim innerDim stateDim outputDim convWidth : ℕ} (m : SelectiveMambaBlockSpec α inputDim innerDim stateDim outputDim convWidth) (x : Spec.Tensor α (Spec.Shape.dim inputDim Spec.Shape.scalar)) :

Gate path projection.

SiLU/Swish applied channelwise.
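For reference, the activation meant here is the standard SiLU/Swish (definition not restated in this file):

```latex
\operatorname{silu}(x) = x \cdot \sigma(x) = \frac{x}{1 + e^{-x}}
```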

def Models.SelectiveMambaBlockSpec.causalDepthwiseConv {α : Type} [Context α] {inputDim innerDim stateDim outputDim convWidth : ℕ} (m : SelectiveMambaBlockSpec α inputDim innerDim stateDim outputDim convWidth) (history : List (Spec.Tensor α (Spec.Shape.dim innerDim Spec.Shape.scalar))) :

Causal depthwise convolution from a newest-first history of projected tokens.

history[0] is the current projected token, history[1] is the previous token, etc. Missing history entries are treated as zero padding.
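A hypothetical single-channel sketch of this convention (not this file's definition): zipping the filter against a newest-first history truncates at the shorter list, so missing entries contribute nothing, which is exactly zero padding.

```lean
-- Hypothetical single-channel sketch; not this file's definition.
-- `weights[0]` multiplies the current token, `weights[1]` the previous one, etc.
-- `List.zip` truncates to the shorter list, so a short history acts as zero padding.
def causalConv1d (weights history : List Float) : Float :=
  (weights.zip history).foldl (fun acc wh => acc + wh.1 * wh.2) 0.0

-- With only the current token present, only `weights[0]` contributes: 0.5 * 2.0 = 1.0.
#eval causalConv1d [0.5, 0.25, 0.125] [2.0]
```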

def Models.SelectiveMambaBlockSpec.delta {α : Type} [Context α] {inputDim innerDim stateDim outputDim convWidth : ℕ} (m : SelectiveMambaBlockSpec α inputDim innerDim stateDim outputDim convWidth) (u : Spec.Tensor α (Spec.Shape.dim innerDim Spec.Shape.scalar)) :

Token-dependent positive time steps Delta = softplus(u @ dtProj + dtBias).
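Softplus keeps the time steps strictly positive, so the decay factors in the state update stay in (0, 1) when A is positive. Writing W_Delta and b_Delta for dtProj and dtBias:

```latex
\Delta_t = \operatorname{softplus}(u_t W_{\Delta} + b_{\Delta}),
\qquad \operatorname{softplus}(x) = \log(1 + e^{x}) > 0
```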

def Models.SelectiveMambaBlockSpec.bToken {α : Type} [Context α] {inputDim innerDim stateDim outputDim convWidth : ℕ} (m : SelectiveMambaBlockSpec α inputDim innerDim stateDim outputDim convWidth) (u : Spec.Tensor α (Spec.Shape.dim innerDim Spec.Shape.scalar)) :

Token-dependent input-state vector B_t.

def Models.SelectiveMambaBlockSpec.cToken {α : Type} [Context α] {inputDim innerDim stateDim outputDim convWidth : ℕ} (m : SelectiveMambaBlockSpec α inputDim innerDim stateDim outputDim convWidth) (u : Spec.Tensor α (Spec.Shape.dim innerDim Spec.Shape.scalar)) :

Token-dependent state-output vector C_t.

def Models.SelectiveMambaBlockSpec.selectiveStateStep {α : Type} [Context α] {inputDim innerDim stateDim outputDim convWidth : ℕ} (m : SelectiveMambaBlockSpec α inputDim innerDim stateDim outputDim convWidth) (h : Spec.Tensor α (Spec.Shape.dim innerDim (Spec.Shape.dim stateDim Spec.Shape.scalar))) (u : Spec.Tensor α (Spec.Shape.dim innerDim Spec.Shape.scalar)) :

One selective diagonal SSM update:

h'[d,n] = exp(-Delta[d] * A[d,n]) * h[d,n] + (Delta[d] * B_t[n]) * u[d].
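Grouping terms per coordinate shows this update is again affine in the state, with token-dependent coefficients, so the same scan algebra as the compact block applies. Here the barred symbols are just local abbreviations for the terms of the formula above:

```latex
\bar{A}_t[d,n] = e^{-\Delta_t[d]\,A[d,n]}, \qquad
\bar{B}_t[d,n] = \Delta_t[d]\,B_t[n], \qquad
h'[d,n] = \bar{A}_t[d,n]\,h[d,n] + \bar{B}_t[d,n]\,u_t[d]
```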

def Models.SelectiveMambaBlockSpec.stateReadout {α : Type} [Context α] {inputDim innerDim stateDim outputDim convWidth : ℕ} (m : SelectiveMambaBlockSpec α inputDim innerDim stateDim outputDim convWidth) (h : Spec.Tensor α (Spec.Shape.dim innerDim (Spec.Shape.dim stateDim Spec.Shape.scalar))) (u : Spec.Tensor α (Spec.Shape.dim innerDim Spec.Shape.scalar)) :

Read out expanded channels from the updated state using C_t, plus the Mamba skip path.
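In the conventional Mamba formulation this readout takes the following shape, assuming the usual per-channel skip coefficient D (the exact skip parameterization is fixed by the source definition, not restated here):

```latex
y_t[d] = \sum_{n} C_t[n]\, h'[d,n] + D[d]\, u_t[d]
```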

def Models.SelectiveMambaBlockSpec.stepWithHistory {α : Type} [Context α] {inputDim innerDim stateDim outputDim convWidth : ℕ} (m : SelectiveMambaBlockSpec α inputDim innerDim stateDim outputDim convWidth) (h : Spec.Tensor α (Spec.Shape.dim innerDim (Spec.Shape.dim stateDim Spec.Shape.scalar))) (history : List (Spec.Tensor α (Spec.Shape.dim innerDim Spec.Shape.scalar))) (z : Spec.Tensor α (Spec.Shape.dim innerDim Spec.Shape.scalar)) :

One full Mamba token step from an already-updated convolution history.

The history argument is newest-first and must include the current projected content token.
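A hypothetical sketch of how a caller can maintain such a history (not this file's API): cons the newest projected token and truncate to the convolution width.

```lean
-- Hypothetical sketch; not this file's API.
-- Newest-first history: cons the current token, keep at most `convWidth` entries.
def pushHistory {β : Type} (convWidth : Nat) (x : β) (history : List β) : List β :=
  (x :: history).take convWidth

#eval pushHistory 3 4 [3, 2, 1]  -- [4, 3, 2]
```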

def Models.SelectiveMambaBlockSpec.runListAux {α : Type} [Context α] {inputDim innerDim stateDim outputDim convWidth : ℕ} (m : SelectiveMambaBlockSpec α inputDim innerDim stateDim outputDim convWidth) (h0 : Spec.Tensor α (Spec.Shape.dim innerDim (Spec.Shape.dim stateDim Spec.Shape.scalar))) (history : List (Spec.Tensor α (Spec.Shape.dim innerDim Spec.Shape.scalar))) :

Internal recurrent runner carrying the causal convolution history.

def Models.SelectiveMambaBlockSpec.runList {α : Type} [Context α] {inputDim innerDim stateDim outputDim convWidth : ℕ} (m : SelectiveMambaBlockSpec α inputDim innerDim stateDim outputDim convWidth) (h0 : Spec.Tensor α (Spec.Shape.dim innerDim (Spec.Shape.dim stateDim Spec.Shape.scalar))) :

Run a sequence through the full selective Mamba block.

@[simp]
theorem Models.SelectiveMambaBlockSpec.runListAux_nil {α : Type} [Context α] {inputDim innerDim stateDim outputDim convWidth : ℕ} (m : SelectiveMambaBlockSpec α inputDim innerDim stateDim outputDim convWidth) (h0 : Spec.Tensor α (Spec.Shape.dim innerDim (Spec.Shape.dim stateDim Spec.Shape.scalar))) (history : List (Spec.Tensor α (Spec.Shape.dim innerDim Spec.Shape.scalar))) :
m.runListAux h0 history [] = (h0, [])
@[simp]
theorem Models.SelectiveMambaBlockSpec.runList_nil {α : Type} [Context α] {inputDim innerDim stateDim outputDim convWidth : ℕ} (m : SelectiveMambaBlockSpec α inputDim innerDim stateDim outputDim convWidth) (h0 : Spec.Tensor α (Spec.Shape.dim innerDim (Spec.Shape.dim stateDim Spec.Shape.scalar))) :
m.runList h0 [] = (h0, [])
theorem Models.SelectiveMambaBlockSpec.runListAux_outputs_length {α : Type} [Context α] {inputDim innerDim stateDim outputDim convWidth : ℕ} (m : SelectiveMambaBlockSpec α inputDim innerDim stateDim outputDim convWidth) (h0 : Spec.Tensor α (Spec.Shape.dim innerDim (Spec.Shape.dim stateDim Spec.Shape.scalar))) (history : List (Spec.Tensor α (Spec.Shape.dim innerDim Spec.Shape.scalar))) (xs : List (Spec.Tensor α (Spec.Shape.dim inputDim Spec.Shape.scalar))) :
(m.runListAux h0 history xs).2.length = xs.length

The full Mamba recurrent pass emits one output token per input token.

theorem Models.SelectiveMambaBlockSpec.runList_outputs_length {α : Type} [Context α] {inputDim innerDim stateDim outputDim convWidth : ℕ} (m : SelectiveMambaBlockSpec α inputDim innerDim stateDim outputDim convWidth) (h0 : Spec.Tensor α (Spec.Shape.dim innerDim (Spec.Shape.dim stateDim Spec.Shape.scalar))) (xs : List (Spec.Tensor α (Spec.Shape.dim inputDim Spec.Shape.scalar))) :
(m.runList h0 xs).2.length = xs.length

The public full Mamba runner emits one output token per input token.