Cross Entropy
One-hot cross entropy over the last axis, computed from the inner product target · log_softmax(logits), negated and averaged over the m rows.
noncomputable def Proofs.Autograd.TapeNodes.crossEntropyOneHotLast
    {Γ : List Spec.Shape} {m n : ℕ}
    (logits target : Idx Γ (Spec.Shape.dim m (Spec.Shape.dim n Spec.Shape.scalar))) :
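Cross-entropy loss for logits and one-hot targets, both of shape (m×n).
Forward:
-(1/m) * ⟪target, log_softmax_last(logits)⟫
This matches PyTorch's cross_entropy with one-hot (class-probability) targets and the default mean reduction over the m rows.
The loss applies log_softmax to the logits directly. In floating point this is more numerically stable than log(softmax(logits)); here the values are in ℝ, so the two forms are equal.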
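The forward formula can be sketched in plain Lean over Float. This is a minimal illustration, assuming rows stored as `List (List Float)` instead of the library's `Idx` values over ℝ; the names `logSoftmaxRow`, `dotRow`, and `crossEntropyOneHotLastF` are hypothetical and not part of Proofs.Autograd. The max-shift inside `logSoftmaxRow` is the usual floating-point trick that the stability remark above refers to.

```lean
-- Hypothetical Float sketch; NOT the Proofs.Autograd definition (which works over ℝ on Idx values).

/-- Log-softmax of one row, shifted by the row maximum so `Float.exp` does not overflow. -/
def logSoftmaxRow (xs : List Float) : List Float :=
  -- Row maximum; subtracting it cancels mathematically but keeps exponents small.
  let mx := xs.foldl (fun a b => if a < b then b else a) (xs.headD 0.0)
  -- log-sum-exp of the row, computed as mx + log (Σ exp (x - mx)).
  let lse := mx + Float.log ((xs.map (fun x => Float.exp (x - mx))).foldl (· + ·) 0.0)
  xs.map (fun x => x - lse)

/-- Inner product of two rows (extra entries in the longer row are ignored). -/
def dotRow (a b : List Float) : Float :=
  (List.zipWith (· * ·) a b).foldl (· + ·) 0.0

/-- -(1/m) * ⟪target, log_softmax_last(logits)⟫ for m rows of n classes. -/
def crossEntropyOneHotLastF (logits target : List (List Float)) : Float :=
  let m := logits.length.toFloat
  -- Sum over rows of ⟪target_i, log_softmax(logits_i)⟫.
  let total := (List.zipWith (fun t x => dotRow t (logSoftmaxRow x)) target logits).foldl (· + ·) 0.0
  -(total / m)

-- Two rows, three classes; the targets select class 0 and class 2.
#eval crossEntropyOneHotLastF [[2.0, 1.0, 0.1], [0.5, 0.5, 3.0]] [[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]
```

With a one-hot target row, `dotRow` simply picks out the log-probability of the labelled class, so each row contributes -log p(correct class) and the result is the mean of those terms.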
noncomputable def Proofs.Autograd.TapeNodes.crossEntropyOneHotLastFderiv
    {Γ : List Spec.Shape} {m n : ℕ}
    (logits target : Idx Γ (Spec.Shape.dim m (Spec.Shape.dim n Spec.Shape.scalar))) :
    NodeFDerivCorrect (crossEntropyOneHotLast logits target)
NodeFDerivCorrect witness for crossEntropyOneHotLast (one-hot targets; reduction over the last axis).
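For orientation, the derivative that such a witness has to agree with can be worked out by hand from the forward formula. This is our own derivation, not the NodeFDerivCorrect statement itself (whose exact form is not shown here); it writes x for logits, t for target, and p_i = softmax(x_i) for each row i.

```latex
L(x,t) = -\frac{1}{m}\sum_{i=1}^{m}\sum_{j=1}^{n} t_{ij}\Big(x_{ij} - \log\sum_{k=1}^{n} e^{x_{ik}}\Big),
\qquad
p_{ij} = \frac{e^{x_{ij}}}{\sum_{k=1}^{n} e^{x_{ik}}} = \frac{\partial}{\partial x_{ij}} \log\sum_{k=1}^{n} e^{x_{ik}}

dL = -\frac{1}{m}\big\langle dt,\ \operatorname{log\_softmax}(x)\big\rangle
     + \frac{1}{m}\sum_{i=1}^{m}\sum_{j=1}^{n}\Big(\Big(\sum_{k=1}^{n} t_{ik}\Big)p_{ij} - t_{ij}\Big)\,dx_{ij}

\text{one-hot rows } \Big(\sum_{k} t_{ik} = 1\Big):\qquad
\frac{\partial L}{\partial x} = \frac{1}{m}\,(p - t)
```

Because target is itself an input of the node, the derivative with respect to logits carries the row-sum factor in general; the familiar (1/m)(p − t) form holds only where each target row sums to 1.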