GRU (spec layer) #
TorchLean provides a small GRU specification that is:
- explicit about shapes (so dimension mistakes are caught early),
- explicit about the math (so we can reason about it and differentiate it),
- close in spirit to PyTorch's nn.GRUCell / nn.GRU documentation.
References (math + PyTorch behavior) #
- Cho et al., "Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation" (EMNLP 2014): https://aclanthology.org/D14-1179/ (PDF: https://aclanthology.org/D14-1179.pdf)
- Chung et al., "Empirical Evaluation of Gated Recurrent Neural Networks on Sequence Modeling" (2014): https://arxiv.org/abs/1412.3555
- PyTorch GRUCell equations: https://docs.pytorch.org/docs/stable/generated/torch.nn.GRUCell.html
- PyTorch GRU equations: https://docs.pytorch.org/docs/stable/generated/torch.nn.modules.rnn.GRU.html
Notes on parameterization #
The GRU equations are often written with separate matrices W_* for the input and U_* for the
hidden state. In this spec we use a single matrix per gate applied to a concatenated vector
[x_t; h_{t-1}] (or [x_t; r_t ⊙ h_{t-1}] for the candidate). This is the same idea, just packaged
in a way that reuses the tensor building blocks already present in the spec layer.
One small place where libraries differ is the candidate equation: some implementations apply the reset gate before the hidden-state linear map (as in Cho et al.), while others apply it after a hidden-state linear map (as in the PyTorch docs). This file follows the former, because it matches the original GRU equations and stays close to the "concatenate then multiply" style used elsewhere in TorchLean's spec layer.
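For concreteness, writing the input and hidden column blocks of the single candidate matrix as W_n and U_n (just notation for this comparison, not separate parameters in this spec), the two conventions are:

Cho et al. (used in this file): n_t = tanh(W_n x_t + U_n (r_t ⊙ h_{t-1}) + b_n)
PyTorch docs: n_t = tanh(W_in x_t + b_in + r_t ⊙ (W_hn h_{t-1} + b_hn))

The difference is whether r_t scales h_{t-1} before the hidden-state matrix multiply, or scales the result of that multiply (plus its bias) afterwards.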
Parameters for a single GRU cell.
This is the spec-level analogue of PyTorch torch.nn.GRUCell parameters, using a concatenated
input [x_t; h_{t-1}] (shape inputSize + hiddenSize) for the reset/update gates and
[x_t; r_t ⊙ h_{t-1}] for the candidate gate.
Shapes:
- each *_weights is [hiddenSize, inputSize + hiddenSize],
- each *_bias is [hiddenSize].
- reset_weights : Tensor α (Shape.dim hiddenSize (Shape.dim (inputSize + hiddenSize) Shape.scalar))
  Reset-gate weights for r_t = sigmoid(W_r [x_t; h_{t-1}] + b_r).
- reset_bias : Tensor α (Shape.dim hiddenSize Shape.scalar)
  Reset-gate bias.
- update_weights : Tensor α (Shape.dim hiddenSize (Shape.dim (inputSize + hiddenSize) Shape.scalar))
  Update-gate weights for z_t = sigmoid(W_z [x_t; h_{t-1}] + b_z).
- update_bias : Tensor α (Shape.dim hiddenSize Shape.scalar)
  Update-gate bias.
- new_weights : Tensor α (Shape.dim hiddenSize (Shape.dim (inputSize + hiddenSize) Shape.scalar))
  Candidate-state weights for n_t = tanh(W_n [x_t; r_t ⊙ h_{t-1}] + b_n).
- new_bias : Tensor α (Shape.dim hiddenSize Shape.scalar)
  Candidate-state bias.
Forward pass for a single GRU cell.
Given input x_t and previous hidden state h_{t-1}, compute the next hidden state h_t using
the standard GRU equations.
PyTorch analogue: torch.nn.GRUCell forward (see module header links).
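Spelled out, the cell equations in the concatenated parameterization of this file are the three gate equations from the parameter docstrings above plus the final interpolation; the interpolation is written here in the PyTorch GRUCell convention, which is an assumption of this sketch (a Cho-style reading swaps the roles of z_t and 1 - z_t):

r_t = sigmoid(W_r [x_t; h_{t-1}] + b_r)
z_t = sigmoid(W_z [x_t; h_{t-1}] + b_z)
n_t = tanh(W_n [x_t; r_t ⊙ h_{t-1}] + b_n)
h_t = (1 - z_t) ⊙ n_t + z_t ⊙ h_{t-1}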
Unroll a GRU over seqLen timesteps (time-major).
This returns the sequence of hidden states [h_0, ..., h_{seqLen-1}]. It is a pure spec-level
definition of semantics; an efficient runtime is free to implement the same behavior with loops and
caching.
PyTorch analogue: torch.nn.GRU run on a time-major input (i.e. batch_first=False), returning the
output sequence.
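As a concrete, deliberately tiny illustration of the unrolling itself, here is a self-contained Lean sketch over plain Float scalars instead of TorchLean's shape-indexed tensors. The names ToyGRU, toyCell, and toyUnroll are hypothetical, and the final interpolation (1 - z) * n + z * h follows the PyTorch convention as an assumption; this is a sketch of the recurrence, not the library's definition.

```lean
-- Toy scalar GRU, standing in for the tensor-level spec (illustrative only).
def sigmoid (x : Float) : Float :=
  1.0 / (1.0 + Float.exp (-x))

structure ToyGRU where
  wr ur br : Float  -- reset gate: input weight, hidden weight, bias
  wz uz bz : Float  -- update gate
  wn un bn : Float  -- candidate

def toyCell (g : ToyGRU) (x h : Float) : Float :=
  let r := sigmoid (g.wr * x + g.ur * h + g.br)
  let z := sigmoid (g.wz * x + g.uz * h + g.bz)
  let n := Float.tanh (g.wn * x + g.un * (r * h) + g.bn)  -- reset applied before the hidden map
  (1.0 - z) * n + z * h  -- PyTorch-style interpolation (an assumption of this sketch)

-- Time-major unroll: fold the cell over the inputs, collecting every hidden state.
def toyUnroll (g : ToyGRU) (h0 : Float) (xs : List Float) : List Float :=
  let step := fun (acc : Float × List Float) (x : Float) =>
    let h := toyCell g x acc.1
    (h, h :: acc.2)
  ((xs.foldl step (h0, [])).2).reverse
```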
GRU cell forward pass that also returns cached intermediates for BPTT.
This computes the same next hidden state as gru_cell_spec, but additionally returns:
reset_gate (r_t), update_gate (z_t), new_candidate (n_t), and reset_hidden (r_t ⊙ h_{t-1}).
These are exactly the quantities commonly saved by a reverse-mode implementation (PyTorch-style autograd) to compute gradients efficiently in the backward pass.
Run a GRU forward pass while collecting the per-timestep intermediates needed for BPTT.
This is the "spec-level" analogue of what frameworks do internally:
- the forward pass produces h_t,
- and it also saves gate activations (r_t, z_t) and the candidate (n_t) for the backward pass.
The returned tensors are all time-major (seqLen first) to match the rest of the spec layer.
Batched GRU forward pass (map gruSequenceSpec over the batch dimension).
This is a simple spec-level definition for semantics, not an optimized kernel.
PyTorch analogue: torch.nn.GRU on a batched input tensor.
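In the same toy setting as the scalar sketch above (reusing its hypothetical ToyGRU and toyUnroll), the batched version is literally a map over the batch:

```lean
-- Map the per-sequence unroll over a batch of sequences (spec-level semantics only).
def toyBatched (g : ToyGRU) (h0 : Float) (batch : List (List Float)) : List (List Float) :=
  batch.map (toyUnroll g h0)
```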
Reference gradient for reset-gate weights via the generic RNN weight-gradient helper.
This uses rnn_weights_deriv_spec on the concatenated inputs/hidden states. It is a convenient
building block, but the more explicit *_bptt_spec helpers below show the time-unrolled
accumulation form.
Reference gradient for update-gate weights (via rnn_weights_deriv_spec).
Reference gradient for candidate ("new") gate weights (via rnn_weights_deriv_spec).
Note the second sequence argument is reset_hiddens = r_t ⊙ h_{t-1}.
Bias gradient by summing per-timestep gradients over the time axis.
This is the spec-level analogue of the common "sum across batch/time" reduction used for bias
gradients. The seqLen ≠ 0 hypothesis is exactly what makes axis 0 a valid reduction axis for
reduce_sum_auto in the shape-indexed tensor API.
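In symbols, writing g_t for the per-timestep bias gradient at step t (one [hiddenSize] vector per timestep), the result is the plain time-axis sum:

dL/db = Σ_{t=0}^{seqLen-1} g_t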
Reset-gate weight gradient by explicit time-unrolled accumulation (BPTT-style).
This computes Σ_t (dL/dr_t) ⊗ [x_t; h_{t-1}], where ⊗ is an outer product.
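As a shape check against the parameter docstrings above: each summand is the outer product of a [hiddenSize] vector with an [inputSize + hiddenSize] vector,

(dL/dr_t) ⊗ [x_t; h_{t-1}] : [hiddenSize, inputSize + hiddenSize],

so every term, and therefore the accumulated sum, has the same shape as reset_weights.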
Update-gate weight gradient by explicit time-unrolled accumulation (BPTT-style).
This computes Σ_t (dL/dz_t) ⊗ [x_t; h_{t-1}].
Candidate-gate weight gradient by explicit time-unrolled accumulation (BPTT-style).
This computes Σ_t (dL/dn_t) ⊗ [x_t; r_t ⊙ h_{t-1}].
Backward (VJP) for a single GRU cell.
Inputs:
- the cell parameters gru,
- the current input x_t,
- the previous hidden state h_{t-1},
- an upstream gradient dL/dh_t,
- and the forward intermediates (r_t, z_t, n_t) that a typical BPTT implementation would cache.
Outputs:
- gradients w.r.t. the input and previous hidden state,
- plus gradients for each parameter tensor (weights and biases).
This is written to match the forward equations in gru_cell_spec. It is not an optimized kernel;
it is a precise spec for what gradients should be.
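For orientation, here is a sketch of the gradient algebra such a VJP specifies, assuming the forward equations shown earlier and, in particular, the PyTorch-style interpolation h_t = (1 - z_t) ⊙ n_t + z_t ⊙ h_{t-1} (with a Cho-style interpolation the roles of z_t and 1 - z_t swap). W_{*,x} and W_{*,h} denote the input and hidden column blocks of each concatenated weight matrix, and a_r, a_z, a_n the gate pre-activations; these names are notation for this sketch, not identifiers in the file.

dL/dn_t = dL/dh_t ⊙ (1 - z_t)
dL/dz_t = dL/dh_t ⊙ (h_{t-1} - n_t)
dL/da_n = dL/dn_t ⊙ (1 - n_t ⊙ n_t)
dL/da_z = dL/dz_t ⊙ z_t ⊙ (1 - z_t)
dL/d(r_t ⊙ h_{t-1}) = W_{n,h}ᵀ (dL/da_n)
dL/dr_t = dL/d(r_t ⊙ h_{t-1}) ⊙ h_{t-1}
dL/da_r = dL/dr_t ⊙ r_t ⊙ (1 - r_t)

dL/dx_t = W_{r,x}ᵀ dL/da_r + W_{z,x}ᵀ dL/da_z + W_{n,x}ᵀ dL/da_n
dL/dh_{t-1} = dL/dh_t ⊙ z_t + r_t ⊙ dL/d(r_t ⊙ h_{t-1}) + W_{r,h}ᵀ dL/da_r + W_{z,h}ᵀ dL/da_z

dL/dW_r = dL/da_r ⊗ [x_t; h_{t-1}], dL/db_r = dL/da_r
dL/dW_z = dL/da_z ⊗ [x_t; h_{t-1}], dL/db_z = dL/da_z
dL/dW_n = dL/da_n ⊗ [x_t; r_t ⊙ h_{t-1}], dL/db_n = dL/da_n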
Reverse-mode backprop through an unrolled GRU over seqLen steps (BPTT).
This function consumes the same intermediates produced by gru_extract_intermediate_values:
per-timestep gate activations and candidates. This mirrors the PyTorch mental model: the forward
pass produces a sequence of hidden states and saves what it needs; the backward pass walks time
in reverse and accumulates gradients.
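Schematically, and under the same assumptions as the cell-level sketch above (u_t, c_t, and "cell VJP" are notation for this sketch only): write u_t for the upstream gradient on the t-th output hidden state and c_t for the gradient carried back into h_t from later timesteps. Walking t from seqLen - 1 down to 0,

c_{seqLen-1} = 0
d_t = u_t + c_t (total gradient reaching h_t)
(dx_t, c_{t-1}, per-step parameter gradients) = cell VJP at step t with upstream d_t
parameter gradients = Σ_t (per-step parameter gradients)
dInitialHidden = c_{-1}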
Convenience wrapper: return only (dInputs, dInitialHidden) from gruSequenceBackwardFullSpec.
This is useful when you only need gradients w.r.t. the input sequence and initial hidden state, and not the full parameter-gradient bundle.