TorchLean API

NN.Spec.Module.Rnn

RNN/LSTM/GRU module wrappers

The layer specs (NN/Spec/Layers/Rnn.lean, lstm.lean, and gru.lean) expose step-level and sequence-level recurrence definitions.

This file wraps the "sequence forward" functions as NNModuleSpecs so recurrent blocks can be composed with other modules in a SpecChain.

Design choices:

- If you think in PyTorch: these are the nn.RNN/nn.LSTM/nn.GRU "return the full output sequence" wrappers, with the initial hidden state (and, for the LSTM, the cell state) fixed to zeros.
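A minimal usage sketch, assuming the TorchLean imports are in scope (the import path is guessed from the module name, and the displayed output shape follows the single-direction pattern of the signatures below; treat both as assumptions):

import NN.Spec.Module.Rnn  -- assumed import path, inferred from the module name

-- A step-level spec lifts to a sequence-to-sequence module: a length-seqLen
-- sequence of inputSize-vectors maps to a length-seqLen sequence of
-- hiddenSize-vectors, computed from a zero initial hidden state.
example {α : Type} [Context α] {seqLen inputSize hiddenSize : ℕ}
    (rnn : RNNSpec α inputSize hiddenSize) :
    ModSpec.NNModuleSpec α
      (Shape.dim seqLen (Shape.dim inputSize Shape.scalar))
      (Shape.dim seqLen (Shape.dim hiddenSize Shape.scalar)) :=
  Spec.RNNModuleSpec rnn

Spec.LSTMModuleSpec and Spec.GRUModuleSpec have the same input/output shape, so any of the three can occupy the same slot in a SpecChain.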

def Spec.RNNModuleSpec {α : Type} [Context α] {seqLen inputSize hiddenSize : ℕ} (rnn : RNNSpec α inputSize hiddenSize) :
ModSpec.NNModuleSpec α (Shape.dim seqLen (Shape.dim inputSize Shape.scalar)) (Shape.dim seqLen (Shape.dim hiddenSize Shape.scalar))

RNN sequence wrapper with a zero initial hidden state.

Instances For
def Spec.LSTMModuleSpec {α : Type} [Context α] {seqLen inputSize hiddenSize : ℕ} (lstm : LSTMSpec α inputSize hiddenSize) :
ModSpec.NNModuleSpec α (Shape.dim seqLen (Shape.dim inputSize Shape.scalar)) (Shape.dim seqLen (Shape.dim hiddenSize Shape.scalar))

LSTM sequence wrapper with a zero initial state; returns the output sequence.

Instances For
def Spec.GRUModuleSpec {α : Type} [Context α] {seqLen inputSize hiddenSize : ℕ} (gru : GRUSpec α inputSize hiddenSize) :
ModSpec.NNModuleSpec α (Shape.dim seqLen (Shape.dim inputSize Shape.scalar)) (Shape.dim seqLen (Shape.dim hiddenSize Shape.scalar))

GRU sequence wrapper with a zero initial hidden state; returns the output sequence.

Instances For
def Spec.BiLSTMModuleSpec {α : Type} [Context α] {seqLen inputSize hiddenSize : ℕ} (forward_lstm backward_lstm : LSTMSpec α inputSize hiddenSize) :
ModSpec.NNModuleSpec α (Shape.dim seqLen (Shape.dim inputSize Shape.scalar)) (Shape.dim seqLen (Shape.dim (hiddenSize + hiddenSize) Shape.scalar))

Bidirectional LSTM wrapper (concatenates forward/backward features); see the type-level example below.

Instances For
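A type-level check of the doubled feature dimension (a sketch assuming the TorchLean imports are in scope; the sizes 10, 8, 16 are arbitrary illustrations):

example {α : Type} [Context α] (f b : LSTMSpec α 8 16) :
    ModSpec.NNModuleSpec α
      (Shape.dim 10 (Shape.dim 8 Shape.scalar))             -- input: 10 steps × 8 dims
      (Shape.dim 10 (Shape.dim (16 + 16) Shape.scalar)) :=  -- output: 10 steps × (16 + 16) fwd ++ bwd features
  Spec.BiLSTMModuleSpec f b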
def Spec.RNNCellModuleSpec {α : Type} [Context α] {inputSize hiddenSize : ℕ} (rnn : RNNSpec α inputSize hiddenSize) :
ModSpec.NNModuleSpec α (Shape.dim (inputSize + hiddenSize) Shape.scalar) (Shape.dim hiddenSize Shape.scalar)

Wrap rnn_cell_spec as an NNModuleSpec for a single timestep.

Input convention: we take a single vector [x; h] (concatenated input and previous hidden state), so the module is shape-safe and easy to compose (see the example below).

Instances For
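The concatenated-input convention is visible directly in the type. A type-level sketch with illustrative sizes (assumes the TorchLean imports are in scope):

example {α : Type} [Context α] (rnn : RNNSpec α 8 16) :
    ModSpec.NNModuleSpec α
      (Shape.dim (8 + 16) Shape.scalar)  -- input [x; h]: 8-dim x ++ 16-dim h
      (Shape.dim 16 Shape.scalar) :=     -- output: the new hidden state h'
  Spec.RNNCellModuleSpec rnn

Spec.GRUCellModuleSpec has exactly the same input/output shape, so the two cells are interchangeable at the type level.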
def Spec.LSTMCellModuleSpec {α : Type} [Context α] {inputSize hiddenSize : ℕ} (lstm : LSTMSpec α inputSize hiddenSize) :
ModSpec.NNModuleSpec α (Shape.dim (inputSize + hiddenSize + hiddenSize) Shape.scalar) (Shape.dim (hiddenSize + hiddenSize) Shape.scalar)

Wrap lstm_cell_spec as an NNModuleSpec for a single timestep.

Input convention: a single concatenated vector [x; h; c] (input, previous hidden, previous cell). Output convention: the concatenated new state [h'; c'] (illustrated below).

Instances For
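The same kind of check for the packed LSTM state (again a sketch with illustrative sizes, assuming the TorchLean imports are in scope):

example {α : Type} [Context α] (lstm : LSTMSpec α 8 16) :
    ModSpec.NNModuleSpec α
      (Shape.dim (8 + 16 + 16) Shape.scalar)  -- input [x; h; c]
      (Shape.dim (16 + 16) Shape.scalar) :=   -- output [h'; c']
  Spec.LSTMCellModuleSpec lstm

Because the output packs the full new state [h'; c'], a caller can iterate the cell by concatenating the next input x onto the previous output; this is the recurrence that the sequence wrapper above unrolls from a zero initial state.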
def Spec.GRUCellModuleSpec {α : Type} [Context α] {inputSize hiddenSize : ℕ} (gru : GRUSpec α inputSize hiddenSize) :
ModSpec.NNModuleSpec α (Shape.dim (inputSize + hiddenSize) Shape.scalar) (Shape.dim hiddenSize Shape.scalar)

Wrap gru_cell_spec as an NNModuleSpec for a single timestep, using input [x; h].

Instances For
def Spec.BiRNNModuleSpec {α : Type} [Context α] {seqLen inputSize hiddenSize : ℕ} (forward_rnn backward_rnn : RNNSpec α inputSize hiddenSize) :
ModSpec.NNModuleSpec α (Shape.dim seqLen (Shape.dim inputSize Shape.scalar)) (Shape.dim seqLen (Shape.dim (hiddenSize + hiddenSize) Shape.scalar))

Bidirectional RNN wrapper (concatenates forward/backward features).

We run forward_rnn over x, run backward_rnn over the reversed sequence, then reverse its outputs back to the original time order and concatenate the two along the feature axis. A self-contained sketch of this recipe follows below.

Instances For
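The recipe above, written out over plain lists as a minimal, self-contained sketch; the List encoding and the runFwd/runBwd stand-ins are illustrative only, not the library's actual representation:

/-- Illustrative only: run `runFwd` in time order, run `runBwd` on the reversed
sequence, restore time order, and concatenate per-timestep feature vectors. -/
def biDirectional {τ φ : Type} (runFwd runBwd : List τ → List (List φ))
    (xs : List τ) : List (List φ) :=
  let fwdOut := runFwd xs                    -- forward features, time order
  let bwdOut := (runBwd xs.reverse).reverse  -- backward features, restored to time order
  List.zipWith (· ++ ·) fwdOut bwdOut        -- concatenate along the feature axis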