TorchLean API

NN.Spec.Models.S4

Diagonal S4-style state-space layer

This module provides TorchLean's diagonal recurrent SSM layer in the S4 family. It exposes the state-space recurrence used by S4-style models:

h_{t+1} = A h_t + B x_t, y_t = C h_{t+1} + D x_t.

The diagonal form is intentional: it shares the selective-scan core used by Mamba-style models, admits direct recurrence proofs, and can be connected to convolutional S4 kernels through a separate structured-kernel layer.
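To make the link to convolutional kernels concrete, the recurrence above can be unrolled. The following is a sketch under two assumptions that this page does not state: a zero initial state, h_0 = 0, and A = diag(a_1, ..., a_N). The symbols a_i, N, and K_k are notation introduced here for illustration; they are not names from the module.

```latex
\begin{aligned}
h_{t+1} &= \sum_{k=0}^{t} A^{k} B\, x_{t-k}, \\
y_t &= C h_{t+1} + D x_t
     = \sum_{k=0}^{t} K_k\, x_{t-k} + D x_t,
\qquad K_k = C A^{k} B, \\
(K_k)_{pq} &= \sum_{i=1}^{N} C_{pi}\, a_i^{k}\, B_{iq}.
\end{aligned}
```

Because A is diagonal, each kernel entry is a weighted sum of scalar powers a_i^k, so no matrix powers are needed.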

Reference: Gu, Goel, Ré. "Efficiently Modeling Long Sequences with Structured State Spaces", ICLR 2022.

structure Models.DiagonalS4Spec (α : Type) (inputDim stateDim outputDim : ℕ) :

Parameters for a diagonal S4-style sequence layer.

def Models.DiagonalS4Spec.projectInput {α : Type} [Add α] [Mul α] [Zero α] {inputDim stateDim outputDim : ℕ} (m : DiagonalS4Spec α inputDim stateDim outputDim) (x : Spec.Tensor α (Spec.Shape.dim inputDim Spec.Shape.scalar)) :

Project an input token into state channels.

def Models.DiagonalS4Spec.projectOutput {α : Type} [Add α] [Mul α] [Zero α] {inputDim stateDim outputDim : ℕ} (m : DiagonalS4Spec α inputDim stateDim outputDim) (h : Spec.Tensor α (Spec.Shape.dim stateDim Spec.Shape.scalar)) :

Project state channels to output channels.

def Models.DiagonalS4Spec.step {α : Type} [Add α] [Mul α] [Zero α] {inputDim stateDim outputDim : ℕ} (m : DiagonalS4Spec α inputDim stateDim outputDim) (h : Spec.Tensor α (Spec.Shape.dim stateDim Spec.Shape.scalar)) (x : Spec.Tensor α (Spec.Shape.dim inputDim Spec.Shape.scalar)) :

One recurrent S4-style token step, returning (new_state, output).
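As an illustration of what one token step computes, here is a minimal Lean 4 sketch on a toy model. Everything in it is hypothetical: the `ToyS4` namespace, the `Params` structure, and its list-based fields are stand-ins chosen for readability. It assumes a scalar input and output over `Int` rather than TorchLean's `Spec.Tensor` types, and it is not the library's definition.

```lean
namespace ToyS4

/-- Hypothetical toy parameters (not TorchLean's `DiagonalS4Spec`):
    scalar input and output, one list entry per state channel. -/
structure Params where
  a : List Int  -- diagonal of A, one coefficient per state channel
  b : List Int  -- B: spreads the scalar input across the state channels
  c : List Int  -- C: reads the state channels back into a scalar
  d : Int       -- D: direct input-to-output skip term

/-- Input projection `B x`. -/
def projectInput (m : Params) (x : Int) : List Int :=
  m.b.map (fun (bi : Int) => bi * x)

/-- Output projection `C h`, a dot product in this scalar-output toy. -/
def projectOutput (m : Params) (h : List Int) : Int :=
  (List.zipWith (fun (ci hi : Int) => ci * hi) m.c h).foldl
    (fun (acc t : Int) => acc + t) 0

/-- One token step: `h' = a ⊙ h + B x`, then `y = C h' + D x`.
    The output is read from the updated state, as in the recurrence above. -/
def step (m : Params) (h : List Int) (x : Int) : List Int × Int :=
  let ah : List Int := List.zipWith (fun (ai hi : Int) => ai * hi) m.a h
  let h' : List Int := List.zipWith (fun (u v : Int) => u + v) ah (projectInput m x)
  (h', projectOutput m h' + m.d * x)

/-- Two state channels with made-up coefficients. -/
def demo : Params := { a := [2, 3], b := [1, 1], c := [1, 1], d := 0 }

#eval step demo [0, 0] 1  -- ([1, 1], 2)
#eval step demo [1, 1] 1  -- ([3, 4], 7)

end ToyS4
```

The second `#eval` feeds the state returned by the first back in, which is the chaining a list runner performs over a whole sequence.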

def Models.DiagonalS4Spec.runList {α : Type} [Add α] [Mul α] [Zero α] {inputDim stateDim outputDim : ℕ} (m : DiagonalS4Spec α inputDim stateDim outputDim) (h0 : Spec.Tensor α (Spec.Shape.dim stateDim Spec.Shape.scalar)) :

Run a list of tokens through the recurrent layer.

@[simp]
theorem Models.DiagonalS4Spec.runList_nil {α : Type} [Add α] [Mul α] [Zero α] {inputDim stateDim outputDim : ℕ} (m : DiagonalS4Spec α inputDim stateDim outputDim) (h0 : Spec.Tensor α (Spec.Shape.dim stateDim Spec.Shape.scalar)) :
m.runList h0 [] = (h0, [])
@[simp]
theorem Models.DiagonalS4Spec.runList_cons {α : Type} [Add α] [Mul α] [Zero α] {inputDim stateDim outputDim : ℕ} (m : DiagonalS4Spec α inputDim stateDim outputDim) (h0 : Spec.Tensor α (Spec.Shape.dim stateDim Spec.Shape.scalar)) (x : Spec.Tensor α (Spec.Shape.dim inputDim Spec.Shape.scalar)) (xs : List (Spec.Tensor α (Spec.Shape.dim inputDim Spec.Shape.scalar))) :
m.runList h0 (x :: xs) =
  match m.step h0 x with
  | (h1, y) =>
    match m.runList h1 xs with
    | (hN, ys) => (hN, y :: ys)
theorem Models.DiagonalS4Spec.runList_outputs_length {α : Type} [Add α] [Mul α] [Zero α] {inputDim stateDim outputDim : ℕ} (m : DiagonalS4Spec α inputDim stateDim outputDim) (h0 : Spec.Tensor α (Spec.Shape.dim stateDim Spec.Shape.scalar)) (xs : List (Spec.Tensor α (Spec.Shape.dim inputDim Spec.Shape.scalar))) :
(m.runList h0 xs).2.length = xs.length

A recurrent S4 pass emits one output token per input token.
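The length property depends only on the shape of the recursion, not on the S4 arithmetic. The Lean 4 sketch below shows this on a hypothetical stand-in: `ToyRun.runList` is generic over an arbitrary step function and is not TorchLean's definition. The type parameters `S`, `I`, `O` and the running-sum example are assumptions made for illustration.

```lean
namespace ToyRun

/-- Hypothetical stand-in for a list runner: thread the state through
    `step`, collecting one output per input token. -/
def runList {S I O : Type} (step : S → I → S × O) : S → List I → S × List O
  | h, [] => (h, [])
  | h, x :: xs =>
    -- Step once on `x`, then recurse on the tail from the new state.
    ((runList step (step h x).1 xs).1,
     (step h x).2 :: (runList step (step h x).1 xs).2)

/-- One output per input, by induction on the token list.
    The state is generalized because it changes at every step. -/
theorem runList_outputs_length {S I O : Type} (step : S → I → S × O)
    (h : S) (xs : List I) : (runList step h xs).2.length = xs.length := by
  induction xs generalizing h with
  | nil => simp [runList]
  | cons x xs ih => simp [runList, ih]

-- Running sum as a toy step: the state and the output are both the new sum.
#eval runList (fun (s x : Nat) => (s + x, s + x)) 0 [1, 2, 3]  -- (6, [1, 3, 6])

end ToyRun
```

Generalizing over the initial state is the key design choice in the proof: the induction hypothesis must apply to the state produced by the first step, not to the original `h`.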