OpSpec
Generic analytic (HasFDerivAt/fderiv) soundness for composed Spec.OpSpecs.
NN.Proofs.Autograd.FDeriv.Core proves the first end-to-end instance (a 2-layer MLP) by:
1. proving OpSpecCorrect (dot/JVP/VJP adjointness), and
2. identifying the JVP with the Fréchet derivative.
This file packages (2) as an extra field and shows it is closed under OpSpecCorrect.compose.
Result: once primitive ops have analytic JVP facts, any sequential OpSpec graph built by
composition gets the theorem:
backward x δ = VJP[forward, x] δ (after converting tensors ↔ Euclidean vectors).
Basic tensor/vector roundtrip
Most analytic statements here are written in Euclidean space (Vec n) because Mathlib’s fderiv
and adjoint API lives there. The following lemma just re-exports the ofVecE/toVecE roundtrip in a
form that is convenient for rewriting.
A proved OpSpec (OpSpecCorrect) together with the analytic fact that its JVP is fderiv.
This is the “bridge object” that upgrades dot-level correctness (JVP/VJP adjointness) into an
actual HasFDerivAt statement about the forward function on Vec n.
PyTorch analogy: this corresponds to saying “the local backward rule is the transpose Jacobian of the true derivative” for a primitive op, so that composing ops yields correct global backward.
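- correct : OpSpecCorrect (Spec.Shape.dim inDim Spec.Shape.scalar) (Spec.Shape.dim outDim Spec.Shape.scalar)
The underlying proved OpSpec, which supplies the op and its dot-level JVP/VJP adjointness.
- deriv
The candidate Fréchet derivative: self.deriv xV is the derivative at xV referenced by the two fields below.
- hasFDerivAt (xV : Vec inDim) : HasFDerivAt (fun (xV : Vec inDim) => toVecE (self.correct.op.forward (ofVecE xV))) (self.deriv xV) xV
The forward map, read on Euclidean vectors, has Fréchet derivative self.deriv xV at xV.
- jvp_eq (xV dxV : Vec inDim) : toVecE (self.correct.jvp (ofVecE xV) (ofVecE dxV)) = (self.deriv xV) dxV
The implemented JVP, converted to Euclidean vectors, equals self.deriv xV applied to dxV.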
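The bridge object described above can be modeled in a simplified, self-contained form. The following is only a sketch under stated assumptions: the name FDerivBridge and the fields forward and jvp are hypothetical, the tensor conversions ofVecE/toVecE are dropped by working directly on abstract real inner product spaces E and F, and the code has not been checked against a particular Mathlib version.

```lean
import Mathlib

/-- Hypothetical, simplified stand-in for the bridge object described above.
It is NOT the library's definition: tensors and the `ofVecE`/`toVecE`
conversions are replaced by abstract real inner product spaces. -/
structure FDerivBridge (E F : Type*)
    [NormedAddCommGroup E] [InnerProductSpace ℝ E]
    [NormedAddCommGroup F] [InnerProductSpace ℝ F] where
  /-- The forward function (assumed name). -/
  forward : E → F
  /-- The implemented forward-mode rule (assumed name). -/
  jvp : E → E → F
  /-- The candidate Fréchet derivative at each point. -/
  deriv : E → E →L[ℝ] F
  /-- `deriv x` really is the Fréchet derivative of `forward` at `x`. -/
  hasFDerivAt : ∀ x, HasFDerivAt forward (deriv x) x
  /-- The implemented JVP agrees with the derivative applied to `dx`. -/
  jvp_eq : ∀ x dx, jvp x dx = deriv x dx
```

In this simplified model the last two fields play the same roles as hasFDerivAt and jvp_eq above: one ties deriv to the analytic derivative, the other ties the implemented JVP to deriv.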
The induced forward function on Euclidean vectors.
Main analytic soundness statement for a single OpSpecFDerivCorrect:
backward x δ is the adjoint of the Fréchet derivative of the forward map, applied to δ.
This is the analytic justification for reverse-mode: it says the implemented VJP is the true Jacobian-transpose product.
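The phrase "adjoint of the Fréchet derivative" can be made concrete with Mathlib's adjoint API. The sketch below is illustrative only: A stands in for the derivative of the forward map at a point, v for an input tangent, and δ for an output cotangent; these names are assumptions, and the snippet has not been checked against a particular Mathlib version. It states the defining property that pairing a JVP with δ equals pairing v with the VJP.

```lean
import Mathlib

open scoped RealInnerProductSpace

variable {E F : Type*}
  [NormedAddCommGroup E] [InnerProductSpace ℝ E] [CompleteSpace E]
  [NormedAddCommGroup F] [InnerProductSpace ℝ F] [CompleteSpace F]

-- `A` plays the role of the Fréchet derivative of the forward map at some point.
-- `A v` is the JVP in direction `v`; `adjoint A δ` is the VJP against `δ`.
example (A : E →L[ℝ] F) (v : E) (δ : F) :
    ⟪A v, δ⟫_ℝ = ⟪v, (ContinuousLinearMap.adjoint A) δ⟫_ℝ :=
  (ContinuousLinearMap.adjoint_inner_right A v δ).symm
```

Since the adjoint is determined by this pairing identity, dot-level JVP/VJP adjointness combined with "JVP equals the derivative" is enough to conclude that backward computes the adjoint of the derivative.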
Composition preserves analytic correctness (chain rule).
If f and g each have a correct fderiv identification of their JVP, then g ∘ f does too.
This is the key closure property used to scale from primitive ops to sequential models.
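The closure property rests on two standard Mathlib facts, sketched below on abstract spaces rather than on OpSpecs. The variable names f, g, f', g' are assumptions for illustration, and the snippet has not been checked against a particular Mathlib version. The first example is the chain rule for Fréchet derivatives; the second shows that taking adjoints reverses the order of composition, which is why the backward pass of g ∘ f runs g's backward rule before f's.

```lean
import Mathlib

variable {E F G : Type*}
  [NormedAddCommGroup E] [InnerProductSpace ℝ E] [CompleteSpace E]
  [NormedAddCommGroup F] [InnerProductSpace ℝ F] [CompleteSpace F]
  [NormedAddCommGroup G] [InnerProductSpace ℝ G] [CompleteSpace G]

-- Forward mode: the derivative of `g ∘ f` at `x` is `g'` composed with `f'`.
example (f : E → F) (g : F → G) (x : E)
    (f' : E →L[ℝ] F) (g' : F →L[ℝ] G)
    (hf : HasFDerivAt f f' x) (hg : HasFDerivAt g g' (f x)) :
    HasFDerivAt (g ∘ f) (g'.comp f') x :=
  hg.comp x hf

-- Reverse mode: the adjoint of a composition is the composition of the
-- adjoints in the opposite order.
example (f' : E →L[ℝ] F) (g' : F →L[ℝ] G) :
    ContinuousLinearMap.adjoint (g'.comp f') =
      (ContinuousLinearMap.adjoint f').comp (ContinuousLinearMap.adjoint g') :=
  ContinuousLinearMap.adjoint_comp g' f'
```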