RealCorrectness #
Real-valued autograd correctness layer (proof-only).
This file does not talk about calculus (`HasFDerivAt`) yet. Instead it proves the standard
reverse-mode / forward-mode adjointness law (aka VJP/JVP duality) for a core set of ops:
⟪ JVP(x, dx), δ ⟫ = ⟪ dx, VJP(x, δ) ⟫
where ⟪·,·⟫ is the tensor dot-product (sum of elementwise products).
This is strong enough to justify the reverse-mode chain rule and to build a proved-correct layer
on top of `Spec.OpSpec.compose`.
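As a scalar sanity check (a toy case, not the tensor statement): take f = exp, so JVP(x, dx) = exp(x) · dx and VJP(x, δ) = exp(x) · δ; both sides of the law reduce to exp(x) · dx · δ.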
Why this file exists (and why there is a second “algebraic” file) #
We keep two correctness developments:
- `real_correctness.lean` (this file) specializes to `ℝ` and is the home for rules whose definitions/proofs genuinely depend on real-analytic structure (e.g. smooth activations and `exp`/`log`-style ops).
- `semiring_correctness.lean` is backend-generic over a type `α` with `[CommSemiring α]`. It is meant to instantiate to exact backends like `ℚ`, so it avoids assuming division, order, or transcendental functions unless an op explicitly requires them.
Keeping them separate prevents importing analysis-heavy assumptions into the semiring-generic proofs and keeps compilation dependencies smaller.
Technical difference #
- This file uses the `Spec.dot`/`Tensor` theory from `NN/Proofs/Tensor/Basic.lean` (specialized to `ℝ`).
- The semiring-generic file uses `TensorAlgebra.dot` from `NN/Proofs/Tensor/Algebra.lean` and keeps all statements polymorphic in `α` with `[CommSemiring α]`.
Runtime note #
- The runtime engine in `NN.Runtime.Autograd.Engine` remains generic over `α` and works whenever the needed ops exist. Relating a concrete backend to these `ℝ`-proofs may require a separate semantic model (e.g. mapping to `ℝ` with rounding-error bounds for `NeuralFloat`).
PyTorch correspondence / citations #
- PyTorch AD background and conventions (VJP in reverse-mode): https://pytorch.org/docs/stable/autograd.html
- Custom VJP rules are analogous to implementing `torch.autograd.Function`: https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function
References (background):
- Baydin et al., “Automatic Differentiation in Machine Learning: a Survey”, JMLR 2018 (originally circulated as arXiv:1502.05767).
- Griewank & Walther, Evaluating Derivatives (2nd ed.), SIAM 2008 (reverse-mode AD foundations).
VJP/JVP adjointness for a unary op σ → τ.
An OpSpec together with a matching JVP and a proof of VJP/JVP adjointness.
This is the “proved-correct local op” interface needed to build a sound reverse-mode tape.
- `op : Spec.OpSpec ℝ σ τ` — the underlying op (forward map together with its VJP).
- `jvp : Spec.Tensor ℝ σ → Spec.Tensor ℝ σ → Spec.Tensor ℝ τ` — the matching forward-mode (tangent) map.
- `correct` — a proof that `jvp` and the op's VJP satisfy the adjointness law above.
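As a shape-only illustration of this interface (hypothetical names `ScalarOp`/`CorrectScalarOp`; scalars stand in for `Spec.Tensor ℝ σ` and plain multiplication for ⟪·,·⟫ — this is not the project's actual `Spec` API):

```lean
import Mathlib

-- Hypothetical stand-in for `Spec.OpSpec ℝ σ τ`: a forward map plus a VJP.
structure ScalarOp where
  forward : ℝ → ℝ
  vjp : ℝ → ℝ → ℝ   -- (point x, upstream cotangent δ) ↦ pulled-back cotangent

-- An op packaged with a matching JVP and the adjointness proof,
-- mirroring the `op` / `jvp` / `correct` fields described above.
structure CorrectScalarOp extends ScalarOp where
  jvp : ℝ → ℝ → ℝ   -- (point x, tangent dx) ↦ pushed-forward tangent
  correct : ∀ x dx δ, jvp x dx * δ = dx * vjp x δ

-- Example instance: squaring, with local derivative 2 * x.
noncomputable def square : CorrectScalarOp where
  forward := fun x => x ^ 2
  vjp := fun x δ => 2 * x * δ
  jvp := fun x dx => dx * (2 * x)
  correct := fun x dx δ => by
    show dx * (2 * x) * δ = dx * (2 * x * δ)
    ring
```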
Composition preserves VJP/JVP correctness (reverse-mode chain rule).
Informally: if f and g each satisfy the adjointness law, then g ∘ f does as well, with the
composed JVP and the composed VJP.
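Concretely, for f : σ → τ and g : τ → υ the composed maps are JVP_{g∘f}(x, dx) = JVP_g(f(x), JVP_f(x, dx)) and VJP_{g∘f}(x, δ) = VJP_f(x, VJP_g(f(x), δ)), and the proof applies each op's adjointness law once:
⟪ JVP_g(f(x), JVP_f(x, dx)), δ ⟫ = ⟪ JVP_f(x, dx), VJP_g(f(x), δ) ⟫ = ⟪ dx, VJP_f(x, VJP_g(f(x), δ)) ⟫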
A reusable adjointness identity #
Most elementwise ops have JVP of the form dx ⊙ f'(x) and VJP of the form f'(x) ⊙ δ.
The following lemma is the “commute elementwise factors under dot” fact that makes those proofs
one-liners.
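A self-contained sketch of that fact in a toy model (the names `dot`, `had`, and `dot_had_assoc` are hypothetical, not the project's `Spec.dot` API):

```lean
import Mathlib

-- Hypothetical minimal model: `dot` plays the role of ⟪·,·⟫ and `had`
-- the role of the elementwise product ⊙, over tensors `Fin n → ℝ`.
def dot {n : ℕ} (x y : Fin n → ℝ) : ℝ := Finset.univ.sum fun i => x i * y i

def had {n : ℕ} (x y : Fin n → ℝ) : Fin n → ℝ := fun i => x i * y i

/-- ⟪ dx ⊙ m, δ ⟫ = ⟪ dx, m ⊙ δ ⟫: pointwise, just associativity of `*`. -/
theorem dot_had_assoc {n : ℕ} (dx m δ : Fin n → ℝ) :
    dot (had dx m) δ = dot dx (had m δ) := by
  simp only [dot, had]
  exact Finset.sum_congr rfl fun i _ => mul_assoc _ _ _
```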
Correctness of ReLU’s backward rule.
PyTorch analogue: torch.nn.functional.relu / torch.relu with its standard VJP.
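As an illustration in the same toy style (hypothetical `reluMask`; the project's actual mask definition may differ, e.g. in its convention at 0): the mask is 1 on positive coordinates and 0 elsewhere, and since both JVP and VJP multiply by the same mask, adjointness is coordinatewise associativity again.

```lean
import Mathlib

-- Hypothetical sketch of ReLU's local derivative mask.
noncomputable def reluMask {n : ℕ} (x : Fin n → ℝ) : Fin n → ℝ :=
  fun i => if 0 < x i then 1 else 0

-- ⟪ dx ⊙ mask, δ ⟫ = ⟪ dx, mask ⊙ δ ⟫, written out as finite sums.
example {n : ℕ} (x dx δ : Fin n → ℝ) :
    (Finset.univ.sum fun i => dx i * reluMask x i * δ i)
      = Finset.univ.sum fun i => dx i * (reluMask x i * δ i) :=
  Finset.sum_congr rfl fun i _ => mul_assoc _ _ _
```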
Correctness of sigmoid’s backward rule.
PyTorch analogue: torch.sigmoid.
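Here the local derivative mask is the classical σ'(x) = σ(x) · (1 − σ(x)), so the adjointness proof is the reusable identity above with m = σ(x) · (1 − σ(x)).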
Correctness of tanh’s backward rule.
PyTorch analogue: torch.tanh.
Correctness of softplus’s backward rule.
PyTorch analogue: torch.nn.functional.softplus.
Correctness of SiLU’s backward rule.
PyTorch analogue: torch.nn.functional.silu, equivalently x * sigmoid(x).
Correctness of tanh-approximate GELU's VJP/JVP adjointness rule.
This proves the linear-algebraic part of the gelu backward rule used by Transformer-style
feed-forward blocks: multiplying the upstream cotangent by the local derivative mask is adjoint to
multiplying the tangent by the same mask. The scalar calculus theorem for the full tanh
approximation is intentionally separate because it depends on a longer chain-rule proof through
tanh, sqrt, and the cubic inner polynomial.
PyTorch analogue: torch.nn.functional.gelu(..., approximate="tanh").
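For reference, the tanh approximation in question is gelu(x) ≈ 0.5 · x · (1 + tanh(√(2/π) · (x + 0.044715 · x³))), which is why the full scalar derivative has to be chained through tanh, sqrt, and the cubic inner polynomial.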
Correctness of safe_log’s backward rule (a log with an ε safeguard).
PyTorch analogue: typically implemented as torch.log(torch.clamp(x, min=ε)) (or similar).
Correctness of a smooth absolute value’s backward rule (a differentiable approximation to |x|).
PyTorch analogue: a custom smooth abs implemented via sqrt(x^2 + ε^2) or similar.
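With the sqrt(x² + ε²) implementation, the local derivative is x / √(x² + ε²): smooth everywhere, bounded in (−1, 1), and converging pointwise to sign(x) as ε → 0.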
Correctness of exp’s backward rule.
PyTorch analogue: torch.exp.
Correctness of square's backward rule.
PyTorch analogue: torch.square; the local derivative is 2 * x.
Correctness of ELU's VJP/JVP adjointness rule.
This is the algebraic half of the argument: once a local derivative mask is chosen, the VJP
elu'(x) ⊙ δ is adjoint to the JVP dx ⊙ elu'(x). The analytic differentiability theorem lives in
Proofs.elu_deriv_correct, which correctly excludes the kink at 0 for arbitrary alpha.
PyTorch analogue: torch.nn.functional.elu.
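For reference: elu(x) = x for x > 0 and α · (exp(x) − 1) for x ≤ 0, so the mask is elu'(x) = 1 for x > 0 and α · exp(x) for x < 0. At x = 0 the one-sided derivatives agree only when α = 1, which is exactly the kink the analytic theorem excludes.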
Correctness of sinh's backward rule.
PyTorch analogue: torch.sinh; the local derivative is cosh.
Correctness of cosh's backward rule.
PyTorch analogue: torch.cosh; the local derivative is sinh.
Correctness of log's backward rule.
PyTorch analogue: torch.log.
Correctness of a linear layer’s backward rule (matrix–vector multiply).
PyTorch analogue: torch.nn.Linear (restricted here to the “weights only” linear map).
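In matrix language the law is transpose duality: ⟪ W dx, δ ⟫ = ⟪ dx, Wᵀ δ ⟫. A hedged Mathlib-level sketch (not the project's `Spec.Tensor` statement):

```lean
import Mathlib

open Matrix

-- For the weights-only linear map x ↦ W *ᵥ x, the JVP is W *ᵥ dx and the
-- VJP is Wᵀ *ᵥ δ; adjointness is the usual transpose identity.
example {m n : ℕ} (W : Matrix (Fin m) (Fin n) ℝ)
    (dx : Fin n → ℝ) (δ : Fin m → ℝ) :
    (W *ᵥ dx) ⬝ᵥ δ = dx ⬝ᵥ (Wᵀ *ᵥ δ) := by
  rw [mulVec_transpose, dotProduct_comm, dotProduct_mulVec, dotProduct_comm]
```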
Correctness of sum (reduce-all) backward rule.
Informally: ∂(sum x)/∂xᵢ = 1 for every entry i, so the VJP replicates the upstream scalar gradient into every entry.
PyTorch analogue: torch.sum (over all elements).
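A hedged scalar-level check of that claim (only distributivity is needed):

```lean
import Mathlib

-- JVP of sum is the scalar `∑ i, dx i`; the VJP broadcasts the upstream
-- scalar δ to every coordinate. Adjointness is `Finset.sum_mul`:
-- (∑ i, dx i) * δ = ∑ i, dx i * δ.
example {n : ℕ} (dx : Fin n → ℝ) (δ : ℝ) :
    (Finset.univ.sum fun i : Fin n => dx i) * δ
      = Finset.univ.sum fun i : Fin n => dx i * δ := by
  rw [Finset.sum_mul]
```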