TorchLean API

NN.Proofs.Autograd.Core.RealCorrectness

RealCorrectness #

Real-valued autograd correctness layer (proof-only).

This file does not talk about calculus (HasFDerivAt) yet. Instead it proves the standard reverse-mode / forward-mode adjointness law (aka VJP/JVP duality) for a core set of ops:

⟪ JVP(x, dx), δ ⟫ = ⟪ dx, VJP(x, δ) ⟫

where ⟪·,·⟫ is the tensor dot-product (sum of elementwise products).

This is strong enough to justify the reverse-mode chain rule and to build a proved-correct layer on top of Spec.OpSpec.compose.

Why this file exists (and why there is a second “algebraic” file) #

We keep two correctness developments: this real-valued one, and a separate semiring-generic (“algebraic”) one. Keeping them separate prevents importing analysis-heavy assumptions into the semiring-generic proofs and keeps compilation dependencies smaller.

Technical difference #

Runtime note #

PyTorch correspondence / citations #

References (background):

def Proofs.Autograd.VJPCorrect {σ τ : Spec.Shape} (_forward : Spec.Tensor σ → Spec.Tensor τ) (jvp : Spec.Tensor σ → Spec.Tensor σ → Spec.Tensor τ) (vjp : Spec.Tensor σ → Spec.Tensor τ → Spec.Tensor σ) : Prop

VJP/JVP adjointness for a unary op σ → τ.
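
To make the statement concrete, here is a minimal self-contained sketch of what such a predicate can look like, written against a toy Vec type rather than the project's actual Spec.Tensor API (Mathlib is assumed; Vec, dot, and VJPCorrectToy are illustrative names, not the real definitions):

```lean
import Mathlib

/-- Toy stand-in for `Spec.Tensor`: a finitely indexed family of reals. -/
abbrev Vec (n : Nat) := Fin n → ℝ

/-- The pairing ⟪·,·⟫ from the law above: sum of elementwise products. -/
def dot {n : Nat} (a b : Vec n) : ℝ :=
  Finset.univ.sum fun i => a i * b i

/-- Sketch of the adjointness predicate: for every point `x`, tangent `dx`
and cotangent `δ`, ⟪jvp x dx, δ⟫ = ⟪dx, vjp x δ⟫. -/
def VJPCorrectToy {n m : Nat} (_forward : Vec n → Vec m)
    (jvp : Vec n → Vec n → Vec m) (vjp : Vec n → Vec m → Vec n) : Prop :=
  ∀ x dx δ, dot (jvp x dx) δ = dot dx (vjp x δ)
```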

An OpSpec together with a matching JVP and a proof of VJP/JVP adjointness.

This is the “proved-correct local op” interface needed to build a sound reverse-mode tape.

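Continuing the toy sketch above, a plausible shape for such a bundle; CorrectOpToy and its field names are hypothetical, not the project's actual API:

```lean
/-- A local op together with its JVP, its VJP, and the adjointness proof. -/
structure CorrectOpToy (n m : Nat) where
  forward : Vec n → Vec m
  jvp     : Vec n → Vec n → Vec m
  vjp     : Vec n → Vec m → Vec n
  correct : VJPCorrectToy forward jvp vjp
```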

Composition preserves VJP/JVP correctness (reverse-mode chain rule).

Informally: if f and g each satisfy the adjointness law, then g ∘ f does as well, with the composed JVP and the composed VJP.

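In the toy model above, the chain rule is literally two applications of the adjointness law: once for g at the intermediate point, once for f. A sketch (illustrative statement, not the project's actual lemma):

```lean
/-- The composed JVP pushes the tangent forward through f then g; the composed
VJP pulls the cotangent back through g then f. -/
theorem vjpCorrectToy_comp {n m k : Nat}
    {fF : Vec n → Vec m} {fJ : Vec n → Vec n → Vec m} {fV : Vec n → Vec m → Vec n}
    {gF : Vec m → Vec k} {gJ : Vec m → Vec m → Vec k} {gV : Vec m → Vec k → Vec m}
    (hf : VJPCorrectToy fF fJ fV) (hg : VJPCorrectToy gF gJ gV) :
    VJPCorrectToy (fun x => gF (fF x))
      (fun x dx => gJ (fF x) (fJ x dx))
      (fun x δ => fV x (gV (fF x) δ)) := by
  intro x dx δ
  -- ⟪gJ (fF x) (fJ x dx), δ⟫ = ⟪fJ x dx, gV (fF x) δ⟫ = ⟪dx, fV x (gV (fF x) δ)⟫
  exact (hg (fF x) (fJ x dx) δ).trans (hf x dx (gV (fF x) δ))
```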

A reusable adjointness identity #

Most elementwise ops have JVP of the form dx ⊙ f'(x) and VJP of the form f'(x) ⊙ δ. The following lemma is the “commute elementwise factors under dot” fact that makes those proofs one-liners.

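A toy version of that lemma, continuing the sketch above (dot_mul_comm_toy is an illustrative name; ⊙ is written out as pointwise multiplication):

```lean
/-- ⟪dx ⊙ m, δ⟫ = ⟪dx, m ⊙ δ⟫: both sides are ∑ i, dx i * m i * δ i. -/
theorem dot_mul_comm_toy {n : Nat} (dx m δ : Vec n) :
    dot (fun i => dx i * m i) δ = dot dx (fun i => m i * δ i) := by
  -- unfold the pairing and re-associate the product under the sum
  simp only [dot, mul_assoc]
```
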
Correctness of ReLU’s backward rule.

PyTorch analogue: torch.nn.functional.relu / torch.relu with its standard VJP.

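For reference, the usual local derivative mask behind this rule; the convention chosen at the kink x = 0 (zero here) is an assumption about the project's definition:

```latex
\mathrm{relu}'(x) = \begin{cases} 1, & x > 0 \\ 0, & x \le 0 \end{cases}
\qquad
\mathrm{JVP}(x, dx) = dx \odot \mathrm{relu}'(x), \quad
\mathrm{VJP}(x, \delta) = \mathrm{relu}'(x) \odot \delta
```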

Correctness of sigmoid’s backward rule.

PyTorch analogue: torch.sigmoid.

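The standard scalar derivative behind the mask (this file only needs the algebraic adjointness of multiplying by it):

```latex
\sigma(x) = \frac{1}{1 + e^{-x}}, \qquad \sigma'(x) = \sigma(x)\,\bigl(1 - \sigma(x)\bigr)
```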

Correctness of tanh’s backward rule.

PyTorch analogue: torch.tanh.

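The standard local derivative used as the elementwise mask:

```latex
\tanh'(x) = 1 - \tanh^{2}(x)
```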

Correctness of softplus’s backward rule.

PyTorch analogue: torch.nn.functional.softplus.

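The standard derivative, which is exactly the sigmoid:

```latex
\mathrm{softplus}(x) = \log\bigl(1 + e^{x}\bigr), \qquad
\mathrm{softplus}'(x) = \frac{1}{1 + e^{-x}} = \sigma(x)
```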

Correctness of SiLU’s backward rule.

PyTorch analogue: torch.nn.functional.silu, equivalently x * sigmoid(x).

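The product-rule derivative behind the mask:

```latex
\mathrm{silu}(x) = x\,\sigma(x), \qquad
\mathrm{silu}'(x) = \sigma(x) + x\,\sigma(x)\,\bigl(1 - \sigma(x)\bigr)
```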

Correctness of tanh-approximate GELU's VJP/JVP adjointness rule.

This proves the linear-algebraic part of the gelu backward rule used by Transformer-style feed-forward blocks: multiplying the upstream cotangent by the local derivative mask is adjoint to multiplying the tangent by the same mask. The scalar calculus theorem for the full tanh approximation is intentionally separate because it depends on a longer chain-rule proof through tanh, sqrt, and the cubic inner polynomial.

PyTorch analogue: torch.nn.functional.gelu(..., approximate="tanh").

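For reference, the tanh approximation whose pointwise derivative supplies the mask (the derivative itself is the longer chain-rule computation kept in the separate scalar theorem):

```latex
\mathrm{gelu}_{\mathrm{tanh}}(x) =
  \tfrac{1}{2}\,x\left(1 + \tanh\!\Bigl(\sqrt{2/\pi}\,\bigl(x + 0.044715\,x^{3}\bigr)\Bigr)\right)
```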

Correctness of safe_log’s backward rule (a log with an ε safeguard).

PyTorch analogue: typically implemented as torch.log(torch.clamp(x, min=ε)) (or similar).

Correctness of a smooth absolute value’s backward rule (a differentiable approximation to |x|).

PyTorch analogue: a custom smooth abs implemented via sqrt(x^2 + ε^2) or similar.

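Assuming the sqrt(x² + ε²) form mentioned above, the local derivative used as the mask would be:

```latex
\frac{d}{dx}\sqrt{x^{2} + \varepsilon^{2}} = \frac{x}{\sqrt{x^{2} + \varepsilon^{2}}}
```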

Correctness of exp’s backward rule.

PyTorch analogue: torch.exp.

Correctness of square's backward rule.

PyTorch analogue: torch.square; the local derivative is 2 * x.

noncomputable def Proofs.Autograd.eluCorrect {s : Spec.Shape} (eluAlpha : ℝ) :

Correctness of ELU's VJP/JVP adjointness rule.

This is the algebraic half of the argument: once a local derivative mask is chosen, the VJP elu'(x) ⊙ δ is adjoint to the JVP dx ⊙ elu'(x). The analytic differentiability theorem lives in Proofs.elu_deriv_correct, which correctly excludes the kink at 0 for arbitrary alpha.

PyTorch analogue: torch.nn.functional.elu.

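For reference, the standard ELU and its derivative away from the kink at 0 (matching the exclusion mentioned above):

```latex
\mathrm{elu}_{\alpha}(x) = \begin{cases} x, & x > 0 \\ \alpha\,(e^{x} - 1), & x \le 0 \end{cases}
\qquad
\mathrm{elu}_{\alpha}'(x) = \begin{cases} 1, & x > 0 \\ \alpha\,e^{x}, & x < 0 \end{cases}
```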

Correctness of sinh's backward rule.

PyTorch analogue: torch.sinh; the local derivative is cosh.

Correctness of cosh's backward rule.

PyTorch analogue: torch.cosh; the local derivative is sinh.

Correctness of log's backward rule.

PyTorch analogue: torch.log.

Correctness of a linear layer’s backward rule (matrix–vector multiply).

PyTorch analogue: torch.nn.Linear (restricted here to the “weights only” linear map).

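For a weights-only map x ↦ W x, adjointness is the familiar transpose identity; writing ⟨·,·⟩ for the dot product, a worked expansion:

```latex
\langle W\,dx,\ \delta\rangle
  = \sum_{i}\Bigl(\sum_{j} W_{ij}\,dx_{j}\Bigr)\delta_{i}
  = \sum_{j} dx_{j}\Bigl(\sum_{i} W_{ij}\,\delta_{i}\Bigr)
  = \langle dx,\ W^{\mathsf{T}}\delta\rangle
```

so the JVP is dx ↦ W dx and the VJP is δ ↦ Wᵀ δ.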

Correctness of the sum (reduce-all) backward rule.

Informally: d/dx (sum x) = 1, so the VJP replicates the upstream scalar gradient into every entry.

PyTorch analogue: torch.sum (over all elements).

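Worked out, with the upstream cotangent δ a scalar:

```latex
\langle \mathrm{JVP}(x, dx),\ \delta\rangle
  = \Bigl(\sum_{i} dx_{i}\Bigr)\delta
  = \sum_{i} dx_{i}\,\delta
  = \langle dx,\ \mathrm{VJP}(x, \delta)\rangle,
\qquad \mathrm{VJP}(x, \delta)_{i} = \delta
```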