TorchLean API

NN.Proofs.Autograd.Tape.Ops.Recurrent.ElmanCell

Elman RNN Cell VJP

This file proves the core differentiable cell used by a vanilla tanh RNN:

h' = tanh(W [x; h] + b).

The theorem is deliberately cell-level. Runtime sequence layers unroll this cell over time and scatter the hidden states into an output sequence; the full BPTT theorem is then an induction over that unroll, combined with the existing gather/scatter adjoint facts. We prove the cell first because it is the right reusable grain size: it composes only vector concatenation, an affine map, and a smooth elementwise tanh.
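As a worked equation, the reverse pass through one cell can be written out by hand. This is a sketch under stated assumptions, not the statement proved in this file: the bar notation for cotangents, the names z and a for the concatenated vector and the preactivation, and the sizes n = inputSize and m = hiddenSize are introduced here for illustration, and the seed is placed only on h'.

```latex
\begin{aligned}
z &= [x; h], \qquad a = W z + b, \qquad h' = \tanh(a), \\
\bar a &= \bigl(1 - h' \odot h'\bigr) \odot \bar h', \\
\bar z &= W^{\top} \bar a, \\
\bar x &= \bar z_{1:n}, \qquad \bar h = \bar z_{n+1:n+m}.
\end{aligned}
```

Each line is the adjoint of one of the three primitives: elementwise tanh, the affine map, and concatenation.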

@[reducible, inline]

Input vector shape for one RNN step.

@[reducible, inline]

Hidden vector shape for one RNN step.

@[reducible, inline]
abbrev Proofs.Autograd.Recurrent.ΓElman (inputSize hiddenSize : ℕ) :

Context for a one-step Elman cell: current input and previous hidden state.

@[reducible, inline]
abbrev Proofs.Autograd.Recurrent.ssElmanCell (inputSize hiddenSize : ℕ) :

Saved tensors: concatenated [x; h], affine preactivation, and next hidden state.

def Proofs.Autograd.Recurrent.idxInput {inputSize hiddenSize : ℕ} {ss : List Spec.Shape} :
Idx (ΓElman inputSize hiddenSize ++ ss) (XShape inputSize)

Index of the current input vector in the Elman-cell context.

def Proofs.Autograd.Recurrent.idxHidden {inputSize hiddenSize : ℕ} {ss : List Spec.Shape} :
Idx (ΓElman inputSize hiddenSize ++ ss) (HShape hiddenSize)

Previous hidden-state index.

def Proofs.Autograd.Recurrent.idxConcat {inputSize hiddenSize : ℕ} :
Idx (ΓElman inputSize hiddenSize ++ [Spec.Shape.dim (inputSize + hiddenSize) Spec.Shape.scalar]) (Spec.Shape.dim (inputSize + hiddenSize) Spec.Shape.scalar)

Index of the concatenated [x; h] vector.

def Proofs.Autograd.Recurrent.idxPre {inputSize hiddenSize : ℕ} :
Idx (ΓElman inputSize hiddenSize ++ [Spec.Shape.dim (inputSize + hiddenSize) Spec.Shape.scalar, HShape hiddenSize]) (HShape hiddenSize)

Index of the affine preactivation.

noncomputable def Proofs.Autograd.Recurrent.elmanCellDGraph {inputSize hiddenSize : ℕ} (cell : Spec.LinearSpec (inputSize + hiddenSize) hiddenSize) :
DGraph (ΓElman inputSize hiddenSize) (ssElmanCell inputSize hiddenSize)

Proof-carrying graph for one Elman RNN cell.

The affine map is represented by a fixed LinearSpec, so this theorem covers the VJP with respect to the cell inputs (x, h). Parameter-gradient theorems are a separate layer over the trainable runtime parameter list.

theorem Proofs.Autograd.Recurrent.elmanCell_backpropVec_eq_adjoint_fderiv {inputSize hiddenSize : ℕ} (cell : Spec.LinearSpec (inputSize + hiddenSize) hiddenSize) (xV : CtxVec (ΓElman inputSize hiddenSize)) (seedV : CtxVec (ΓElman inputSize hiddenSize ++ ssElmanCell inputSize hiddenSize)) :

End-to-end VJP theorem for one vanilla RNN cell.

This is the recurrent analogue of the attention block theorems: the graph-level reverse pass equals the adjoint of the Fréchet derivative of the cell evaluation function.
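The claim "reverse pass equals adjoint of the derivative" can be checked numerically on a small instance. The following is a minimal sketch in plain Python, not TorchLean code: the function names, the list-based vectors, and the choice to seed only the next hidden state (the theorem's seedV also covers the context and the other saved tensors) are all assumptions made for illustration.

```python
import math

def cell_forward(W, b, x, h):
    """Return the saved tensors (z, a, h_next) for h' = tanh(W [x; h] + b)."""
    z = x + h  # list concatenation plays the role of [x; h]
    a = [sum(w * v for w, v in zip(row, z)) + bi for row, bi in zip(W, b)]
    h_next = [math.tanh(ai) for ai in a]
    return z, a, h_next

def cell_vjp(W, n_in, h_next, seed):
    """Hand-written reverse pass: cotangent on h' -> cotangents on (x, h)."""
    a_bar = [(1.0 - hn * hn) * s for hn, s in zip(h_next, seed)]
    z_bar = [sum(W[i][j] * a_bar[i] for i in range(len(W)))
             for j in range(len(W[0]))]
    return z_bar[:n_in], z_bar[n_in:]

def adjoint_by_finite_differences(W, b, x, h, seed, eps=1e-6):
    """Apply the transposed Jacobian of (x, h) -> h' to seed, column by column."""
    n_in = len(x)
    z0 = x + h
    out = []
    for j in range(len(z0)):
        zp, zm = list(z0), list(z0)
        zp[j] += eps
        zm[j] -= eps
        hp = cell_forward(W, b, zp[:n_in], zp[n_in:])[2]
        hm = cell_forward(W, b, zm[:n_in], zm[n_in:])[2]
        out.append(sum(s * (p - m) / (2 * eps) for s, p, m in zip(seed, hp, hm)))
    return out[:n_in], out[n_in:]

# Hypothetical small instance: hiddenSize = 2, inputSize = 1.
W = [[0.5, -0.3, 0.8], [0.1, 0.4, -0.6]]
b = [0.0, 0.1]
x, h, seed = [1.0], [0.2, -0.1], [1.0, -2.0]

_, _, h_next = cell_forward(W, b, x, h)
x_bar, h_bar = cell_vjp(W, len(x), h_next, seed)
x_fd, h_fd = adjoint_by_finite_differences(W, b, x, h, seed)
```

The two results should agree up to finite-difference error. The theorem states the corresponding equality exactly, and for every seed.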

theorem Proofs.Autograd.Recurrent.elmanCell_eval_hasFDerivAt {inputSize hiddenSize : ℕ} (cell : Spec.LinearSpec (inputSize + hiddenSize) hiddenSize) (xV : CtxVec (ΓElman inputSize hiddenSize)) :

Forward evaluation of one Elman cell is differentiable at every input context.

This is the recurrent analogue of the Transformer sublayer calculus bridges: it exposes the cell as a differentiable map that can be composed repeatedly when proving BPTT for an unrolled RNN.

theorem Proofs.Autograd.Recurrent.elmanTwoStep_hasFDerivAt {E : Type u} [NormedAddCommGroup E] [NormedSpace ℝ E] {inputSize hiddenSize : ℕ} (cell : Spec.LinearSpec (inputSize + hiddenSize) hiddenSize) (firstCtx : E → CtxVec (ΓElman inputSize hiddenSize)) (DfirstCtx : E →L[ℝ] CtxVec (ΓElman inputSize hiddenSize)) (secondCtx : CtxVec (ΓElman inputSize hiddenSize ++ ssElmanCell inputSize hiddenSize) → CtxVec (ΓElman inputSize hiddenSize)) (DsecondCtx : CtxVec (ΓElman inputSize hiddenSize ++ ssElmanCell inputSize hiddenSize) →L[ℝ] CtxVec (ΓElman inputSize hiddenSize)) (x : E) (hFirstCtx : HasFDerivAt firstCtx DfirstCtx x) (hSecondCtx : HasFDerivAt secondCtx DsecondCtx ((elmanCellDGraph cell).g.evalVec (firstCtx x))) :
HasFDerivAt (fun (z : E) => (elmanCellDGraph cell).g.evalVec (secondCtx ((elmanCellDGraph cell).g.evalVec (firstCtx z)))) (fderiv ℝ (elmanCellDGraph cell).g.evalVec (secondCtx ((elmanCellDGraph cell).g.evalVec (firstCtx x))) ∘SL DsecondCtx ∘SL fderiv ℝ (elmanCellDGraph cell).g.evalVec (firstCtx x) ∘SL DfirstCtx) x

Two-step recurrent composition bridge.

Suppose firstCtx builds the first cell context from some outer state E, and secondCtx builds the next cell context from the evaluated first-cell graph (for example by selecting the next input and the hidden state produced by the first step). If both context builders are differentiable, then the two-cell unroll is differentiable.

The theorem is intentionally abstract over secondCtx: different sequence layouts store inputs, hidden states, and caches differently, but every vanilla RNN BPTT proof follows this same chain-rule shape.
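The chain-rule shape can be spelled out in conventional notation. In this sketch the shorthand is an assumption introduced for readability: F stands for (elmanCellDGraph cell).g.evalVec, c_1 for firstCtx, c_2 for secondCtx, and a star for the adjoint. The last line is the standard order reversal of adjoints under composition; it is not part of the theorem statement, but it is why the backward pass visits the steps in reverse.

```latex
\begin{aligned}
G &= F \circ c_2 \circ F \circ c_1, \\
DG(x) &= DF\bigl(c_2(F(c_1(x)))\bigr) \circ Dc_2\bigl(F(c_1(x))\bigr) \circ DF\bigl(c_1(x)\bigr) \circ Dc_1(x), \\
DG(x)^{*} &= Dc_1(x)^{*} \circ DF\bigl(c_1(x)\bigr)^{*} \circ Dc_2\bigl(F(c_1(x))\bigr)^{*} \circ DF\bigl(c_2(F(c_1(x)))\bigr)^{*}.
\end{aligned}
```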

noncomputable def Proofs.Autograd.Recurrent.elmanCellEval {inputSize hiddenSize : ℕ} (cell : Spec.LinearSpec (inputSize + hiddenSize) hiddenSize) :
CtxVec (ΓElman inputSize hiddenSize) → CtxVec (ΓElman inputSize hiddenSize ++ ssElmanCell inputSize hiddenSize)

Forward evaluation map for a single Elman-cell graph.

def Proofs.Autograd.Recurrent.elmanUnroll {inputSize hiddenSize : ℕ} (cell : Spec.LinearSpec (inputSize + hiddenSize) hiddenSize) :
List (CtxVec (ΓElman inputSize hiddenSize ++ ssElmanCell inputSize hiddenSize) → CtxVec (ΓElman inputSize hiddenSize)) → CtxVec (ΓElman inputSize hiddenSize) → CtxVec (ΓElman inputSize hiddenSize ++ ssElmanCell inputSize hiddenSize)

Unroll an Elman cell through an arbitrary list of transition builders.

Each transition consumes the full trace of the previous cell and constructs the next cell context. This covers the usual sequence case, where the transition selects the next input token/vector and threads the hidden state produced by the previous step.
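The unroll pattern described above can be sketched in a few lines of plain Python. This is an illustration only: the recursion shown (evaluate the cell once, then alternate transition and cell evaluation) is inferred from the type signature and is an assumption about how elmanUnroll is defined, and every name below is hypothetical.

```python
import math

def make_cell_eval(W, b):
    """Return a map from a context (x, h) to the full trace (x, h, z, a, h_next)."""
    def cell_eval(ctx):
        x, h = ctx
        z = x + h  # list concatenation plays the role of [x; h]
        a = [sum(w * v for w, v in zip(row, z)) + bi for row, bi in zip(W, b)]
        h_next = [math.tanh(ai) for ai in a]
        return (x, h, z, a, h_next)
    return cell_eval

def elman_unroll(cell_eval, transitions, ctx):
    """Evaluate the cell, then repeatedly build the next context and re-evaluate."""
    trace = cell_eval(ctx)
    for step in transitions:
        trace = cell_eval(step(trace))
    return trace

def next_input(x_next):
    """Transition that feeds x_next and threads the hidden state just produced."""
    return lambda trace: (x_next, trace[4])

# Hypothetical small instance: hiddenSize = 2, inputSize = 1, three time steps.
W = [[0.5, -0.3, 0.8], [0.1, 0.4, -0.6]]
b = [0.0, 0.1]
xs = [[1.0], [0.5], [-1.0]]
h0 = [0.0, 0.0]

cell_eval = make_cell_eval(W, b)
final_trace = elman_unroll(cell_eval, [next_input(x) for x in xs[1:]], (xs[0], h0))
```

Because a transition sees the whole previous trace, other layouts (for example, ones that read the preactivation or reuse the previous input) fit the same interface without changing the unroll itself.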

def Proofs.Autograd.Recurrent.elmanTransitionsDifferentiableAt {inputSize hiddenSize : ℕ} (cell : Spec.LinearSpec (inputSize + hiddenSize) hiddenSize) :
List (CtxVec (ΓElman inputSize hiddenSize ++ ssElmanCell inputSize hiddenSize) → CtxVec (ΓElman inputSize hiddenSize)) → CtxVec (ΓElman inputSize hiddenSize) → Prop

Local differentiability obligations for an arbitrary recurrent unroll.

For each transition, the obligation is stated at the trace actually produced by the previous cell. Keeping this as a recursive predicate makes the induction theorem independent of any particular sequence storage convention.

theorem Proofs.Autograd.Recurrent.elmanUnroll_hasFDerivAt {inputSize hiddenSize : ℕ} (cell : Spec.LinearSpec (inputSize + hiddenSize) hiddenSize) (steps : List (CtxVec (ΓElman inputSize hiddenSize ++ ssElmanCell inputSize hiddenSize) → CtxVec (ΓElman inputSize hiddenSize))) (x : CtxVec (ΓElman inputSize hiddenSize)) (hSteps : elmanTransitionsDifferentiableAt cell steps x) :
∃ (D : CtxVec (ΓElman inputSize hiddenSize) →L[ℝ] CtxVec (ΓElman inputSize hiddenSize ++ ssElmanCell inputSize hiddenSize)), HasFDerivAt (elmanUnroll cell steps) D x

Arbitrary-length BPTT chain-rule induction for an Elman RNN unroll.

If every transition between cells is differentiable at the trace reached during the forward pass, then the whole unrolled recurrence is differentiable at the initial cell context. This is the mathematical induction principle behind backpropagation through time for the vanilla tanh RNN cell.
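To see the induction in action, the sketch below accumulates per-step VJPs backwards through a short sequence and compares the result with a finite-difference gradient. It is a plain Python illustration under stated assumptions, not the TorchLean proof: the theorem asserts only that some derivative D exists, whereas this example fixes one concrete transition (feed the next input, thread the hidden state) and one scalar loss, the inner product of a hypothetical seed with the final hidden state. All names are invented for the example.

```python
import math

def step(W, b, x, h):
    """One Elman step: h' = tanh(W [x; h] + b)."""
    z = x + h
    return [math.tanh(sum(w * v for w, v in zip(row, z)) + bi)
            for row, bi in zip(W, b)]

def step_vjp(W, n_in, h_next, h_next_bar):
    """Pull a cotangent on h' back to cotangents on (x, h)."""
    a_bar = [(1.0 - hn * hn) * g for hn, g in zip(h_next, h_next_bar)]
    z_bar = [sum(W[i][j] * a_bar[i] for i in range(len(W)))
             for j in range(len(W[0]))]
    return z_bar[:n_in], z_bar[n_in:]

def run(W, b, xs, h0):
    """Forward unroll; returns the list [h0, h1, ..., hT]."""
    hs = [h0]
    for x in xs:
        hs.append(step(W, b, x, hs[-1]))
    return hs

def bptt_h0(W, b, xs, h0, seed):
    """Gradient of <seed, hT> with respect to h0, by reverse accumulation."""
    hs = run(W, b, xs, h0)
    h_bar = seed
    for t in reversed(range(len(xs))):
        _, h_bar = step_vjp(W, len(xs[t]), hs[t + 1], h_bar)
    return h_bar

def fd_h0(W, b, xs, h0, seed, eps=1e-6):
    """Central finite-difference estimate of the same gradient."""
    def loss(h):
        return sum(s * v for s, v in zip(seed, run(W, b, xs, h)[-1]))
    grad = []
    for k in range(len(h0)):
        hp, hm = list(h0), list(h0)
        hp[k] += eps
        hm[k] -= eps
        grad.append((loss(hp) - loss(hm)) / (2 * eps))
    return grad

# Hypothetical small instance: hiddenSize = 2, inputSize = 1, three time steps.
W = [[0.5, -0.3, 0.8], [0.1, 0.4, -0.6]]
b = [0.0, 0.1]
xs = [[1.0], [0.5], [-1.0]]
h0 = [0.2, -0.1]
seed = [1.0, -2.0]

grad_bptt = bptt_h0(W, b, xs, h0, seed)
grad_fd = fd_h0(W, b, xs, h0, seed)
```

The backward loop applies one cell adjoint per time step, which mirrors the inductive step of the theorem: differentiability of the tail of the unroll plus one more differentiable cell and transition.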