Binary Cross-Entropy With Logits
Stable BCE-with-logits loss and its tape-level derivative proof.
noncomputable def Proofs.Autograd.TapeNodes.bceWithLogits {Γ : List Spec.Shape} {s : Spec.Shape} (logits target : Idx Γ s) :
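Binary cross-entropy with logits, for logits and targets of the same shape s.
Forward (mean reduction over all N entries of shape s):
(1/N) * Σ_i (softplus(logits_i) - target_i * logits_i)
This matches PyTorch's BCEWithLogitsLoss with reduction="mean" and the default weight=None, pos_weight=None.
It uses the identity BCEWithLogits(x, t) = softplus(x) - t*x, which is algebraically equal to -[t*log σ(x) + (1 - t)*log(1 - σ(x))] but avoids taking the log of a sigmoid value that may round to 0 or 1.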
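As a hedged illustration only, the sketch below evaluates the same forward formula on plain Lean 4 Float lists rather than on Idx Γ s tape values. The helper names softplusStable and bceWithLogitsMean are hypothetical and are not part of Proofs.Autograd.TapeNodes. It also shows why the softplus form is numerically safe: writing softplus(x) as max(x, 0) + log(1 + exp(-|x|)) never exponentiates a large positive number.

```lean
-- Sketch only: a Float-level model of the forward formula, not the tape node.
-- softplusStable and bceWithLogitsMean are hypothetical names for illustration.

/-- Stable softplus: log (1 + exp x) = max x 0 + log (1 + exp (-|x|)). -/
def softplusStable (x : Float) : Float :=
  let pos := if x > 0.0 then x else 0.0
  pos + Float.log (1.0 + Float.exp (-x.abs))

/-- Mean of softplus(x) - t * x over paired entries.
    Assumes equal lengths (zipWith silently truncates) and a nonempty list
    (an empty list gives 0 / 0 = NaN). -/
def bceWithLogitsMean (logits targets : List Float) : Float :=
  let terms := List.zipWith (fun x t => softplusStable x - t * x) logits targets
  terms.foldl (· + ·) 0.0 / logits.length.toFloat

-- Large logit: the stable form returns 1000, whereas log (1 + exp 1000) overflows to inf.
#eval softplusStable 1000.0
-- Three entries; the mean is about 1.956.
#eval bceWithLogitsMean [0.0, 2.0, -3.0] [1.0, 0.0, 1.0]
```

The same-length requirement that the tape version enforces through the shared shape s is only a comment here, since List carries no shape information.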
noncomputable def Proofs.Autograd.TapeNodes.bceWithLogitsFderiv {Γ : List Spec.Shape} {s : Spec.Shape} (logits target : Idx Γ s) : NodeFDerivCorrect (bceWithLogits logits target)
NodeFDerivCorrect for bce_with_logits (binary cross-entropy with logits).
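The loss depends on both inputs, so its derivative has two parts. Since softplus'(x) = σ(x), differentiating (1/N) * Σ_i (softplus(logits_i) - target_i * logits_i) gives ∂L/∂logits_i = (σ(logits_i) - target_i) / N and ∂L/∂target_i = -logits_i / N. The Lean 4 sketch below, which uses plain Float values and hypothetical helper names (sigmoid, bceTerm, bceGrad, centralDiff, gradCheck), compares these formulas against central finite differences for one entry. It is a numeric sanity check under those assumptions, not a substitute for the formal proof.

```lean
-- Sketch only: a numeric check of the BCE-with-logits partial derivatives on Float,
-- not the tape-level NodeFDerivCorrect proof. All names here are hypothetical.

def sigmoid (x : Float) : Float := 1.0 / (1.0 + Float.exp (-x))

/-- Stable softplus: max x 0 + log (1 + exp (-|x|)). -/
def softplusStable (x : Float) : Float :=
  (if x > 0.0 then x else 0.0) + Float.log (1.0 + Float.exp (-x.abs))

/-- One entry's contribution to the mean loss, for n entries in total. -/
def bceTerm (n x t : Float) : Float := (softplusStable x - t * x) / n

/-- Analytic partials of bceTerm: ((σ(x) - t) / n, -x / n). -/
def bceGrad (n x t : Float) : Float × Float :=
  ((sigmoid x - t) / n, -x / n)

/-- Central finite difference (f (x + h) - f (x - h)) / (2h). -/
def centralDiff (f : Float → Float) (x h : Float) : Float :=
  (f (x + h) - f (x - h)) / (2.0 * h)

/-- (analytic d/dx, numeric d/dx, analytic d/dt, numeric d/dt) for one example entry. -/
def gradCheck : Float × Float × Float × Float :=
  let n : Float := 3.0   -- example N
  let x : Float := 2.0   -- example logit
  let t : Float := 0.0   -- example target
  let h : Float := 1e-5  -- finite-difference step
  let g := bceGrad n x t
  (g.1, centralDiff (fun x' => bceTerm n x' t) x h,
   g.2, centralDiff (fun t' => bceTerm n x t') t h)

-- The analytic and numeric values in each pair should agree to several decimal places.
#eval gradCheck
```

Because the loss is linear in the target, the finite difference in t matches -x / n up to rounding; the check in x is the informative one.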