Tensor gradient utilities (spec layer)
These are small, generic helpers that operate on gradient tensors:
- norm-based clipping (`clip_gradients_spec`)
- value-based clipping (`clip_by_value_spec`)
- percentile-based clipping (`clip_by_percentile_spec`)
They are defined at the spec layer so they can be used both:
- in executable training examples (instantiated at `Float`/`NF` backends), and
- in proofs (instantiated at `ℝ`).
Why clipping utilities belong in the spec layer:
- Gradient clipping is part of the algorithmic definition of many training loops, not just an implementation detail. If we want to reason about "the training step we ran", we need clipping to be part of the pure model of that step.
- We also want to reuse the same clipping logic across scalar backends: `Float` for executable runs, and proof-friendly scalars (`ℝ`, `NF`, etc.) for theorems and approximation statements.
Design note:
- These definitions are written for clarity and reuse across scalar backends. Backend-specific implementations (for example, fused kernels) belong in the runtime layer.
def Spec.clipGradientsSpec {α : Type} [Context α]
    [DecidableRel fun (x1 x2 : α) => x1 > x2] {s : Shape}
    (gradients : Tensor α s) (max_norm : α) : Tensor α s
Clip gradients by L2 norm (global norm over all elements).

This implements the common "global norm clipping" used in many optimizers:
- compute `||g||_2`
- if it exceeds `max_norm`, rescale `g` so that `||g||_2 = max_norm`.

Implementation details:
- We compare squared norms first so we only compute `sqrt` in the clipping branch.
- We treat `max_norm` as a magnitude, so we use `abs max_norm` as the threshold.
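As a conceptual sketch of the rule above, here is the same computation in plain Python (not the Lean spec: the function name and the tensor-as-list representation are illustrative, but the squared-norm comparison and the `abs max_norm` threshold mirror the implementation details just described):

```python
import math

def clip_gradients_spec(gradients, max_norm):
    """Global L2-norm clipping (illustrative sketch of the spec).

    Compares squared norms first so sqrt is only computed in the
    clipping branch, and treats max_norm as a magnitude via abs().
    """
    threshold = abs(max_norm)
    sq_norm = sum(g * g for g in gradients)
    if sq_norm <= threshold * threshold:
        return list(gradients)  # norm within bound: return unchanged
    scale = threshold / math.sqrt(sq_norm)
    return [g * scale for g in gradients]
```

For example, `clip_gradients_spec([3.0, 4.0], 1.0)` has `||g||_2 = 5 > 1`, so every element is rescaled by `1/5`, yielding `[0.6, 0.8]`.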
def Spec.clipByPercentileSpec {α : Type} [Context α] {s : Shape}
    (gradients : Tensor α s) (pct : ℕ) [DecidableLT α] : Tensor α s
Clip gradients by percentile of absolute values.

This is a value clipping rule driven by the data:
- Flatten `abs(g)` to an array.
- Take the `pct` percentile (0..100) as a bound `b`.
- Return `clamp(g, -b, b)`.

Notes:
- This definition sorts values, so it requires decidable comparison (`DecidableLT α`).
- In practice this is meant for executable scalars like `Float` or `IEEE32Exec`.
- PyTorch analogy (conceptual): compute `b = quantile(abs(g), pct/100)` and clamp to `[-b, b]`.
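The three steps above can be sketched in plain Python (an illustrative sketch, not the Lean definition: the nearest-rank index choice for the percentile is an assumption, and the spec may pick the sorted index differently):

```python
def clip_by_percentile_spec(gradients, pct):
    """Percentile-based value clipping (illustrative sketch of the spec).

    Sorts abs(g), takes the pct-th percentile (0..100) as a bound b,
    then clamps every element to [-b, b].
    """
    magnitudes = sorted(abs(g) for g in gradients)
    # Nearest-rank-style index into the sorted magnitudes; this exact
    # convention is an assumption made for the sketch.
    idx = min(len(magnitudes) - 1, (pct * len(magnitudes)) // 100)
    bound = magnitudes[idx]
    return [max(-bound, min(bound, g)) for g in gradients]
```

For example, with `g = [-10, -1, 0, 2, 5]` and `pct = 50`, the sorted magnitudes are `[0, 1, 2, 5, 10]`, the bound is `b = 2`, and the result is `[-2, -1, 0, 2, 2]`; with `pct = 100` the gradients pass through unchanged.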