Elman RNN Cell VJP #
This file proves the VJP (vector-Jacobian product) theorem for the core differentiable cell used
by a vanilla tanh RNN: h' = tanh(W [x; h] + b).
The theorem is deliberately cell-level. Runtime sequence layers unroll this cell over time and
scatter the hidden states into an output sequence; the full BPTT theorem is the induction over that
unroll plus the existing gather/scatter adjoint facts. We prove the cell first because it is the
right reusable grain size: vector concatenation, affine maps, and smooth elementwise tanh.
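Concretely, writing z for the affine preactivation and g for the upstream cotangent on h', the
paper-level content of the VJP is the following (notation ours, not the library's):

$$
z = W \begin{bmatrix} x \\ h \end{bmatrix} + b, \qquad
h' = \tanh(z), \qquad
\bar z = g \odot (1 - h' \odot h'), \qquad
\begin{bmatrix} \bar x \\ \bar h \end{bmatrix} = W^\top \bar z .
$$

The three factors line up exactly with the ingredients named above: the adjoint of concatenation
splits $W^\top \bar z$ back into the x and h components, the affine map contributes $W^\top$, and
elementwise tanh contributes the $1 - h' \odot h'$ mask.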
References:
- Elman, "Finding Structure in Time", Cognitive Science 1990.
- PyTorch torch.nn.RNN: https://pytorch.org/docs/stable/generated/torch.nn.RNN.html
Input vector shape for one RNN step.
Hidden vector shape for one RNN step.
Context for a one-step Elman cell: current input and previous hidden state.
Saved tensors: concatenated [x; h], affine preactivation, and next hidden state.
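A brief note on why the next hidden state earns a slot here (an observation about the mathematics,
not about the library's storage policy): by the identity

$$
\tanh'(z) = 1 - \tanh^2(z) = 1 - h' \odot h',
$$

the backward pass can form the tanh adjoint from h' alone, with no re-evaluation of tanh at z.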
Current input index.
Previous hidden-state index.
Index helper for the most recently appended tensor.
Index of the concatenated [x; h] vector.
Index of the affine preactivation.
Instances For
Proof-carrying graph for one Elman RNN cell.
The affine map is represented by a fixed LinearSpec, so the VJP theorem below covers gradients
with respect to the cell inputs (x, h) only. Parameter-gradient theorems are a separate layer over
the trainable runtime parameter list.
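For orientation, the parameter cotangents such a separate layer would establish are the standard
ones from matrix calculus (stated here for contrast only; this file does not prove them):

$$
\bar W = \bar z \begin{bmatrix} x \\ h \end{bmatrix}^\top, \qquad
\bar b = \bar z, \qquad
\text{where } \bar z = g \odot (1 - h' \odot h').
$$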
End-to-end VJP theorem for one vanilla RNN cell.
This is the recurrent analogue of the attention block theorems: the graph-level reverse pass equals the adjoint of the Fréchet derivative of the cell evaluation function.
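Spelled out in inner-product form (notation ours), with f(x, h) = tanh(W [x; h] + b) the cell
evaluation and backward the graph-level reverse pass, the theorem asserts that for every cotangent
g and every perturbation (dx, dh):

$$
\langle \mathrm{backward}(g), (dx, dh) \rangle
  = \langle g, \, Df(x, h)(dx, dh) \rangle,
\qquad \text{equivalently} \qquad
\mathrm{backward}(g) = (Df(x, h))^{*}\, g .
$$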
Forward evaluation of one Elman cell is differentiable at every input context.
This is the recurrent analogue of the Transformer sublayer calculus bridges: it exposes the cell as a differentiable map that can be composed repeatedly when proving BPTT for an unrolled RNN.
Two-step recurrent composition bridge.
Suppose firstCtx builds the first cell context from some outer state E, and secondCtx builds
the next cell context from the evaluated first-cell graph (for example by selecting the next input
and the hidden state produced by the first step). If both context builders are differentiable, then
the two-cell unroll is differentiable.
The theorem is intentionally abstract over secondCtx: different sequence layouts store inputs,
hidden states, and caches differently, but every vanilla RNN BPTT proof follows this same chain-rule
shape.
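Schematically (symbols ours), writing eval for the cell evaluation map, the two-step unroll is the
four-fold composite below, and both its differentiability and its derivative fall out of the chain
rule:

$$
F = \mathrm{eval} \circ \mathrm{secondCtx} \circ \mathrm{eval} \circ \mathrm{firstCtx},
\qquad
DF(E) = D\mathrm{eval} \circ D\mathrm{secondCtx} \circ D\mathrm{eval} \circ D\mathrm{firstCtx},
$$

with each factor taken at the image of E under the preceding maps. A full BPTT proof for a
length-T unroll is the induction over this same shape.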