TorchLean API

NN.Proofs.Autograd.Tape.Nodes.Softmax

Last-axis softmax and log-softmax tape nodes

Softmax and log-softmax over matrix rows, together with their Jacobian/VJP lemmas and NodeFDerivCorrect wrappers for graph-level autograd proofs.

@[reducible, inline]
def Proofs.Autograd.TapeNodes.SoftmaxLastAxis.MNSize (m n : ℕ) : ℕ

Flattened size for an m×n matrix when viewed as a single vector (m*n).

noncomputable def Proofs.Autograd.TapeNodes.SoftmaxLastAxis.rows {m n : ℕ} (x : Vec (MNSize m n)) :
    Fin m → Vec n

Split a flattened m*n vector into m rows of length n.

noncomputable def Proofs.Autograd.TapeNodes.SoftmaxLastAxis.unrows {m n : ℕ} (r : Fin m → Vec n) :
    Vec (MNSize m n)

Inverse of rows: assemble m rows back into a flattened m*n vector.
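As a hedged sketch of the index bookkeeping (assuming `Vec n` unfolds to `Fin n → ℝ` and `MNSize m n` to `m * n`, neither of which is shown on this page), the split/assemble pair can be modeled with Mathlib's `finProdFinEquiv`:

```lean
import Mathlib

-- Sketch only: `Vec n := Fin n → ℝ` and `MNSize m n := m * n` are assumptions,
-- and `finProdFinEquiv`'s index order may differ from TorchLean's actual layout.
def rowsSketch {m n : ℕ} (x : Fin (m * n) → ℝ) : Fin m → Fin n → ℝ :=
  fun i j => x (finProdFinEquiv (i, j))

def unrowsSketch {m n : ℕ} (r : Fin m → Fin n → ℝ) : Fin (m * n) → ℝ :=
  fun k => r (finProdFinEquiv.symm k).1 (finProdFinEquiv.symm k).2

-- The round trip is the equivalence's `symm_apply_apply`,
-- independent of the concrete layout order.
example {m n : ℕ} (r : Fin m → Fin n → ℝ) :
    rowsSketch (unrowsSketch r) = r := by
  funext i j
  simp [rowsSketch, unrowsSketch]
```

Packaging the split through an equivalence makes the `unrows ∘ rows = id` direction a one-line `simp`, which is what the derivative lemmas below rely on when transporting rowwise facts to the flattened representation.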

Continuous-linear-map version of rows.

Continuous-linear-map version of unrows.

noncomputable def Proofs.Autograd.TapeNodes.SoftmaxLastAxis.forwardMN {m n : ℕ} (x : Vec (MNSize m n)) :
    Vec (MNSize m n)

Apply softmaxVec independently to each row of an m×n matrix (flattened representation).
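This page does not display the definition of softmaxVec itself; as a non-authoritative sketch, a standard single-row softmax (with `Fin n → ℝ` again assumed as the model of `Vec n`) is:

```lean
import Mathlib

-- Sketch: one common definition of softmax on a single row; TorchLean's actual
-- `softmaxVec` may differ, e.g. by subtracting the row maximum for numerical
-- stability (which leaves the mathematical value unchanged).
noncomputable def softmaxVecSketch {n : ℕ} (v : Fin n → ℝ) : Fin n → ℝ :=
  fun i => Real.exp (v i) / ∑ j, Real.exp (v j)

-- `forwardMN` then applies this map to each of the m rows independently.
```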

noncomputable def Proofs.Autograd.TapeNodes.SoftmaxLastAxis.jvpMN {m n : ℕ} (x dx : Vec (MNSize m n)) :
    Vec (MNSize m n)

JVP of forwardMN, computed rowwise.
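Per row, the softmax JVP has the familiar closed form y ⊙ (dx − ⟨y, dx⟩·1) where y = softmax x. A hedged single-row sketch (the name `softmaxJvpSketch` and the `Fin n → ℝ` model are illustrative, not TorchLean's API):

```lean
import Mathlib

-- Sketch: single-row softmax JVP. With y = softmax x,
--   (D softmax x) dx = y * (dx - ⟨y, dx⟩ • 1),
-- i.e. componentwise  y i * (dx i - ∑ j, y j * dx j).
noncomputable def softmaxJvpSketch {n : ℕ} (x dx : Fin n → ℝ) : Fin n → ℝ :=
  let y : Fin n → ℝ := fun i => Real.exp (x i) / ∑ j, Real.exp (x j)
  fun i => y i * (dx i - ∑ j, y j * dx j)
```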

noncomputable def Proofs.Autograd.TapeNodes.SoftmaxLastAxis.derivMN {m n : ℕ} (x : Vec (MNSize m n)) :
    Vec (MNSize m n) →L[ℝ] Vec (MNSize m n)

Derivative of forwardMN as a continuous linear map.

HasFDerivAt statement for forwardMN, using the rowwise softmax derivative.

JVP computed by jvpMN agrees with applying the derivative derivMN.

theorem Proofs.Autograd.TapeNodes.SoftmaxLastAxis.inner_jvpMN_comm {m n : ℕ} (x dx δ : Vec (MNSize m n)) :
    inner (jvpMN x dx) δ = inner dx (jvpMN x δ)

Symmetry property of the JVP under the inner product (rowwise softmaxJvp commutation).
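The symmetry is the statement that the single-row softmax Jacobian is a symmetric matrix; writing y = softmax x:

```latex
J_{ij} \;=\; \frac{\partial y_i}{\partial x_j}
       \;=\; y_i(\delta_{ij} - y_j)
       \;=\; \delta_{ij}\,y_i - y_i y_j
       \;=\; J_{ji}
```

Hence ⟨J dx, δ⟩ = ⟨dx, J δ⟩ on each row, and the flattened version inherits the identity because the Jacobian of the rowwise map is block-diagonal with one such symmetric block per row.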

@[reducible, inline]
def Proofs.Autograd.TapeNodes.LogSoftmaxLastAxis.MNSize (m n : ℕ) : ℕ

Reuse the m*n flattened size from SoftmaxLastAxis.

noncomputable def Proofs.Autograd.TapeNodes.LogSoftmaxLastAxis.forwardMN {m n : ℕ} (x : Vec (MNSize m n)) :
    Vec (MNSize m n)

Apply logSoftmaxVec independently to each row (flattened representation).

noncomputable def Proofs.Autograd.TapeNodes.LogSoftmaxLastAxis.jvpMN {m n : ℕ} (x dx : Vec (MNSize m n)) :
    Vec (MNSize m n)

JVP of forwardMN, computed rowwise via logSoftmaxJvp.

noncomputable def Proofs.Autograd.TapeNodes.LogSoftmaxLastAxis.vjpMN {m n : ℕ} (x δ : Vec (MNSize m n)) :
    Vec (MNSize m n)

VJP of forwardMN, computed rowwise via logSoftmaxVjp.
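Single-row sketches of the two maps, assuming the standard log-softmax Jacobian J = I − 1·yᵀ with y = softmax x (all names here are illustrative, not TorchLean's API):

```lean
import Mathlib

-- Sketch: with y = softmax x, the log-softmax Jacobian is J = I - 1 yᵀ, so
--   JVP: (J dx) i  = dx i - ∑ j, y j * dx j
--   VJP: (Jᵀ δ) i  = δ i  - (∑ j, δ j) * y i
noncomputable def logSoftmaxJvpSketch {n : ℕ} (x dx : Fin n → ℝ) : Fin n → ℝ :=
  let y : Fin n → ℝ := fun i => Real.exp (x i) / ∑ j, Real.exp (x j)
  fun i => dx i - ∑ j, y j * dx j

noncomputable def logSoftmaxVjpSketch {n : ℕ} (x δ : Fin n → ℝ) : Fin n → ℝ :=
  let y : Fin n → ℝ := fun i => Real.exp (x i) / ∑ j, Real.exp (x j)
  fun i => δ i - (∑ j, δ j) * y i
```

Expanding both sides of the adjointness identity gives ⟨J dx, δ⟩ = ∑ dxᵢδᵢ − (∑ yⱼdxⱼ)(∑ δᵢ) = ⟨dx, Jᵀ δ⟩, which is the per-row content of the adjointness lemma stated below.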

Derivative of forwardMN as a continuous linear map.

HasFDerivAt statement for logSoftmaxVec applied rowwise.

JVP computed by jvpMN agrees with applying the derivative derivMN.

logSoftmaxJvp / logSoftmaxVjp adjointness under the inner product, lifted rowwise.

Tape node for applying softmaxVec along the last axis of an m×n matrix (rowwise).

Tape node for applying logSoftmaxVec along the last axis of an m×n matrix (rowwise).

NodeFDerivCorrect for softmax_last (rowwise softmax).

NodeFDerivCorrect for log_softmax_last (rowwise log-softmax).