TorchLean API

NN.Proofs.Autograd.Tape.Nodes.Losses.KLDivergence

KL Divergence

Batchmean KL-divergence for log-probability inputs and probability targets.

KL-divergence loss for logProbs and target probabilities of shape (m×n).

Forward (batchmean reduction): (1/m) * Σ_{i,j} target[i,j] * (log(target[i,j]) - logProbs[i,j])

This matches PyTorch KLDivLoss / F.kl_div with:

  • input = log-probabilities,
  • target = probabilities (not log-target),
  • reduction="batchmean".
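The forward formula and the three conventions above can be illustrated with a small NumPy sketch. This is not TorchLean or PyTorch code: the function name `kl_div_batchmean` is hypothetical, and the sketch assumes every target entry is strictly positive so that the logarithm is finite.

```python
import numpy as np

def kl_div_batchmean(log_probs, target):
    """Return (1/m) * sum_{i,j} target[i,j] * (log(target[i,j]) - log_probs[i,j])."""
    log_probs = np.asarray(log_probs, dtype=float)
    target = np.asarray(target, dtype=float)
    if log_probs.shape != target.shape or log_probs.ndim != 2:
        raise ValueError("expected two arrays of the same (m, n) shape")
    # Assumption of this sketch: every target entry is strictly positive,
    # so np.log(target) is finite.
    m = log_probs.shape[0]  # batchmean divides by the batch size m, not by m * n
    return float(np.sum(target * (np.log(target) - log_probs)) / m)

# Targets are probabilities; the first argument holds log-probabilities.
target = np.array([[0.5, 0.5], [0.25, 0.75]])
uniform_log_probs = np.log(np.full((2, 2), 0.5))

loss_same = kl_div_batchmean(np.log(target), target)        # prediction equals target
loss_uniform = kl_div_batchmean(uniform_log_probs, target)  # uniform prediction
```

When the prediction equals the target, every summand vanishes, so `loss_same` is zero. With the uniform prediction, only the second row contributes, and the sum is divided by the batch size m = 2.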

The derivative specification uses the rule that Real.log has derivative x⁻¹, which holds only at nonzero x, so the node's VJP is proved correct only at points where every target entry is nonzero.
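As a worked sketch of why the nonzero condition appears, the partial derivatives of the forward formula can be written out directly. The symbols ℓ (for logProbs), t (for target), and L (for the loss) are notation introduced here for illustration, and each target entry is treated as an independent variable.

```latex
\[
L(\ell, t) = \frac{1}{m} \sum_{i,j} t_{ij} \left( \log t_{ij} - \ell_{ij} \right)
\]
\[
\frac{\partial L}{\partial \ell_{ij}} = -\frac{t_{ij}}{m},
\qquad
\frac{\partial L}{\partial t_{ij}}
  = \frac{1}{m} \left( \log t_{ij} + t_{ij} \cdot t_{ij}^{-1} - \ell_{ij} \right)
  = \frac{1}{m} \left( \log t_{ij} + 1 - \ell_{ij} \right)
  \quad \text{for } t_{ij} \neq 0 .
\]
```

The derivative with respect to ℓ is defined everywhere. The derivative with respect to t goes through the logarithm, and the simplification of t · t⁻¹ to 1 needs t ≠ 0, which is where the hypothesis on the target entries enters.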

    noncomputable def Proofs.Autograd.TapeNodes.klDivLastFderivAt {Γ : List Spec.Shape} {m n : ℕ} (logProbs target : Idx Γ (Spec.Shape.dim m (Spec.Shape.dim n Spec.Shape.scalar))) (xV : CtxVec Γ) (ht : ∀ (i : Fin (Spec.Shape.dim m (Spec.Shape.dim n Spec.Shape.scalar)).size), (CtxVec.get target xV).ofLp i ≠ 0) :
    NodeFDerivCorrectAt (klDivLast logProbs target) xV

    Pointwise NodeFDerivCorrectAt for klDivLast, assuming every target entry is nonzero.
