KL Divergence
Batchmean KL-divergence for log-probability inputs and probability targets.
noncomputable def
Proofs.Autograd.TapeNodes.klDivLast
{Γ : List Spec.Shape}
{m n : ℕ}
(logProbs target : Idx Γ (Spec.Shape.dim m (Spec.Shape.dim n Spec.Shape.scalar)))
:
KL-divergence loss for logProbs and target probabilities of shape (m×n).
Forward (batchmean reduction):
(1/m) * Σ_{i,j} target[i,j] * (log(target[i,j]) - logProbs[i,j])
This matches PyTorch KLDivLoss / F.kl_div with:
input = log-probabilities, target = probabilities (not log-target), reduction = "batchmean".
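As a rough illustration of the forward formula, the following pure-Python sketch evaluates the batchmean sum over nested lists. The function names `kl_div_batchmean` and `grad_log_probs` are hypothetical and are not part of the Lean development or of PyTorch. The sketch assumes that a zero target entry contributes 0 to the sum, which agrees with the convention `Real.log 0 = 0`. The gradient with respect to `logProbs` used here, `-target[i][j] / m`, is derived from the formula above rather than quoted from the source.

```python
import math


def kl_div_batchmean(log_probs, target):
    """Return (1/m) * sum_{i,j} target[i][j] * (log(target[i][j]) - log_probs[i][j])."""
    m = len(log_probs)  # batch size: number of rows
    total = 0.0
    for lp_row, t_row in zip(log_probs, target):
        for lp, t in zip(lp_row, t_row):
            # Assumption: a zero target entry contributes 0 to the sum.
            if t != 0.0:
                total += t * (math.log(t) - lp)
    return total / m


def grad_log_probs(target):
    """Gradient of the loss w.r.t. log_probs: the loss is linear in log_probs."""
    m = len(target)
    return [[-t / m for t in row] for row in target]


if __name__ == "__main__":
    target = [[0.2, 0.3, 0.5], [0.1, 0.6, 0.3]]
    probs = [[0.3, 0.3, 0.4], [0.2, 0.5, 0.3]]
    log_probs = [[math.log(p) for p in row] for row in probs]
    print("loss:", kl_div_batchmean(log_probs, target))
    print("grad:", grad_log_probs(target))
```

Because the loss is linear in `log_probs`, a central finite difference should reproduce `grad_log_probs` up to rounding error, which makes this a convenient sanity check against any other implementation.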
We use the Real.log/x⁻¹ derivative spec, which holds only for x ≠ 0, so the node's VJP is
guaranteed correct at points where every target entry is nonzero.
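To see where the nonzero condition enters, it may help to write out the partial derivatives implied by the forward formula. This is a sketch derived from the formula above, not a statement quoted from the source. The symbols $\ell_{ij}$ (for `logProbs[i,j]`), $t_{ij}$ (for `target[i,j]`), and $\bar g$ (the upstream scalar cotangent) are notation introduced here.

```latex
L = \frac{1}{m} \sum_{i,j} t_{ij}\,\bigl(\log t_{ij} - \ell_{ij}\bigr)

% The loss is linear in the log-probabilities, so this holds everywhere:
\frac{\partial L}{\partial \ell_{ij}} = -\frac{t_{ij}}{m}

% Product rule with (\log t)' = t^{-1}, which is valid only for t \neq 0:
\frac{\partial L}{\partial t_{ij}}
  = \frac{1}{m}\bigl(\log t_{ij} + t_{ij}\cdot t_{ij}^{-1} - \ell_{ij}\bigr)
  = \frac{1}{m}\bigl(\log t_{ij} + 1 - \ell_{ij}\bigr),
  \qquad t_{ij} \neq 0

% VJP for an upstream scalar cotangent \bar g:
\bar\ell_{ij} = -\bar g\,\frac{t_{ij}}{m},
\qquad
\bar t_{ij} = \frac{\bar g}{m}\bigl(\log t_{ij} + 1 - \ell_{ij}\bigr)
```

Only the derivative with respect to the target passes through $\log$, so that is the step that needs $t_{ij} \neq 0$.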
noncomputable def
Proofs.Autograd.TapeNodes.klDivLastFderivAt
{Γ : List Spec.Shape}
{m n : ℕ}
(logProbs target : Idx Γ (Spec.Shape.dim m (Spec.Shape.dim n Spec.Shape.scalar)))
(xV : CtxVec Γ)
(ht : ∀ (i : Fin (Spec.Shape.dim m (Spec.Shape.dim n Spec.Shape.scalar)).size), (CtxVec.get target xV).ofLp i ≠ 0)
:
NodeFDerivCorrectAt (klDivLast logProbs target) xV
Pointwise NodeFDerivCorrectAt for klDivLast at xV, assuming every target entry is nonzero (hypothesis ht).