TorchLean API

NN.Spec.Layers.Gru

GRU (spec layer) #

TorchLean provides a small GRU specification layer.

References (math + PyTorch behavior) #

Notes on parameterization #

The GRU equations are often written with separate matrices W_* for the input and U_* for the hidden state. In this spec we use a single matrix per gate applied to a concatenated vector [x_t; h_{t-1}] (or [x_t; r_t ⊙ h_{t-1}] for the candidate). This is the same idea, just packaged in a way that reuses the tensor building blocks already present in the spec layer.
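Concretely, the two parameterizations store the same information: splitting each per-gate matrix column-wise into an input block and a hidden block recovers the separate matrices. For the reset gate, for example,

  W_r · [x_t; h_{t-1}] = W_r^x · x_t + U_r · h_{t-1},    where W_r = [W_r^x | U_r].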

One small place where libraries differ is the candidate equation: some implementations apply the reset gate before the hidden-state linear map (as in Cho et al.), while others apply it after a hidden-state linear map (as in the PyTorch docs). This file follows the former, because it matches the original GRU equations and stays close to the "concatenate then multiply" style used elsewhere in TorchLean's spec layer.
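Side by side, the two candidate conventions are:

  n_t = tanh(W_n · [x_t; r_t ⊙ h_{t-1}] + b_n)                        (Cho et al.; this file)
  n_t = tanh(W_in · x_t + b_in + r_t ⊙ (W_hn · h_{t-1} + b_hn))       (PyTorch docs)

These differ in general: the first multiplies by r_t before the hidden-state linear map, while the second applies r_t to the whole hidden-state linear map, including its bias.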

structure Spec.GRUSpec (α : Type) (inputSize hiddenSize : ℕ) : Type

Parameters for a single GRU cell.

This is the spec-level analogue of PyTorch torch.nn.GRUCell parameters, using a concatenated input [x_t; h_{t-1}] (shape inputSize + hiddenSize) for the reset/update gates and [x_t; r_t ⊙ h_{t-1}] for the candidate gate.

Shapes:

  • each *_weights is [hiddenSize, inputSize + hiddenSize],
  • each *_bias is [hiddenSize].
def Spec.gruCellSpec {α : Type} [Context α] {inputSize hiddenSize : ℕ} (gru : GRUSpec α inputSize hiddenSize) (input : Tensor α (Shape.dim inputSize Shape.scalar)) (prev_hidden : Tensor α (Shape.dim hiddenSize Shape.scalar)) :
    Tensor α (Shape.dim hiddenSize Shape.scalar)

Forward pass for a single GRU cell.

Given input x_t and previous hidden state h_{t-1}, compute the next hidden state h_t using the standard GRU equations.

PyTorch analogue: torch.nn.GRUCell forward (see module header links).
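Written out in the concatenated parameterization above, with σ the logistic sigmoid (the (1 − z_t)/z_t roles below match both Cho et al. and the PyTorch docs, but the definition itself is authoritative):

  r_t = σ(W_r · [x_t; h_{t-1}] + b_r)
  z_t = σ(W_z · [x_t; h_{t-1}] + b_z)
  n_t = tanh(W_n · [x_t; r_t ⊙ h_{t-1}] + b_n)
  h_t = (1 − z_t) ⊙ n_t + z_t ⊙ h_{t-1}

At the type level the shapes are forced by the signature; a minimal shape-checking sketch using only the signatures on this page:

  -- An input vector of size 3 and a hidden state of size 4 step to a hidden state of size 4.
  example {α : Type} [Context α] (gru : Spec.GRUSpec α 3 4)
      (x : Tensor α (Shape.dim 3 Shape.scalar))
      (h : Tensor α (Shape.dim 4 Shape.scalar)) :
      Tensor α (Shape.dim 4 Shape.scalar) :=
    Spec.gruCellSpec gru x h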

def Spec.gruSequenceSpec {α : Type} [Context α] {seqLen inputSize hiddenSize : ℕ} (gru : GRUSpec α inputSize hiddenSize) (inputs : Tensor α (Shape.dim seqLen (Shape.dim inputSize Shape.scalar))) (initial_hidden : Tensor α (Shape.dim hiddenSize Shape.scalar)) :
    Tensor α (Shape.dim seqLen (Shape.dim hiddenSize Shape.scalar))

Unroll a GRU over seqLen timesteps (time-major).

This returns the sequence of hidden states [h_0, ..., h_{seqLen-1}]. It is a pure spec-level definition of the semantics; an efficient runtime is free to implement the same behavior with loops and caching.

PyTorch analogue: torch.nn.GRU run on a time-major input (i.e. batch_first=False), returning the output sequence.
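Unrolled, the semantics is the expected recurrence (writing h_init for initial_hidden):

  h_{-1} = h_init
  h_t = gruCellSpec gru x_t h_{t-1}    (t = 0, …, seqLen − 1)

with the result stacking [h_0, …, h_{seqLen−1}] along the leading (time) axis.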

@[irreducible]
def Spec.gruSequenceSpec.process_sequence {α : Type} [Context α] {seqLen inputSize hiddenSize : ℕ} (gru : GRUSpec α inputSize hiddenSize) (inputs : Tensor α (Shape.dim seqLen (Shape.dim inputSize Shape.scalar))) (t : ℕ) (prev_hidden : Tensor α (Shape.dim hiddenSize Shape.scalar)) :
    Tensor α (Shape.dim hiddenSize Shape.scalar) × List (Tensor α (Shape.dim hiddenSize Shape.scalar))

def Spec.gruCellSpecWithIntermediates {α : Type} [Context α] {inputSize hiddenSize : ℕ} (gru : GRUSpec α inputSize hiddenSize) (input : Tensor α (Shape.dim inputSize Shape.scalar)) (prev_hidden : Tensor α (Shape.dim hiddenSize Shape.scalar)) :
    Tensor α (Shape.dim hiddenSize Shape.scalar) × Tensor α (Shape.dim hiddenSize Shape.scalar) × Tensor α (Shape.dim hiddenSize Shape.scalar) × Tensor α (Shape.dim hiddenSize Shape.scalar) × Tensor α (Shape.dim hiddenSize Shape.scalar)

GRU cell forward pass that also returns cached intermediates for BPTT.

This computes the same next hidden state as gru_cell_spec, but additionally returns:

• reset_gate (r_t),
• update_gate (z_t),
• new_candidate (n_t), and
• reset_hidden (r_t ⊙ h_{t-1}).

These are exactly the quantities commonly saved by a reverse-mode implementation (PyTorch-style autograd) to compute gradients efficiently in the backward pass.
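A small consumption sketch; the positional order of the 5-tuple (next hidden state first, then the four intermediates in the order listed above) is an assumption read off this docstring:

  example {α : Type} [Context α] (gru : Spec.GRUSpec α 3 4)
      (x : Tensor α (Shape.dim 3 Shape.scalar))
      (h : Tensor α (Shape.dim 4 Shape.scalar)) :
      Tensor α (Shape.dim 4 Shape.scalar) :=
    -- Destructure the nested product; only the next hidden state is used here.
    let (h', _r, _z, _n, _rh) := Spec.gruCellSpecWithIntermediates gru x h
    h'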

def Spec.gruExtractIntermediateValues {α : Type} [Context α] {seqLen inputSize hiddenSize : ℕ} (gru : GRUSpec α inputSize hiddenSize) (inputs : Tensor α (Shape.dim seqLen (Shape.dim inputSize Shape.scalar))) (initial_hidden : Tensor α (Shape.dim hiddenSize Shape.scalar)) :
    Tensor α (Shape.dim seqLen (Shape.dim hiddenSize Shape.scalar)) × Tensor α (Shape.dim seqLen (Shape.dim hiddenSize Shape.scalar)) × Tensor α (Shape.dim seqLen (Shape.dim hiddenSize Shape.scalar)) × Tensor α (Shape.dim seqLen (Shape.dim hiddenSize Shape.scalar)) × Tensor α (Shape.dim seqLen (Shape.dim hiddenSize Shape.scalar))

Run a GRU forward pass while collecting the per-timestep intermediates needed for BPTT.

This is the "spec-level" analogue of what frameworks do internally:

• the forward pass produces h_t,
• and it also saves gate activations (r_t, z_t) and the candidate (n_t) for the backward pass.

The returned tensors are all time-major (seqLen first) to match the rest of the spec layer.

@[irreducible]
def Spec.gruExtractIntermediateValues.process_sequence {α : Type} [Context α] {seqLen inputSize hiddenSize : ℕ} (gru : GRUSpec α inputSize hiddenSize) (inputs : Tensor α (Shape.dim seqLen (Shape.dim inputSize Shape.scalar))) (t : ℕ) (prev_hidden : Tensor α (Shape.dim hiddenSize Shape.scalar)) :
    Tensor α (Shape.dim hiddenSize Shape.scalar) × List (Tensor α (Shape.dim hiddenSize Shape.scalar)) × List (Tensor α (Shape.dim hiddenSize Shape.scalar)) × List (Tensor α (Shape.dim hiddenSize Shape.scalar)) × List (Tensor α (Shape.dim hiddenSize Shape.scalar)) × List (Tensor α (Shape.dim hiddenSize Shape.scalar))

def Spec.gruBatchedSpec {α : Type} [Context α] {batchSize seqLen inputSize hiddenSize : ℕ} (gru : GRUSpec α inputSize hiddenSize) (inputs : Tensor α (Shape.dim batchSize (Shape.dim seqLen (Shape.dim inputSize Shape.scalar)))) (initial_hidden : Tensor α (Shape.dim batchSize (Shape.dim hiddenSize Shape.scalar))) :
    Tensor α (Shape.dim batchSize (Shape.dim seqLen (Shape.dim hiddenSize Shape.scalar)))

Batched GRU forward pass (map gruSequenceSpec over the batch dimension).

This is a simple spec-level definition for semantics, not an optimized kernel. PyTorch analogue: torch.nn.GRU on a batched input tensor.
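Per batch element b, the semantics is simply

  out[b] = gruSequenceSpec gru inputs[b] initial_hidden[b]    (b = 0, …, batchSize − 1).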

def Spec.gruResetWeightsDerivSpec {α : Type} [Context α] {seqLen inputSize hiddenSize : ℕ} (inputs : Tensor α (Shape.dim seqLen (Shape.dim inputSize Shape.scalar))) (hiddens grad_reset : Tensor α (Shape.dim seqLen (Shape.dim hiddenSize Shape.scalar))) :
    Tensor α (Shape.dim hiddenSize (Shape.dim (inputSize + hiddenSize) Shape.scalar))

Reference gradient for reset-gate weights via the generic RNN weight-gradient helper.

This uses rnn_weights_deriv_spec on the concatenated inputs/hidden states. It is a convenient building block, but the more explicit *_bptt_spec helpers below show the time-unrolled accumulation form.
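Assuming rnn_weights_deriv_spec performs the usual outer-product accumulation (the same form the BPTT helpers below make explicit), the result is

  dL/dW_r = Σ_t (dL/dr_t) ⊗ [x_t; h_{t-1}].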

def Spec.gruUpdateWeightsDerivSpec {α : Type} [Context α] {seqLen inputSize hiddenSize : ℕ} (inputs : Tensor α (Shape.dim seqLen (Shape.dim inputSize Shape.scalar))) (hiddens grad_update : Tensor α (Shape.dim seqLen (Shape.dim hiddenSize Shape.scalar))) :
    Tensor α (Shape.dim hiddenSize (Shape.dim (inputSize + hiddenSize) Shape.scalar))

Reference gradient for update-gate weights (via rnn_weights_deriv_spec).

def Spec.gruNewWeightsDerivSpec {α : Type} [Context α] {seqLen inputSize hiddenSize : ℕ} (inputs : Tensor α (Shape.dim seqLen (Shape.dim inputSize Shape.scalar))) (reset_hiddens grad_new : Tensor α (Shape.dim seqLen (Shape.dim hiddenSize Shape.scalar))) :
    Tensor α (Shape.dim hiddenSize (Shape.dim (inputSize + hiddenSize) Shape.scalar))

Reference gradient for candidate ("new") gate weights (via rnn_weights_deriv_spec).

Note the second sequence argument is reset_hiddens = r_t ⊙ h_{t-1}.

def Spec.gruBiasDerivSpec {α : Type} [Context α] {seqLen hiddenSize : ℕ} (grad_outputs : Tensor α (Shape.dim seqLen (Shape.dim hiddenSize Shape.scalar))) (h : seqLen ≠ 0) :
    Tensor α (Shape.dim hiddenSize Shape.scalar)

Bias gradient by summing per-timestep gradients over the time axis.

This is the spec-level analogue of the common "sum across batch/time" reduction used for bias gradients. The seqLen ≠ 0 hypothesis is exactly what makes axis 0 a valid reduction axis for reduce_sum_auto in the shape-indexed tensor API.
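In symbols, writing g_t for the t-th slice of grad_outputs:

  dL/db = Σ_{t=0}^{seqLen−1} g_t.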

def Spec.gruResetWeightsDerivBpttSpec {α : Type} [Context α] {seqLen inputSize hiddenSize : ℕ} (inputs : Tensor α (Shape.dim seqLen (Shape.dim inputSize Shape.scalar))) (hiddens _reset_gates grad_reset_gates : Tensor α (Shape.dim seqLen (Shape.dim hiddenSize Shape.scalar))) :
    Tensor α (Shape.dim hiddenSize (Shape.dim (inputSize + hiddenSize) Shape.scalar))

Reset-gate weight gradient by explicit time-unrolled accumulation (BPTT-style).

This computes Σ_t (dL/dr_t) ⊗ [x_t; h_{t-1}], where ⊗ is an outer product.

@[irreducible]
def Spec.gruResetWeightsDerivBpttSpec.accumulate_grads {α : Type} [Context α] {seqLen inputSize hiddenSize : ℕ} (inputs : Tensor α (Shape.dim seqLen (Shape.dim inputSize Shape.scalar))) (hiddens grad_reset_gates : Tensor α (Shape.dim seqLen (Shape.dim hiddenSize Shape.scalar))) (t : ℕ) (acc : Tensor α (Shape.dim hiddenSize (Shape.dim (inputSize + hiddenSize) Shape.scalar))) :
    Tensor α (Shape.dim hiddenSize (Shape.dim (inputSize + hiddenSize) Shape.scalar))

def Spec.gruUpdateWeightsDerivBpttSpec {α : Type} [Context α] {seqLen inputSize hiddenSize : ℕ} (inputs : Tensor α (Shape.dim seqLen (Shape.dim inputSize Shape.scalar))) (hiddens grad_update_gates : Tensor α (Shape.dim seqLen (Shape.dim hiddenSize Shape.scalar))) :
    Tensor α (Shape.dim hiddenSize (Shape.dim (inputSize + hiddenSize) Shape.scalar))

Update-gate weight gradient by explicit time-unrolled accumulation (BPTT-style).

This computes Σ_t (dL/dz_t) ⊗ [x_t; h_{t-1}].

@[irreducible]
def Spec.gruUpdateWeightsDerivBpttSpec.accumulate_grads {α : Type} [Context α] {seqLen inputSize hiddenSize : ℕ} (inputs : Tensor α (Shape.dim seqLen (Shape.dim inputSize Shape.scalar))) (hiddens grad_update_gates : Tensor α (Shape.dim seqLen (Shape.dim hiddenSize Shape.scalar))) (t : ℕ) (acc : Tensor α (Shape.dim hiddenSize (Shape.dim (inputSize + hiddenSize) Shape.scalar))) :
    Tensor α (Shape.dim hiddenSize (Shape.dim (inputSize + hiddenSize) Shape.scalar))

def Spec.gruNewWeightsDerivBpttSpec {α : Type} [Context α] {seqLen inputSize hiddenSize : ℕ} (inputs : Tensor α (Shape.dim seqLen (Shape.dim inputSize Shape.scalar))) (reset_hiddens grad_new_candidates : Tensor α (Shape.dim seqLen (Shape.dim hiddenSize Shape.scalar))) :
    Tensor α (Shape.dim hiddenSize (Shape.dim (inputSize + hiddenSize) Shape.scalar))

Candidate-gate weight gradient by explicit time-unrolled accumulation (BPTT-style).

This computes Σ_t (dL/dn_t) ⊗ [x_t; r_t ⊙ h_{t-1}].

@[irreducible]
def Spec.gruNewWeightsDerivBpttSpec.accumulate_grads {α : Type} [Context α] {seqLen inputSize hiddenSize : ℕ} (inputs : Tensor α (Shape.dim seqLen (Shape.dim inputSize Shape.scalar))) (reset_hiddens grad_new_candidates : Tensor α (Shape.dim seqLen (Shape.dim hiddenSize Shape.scalar))) (t : ℕ) (acc : Tensor α (Shape.dim hiddenSize (Shape.dim (inputSize + hiddenSize) Shape.scalar))) :
    Tensor α (Shape.dim hiddenSize (Shape.dim (inputSize + hiddenSize) Shape.scalar))

def Spec.gruCellBackwardFullSpec {α : Type} [Context α] {inputSize hiddenSize : ℕ} (gru : GRUSpec α inputSize hiddenSize) (input : Tensor α (Shape.dim inputSize Shape.scalar)) (prev_hidden grad_output reset_gate update_gate new_candidate : Tensor α (Shape.dim hiddenSize Shape.scalar)) :
    Tensor α (Shape.dim inputSize Shape.scalar) × Tensor α (Shape.dim hiddenSize Shape.scalar) × Tensor α (Shape.dim hiddenSize (Shape.dim (inputSize + hiddenSize) Shape.scalar)) × Tensor α (Shape.dim hiddenSize Shape.scalar) × Tensor α (Shape.dim hiddenSize (Shape.dim (inputSize + hiddenSize) Shape.scalar)) × Tensor α (Shape.dim hiddenSize Shape.scalar) × Tensor α (Shape.dim hiddenSize (Shape.dim (inputSize + hiddenSize) Shape.scalar)) × Tensor α (Shape.dim hiddenSize Shape.scalar)

Backward (VJP) for a single GRU cell.

Inputs:

• the cell parameters gru,
• the current input x_t,
• the previous hidden state h_{t-1},
• an upstream gradient dL/dh_t,
• and the forward intermediates (r_t, z_t, n_t) that a typical BPTT implementation would cache.

Outputs:

• gradients w.r.t. the input and previous hidden state,
• plus gradients for each parameter tensor (weights and biases).

This is written to match the forward equations in gru_cell_spec. It is not an optimized kernel; it is a precise spec for what the gradients should be.
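As a sketch of the local derivatives, assuming the forward convention h_t = (1 − z_t) ⊙ n_t + z_t ⊙ h_{t-1} and writing g = dL/dh_t (the definition itself is the authority here):

  dL/dn_t = g ⊙ (1 − z_t)            (then ⊙ (1 − n_t²) to reach the candidate pre-activation)
  dL/dz_t = g ⊙ (h_{t-1} − n_t)      (then ⊙ z_t ⊙ (1 − z_t) to reach the update pre-activation)
  dL/dh_{t-1} = g ⊙ z_t + (paths through r_t and n_t)

The weight and bias gradients are then outer products and sums of these pre-activation gradients against the corresponding concatenated inputs, as in the *_bptt_spec helpers above.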

def Spec.gruSequenceBackwardFullSpec {α : Type} [Context α] {seqLen inputSize hiddenSize : ℕ} (gru : GRUSpec α inputSize hiddenSize) (inputs : Tensor α (Shape.dim seqLen (Shape.dim inputSize Shape.scalar))) (hiddens grad_outputs reset_gates update_gates new_candidates : Tensor α (Shape.dim seqLen (Shape.dim hiddenSize Shape.scalar))) (initial_hidden : Tensor α (Shape.dim hiddenSize Shape.scalar) := fill 0 (Shape.dim hiddenSize Shape.scalar)) :
    Tensor α (Shape.dim hiddenSize (Shape.dim (inputSize + hiddenSize) Shape.scalar)) × Tensor α (Shape.dim hiddenSize Shape.scalar) × Tensor α (Shape.dim hiddenSize (Shape.dim (inputSize + hiddenSize) Shape.scalar)) × Tensor α (Shape.dim hiddenSize Shape.scalar) × Tensor α (Shape.dim hiddenSize (Shape.dim (inputSize + hiddenSize) Shape.scalar)) × Tensor α (Shape.dim hiddenSize Shape.scalar) × Tensor α (Shape.dim seqLen (Shape.dim inputSize Shape.scalar)) × Tensor α (Shape.dim hiddenSize Shape.scalar)

Reverse-mode backprop through an unrolled GRU over seqLen steps (BPTT).

This function consumes the same intermediates produced by gru_extract_intermediate_values: per-timestep gate activations and candidates. This mirrors the PyTorch mental model: the forward pass produces a sequence of hidden states and saves what it needs; the backward pass walks time in reverse and accumulates gradients.
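A pipeline sketch wiring the forward intermediates into the backward pass; the 5-tuple destructuring order follows the docstring of gruExtractIntermediateValues and is an assumption:

  example {α : Type} [Context α] (gru : Spec.GRUSpec α 3 4)
      (inputs : Tensor α (Shape.dim 10 (Shape.dim 3 Shape.scalar)))
      (grad_outputs : Tensor α (Shape.dim 10 (Shape.dim 4 Shape.scalar)))
      (h0 : Tensor α (Shape.dim 4 Shape.scalar)) : True :=
    -- Forward pass: hidden states plus the cached gate activations and candidates.
    let (hiddens, resets, updates, cands, _rh) :=
      Spec.gruExtractIntermediateValues gru inputs h0
    -- Backward pass: consumes exactly those intermediates.
    let _grads :=
      Spec.gruSequenceBackwardFullSpec gru inputs hiddens grad_outputs resets updates cands h0
    True.intro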

@[irreducible]
def Spec.gruSequenceBackwardFullSpec.backward_step {α : Type} [Context α] {seqLen inputSize hiddenSize : ℕ} (gru : GRUSpec α inputSize hiddenSize) (inputs : Tensor α (Shape.dim seqLen (Shape.dim inputSize Shape.scalar))) (hiddens grad_outputs reset_gates update_gates new_candidates : Tensor α (Shape.dim seqLen (Shape.dim hiddenSize Shape.scalar))) (initial_hidden : Tensor α (Shape.dim hiddenSize Shape.scalar)) (t : ℕ) (_h_t : t ≤ seqLen) (dHidden_next : Tensor α (Shape.dim hiddenSize Shape.scalar)) (acc_inputs : List (Tensor α (Shape.dim inputSize Shape.scalar))) (dResetW : Tensor α (Shape.dim hiddenSize (Shape.dim (inputSize + hiddenSize) Shape.scalar))) (dResetB : Tensor α (Shape.dim hiddenSize Shape.scalar)) (dUpdateW : Tensor α (Shape.dim hiddenSize (Shape.dim (inputSize + hiddenSize) Shape.scalar))) (dUpdateB : Tensor α (Shape.dim hiddenSize Shape.scalar)) (dNewW : Tensor α (Shape.dim hiddenSize (Shape.dim (inputSize + hiddenSize) Shape.scalar))) (dNewB : Tensor α (Shape.dim hiddenSize Shape.scalar)) :
    List (Tensor α (Shape.dim inputSize Shape.scalar)) × Tensor α (Shape.dim hiddenSize Shape.scalar) × Tensor α (Shape.dim hiddenSize (Shape.dim (inputSize + hiddenSize) Shape.scalar)) × Tensor α (Shape.dim hiddenSize Shape.scalar) × Tensor α (Shape.dim hiddenSize (Shape.dim (inputSize + hiddenSize) Shape.scalar)) × Tensor α (Shape.dim hiddenSize Shape.scalar) × Tensor α (Shape.dim hiddenSize (Shape.dim (inputSize + hiddenSize) Shape.scalar)) × Tensor α (Shape.dim hiddenSize Shape.scalar)

def Spec.gruSequenceBackwardSpec {α : Type} [Context α] {seqLen inputSize hiddenSize : ℕ} (gru : GRUSpec α inputSize hiddenSize) (inputs : Tensor α (Shape.dim seqLen (Shape.dim inputSize Shape.scalar))) (hiddens grad_outputs reset_gates update_gates new_candidates : Tensor α (Shape.dim seqLen (Shape.dim hiddenSize Shape.scalar))) (initial_hidden : Tensor α (Shape.dim hiddenSize Shape.scalar) := fill 0 (Shape.dim hiddenSize Shape.scalar)) :
    Tensor α (Shape.dim seqLen (Shape.dim inputSize Shape.scalar)) × Tensor α (Shape.dim hiddenSize Shape.scalar)

Convenience wrapper: return only (dInputs, dInitialHidden) from gruSequenceBackwardFullSpec.

This is useful when you only need gradients w.r.t. the input sequence and initial hidden state, and not the full parameter-gradient bundle.
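Usage mirrors the full variant (same destructuring assumption as above), keeping only the two results:

  example {α : Type} [Context α] (gru : Spec.GRUSpec α 3 4)
      (inputs : Tensor α (Shape.dim 10 (Shape.dim 3 Shape.scalar)))
      (grad_outputs : Tensor α (Shape.dim 10 (Shape.dim 4 Shape.scalar)))
      (h0 : Tensor α (Shape.dim 4 Shape.scalar)) :
      Tensor α (Shape.dim 10 (Shape.dim 3 Shape.scalar)) × Tensor α (Shape.dim 4 Shape.scalar) :=
    -- Reuse the forward intermediates, then keep only (dInputs, dInitialHidden).
    let (hiddens, resets, updates, cands, _rh) :=
      Spec.gruExtractIntermediateValues gru inputs h0
    Spec.gruSequenceBackwardSpec gru inputs hiddens grad_outputs resets updates cands h0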
