TorchLean API

NN.Proofs.Autograd.FDeriv.LogSoftmax

LogSoftmax

Fréchet-derivative facts for log-softmax on Euclidean vectors.

This is the analytic (ℝ) ingredient used to justify log_softmax nodes (Vec n → Vec n) in the tape/DAG autograd proofs.


theorem Proofs.Autograd.sumExp_pos {n : ℕ} (x : Vec n.succ) :
0 < sumExp x

sumExp x is strictly positive when the index type is nonempty.

Convenience corollary: sumExp x ≠ 0 (for n = succ _).
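
Why positivity holds: every summand exp xᵢ is strictly positive and the index type Fin n.succ is nonempty. A minimal Mathlib sketch, assuming sumExp x unfolds to ∑ i, Real.exp (x i) and writing the vector as a plain function type (import Mathlib):

example {n : ℕ} (x : Fin n.succ → ℝ) : 0 < ∑ i, Real.exp (x i) :=
  -- each term is positive and Fin n.succ is nonempty
  Finset.sum_pos (fun i _ => Real.exp_pos (x i)) Finset.univ_nonempty

The nonvanishing corollary is then immediate from (sumExp_pos x).ne'.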

noncomputable def Proofs.Autograd.logSoftmaxVec {n : ℕ} :
Vec n → Vec n

Log-softmax on Euclidean vectors.

For n = succ _: logSoftmaxVec x i = xᵢ - log(sumExp x).

The n = 0 branch is the identity on the trivial space.
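
A sketch of the succ branch (the definition body is not shown on this page), assuming sumExp x = ∑ j, Real.exp (x j), with a plain function type standing in for Vec; the name logSoftmaxSketch is illustrative:

noncomputable def logSoftmaxSketch {n : ℕ} (x : Fin n.succ → ℝ) : Fin n.succ → ℝ :=
  -- every coordinate subtracts the same scalar log (sumExp x)
  fun i => x i - Real.log (∑ j, Real.exp (x j))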

noncomputable def Proofs.Autograd.logSoftmaxDerivCoord {n : ℕ} (x : Vec n.succ) (i : Fin n.succ) :
Vec n.succ →L[ℝ] ℝ
The ith output coordinate of the log-softmax derivative at x (for n = succ _).

If y = softmaxVec x, then this is the linear functional dx ↦ dxᵢ - ⟪y, dx⟫.
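
In Mathlib vocabulary this functional can be written as a coordinate projection minus an inner-product functional. A sketch, assuming Vec n unfolds to EuclideanSpace ℝ (Fin n); derivCoordSketch is an illustrative name and y stands for softmaxVec x:

noncomputable def derivCoordSketch {n : ℕ} (y : EuclideanSpace ℝ (Fin n.succ))
    (i : Fin n.succ) : EuclideanSpace ℝ (Fin n.succ) →L[ℝ] ℝ :=
  -- dx ↦ dxᵢ - ⟪y, dx⟫
  EuclideanSpace.proj i - innerSL ℝ y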

noncomputable def Proofs.Autograd.logSoftmaxDerivCLM {n : ℕ} :
Vec n → Vec n →L[ℝ] Vec n

The full Fréchet derivative of logSoftmaxVec at x, packaged as a CLM.

noncomputable def Proofs.Autograd.logSoftmaxJvp {n : ℕ} :
Vec n → Vec n → Vec n

Closed-form JVP (directional derivative) for log-softmax.

For n = succ _, if y = softmaxVec x and s = ⟪y, dx⟫, then (logSoftmaxJvp x dx)ᵢ = dxᵢ - s.
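
A runnable sketch of this closed form, with plain function types in place of Vec and softmax written out explicitly (names are illustrative):

noncomputable def logSoftmaxJvpSketch {n : ℕ} (x dx : Fin n.succ → ℝ) :
    Fin n.succ → ℝ :=
  let y : Fin n.succ → ℝ := fun i => Real.exp (x i) / (∑ j, Real.exp (x j))
  let s : ℝ := ∑ i, y i * dx i  -- s = ⟪y, dx⟫
  fun i => dx i - s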

noncomputable def Proofs.Autograd.logSoftmaxVjp {n : ℕ} :
Vec n → Vec n → Vec n

Closed-form VJP for log-softmax (transpose-Jacobian product).

For n = succ _, if y = softmaxVec x and t = ∑ᵢ δᵢ, then (logSoftmaxVjp x δ)ᵢ = δᵢ - yᵢ * t.
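
The corresponding sketch for the backward direction, under the same conventions and illustrative naming as the JVP sketch above:

noncomputable def logSoftmaxVjpSketch {n : ℕ} (x δ : Fin n.succ → ℝ) :
    Fin n.succ → ℝ :=
  let y : Fin n.succ → ℝ := fun i => Real.exp (x i) / (∑ j, Real.exp (x j))
  let t : ℝ := ∑ i, δ i  -- t = ∑ᵢ δᵢ
  fun i => δ i - y i * t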


The closed-form JVP logSoftmaxJvp agrees with the CLM derivative logSoftmaxDerivCLM.

Adjointness identity: the log-softmax JVP and VJP are adjoint with respect to the Euclidean inner product.

This is the analytic statement that justifies using logSoftmaxVjp as the backward pass.
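
The identity is a short computation from the closed forms above. With y = softmaxVec x, s = ⟪y, dx⟫ and t = ∑ᵢ δᵢ:

⟪logSoftmaxJvp x dx, δ⟫ = ∑ᵢ (dxᵢ - s) * δᵢ = ∑ᵢ dxᵢ * δᵢ - s * t
⟪dx, logSoftmaxVjp x δ⟫ = ∑ᵢ dxᵢ * (δᵢ - yᵢ * t) = ∑ᵢ dxᵢ * δᵢ - t * ∑ᵢ yᵢ * dxᵢ = ∑ᵢ dxᵢ * δᵢ - s * t

Both pairings reduce to the same expression, which is exactly the adjointness claim.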

Log-softmax is Fréchet-differentiable everywhere, with derivative logSoftmaxDerivCLM.
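
The nonlinear content is concentrated in the scalar term log (sumExp x); each coordinate of logSoftmaxVec is a coordinate projection minus that term. A standalone Mathlib sketch of this ingredient, with plain function types in place of Vec (import Mathlib):

example {n : ℕ} (x : Fin n.succ → ℝ) :
    DifferentiableAt ℝ (fun v : Fin n.succ → ℝ => Real.log (∑ j, Real.exp (v j))) x := by
  -- the inner sum is strictly positive, so log is differentiable at it
  have hpos : 0 < ∑ j, Real.exp (x j) :=
    Finset.sum_pos (fun j _ => Real.exp_pos _) Finset.univ_nonempty
  -- a finite sum of exp ∘ (coordinate projection) is differentiable
  have hsum : DifferentiableAt ℝ (fun v : Fin n.succ → ℝ => ∑ j, Real.exp (v j)) x :=
    DifferentiableAt.sum fun j _ =>
      ((ContinuousLinearMap.proj j : (Fin n.succ → ℝ) →L[ℝ] ℝ).differentiableAt).exp
  exact hsum.log hpos.ne'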