TorchLean API

NN.Proofs.Autograd.Tape.Ops.Recurrent.ElmanCell

Elman RNN Cell VJP

This file proves the core differentiable cell used by a vanilla tanh RNN:

h' = tanh(W [x; h] + b).

The theorem is deliberately cell-level. Runtime sequence layers unroll this cell over time and scatter the hidden states into an output sequence; the full BPTT theorem is the induction over that unroll plus the existing gather/scatter adjoint facts. We prove the cell first because it is the right reusable grain size: vector concatenation, affine maps, and smooth elementwise tanh.
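For readers who want the cell pinned down independently of the graph machinery, here is a minimal standalone sketch of one step over plain `Fin`-indexed real vectors. It uses only Mathlib (`Fin.append`, `Matrix.mulVec`, `Real.tanh`); the name `elmanStep` and the explicit `W`/`b` parameters are illustrative, standing in for the fixed `LinearSpec` used by the graph below.

```lean
import Mathlib

-- One Elman step: h' = tanh(W [x; h] + b), written directly over
-- Fin-indexed vectors. All names here are illustrative.
noncomputable def elmanStep {n m : ℕ}
    (W : Matrix (Fin m) (Fin (n + m)) ℝ) (b : Fin m → ℝ)
    (x : Fin n → ℝ) (h : Fin m → ℝ) : Fin m → ℝ :=
  let xh : Fin (n + m) → ℝ := Fin.append x h  -- concatenated [x; h]
  fun i => Real.tanh (W.mulVec xh i + b i)    -- affine map, then tanh
```

The concatenate-then-affine-then-tanh staging is exactly the decomposition the saved-tensor list below records.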

@[reducible, inline]
abbrev Proofs.Autograd.Recurrent.XShape (inputSize : ℕ) : Spec.Shape

Input vector shape for one RNN step.
@[reducible, inline]
abbrev Proofs.Autograd.Recurrent.HShape (hiddenSize : ℕ) : Spec.Shape

Hidden vector shape for one RNN step.
@[reducible, inline]
abbrev Proofs.Autograd.Recurrent.ΓElman (inputSize hiddenSize : ℕ) : List Spec.Shape

Context for a one-step Elman cell: current input and previous hidden state.
@[reducible, inline]
abbrev Proofs.Autograd.Recurrent.ssElmanCell (inputSize hiddenSize : ℕ) : List Spec.Shape

Saved tensors: concatenated [x; h], affine preactivation, and next hidden state.
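The doc generator elides the abbreviation bodies. The following reconstruction is an assumption, inferred from the index signatures below (which fix the concatenated shape to `Spec.Shape.dim (inputSize + hiddenSize) Spec.Shape.scalar` and give both the preactivation and the new hidden state the shape `HShape hiddenSize`):

```lean
-- Assumed bodies, inferred from the signatures on this page;
-- not quoted from the actual source.
abbrev XShape (inputSize : ℕ) : Spec.Shape :=
  Spec.Shape.dim inputSize Spec.Shape.scalar

abbrev HShape (hiddenSize : ℕ) : Spec.Shape :=
  Spec.Shape.dim hiddenSize Spec.Shape.scalar

-- Graph inputs: current input x and previous hidden state h.
abbrev ΓElman (inputSize hiddenSize : ℕ) : List Spec.Shape :=
  [XShape inputSize, HShape hiddenSize]

-- Saved in execution order: [x; h], the preactivation z = W [x; h] + b,
-- and the next hidden state h' = tanh z.
abbrev ssElmanCell (inputSize hiddenSize : ℕ) : List Spec.Shape :=
  [Spec.Shape.dim (inputSize + hiddenSize) Spec.Shape.scalar,
   HShape hiddenSize, HShape hiddenSize]
```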

def Proofs.Autograd.Recurrent.idxInput {inputSize hiddenSize : ℕ} {ss : List Spec.Shape} :
    Idx (ΓElman inputSize hiddenSize ++ ss) (XShape inputSize)

Current input index.
def Proofs.Autograd.Recurrent.idxHidden {inputSize hiddenSize : ℕ} {ss : List Spec.Shape} :
    Idx (ΓElman inputSize hiddenSize ++ ss) (HShape hiddenSize)

Previous hidden-state index.

Most recently appended tensor helper.
def Proofs.Autograd.Recurrent.idxConcat {inputSize hiddenSize : ℕ} :
    Idx (ΓElman inputSize hiddenSize ++ [Spec.Shape.dim (inputSize + hiddenSize) Spec.Shape.scalar]) (Spec.Shape.dim (inputSize + hiddenSize) Spec.Shape.scalar)

Index of the concatenated [x; h] vector.
def Proofs.Autograd.Recurrent.idxPre {inputSize hiddenSize : ℕ} :
    Idx (ΓElman inputSize hiddenSize ++ [Spec.Shape.dim (inputSize + hiddenSize) Spec.Shape.scalar, HShape hiddenSize]) (HShape hiddenSize)

Index of the affine preactivation.

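Taken together, the four indices name every tensor the backward pass reads. The construction-time context grows as follows (a comment-only sketch; the constructor calls themselves are not shown on this page):

```lean
-- Context growth while building the cell graph (shapes only):
--   ΓElman                = [x, h]      -- graph inputs
--   ΓElman ++ [xh]                      -- after concatenating [x; h]
--   ΓElman ++ [xh, z]                   -- after the affine map z = W xh + b
--   ΓElman ++ [xh, z, h']               -- after h' = tanh z
-- idxConcat addresses xh in the second context and idxPre addresses z in
-- the third, matching the `++ [...]` suffixes in their types above.
```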
noncomputable def Proofs.Autograd.Recurrent.elmanCellDGraph {inputSize hiddenSize : ℕ} (cell : Spec.LinearSpec (inputSize + hiddenSize) hiddenSize) :
    DGraph (ΓElman inputSize hiddenSize) (ssElmanCell inputSize hiddenSize)

Proof-carrying graph for one Elman RNN cell.

The affine map is represented by a fixed LinearSpec, so the VJP theorem below covers the derivative with respect to the cell inputs (x, h) only. Parameter-gradient theorems are a separate layer over the trainable runtime parameter list.
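For orientation (standard calculus, not quoted from the formal proof), the VJP this graph certifies unwinds to the familiar elementwise computation. Writing z = W [x; h] + b for the saved preactivation and u for the seed arriving at h':

δz = u ⊙ (1 − tanh²(z)) = u ⊙ (1 − h' ⊙ h'),   δ[x; h] = Wᵀ δz,

and δx, δh are the first inputSize and last hiddenSize entries of δ[x; h]. Saving [x; h], z, and h' (the ssElmanCell list) is exactly what makes the reverse pass computable without re-running the forward pass.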

theorem Proofs.Autograd.Recurrent.elmanCell_backpropVec_eq_adjoint_fderiv {inputSize hiddenSize : ℕ} (cell : Spec.LinearSpec (inputSize + hiddenSize) hiddenSize) (xV : CtxVec (ΓElman inputSize hiddenSize)) (seedV : CtxVec (ΓElman inputSize hiddenSize ++ ssElmanCell inputSize hiddenSize)) :

End-to-end VJP theorem for one vanilla RNN cell.

This is the recurrent analogue of the attention block theorems: the graph-level reverse pass equals the adjoint of the Fréchet derivative of the cell evaluation function.
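In inner-product form (informal notation; the formal right-hand side is elided above), the equality says that for every input perturbation v

⟨backpropVec seedV, v⟩ = ⟨seedV, Df(xV) v⟩,

where Df(xV) is the Fréchet derivative of the graph's evaluation map at xV. Since this property characterizes the adjoint, the graph-level reverse pass computes exactly the vector–Jacobian product of the cell.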

theorem Proofs.Autograd.Recurrent.elmanCell_eval_hasFDerivAt {inputSize hiddenSize : ℕ} (cell : Spec.LinearSpec (inputSize + hiddenSize) hiddenSize) (xV : CtxVec (ΓElman inputSize hiddenSize)) :

Forward evaluation of one Elman cell is differentiable at every input context.

This is the recurrent analogue of the Transformer sublayer calculus bridges: it exposes the cell as a differentiable map that can be composed repeatedly when proving BPTT for an unrolled RNN.

theorem Proofs.Autograd.Recurrent.elmanTwoStep_hasFDerivAt {E : Type u} [NormedAddCommGroup E] [NormedSpace ℝ E] {inputSize hiddenSize : ℕ} (cell : Spec.LinearSpec (inputSize + hiddenSize) hiddenSize) (firstCtx : E → CtxVec (ΓElman inputSize hiddenSize)) (DfirstCtx : E →L[ℝ] CtxVec (ΓElman inputSize hiddenSize)) (secondCtx : CtxVec (ΓElman inputSize hiddenSize ++ ssElmanCell inputSize hiddenSize) → CtxVec (ΓElman inputSize hiddenSize)) (DsecondCtx : CtxVec (ΓElman inputSize hiddenSize ++ ssElmanCell inputSize hiddenSize) →L[ℝ] CtxVec (ΓElman inputSize hiddenSize)) (x : E) (hFirstCtx : HasFDerivAt firstCtx DfirstCtx x) (hSecondCtx : HasFDerivAt secondCtx DsecondCtx ((elmanCellDGraph cell).g.evalVec (firstCtx x))) :
HasFDerivAt (fun (z : E) => (elmanCellDGraph cell).g.evalVec (secondCtx ((elmanCellDGraph cell).g.evalVec (firstCtx z)))) ((fderiv ℝ (elmanCellDGraph cell).g.evalVec (secondCtx ((elmanCellDGraph cell).g.evalVec (firstCtx x)))).comp (DsecondCtx.comp ((fderiv ℝ (elmanCellDGraph cell).g.evalVec (firstCtx x)).comp DfirstCtx))) x

Two-step recurrent composition bridge.

Suppose firstCtx builds the first cell context from some outer state E, and secondCtx builds the next cell context from the evaluated first-cell graph (for example by selecting the next input and the hidden state produced by the first step). If both context builders are differentiable, then the two-cell unroll is differentiable.

The theorem is intentionally abstract over secondCtx: different sequence layouts store inputs, hidden states, and caches differently, but every vanilla RNN BPTT proof follows this same chain-rule shape.
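Spelled out, with eval := (elmanCellDGraph cell).g.evalVec, c₁ := firstCtx x, and c₂ := secondCtx (eval c₁), the derivative produced by the theorem is the four-factor chain rule

D(eval ∘ secondCtx ∘ eval ∘ firstCtx)(x) = D eval(c₂) ∘ DsecondCtx ∘ D eval(c₁) ∘ DfirstCtx.

An n-step unroll applies this bridge step by step, threading each hidden state into the next secondCtx; the D eval factors are exactly what elmanCell_eval_hasFDerivAt supplies.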