TorchLean API

NN.Proofs.Autograd.Tape.Nodes.Losses

Loss tape nodes

Differentiable loss nodes used by training and verification examples: MSE, cross entropy with one-hot targets, negative log likelihood, BCE-with-logits, and KL divergence.

noncomputable def Proofs.Autograd.TapeNodes.mseLoss {Γ : List Spec.Shape} {s : Spec.Shape} (yhat target : Idx Γ s) :

Mean-squared-error loss node: c * ‖yhat - target‖^2, with c = 1 / size(s).
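For intuition, here is a minimal standalone sketch of the forward value over flat Float arrays (the name `mseLossRef` and the array representation are illustrative only; the actual node operates on `Idx Γ s` tape values):

```lean
-- Illustrative reference only, not the TorchLean tape node:
-- mean of squared differences, i.e. c * ‖yhat - target‖^2 with c = 1 / size.
def mseLossRef (yhat target : Array Float) : Float :=
  let n := Float.ofNat yhat.size
  let sumSq := (yhat.zip target).foldl
    (fun acc (p, t) => acc + (p - t) * (p - t)) 0.0
  sumSq / n

#eval mseLossRef #[1.0, 2.0, 3.0] #[1.0, 2.0, 5.0]  -- (0 + 0 + 4) / 3 ≈ 1.3333
```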

Instances For
noncomputable def Proofs.Autograd.TapeNodes.mseLossFderiv {Γ : List Spec.Shape} {s : Spec.Shape} (yhat target : Idx Γ s) :

NodeFDerivCorrect for mse_loss.
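For reference, the gradient of yhat ↦ c * ‖yhat - target‖^2 with respect to yhat is 2c * (yhat - target), and its negation with respect to target; a correct VJP therefore scales the residual yhat - target by 2c times the incoming scalar cotangent.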

Instances For

Cross-entropy loss for logits and one-hot targets of shape (m×n).

Forward: -(1/m) * ⟪target, log_softmax_last(logits)⟫

This matches the common PyTorch cross_entropy convention with one-hot targets, applying log_softmax to the logits (numerically more stable than log(softmax) in floating point; here the values live in ℝ).
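A standalone sketch of this forward computation on nested Float arrays (the names `logSoftmaxRow` and `crossEntropyOneHotRef` are illustrative; rows play the role of the last axis):

```lean
-- Illustrative reference only, not the TorchLean tape node.
-- Row-wise log-softmax, subtracting the row max before exponentiating
-- so that no exp overflows.
def logSoftmaxRow (row : Array Float) : Array Float :=
  let mx := row.foldl max ((-1.0) / 0.0)  -- start from -inf
  let lse := mx + Float.log (row.foldl (fun a x => a + Float.exp (x - mx)) 0.0)
  row.map (fun x => x - lse)

-- -(1/m) * ⟪target, log_softmax_last(logits)⟫ with m = number of rows.
def crossEntropyOneHotRef (logits target : Array (Array Float)) : Float :=
  let m := Float.ofNat logits.size
  let dot := (logits.zip target).foldl
    (fun acc (z, t) =>
      let lp := logSoftmaxRow z
      acc + (lp.zip t).foldl (fun a (l, ti) => a + ti * l) 0.0)
    0.0
  -(dot / m)
```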

Instances For

NodeFDerivCorrect for cross_entropy_one_hot_last (one-hot targets; last-axis reduction).
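With one-hot targets (each row of target summing to 1), the gradient of the forward value with respect to logits is the familiar (1/m) * (softmax_last(logits) - target); and since the loss is linear in target, the target-gradient is -(1/m) * log_softmax_last(logits).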

Instances For

Negative log-likelihood loss for log-probabilities and one-hot targets of shape (m×n).

Forward: -(1/m) * ⟪target, logProbs⟫

This is the natural primitive loss that cross_entropy reduces to after log_softmax.
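The forward value is just a scaled inner product; a standalone Float sketch (the name `nllOneHotRef` is illustrative; rows are the batch dimension):

```lean
-- Illustrative reference only, not the TorchLean tape node:
-- nll = -(1/m) * Σ_{i,j} target[i][j] * logProbs[i][j].
def nllOneHotRef (logProbs target : Array (Array Float)) : Float :=
  let m := Float.ofNat logProbs.size
  let dot := (logProbs.zip target).foldl
    (fun acc (lp, t) =>
      acc + (lp.zip t).foldl (fun a (l, ti) => a + ti * l) 0.0)
    0.0
  -(dot / m)
```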

Instances For

NodeFDerivCorrect for nll_one_hot_last (negative log-likelihood with one-hot targets).
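Because the forward value is linear in logProbs, its logProbs-gradient is the constant -(1/m) * target; symmetrically, the target-gradient is -(1/m) * logProbs.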

Instances For

noncomputable def Proofs.Autograd.TapeNodes.bceWithLogits {Γ : List Spec.Shape} {s : Spec.Shape} (logits target : Idx Γ s) :

Binary cross-entropy with logits for same-shaped logits/targets.

Forward (mean reduction over all entries): (1/N) * Σ_i (softplus(logits_i) - target_i * logits_i)

This matches PyTorch's BCEWithLogitsLoss with reduction="mean", and uses the stable identity BCEWithLogits(x, t) = softplus(x) - t * x.
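A standalone sketch of the forward value (the helper `softplus` and the name `bceWithLogitsRef` are illustrative), using the overflow-safe evaluation softplus(x) = max(x, 0) + log(1 + exp(-|x|)):

```lean
-- Illustrative reference only, not the TorchLean tape node.
-- softplus(x) = log(1 + exp x), evaluated so exp never overflows.
def softplus (x : Float) : Float :=
  max x 0.0 + Float.log (1.0 + Float.exp (-(Float.abs x)))

-- (1/N) * Σ_i (softplus(logits_i) - target_i * logits_i)
def bceWithLogitsRef (logits target : Array Float) : Float :=
  let n := Float.ofNat logits.size
  let s := (logits.zip target).foldl
    (fun acc (x, t) => acc + softplus x - t * x) 0.0
  s / n
```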

Instances For

noncomputable def Proofs.Autograd.TapeNodes.bceWithLogitsFderiv {Γ : List Spec.Shape} {s : Spec.Shape} (logits target : Idx Γ s) :

NodeFDerivCorrect for bce_with_logits (binary cross-entropy with logits).
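Since the derivative of softplus is the sigmoid, the per-entry gradient with respect to logits is (1/N) * (sigmoid(logits_i) - target_i), and the target-gradient is -(1/N) * logits_i.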

Instances For

KL-divergence loss for logProbs and target probabilities of shape (m×n).

Forward (batchmean reduction): (1/m) * Σ_{i,j} target[i,j] * (log(target[i,j]) - logProbs[i,j])

This matches PyTorch KLDivLoss / F.kl_div with:

• input = log-probabilities,
• target = probabilities (not log-target),
• reduction="batchmean".

We use the Real.log / x⁻¹ derivative spec, so the node's VJP is correct on points where target entries are nonzero.
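A standalone sketch of the batchmean forward value (the name `klDivBatchmeanRef` is illustrative; zero-target terms are skipped, consistent with the 0 · log 0 = 0 convention, while the proof-side derivative spec additionally assumes nonzero targets):

```lean
-- Illustrative reference only, not the TorchLean tape node:
-- (1/m) * Σ_{i,j} t * (log t - lp), skipping t = 0 terms.
def klDivBatchmeanRef (logProbs target : Array (Array Float)) : Float :=
  let m := Float.ofNat logProbs.size
  let s := (logProbs.zip target).foldl
    (fun acc (lp, t) =>
      acc + (lp.zip t).foldl
        (fun a (l, ti) =>
          if ti == 0.0 then a else a + ti * (Float.log ti - l))
        0.0)
    0.0
  s / m
```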

Instances For

noncomputable def Proofs.Autograd.TapeNodes.klDivLastFderivAt {Γ : List Spec.Shape} {m n : ℕ} (logProbs target : Idx Γ (Spec.Shape.dim m (Spec.Shape.dim n Spec.Shape.scalar))) (xV : CtxVec Γ) (ht : ∀ (i : Fin (Spec.Shape.dim m (Spec.Shape.dim n Spec.Shape.scalar)).size), (CtxVec.get target xV).ofLp i ≠ 0) :
NodeFDerivCorrectAt (klDivLast logProbs target) xV

Pointwise NodeFDerivCorrectAt for kl_div_last, assuming target entries are nonzero.
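At such a point the logProbs-gradient is -(1/m) * target[i,j] (the loss is linear in logProbs), and the target-gradient is (1/m) * (log(target[i,j]) + 1 - logProbs[i,j]); the latter uses the x⁻¹ derivative of Real.log, which is where the nonzero-target hypothesis enters.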

Instances For