TorchLean API

NN.Spec.Layers.Lstm

LSTM (spec layer) #

TorchLean provides a small specification of an LSTM layer, written in terms of TorchLean's tensor building blocks.

References (math + PyTorch behavior) #

Notes on parameterization #

Many libraries expose two matrices per gate (W_ih and W_hh) and add their contributions. In this spec we use a single matrix applied to the concatenated vector [x_t; h_{t-1}]. It is the same computation, just packaged to reuse TorchLean's tensor building blocks.
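Concretely, if the single matrix is the column-block concatenation of the two per-gate matrices, the two parameterizations agree (standard linear algebra, not quoted from the spec):

```latex
% For each gate, let W = [ W_ih | W_hh ]
% (first inputSize columns, then hiddenSize columns). Then:
W \begin{bmatrix} x_t \\ h_{t-1} \end{bmatrix}
  = W_{ih}\, x_t + W_{hh}\, h_{t-1}
```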

structure Spec.LSTMSpec (α : Type) (inputSize hiddenSize : ℕ) : Type

Parameters for an LSTM cell, with one (hiddenSize × (inputSize + hiddenSize)) matrix per gate.

This corresponds to the usual (W_ih, W_hh) parameterization in libraries like PyTorch, but we package it as a single matrix applied to [x_t; h_{t-1}] to reuse TorchLean's tensor building blocks.

  • forget_weights : WeightMatrix α hiddenSize (inputSize + hiddenSize)

    Forget-gate weights for f_t = sigmoid(W_f [x_t; h_{t-1}] + b_f).

  • forget_bias : HiddenVector α hiddenSize

    Forget-gate bias.

  • input_weights : WeightMatrix α hiddenSize (inputSize + hiddenSize)

    Input-gate weights for i_t = sigmoid(W_i [x_t; h_{t-1}] + b_i).

  • input_bias : HiddenVector α hiddenSize

    Input-gate bias.

  • candidate_weights : WeightMatrix α hiddenSize (inputSize + hiddenSize)

    Candidate/cell-proposal weights for g_t = tanh(W_g [x_t; h_{t-1}] + b_g).

  • candidate_bias : HiddenVector α hiddenSize

    Candidate/cell-proposal bias.

  • output_weights : WeightMatrix α hiddenSize (inputSize + hiddenSize)

    Output-gate weights for o_t = sigmoid(W_o [x_t; h_{t-1}] + b_o).

  • output_bias : HiddenVector α hiddenSize

    Output-gate bias.

structure Spec.LSTMState (α : Type) (hiddenSize : ℕ) : Type

LSTM recurrent state: hidden vector h_t and cell vector c_t.

  • hidden : HiddenVector α hiddenSize

    Exposed hidden state h_t.

  • cell : HiddenVector α hiddenSize

    Internal memory/cell state c_t.

def Spec.lstmCellSpec {α : Type} [Context α] {inputSize hiddenSize : ℕ} (lstm : LSTMSpec α inputSize hiddenSize) (input : InputVector α inputSize) (prev_state : LSTMState α hiddenSize) :
LSTMState α hiddenSize

One LSTM cell step: update (h_{t-1}, c_{t-1}) given x_t and parameters.
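For reference, the textbook cell update this step computes; the gate formulas repeat the field docs above, and the c_t/h_t updates are the standard LSTM equations (not quoted from the Lean definition):

```latex
f_t = \sigma(W_f [x_t; h_{t-1}] + b_f) \qquad
i_t = \sigma(W_i [x_t; h_{t-1}] + b_i) \\
g_t = \tanh(W_g [x_t; h_{t-1}] + b_g) \qquad
o_t = \sigma(W_o [x_t; h_{t-1}] + b_o) \\
c_t = f_t \odot c_{t-1} + i_t \odot g_t \qquad
h_t = o_t \odot \tanh(c_t)
```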

def Spec.lstmSequenceSpec {α : Type} [Context α] {seqLen inputSize hiddenSize : ℕ} (lstm : LSTMSpec α inputSize hiddenSize) (inputs : SequenceTensor α seqLen (Shape.dim inputSize Shape.scalar)) (initial_state : LSTMState α hiddenSize) :
SequenceTensor α seqLen (Shape.dim hiddenSize Shape.scalar) × LSTMState α hiddenSize

Run an LSTM cell over a length-seqLen input sequence, returning outputs and final state.
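In math form, the recursion being specified is a fold of the cell step over time (a sketch; the recursion itself lives in process_sequence):

```latex
(h_t, c_t) = \mathrm{cell}\big(x_t,\, (h_{t-1}, c_{t-1})\big), \quad t = 1, \dots, \mathrm{seqLen} \\
\text{outputs} = (h_1, \dots, h_{\mathrm{seqLen}}), \qquad
\text{final state} = (h_{\mathrm{seqLen}}, c_{\mathrm{seqLen}})
```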

@[irreducible]
def Spec.lstmSequenceSpec.process_sequence {α : Type} [Context α] {seqLen inputSize hiddenSize : ℕ} (lstm : LSTMSpec α inputSize hiddenSize) (inputs : SequenceTensor α seqLen (Shape.dim inputSize Shape.scalar)) (t : ℕ) (prev_state : LSTMState α hiddenSize) :
LSTMState α hiddenSize × List (HiddenVector α hiddenSize)

def Spec.lstmBatchedSpec {α : Type} [Context α] {batchSize seqLen inputSize hiddenSize : ℕ} (lstm : LSTMSpec α inputSize hiddenSize) (inputs : BatchedTensor α batchSize (Shape.dim seqLen (Shape.dim inputSize Shape.scalar))) (initial_hiddens : BatchedTensor α batchSize (Shape.dim hiddenSize Shape.scalar)) :
BatchedTensor α batchSize (Shape.dim seqLen (Shape.dim hiddenSize Shape.scalar)) × BatchedTensor α batchSize (Shape.dim hiddenSize Shape.scalar)

Batched wrapper around lstmSequenceSpec (runs one sequence per batch element).

def Spec.lstmCellSpecWithIntermediates {α : Type} [Context α] {inputSize hiddenSize : ℕ} (lstm : LSTMSpec α inputSize hiddenSize) (input : InputVector α inputSize) (prev_state : LSTMState α hiddenSize) :
LSTMState α hiddenSize × HiddenVector α hiddenSize × HiddenVector α hiddenSize × HiddenVector α hiddenSize × HiddenVector α hiddenSize

Forward pass for one LSTM cell that also returns the gate activations.

This is the spec analogue of the "saved tensors" that a runtime will keep for backward.

def Spec.lstmCellBackwardSpec {α : Type} [Context α] {inputSize hiddenSize : ℕ} (lstm : LSTMSpec α inputSize hiddenSize) (input : InputVector α inputSize) (prev_state state : LSTMState α hiddenSize) (forget_gate input_gate candidate output_gate grad_hidden grad_cell : HiddenVector α hiddenSize) :
InputVector α inputSize × LSTMState α hiddenSize × WeightMatrix α hiddenSize (inputSize + hiddenSize) × HiddenVector α hiddenSize × WeightMatrix α hiddenSize (inputSize + hiddenSize) × HiddenVector α hiddenSize × WeightMatrix α hiddenSize (inputSize + hiddenSize) × HiddenVector α hiddenSize × WeightMatrix α hiddenSize (inputSize + hiddenSize) × HiddenVector α hiddenSize

Backward pass (VJP) for a single LSTM cell.

Inputs:

• parameters lstm,
• inputs x_t, previous state (h_{t-1}, c_{t-1}), and current state (h_t, c_t),
• the gate activations from the forward pass,
• upstream gradients for both h_t and c_t.

Outputs:

• gradients w.r.t. x_t and the previous state,
• plus gradients for each parameter tensor.

PyTorch mental model: this is what autograd computes for nn.LSTMCell when unrolled in time.
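One consistent way to write this VJP, using the forward equations f = σ(z_f), i = σ(z_i), o = σ(z_o), g = tanh(z_g), c_t = f ⊙ c_{t-1} + i ⊙ g, h_t = o ⊙ tanh(c_t). This is standard LSTM backward algebra; the spec's exact factoring may differ:

```latex
% Upstream gradients: dh (for h_t) and dc^{up} (for c_t).
do_t = dh \odot \tanh(c_t) \qquad
dc = dc^{up} + dh \odot o_t \odot (1 - \tanh^2(c_t)) \\
df_t = dc \odot c_{t-1} \qquad di_t = dc \odot g_t \qquad
dg_t = dc \odot i_t \qquad dc_{t-1} = dc \odot f_t \\
% Pre-activation gradients (through sigmoid / tanh):
dz_f = df_t \odot f_t (1 - f_t) \qquad
dz_i = di_t \odot i_t (1 - i_t) \\
dz_o = do_t \odot o_t (1 - o_t) \qquad
dz_g = dg_t \odot (1 - g_t^2) \\
% Parameter and input gradients:
dW_* = dz_* \, [x_t; h_{t-1}]^{\top} \qquad db_* = dz_* \qquad
d[x_t; h_{t-1}] = \textstyle\sum_{* \in \{f,i,g,o\}} W_*^{\top} dz_*
```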

def Spec.lstmSequenceBackwardSpec {α : Type} [Context α] {seqLen inputSize hiddenSize : ℕ} (lstm : LSTMSpec α inputSize hiddenSize) (inputs : SequenceTensor α seqLen (Shape.dim inputSize Shape.scalar)) (initial_state : LSTMState α hiddenSize) (grad_hiddens : SequenceTensor α seqLen (Shape.dim hiddenSize Shape.scalar)) :
WeightMatrix α hiddenSize (inputSize + hiddenSize) × HiddenVector α hiddenSize × WeightMatrix α hiddenSize (inputSize + hiddenSize) × HiddenVector α hiddenSize × WeightMatrix α hiddenSize (inputSize + hiddenSize) × HiddenVector α hiddenSize × WeightMatrix α hiddenSize (inputSize + hiddenSize) × HiddenVector α hiddenSize × SequenceTensor α seqLen (Shape.dim inputSize Shape.scalar) × LSTMState α hiddenSize

Backprop through time (BPTT) for the whole sequence.

This function recomputes and stores the forward intermediates (gates and states) internally, then walks time backward accumulating parameter gradients and input gradients. This matches the usual PyTorch training structure, with the save-vs-recompute choice made explicit.
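In equation form, the backward walk carries (dH, dC) across time and sums per-step parameter gradients (a sketch of standard BPTT, not quoted from the definition):

```latex
% Effective upstream on h_t combines the per-step gradient and the carried state gradient:
dh_t^{\text{total}} = \text{grad\_hiddens}_t + dH_{t+1},
\qquad dH_{\mathrm{seqLen}+1} = 0, \quad dC_{\mathrm{seqLen}+1} = 0 \\
% Each step applies the single-cell VJP, then parameter gradients accumulate over time:
dW_* = \textstyle\sum_{t=1}^{\mathrm{seqLen}} dz_*^{(t)} \, [x_t; h_{t-1}]^{\top},
\qquad db_* = \textstyle\sum_{t=1}^{\mathrm{seqLen}} dz_*^{(t)}
```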

@[irreducible]
def Spec.lstmSequenceBackwardSpec.forward_collect {α : Type} [Context α] {seqLen inputSize hiddenSize : ℕ} (lstm : LSTMSpec α inputSize hiddenSize) (inputs : SequenceTensor α seqLen (Shape.dim inputSize Shape.scalar)) (t : ℕ) (st : LSTMState α hiddenSize) :
LSTMState α hiddenSize × List (LSTMState α hiddenSize) × List (HiddenVector α hiddenSize) × List (HiddenVector α hiddenSize) × List (HiddenVector α hiddenSize) × List (HiddenVector α hiddenSize)

@[irreducible]
def Spec.lstmSequenceBackwardSpec.backward_step {α : Type} [Context α] {seqLen inputSize hiddenSize : ℕ} (lstm : LSTMSpec α inputSize hiddenSize) (inputs : SequenceTensor α seqLen (Shape.dim inputSize Shape.scalar)) (initial_state : LSTMState α hiddenSize) (grad_hiddens : SequenceTensor α seqLen (Shape.dim hiddenSize Shape.scalar)) (state_seq : List (LSTMState α hiddenSize)) (f_seq i_seq g_seq o_seq : List (HiddenVector α hiddenSize)) (t : ℕ) (_h_t : t ≤ seqLen) (dH_next dC_next : HiddenVector α hiddenSize) (acc_inputs : List (InputVector α inputSize)) (accWf : WeightMatrix α hiddenSize (inputSize + hiddenSize)) (accbf : HiddenVector α hiddenSize) (accWi : WeightMatrix α hiddenSize (inputSize + hiddenSize)) (accbi : HiddenVector α hiddenSize) (accWc : WeightMatrix α hiddenSize (inputSize + hiddenSize)) (accbc : HiddenVector α hiddenSize) (accWo : WeightMatrix α hiddenSize (inputSize + hiddenSize)) (accbo : HiddenVector α hiddenSize) :
List (InputVector α inputSize) × HiddenVector α hiddenSize × HiddenVector α hiddenSize × WeightMatrix α hiddenSize (inputSize + hiddenSize) × HiddenVector α hiddenSize × WeightMatrix α hiddenSize (inputSize + hiddenSize) × HiddenVector α hiddenSize × WeightMatrix α hiddenSize (inputSize + hiddenSize) × HiddenVector α hiddenSize × WeightMatrix α hiddenSize (inputSize + hiddenSize) × HiddenVector α hiddenSize