BackwardScale
Backward (reverse-mode) scale propagation.
This optional module mirrors NN.Proofs.RuntimeApprox.Graph.BackwardApprox, but for scale bounds (nonnegative bounds on linf_norm) rather than eps error bounds.
Use it alongside the backward approximation graph to derive absolute and relative tolerances for gradients and cotangents from both an eps error bound and a propagated magnitude (scale) bound.
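One elementary way the two kinds of bound combine: if every runtime entry is within eps of the corresponding spec entry and the spec tensor's linf norm is at most b, the triangle inequality bounds the runtime tensor's linf norm by b + eps. The following is a minimal sketch under simplifying assumptions, not this module's API: tensors are lists of Int, bounds are Nat (standing in for NNReal), and linfNorm, withinEps and magnitudeBound are hypothetical names.

```lean
-- Hypothetical stand-ins: tensors are `List Int`, bounds are `Nat`
-- (this module works with Spec.Tensor values and NNReal bounds).

/-- linf norm: the largest absolute entry (0 for the empty list). -/
def linfNorm (xs : List Int) : Nat :=
  xs.foldl (fun m x => max m x.natAbs) 0

/-- Entrywise eps error bound between a spec and a runtime tensor. -/
def withinEps (spec runtime : List Int) (eps : Nat) : Bool :=
  spec.length == runtime.length &&
    (List.zipWith (fun s r => (r - s).natAbs) spec runtime).all
      (fun d => decide (d ≤ eps))

/-- Runtime magnitude bound from a spec scale bound `b` and an error bound
    `eps`: by the triangle inequality, |r| ≤ |s| + |r - s| ≤ b + eps. -/
def magnitudeBound (b eps : Nat) : Nat := b + eps

def specT : List Int := [3, -5, 2]
def runT  : List Int := [4, -6, 2]

#eval withinEps specT runT 1                               -- true
#eval (linfNorm runT, magnitudeBound (linfNorm specT) 1)   -- (6, 6)
```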
Soundness condition for accumulating scale bounds under addition in a context.
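In reverse mode, a context entry that feeds several nodes receives the sum of their cotangents, so its scale bound has to be accumulated as well. Since the linf norm satisfies ‖a + b‖ ≤ ‖a‖ + ‖b‖, adding the bounds is one sound choice. The sketch below is a hypothetical illustration of that inequality, not the exact condition stated by this module; cotangents are Int lists, bounds are Nat, and addVec and accumulate are made-up names.

```lean
-- Hypothetical stand-ins: cotangents are `List Int`, scale bounds are `Nat`.

/-- linf norm: the largest absolute entry (0 for the empty list). -/
def linfNorm (xs : List Int) : Nat :=
  xs.foldl (fun m x => max m x.natAbs) 0

/-- Entrywise sum of two cotangents (truncates to the shorter list). -/
def addVec (a b : List Int) : List Int :=
  List.zipWith (· + ·) a b

/-- Accumulated scale bound for a sum of cotangents: the bounds add,
    which is sound because ‖a + b‖ ≤ ‖a‖ + ‖b‖ for the linf norm. -/
def accumulate (ba bb : Nat) : Nat := ba + bb

def δ₁ : List Int := [3, -5, 2]
def δ₂ : List Int := [-1, 4, 6]

#eval linfNorm (addVec δ₁ δ₂)                   -- 8
#eval accumulate (linfNorm δ₁) (linfNorm δ₂)    -- 11
```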
A reverse node augmented with forward and VJP scale bounds.
- forwardRuntime : TList α Γ → Spec.Tensor α τ
- bound : EList Γ → TList α Γ → Spec.SpecScalar
- sound (xS : TList Spec.SpecScalar Γ) (xR : TList α Γ) (eps : EList Γ) : approxCtx toSpec xS xR eps → approxT toSpec (self.forwardSpec xS) (self.forwardRuntime xR) (self.bound eps xR)
- vjpSpec : TList Spec.SpecScalar Γ → Spec.SpecTensor τ → TList Spec.SpecScalar Γ
- vjpRuntime : TList α Γ → Spec.Tensor α τ → TList α Γ
- vjpBound : EList Γ → TList α Γ → Spec.SpecScalar → Spec.Tensor α τ → EList Γ
- vjpSound (ctxS : TList Spec.SpecScalar Γ) (ctxR : TList α Γ) (epsCtx : EList Γ) (δS : Spec.SpecTensor τ) (δR : Spec.Tensor α τ) (epsδ : Spec.SpecScalar) : approxCtx toSpec ctxS ctxR epsCtx → approxT toSpec δS δR epsδ → approxCtx toSpec (self.vjpSpec ctxS δS) (self.vjpRuntime ctxR δR) (self.vjpBound epsCtx ctxR epsδ δR)
- fwdScaleSound (ctxS : TList Spec.SpecScalar Γ) (ctxR : TList α Γ) (epsCtx : EList Γ) (bCtx : BList Γ) : approxCtx toSpec ctxS ctxR epsCtx → scaleCtx toSpec ctxS ctxR bCtx → scaleT toSpec (self.forwardSpec ctxS) (self.forwardRuntime ctxR) (self.fwdScaleBound bCtx ctxR)
- vjpScaleBound : BList Γ → TList α Γ → NNReal → Spec.Tensor α τ → BList Γ
- vjpScaleSound (ctxS : TList Spec.SpecScalar Γ) (ctxR : TList α Γ) (epsCtx : EList Γ) (bCtx : BList Γ) (δS : Spec.SpecTensor τ) (δR : Spec.Tensor α τ) (bδ : NNReal) : approxCtx toSpec ctxS ctxR epsCtx → scaleCtx toSpec ctxS ctxR bCtx → scaleT toSpec δS δR bδ → scaleCtx toSpec (self.vjpSpec ctxS δS) (self.vjpRuntime ctxR δR) (self.vjpScaleBound bCtx ctxR bδ δR)
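To make the scale fields concrete, consider a hypothetical scalar multiplication node y = x₁ · x₂ (not a node defined by this module). If |x₁| ≤ b₁ and |x₂| ≤ b₂, then |y| ≤ b₁ · b₂, which plays the role of fwdScaleBound. Its VJP maps a cotangent δ to (δ · x₂, δ · x₁), so |δ| ≤ bδ gives the bounds (bδ · b₂, bδ · b₁), which play the role of vjpScaleBound; both facts follow from |u · v| = |u| · |v|. The sketch uses Int values and Nat bounds, and all names are illustrative.

```lean
-- Hypothetical scalar node y = x₁ * x₂. Values are `Int` and bounds are
-- `Nat`, standing in for tensors and NNReal.

/-- Forward pass. -/
def mulForward (x₁ x₂ : Int) : Int := x₁ * x₂

/-- VJP: a cotangent δ on y becomes (δ * x₂, δ * x₁) on (x₁, x₂). -/
def mulVjp (x₁ x₂ δ : Int) : Int × Int := (δ * x₂, δ * x₁)

/-- Forward scale bound (plays the role of fwdScaleBound). -/
def mulFwdScale (b₁ b₂ : Nat) : Nat := b₁ * b₂

/-- VJP scale bounds from input bounds and a cotangent bound
    (plays the role of vjpScaleBound). -/
def mulVjpScale (b₁ b₂ bδ : Nat) : Nat × Nat := (bδ * b₂, bδ * b₁)

/-- Does each component of `v` respect the matching bound in `b`? -/
def withinPair (v : Int × Int) (b : Nat × Nat) : Bool :=
  decide (v.1.natAbs ≤ b.1) && decide (v.2.natAbs ≤ b.2)

-- Inputs x₁ = -3, x₂ = 2 with bounds 3 and 4; cotangent δ = 2 with bound 2.
#eval decide ((mulForward (-3) 2).natAbs ≤ mulFwdScale 3 4)   -- true
#eval withinPair (mulVjp (-3) 2 2) (mulVjpScale 3 4 2)        -- true
```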
Reverse-mode graph with scale-aware nodes.
- nil {α : Type} {toSpec : α → Spec.SpecScalar} {Γ : List Spec.Shape} : RevGraphScale toSpec Γ []
- snoc {α : Type} {toSpec : α → Spec.SpecScalar} {Γ ss : List Spec.Shape} {τ : Spec.Shape} : RevGraphScale toSpec Γ ss → RevNodeScale toSpec (Γ ++ ss) τ → RevGraphScale toSpec Γ (ss ++ [τ])
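The nil/snoc shape means a graph is built by appending one node at a time, and each new node may read the inputs Γ together with all earlier outputs ss. A minimal, untyped sketch of that idea follows, assuming shapes and scale payloads are dropped, values are Int, and a Nat index counting nodes replaces the shape list; SnocGraph, SnocGraph.eval and twoNodes are hypothetical names. Its evaluator returns the inputs followed by every node output, in the spirit of the Γ ++ ss context returned by the forward passes.

```lean
-- Hypothetical, simplified analogue of RevGraphScale: shapes and scale
-- payloads are dropped, and the Nat index counts nodes. `snoc` turns a
-- graph with `n` nodes into one with `n + 1`, much as RevGraphScale's
-- `snoc` turns `ss` into `ss ++ [τ]`.
inductive SnocGraph : Nat → Type where
  | nil : SnocGraph 0
  -- A node reads its whole context: inputs followed by earlier outputs.
  | snoc : SnocGraph n → (List Int → Int) → SnocGraph (n + 1)

/-- Forward evaluation: the inputs followed by every node output. -/
def SnocGraph.eval : SnocGraph n → List Int → List Int
  | .nil, xs => xs
  | .snoc g node, xs =>
    let ctx := g.eval xs
    ctx ++ [node ctx]

-- Inputs [a, b]; node 1 computes c = a + b, node 2 computes c * b.
def twoNodes : SnocGraph 2 :=
  .snoc (.snoc .nil (fun ctx => ctx.getD 0 0 + ctx.getD 1 0))
        (fun ctx => ctx.getD 2 0 * ctx.getD 1 0)

#eval twoNodes.eval [3, 4]   -- [3, 4, 7, 28]
```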
Forget the scale annotations on nodes, producing an ordinary RevGraph.
Convert a RevGraphScale into a FwdGraphScale by dropping the reverse-mode payload.
Evaluate the forward pass on spec values, returning the extended context Γ ++ ss.
Evaluate the forward pass on runtime values, returning the extended context Γ ++ ss.
Forward-pass error bounds for all intermediate nodes, computed from input bounds epsIn.
Forward-pass scale bounds for all intermediate nodes, computed from input bounds bIn.
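Forward scale propagation can be pictured as follows: each node carries a rule from the bounds of its context to a bound on its own output, and the pass appends one bound per node. This is a hypothetical simplification, not the module's definition: bounds are Nat (standing in for NNReal), the rules ignore runtime values (unlike fwdScaleBound, which also receives a TList α Γ), and ScaleGraph, fwdScale and twoNodeBounds are illustrative names. The two example nodes use |u + v| ≤ b_u + b_v and |u · v| ≤ b_u · b_v.

```lean
-- Hypothetical simplified graph whose nodes carry only a forward scale
-- rule: bounds of the context (inputs, then earlier outputs) ↦ output bound.
inductive ScaleGraph : Nat → Type where
  | nil : ScaleGraph 0
  | snoc : ScaleGraph n → (List Nat → Nat) → ScaleGraph (n + 1)

/-- Bounds for the inputs followed by a bound for every node output. -/
def ScaleGraph.fwdScale : ScaleGraph n → List Nat → List Nat
  | .nil, bIn => bIn
  | .snoc g rule, bIn =>
    let bCtx := g.fwdScale bIn
    bCtx ++ [rule bCtx]

-- Inputs a, b; node 1 is c = a + b, node 2 is d = c * b.
def twoNodeBounds : ScaleGraph 2 :=
  .snoc (.snoc .nil (fun b => b.getD 0 0 + b.getD 1 0))
        (fun b => b.getD 2 0 * b.getD 1 0)

-- With |a| ≤ 5 and |b| ≤ 2: |c| ≤ 7 and |d| ≤ 14.
#eval twoNodeBounds.fwdScale [5, 2]   -- [5, 2, 7, 14]
```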
Backpropagate scale bounds through a RevGraphScale, analogous to RevGraph.backpropRuntime.
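The backward pass can be pictured as follows: visit the nodes from last to first, read the cotangent bound on the current node's output, map it through the node's VJP scale rule (which may use the forward bounds of the node's context), and add the result into the bounds of the earlier entries, matching the additive accumulation of scale bounds described above. This is a hypothetical sketch, not the module's definition: bounds are Nat, the graph is untyped, forward bounds are recomputed at every step for simplicity, and RevScaleGraph, fwdScale, backprop and twoNodes are illustrative names.

```lean
-- Hypothetical simplified reverse graph: each node has a forward scale rule
-- and a VJP scale rule. The VJP rule receives the forward bounds of the
-- node's context and a bound on the output cotangent, and returns one
-- cotangent bound per context entry.
inductive RevScaleGraph : Nat → Type where
  | nil : RevScaleGraph 0
  | snoc : RevScaleGraph n → (List Nat → Nat) → (List Nat → Nat → List Nat) →
      RevScaleGraph (n + 1)

/-- Forward bounds: inputs followed by one bound per node output. -/
def RevScaleGraph.fwdScale : RevScaleGraph n → List Nat → List Nat
  | .nil, bIn => bIn
  | .snoc g fwd _, bIn =>
    let bCtx := g.fwdScale bIn
    bCtx ++ [fwd bCtx]

/-- Backpropagate cotangent bounds. `bCot` has one entry per input and per
    node output; the result has one entry per input. -/
def RevScaleGraph.backprop : RevScaleGraph n → List Nat → List Nat → List Nat
  | .nil, _, bCot => bCot
  | .snoc g _ vjp, bIn, bCot =>
    let bCtx := g.fwdScale bIn            -- forward bounds of this node's context
    let bδ := bCot.getLast?.getD 0        -- bound on this node's output cotangent
    -- Accumulate the node's contributions into the earlier entries by addition.
    g.backprop bIn (List.zipWith (· + ·) bCot.dropLast (vjp bCtx bδ))

-- Inputs a, b; c = a + b, then d = c * b.
def twoNodes : RevScaleGraph 2 :=
  .snoc
    (.snoc .nil
      (fun b => b.getD 0 0 + b.getD 1 0)                    -- |c| ≤ b_a + b_b
      (fun _ bδ => [bδ, bδ]))                               -- δ ↦ (δ, δ)
    (fun b => b.getD 2 0 * b.getD 1 0)                      -- |d| ≤ b_c · b_b
    (fun b bδ => [0, bδ * b.getD 2 0, bδ * b.getD 1 0])     -- δ ↦ (0, δ·c, δ·b)

-- Seed: cotangent bound 1 on d, 0 elsewhere; inputs |a| ≤ 5, |b| ≤ 2.
#eval twoNodes.backprop [5, 2] [0, 0, 0, 1]   -- [2, 9]
```

In this example the bounds agree with the exact gradients of d = (a + b) · b: ∂d/∂a = b, with |b| ≤ 2, and ∂d/∂b = a + 2b, with |a + 2b| ≤ 5 + 4 = 9.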