
NN.Proofs.Autograd.Core.SemiringCorrectness

SemiringCorrectness #

Semiring-generic autograd correctness layer (backend-generic).

This mirrors NN/Proofs/Autograd/Core/RealCorrectness.lean, but avoids analytic assumptions and works over any commutative semiring. In particular, it applies to exact backends.

The correctness notion is the standard reverse-mode / forward-mode adjointness law:

⟪ JVP(x, dx), δ ⟫ = ⟪ dx, VJP(x, δ) ⟫

where ⟪·,·⟫ is the tensor dot-product from NN/Proofs/Tensor/Algebra.lean.
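For intuition, the scaling op proved later in this file already illustrates the law: with JVP(x, dx) = c • dx and VJP(x, δ) = c • δ,

⟪ c • dx, δ ⟫ = c * ⟪ dx, δ ⟫ = ⟪ dx, c • δ ⟫,

so adjointness is just commutativity and associativity of multiplication; no analytic structure is involved.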

Why this is separate from the ℝ file #

Many ML ops are definable over a commutative semiring (addition/multiplication/linear maps), and their reverse-mode rules can be proved from algebraic identities alone. This file isolates that “pure algebra” portion so it can be instantiated for exact backends without pulling in real-analytic structure.

Ops that require extra structure (e.g. ReLU needs an order/max, MSE needs division by Shape.size) appear here only under the corresponding extra typeclass assumptions.

If you only care about real-valued training semantics, prefer NN.Proofs.Autograd.Core.RealCorrectness. If you want proofs that can be instantiated for exact backends, prefer this file.

PyTorch correspondence / citations #

This is the proof-level analogue of the “VJP correctness” property implicitly relied upon by PyTorch Autograd: each primitive op must supply a correct local backward/VJP rule. https://pytorch.org/docs/stable/autograd.html

def Proofs.Autograd.Algebra.VJPCorrect {α : Type} [CommSemiring α] {σ τ : Spec.Shape} (_forward : Spec.Tensor α σ → Spec.Tensor α τ) (jvp : Spec.Tensor α σ → Spec.Tensor α σ → Spec.Tensor α τ) (vjp : Spec.Tensor α σ → Spec.Tensor α τ → Spec.Tensor α σ) : Prop

VJP/JVP adjointness for a unary op σ → τ.
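A plausible unfolding, with the binder order and the bracket notation assumed rather than read off the file:

```lean
-- Sketch only: the actual definition may order binders differently or apply
-- the dot product by name instead of via ⟪·,·⟫ notation.
def VJPCorrectSketch {α : Type} [CommSemiring α] {σ τ : Spec.Shape}
    (_forward : Spec.Tensor α σ → Spec.Tensor α τ)
    (jvp : Spec.Tensor α σ → Spec.Tensor α σ → Spec.Tensor α τ)
    (vjp : Spec.Tensor α σ → Spec.Tensor α τ → Spec.Tensor α σ) : Prop :=
  ∀ (x dx : Spec.Tensor α σ) (δ : Spec.Tensor α τ),
    ⟪jvp x dx, δ⟫ = ⟪dx, vjp x δ⟫
```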


An OpSpec together with a matching JVP and a proof of VJP/JVP adjointness.

This is the backend-generic analogue of Proofs.Autograd.OpSpecCorrect from NN.Proofs.Autograd.Core.RealCorrectness.

def Proofs.Autograd.Algebra.OpSpecCorrect.compose {α : Type} [CommSemiring α] {σ τ υ : Spec.Shape} (f : OpSpecCorrect α σ τ) (g : OpSpecCorrect α τ υ) : OpSpecCorrect α σ υ

Composition preserves VJP/JVP correctness (reverse-mode chain rule).

Informally: if f and g satisfy ⟪JVP,·⟫ = ⟪·,VJP⟫, then so does g ∘ f, with the obvious composed JVP and VJP.
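A sketch of the composed data; the field names (forward, jvp, vjp, adjoint) are hypothetical, only the shape of the construction is the point:

```lean
-- Hypothetical field names for OpSpecCorrect; the adjointness proof chains
-- the two laws through the middle pairing:
--   ⟪g.jvp (f.forward x) (f.jvp x dx), δ⟫
--     = ⟪f.jvp x dx, g.vjp (f.forward x) δ⟫    -- g's law, at the point f.forward x
--     = ⟪dx, f.vjp x (g.vjp (f.forward x) δ)⟫  -- f's law, at the point x
def composeSketch {α : Type} [CommSemiring α] {σ τ υ : Spec.Shape}
    (f : OpSpecCorrect α σ τ) (g : OpSpecCorrect α τ υ) :
    OpSpecCorrect α σ υ where
  forward := fun x => g.forward (f.forward x)
  jvp := fun x dx => g.jvp (f.forward x) (f.jvp x dx)  -- forward-mode chain rule
  vjp := fun x δ => f.vjp x (g.vjp (f.forward x) δ)    -- reverse-mode chain rule
  adjoint := sorry  -- the two-step calculation above
```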


Small list/finite-sum bookkeeping #

The semiring-generic dot product is defined by folding over coordinates, so we collect a small congruence lemma to rewrite the per-coordinate term.
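For concreteness, a self-contained lemma of that flavor; the file's actual lemma is likely stated on the fold itself and may differ in name and statement:

```lean
import Mathlib

-- If two per-coordinate terms agree on every index in the list, the summed
-- (folded) values agree. This is the congruence step used to rewrite under ∑.
theorem map_sum_congr {ι α : Type} [AddCommMonoid α]
    (l : List ι) (f g : ι → α) (h : ∀ i ∈ l, f i = g i) :
    (l.map f).sum = (l.map g).sum := by
  induction l with
  | nil => rfl
  | cons a t ih =>
    have ha : f a = g a := h a (by simp)
    have ht : (t.map f).sum = (t.map g).sum :=
      ih fun i hi => h i (by simp [hi])
    simp [ha, ht]
```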

def Proofs.Autograd.Algebra.reluCorrect {α : Type} [CommSemiring α] [Max α] [LT α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {s : Spec.Shape} : OpSpecCorrect α s s

Correctness of ReLU’s backward rule, stated generically over α.

We assume the extra structure needed to define ReLU and its derivative (Max, order, and decidable comparison). PyTorch analogue: torch.relu / torch.nn.functional.relu.
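At the scalar level the two rules use exactly this structure; a minimal sketch with illustrative names:

```lean
-- Max supplies the forward rule; the decidable order supplies the gradient mask.
def reluScalar {α : Type} [CommSemiring α] [Max α] (x : α) : α :=
  max x 0

def reluBackScalar {α : Type} [CommSemiring α] [LT α]
    [DecidableRel fun (x1 x2 : α) => x1 > x2] (x dx : α) : α :=
  if x > 0 then dx else 0  -- pass the tangent through iff the input was positive
```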


Correctness of a linear layer’s backward rule (matrix–vector multiply), stated generically over α.

This is purely algebraic: it relies only on semiring laws and the adjointness lemma for matrix multiplication in TensorAlgebra. PyTorch analogue: torch.nn.Linear’s linear map.
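Written out in coordinates, that adjointness lemma is an interchange of two finite sums plus commutativity:

⟪ W dx, δ ⟫ = ∑ i, (∑ j, W i j * dx j) * δ i = ∑ j, dx j * (∑ i, W i j * δ i) = ⟪ dx, Wᵀ δ ⟫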


Correctness of scaling by a constant: forward and backward are both x ↦ c • x.

PyTorch analogue: c * x (with broadcasting aligned to shape).
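The scalar shadow of the law is checkable directly; Mathlib's ring tactic works over any commutative semiring:

```lean
import Mathlib.Tactic.Ring

-- Commutativity is exactly what moves c across the pairing.
example {α : Type} [CommSemiring α] (c dx δ : α) :
    (c * dx) * δ = dx * (c * δ) := by ring
```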


Correctness of pointwise multiplication by a fixed tensor rhs.

PyTorch analogue: x * rhs (elementwise).
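This is the diagonal special case of the linear-layer adjointness above: coordinatewise,

⟪ dx * rhs, δ ⟫ = ∑ i, dx i * rhs i * δ i = ⟪ dx, rhs * δ ⟫.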

def Proofs.Autograd.Algebra.mseSpecBasic {α : Type} [CommSemiring α] [Sub α] [Div α] {s : Spec.Shape} (predicted target : Spec.Tensor α s) : α

Basic mean-squared error (MSE) scalar value:

mse(predicted, target) = (∑ (predicted - target)^2) / size.

This local definition is used only to define the loss OpSpec.
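A quick numeric check of the formula over ℚ, the kind of exact backend this file targets (scalar literals stand in for coordinate values):

```lean
import Mathlib.Tactic.NormNum

-- predicted = (3, 1), target = (1, 1): squared differences (4, 0),
-- so mse = ((3 - 1)^2 + (1 - 1)^2) / 2 = 2.
example : (((3 - 1) ^ 2 + (1 - 1) ^ 2) / 2 : ℚ) = 2 := by norm_num
```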

def Proofs.Autograd.Algebra.mseDerivSpecBasic {α : Type} [CommSemiring α] [Sub α] [Div α] {s : Spec.Shape} (predicted target : Spec.Tensor α s) : Spec.Tensor α s

Gradient of mseSpecBasic with respect to predicted, as a tensor of the same shape.

Up to conventions, this is 2 * (predicted - target) / size.
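Per coordinate, only the matching summand contributes, which is where the factor 2 comes from:

∂/∂ predicted i [ (∑ j, (predicted j - target j)^2) / size ] = 2 * (predicted i - target i) / size

Over a general semiring this derivative is formal rather than analytic, hence the “up to conventions” caveat.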


Correctness of mean-squared error loss (MSE) as an OpSpecCorrect.

This section assumes extra operations (Sub, Div, and coercions from naturals) because the MSE definition uses subtraction and division by Shape.size. PyTorch analogue: torch.nn.functional.mse_loss(reduction="mean") (up to normalization conventions).
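An illustrative instantiation at an exact backend: ℚ is a CommSemiring with Sub, Div, and a coercion from ℕ, so the definitions above specialize to it directly (the backend choice is an example, not one named by the file):

```lean
-- mseSpecBasic is defined above; ℚ supplies all the assumed structure.
example {s : Spec.Shape} (p t : Spec.Tensor ℚ s) : ℚ :=
  Proofs.Autograd.Algebra.mseSpecBasic p t
```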
