Optimizer integration for Runtime.Autograd #
This module is the training-loop side of autograd: it takes a gradient map produced by
Runtime.Autograd and applies parameter updates.
PyTorch analogy:
- ParamTable is like an ordered list of parameters, but we key everything by a stable Nat id (closer to state_dict keys than pointer identity).
- ParamGroup and OptimizerState mirror torch.optim.Optimizer parameter groups and state.
- LRScheduler is a small wrapper around our scheduler implementations, similar to torch.optim.lr_scheduler.*.
All updates are shape checked and implemented using the pure Spec tensor operators, so they
can be used in eager execution or lowered into the compiled IR.
Formula ownership:
- this file owns the heterogeneous parameter-table handling, parameter groups, lazy state maps, scheduler stepping, and PyTorch-style coupled weight decay at the training-loop boundary;
- NN.Runtime.Optim.Optimizers owns the canonical per-tensor optimizer equations.
The important rule is: this file must not define a second public optimizer-formula surface. The
step implementation below constructs canonical optimizer states from the dynamic parameter-table
buffers and calls NN.Runtime.Optim.Optimizers directly. The only local algebra left here is
training-loop glue that is not represented by the canonical pure states, such as coupled
weight-decay preprocessing and PyTorch-style momentum dampening/Nesterov handling.
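For reference, here is a minimal scalar sketch of that local momentum glue, with plain Floats standing in for tensors; the helper name and argument layout are illustrative, not part of the module's API.

```lean
/-- Scalar sketch of PyTorch-style momentum handling: update the buffer with
    dampening, then choose the effective gradient depending on the Nesterov flag.
    The canonical per-tensor update itself lives in NN.Runtime.Optim.Optimizers. -/
def momentumGlue (g buf momentum dampening : Float) (nesterov : Bool) : Float × Float :=
  let buf' := momentum * buf + (1.0 - dampening) * g
  let gEff := if nesterov then g + momentum * buf' else buf'
  (gEff, buf')  -- (effective gradient, updated buffer)
```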
Parameter table #
This section provides the parameter registry used by the training loop.
Unlike PyTorch (where parameters are objects with identity), we use an explicit Nat id so that gradients and optimizer state can be keyed in pure maps.
A single trainable parameter entry.
This is the Runtime.Autograd equivalent of a "parameter tensor" in PyTorch, except we make the
identifier explicit (id : Nat) so we can key gradients and optimizer state in pure maps.
- id : ℕ
Stable identifier used to key gradients and optimizer state.
Optional human-readable name (e.g. module path); used only for reporting/debugging.
- value : AnyTensor α
The parameter value, stored as an AnyTensor (shape erased).
A flat list of parameters used by the training loop.
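A minimal sketch of these two shapes, using a flat Float array plus an explicit shape in place of the shape-erased AnyTensor; the Sketch-suffixed names are illustrative stand-ins, not the module's declarations.

```lean
/-- Illustrative stand-in for a parameter entry: the real entry stores an
    AnyTensor, here a flat array plus an explicit shape plays that role. -/
structure ParamEntrySketch where
  id    : Nat                     -- stable identifier used as the map key
  name  : Option String := none   -- optional human-readable name
  shape : List Nat
  data  : Array Float

/-- The table itself is just a flat list of entries, looked up by id. -/
abbrev ParamTableSketch := List ParamEntrySketch

def findEntryById (t : ParamTableSketch) (id : Nat) : Option ParamEntrySketch :=
  t.find? (·.id == id)
```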
Constructors #
Create a ParamEntry from a typed tensor.
This is mostly a convenience for assembling a ParamTable from known-shaped tensors.
List of ids for quick membership checks.
Find a parameter entry by id.
Get a typed tensor from the table, with shape checking.
Replace a parameter entry value by id.
Scheduler wrapper #
Learning-rate scheduler wrapper used by the training loop.
PyTorch analogy: this plays the role of torch.optim.lr_scheduler.* objects, except we keep the
state as an inductive value and expose a pure getLR/advance API.
- constant {α : Type} : Optim.ConstantScheduler α → LRScheduler α
- exponential {α : Type} : Optim.ExponentialDecayScheduler α → LRScheduler α
- step {α : Type} : Optim.StepDecayScheduler α → LRScheduler α
- cosine {α : Type} : Optim.CosineAnnealingScheduler α → LRScheduler α
- linearWarmup {α : Type} : Optim.LinearWarmupScheduler α → LRScheduler α
- warmupCosine {α : Type} : Optim.WarmupCosineScheduler α → LRScheduler α
- cyclic {α : Type} : Optim.CyclicScheduler α → LRScheduler α
- triangular {α : Type} : Optim.TriangularCycleScheduler α → LRScheduler α
- oneCycle {α : Type} : Optim.OneCycleScheduler α → LRScheduler α
- lrFinder {α : Type} : Optim.LRFinder α → LRScheduler α
- custom {α : Type} : (ℕ → α) → ℕ → LRScheduler α
Read current learning rate from the scheduler state.
Advance scheduler state by one step.
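A minimal sketch of this wrapper pattern, covering only two illustrative constructors; the real LRScheduler wraps the Optim.* scheduler states listed above, one constructor per kind.

```lean
/-- Illustrative two-case scheduler: a fixed LR and a custom Nat → Float schedule
    with its own step counter. getLR reads the current LR; advance is pure and
    returns the next state. -/
inductive LRSchedulerSketch where
  | constant (lr : Float)
  | custom (f : Nat → Float) (step : Nat)

def LRSchedulerSketch.getLR : LRSchedulerSketch → Float
  | .constant lr => lr
  | .custom f t  => f t

def LRSchedulerSketch.advance : LRSchedulerSketch → LRSchedulerSketch
  | .constant lr => .constant lr
  | .custom f t  => .custom f (t + 1)

/-- Training-loop pattern: read the LR for this step, then advance the state. -/
def LRSchedulerSketch.stepLR (s : LRSchedulerSketch) : Float × LRSchedulerSketch :=
  (s.getLR, s.advance)
```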
Optimizer configuration #
Which optimizer update rule to apply.
PyTorch analogy: these correspond roughly to torch.optim.SGD, Adam, AdamW, etc.
- sgd : OptimizerKind
- momentum : OptimizerKind
- adagrad : OptimizerKind
- rmsprop : OptimizerKind
- adam : OptimizerKind
- adamw : OptimizerKind
- adadelta : OptimizerKind
Optimizer hyperparameters for a subset of parameters.
PyTorch analogy: this is a single entry in the optimizer's param-group list
(optimizer.param_groups).
Parameter ids that belong to this group.
- lr : α
Base learning rate (possibly overridden by scheduler on each step).
- weight_decay : α
L2 regularization coefficient (behavior depends on the optimizer kind; see AdamW).
- momentum : α
Momentum factor (SGD with momentum).
- dampening : α
Dampening for momentum updates.
- nesterov : Bool
Use Nesterov variant for momentum updates.
- beta1 : α
Adam beta1 parameter (exponential decay for the first moment).
- beta2 : α
Adam beta2 parameter (exponential decay for the second moment).
- epsilon : α
Numerical stability term used by adaptive optimizers.
- rho : α
"Rho" decay parameter for RMSProp/AdaDelta style optimizers.
- scheduler : Option (LRScheduler α)
Optional learning-rate scheduler for this group.
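As a rough picture, here is a Float-specialized sketch of such a group; the default values below are assumptions chosen to echo common torch.optim settings, not the module's actual defaults, and the real ParamGroup is polymorphic in α and also carries the optional scheduler.

```lean
/-- Illustrative parameter group over Float; defaults are hypothetical. -/
structure ParamGroupSketch where
  ids          : List Nat := []      -- parameter ids owned by this group
  lr           : Float := 0.001
  weight_decay : Float := 0.0
  momentum     : Float := 0.0
  dampening    : Float := 0.0
  nesterov     : Bool  := false
  beta1        : Float := 0.9
  beta2        : Float := 0.999
  epsilon      : Float := 0.00000001 -- 1e-8
  rho          : Float := 0.9
```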
Full optimizer state used by the training loop.
This mirrors PyTorch's optimizer state:
- a global step counter,
- hyperparameter groups,
- and per-parameter state buffers keyed by parameter id (Nat).
- kind : OptimizerKind
Which update rule to apply on step.
- groups : List (ParamGroup α)
Parameter groups (hyperparameters + membership).
- step : ℕ
Global step counter (increments once per step).
- momentum_buf : Std.HashMap ℕ (AnyTensor α)
Momentum buffer (SGD with momentum / Nesterov), keyed by parameter id.
- m : Std.HashMap ℕ (AnyTensor α)
Adam first-moment estimate, keyed by parameter id.
- v : Std.HashMap ℕ (AnyTensor α)
Adam second-moment estimate, keyed by parameter id.
- acc : Std.HashMap ℕ (AnyTensor α)
Accumulator buffer (AdaGrad/RMSProp/AdaDelta), keyed by parameter id.
- acc2 : Std.HashMap ℕ (AnyTensor α)
Second accumulator buffer (AdaDelta), keyed by parameter id.
A pure state snapshot for saving/restoring optimizer state.
PyTorch analogy: this is the data carried by optimizer.state_dict() (modulo naming/layout).
We use association lists instead of HashMap so the result is deterministic and easy to serialize.
- kind : OptimizerKind
Optimizer algorithm used to interpret the stored buffers.
- step : ℕ
Global optimizer step at the time the snapshot was taken.
- groups : List (ParamGroup α)
Parameter groups, including scheduler state and hyperparameters.
Momentum buffers keyed by parameter id.
Adam-family first-moment buffers keyed by parameter id.
Adam-family second-moment buffers keyed by parameter id.
AdaGrad/RMSProp/Adadelta accumulator buffers keyed by parameter id.
Adadelta second accumulator buffers keyed by parameter id.
Serialize optimizer state to a pure record.
PyTorch analogy: this is the "export" step for state_dict().
Restore optimizer state from a state dict.
PyTorch analogy: this is the "import" step for load_state_dict(...).
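A minimal sketch of the export/import idea on a single buffer map, assuming plain Float values in place of tensors: the runtime HashMap is flattened to an id-sorted association list for the snapshot, and folded back into a HashMap on restore. The function names are illustrative.

```lean
import Std.Data.HashMap

/-- Export: flatten the runtime buffer map into an id-sorted association list so
    the serialized form is deterministic. -/
def exportBuffers (buf : Std.HashMap Nat Float) : List (Nat × Float) :=
  buf.toList.mergeSort (fun a b => decide (a.1 ≤ b.1))

/-- Import: rebuild the runtime HashMap from the association list. -/
def importBuffers (entries : List (Nat × Float)) : Std.HashMap Nat Float :=
  entries.foldl (fun m (id, v) => m.insert id v) ∅
```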
Optimizer step #
Lookup a per-parameter state buffer, initializing it with zeros if absent.
This is used for momentum/Adam accumulator initialization (PyTorch does this lazily on first step).
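A minimal sketch of the lazy lookup, with a plain Float standing in for the zero tensor (hypothetical helper name):

```lean
import Std.Data.HashMap

/-- Look up the per-parameter buffer for `id`, falling back to a zero value when
    it has not been created yet, mirroring PyTorch's lazy initialization on the
    first step. The real helper builds a zero tensor matching the parameter shape. -/
def getOrInitBuffer (buffers : Std.HashMap Nat Float) (id : Nat) : Float :=
  buffers.getD id 0.0
```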
Shape-check and cast an optimizer state buffer to match the current parameter value.
This prevents silent shape mismatches when reloading a checkpoint into a model with different parameter shapes.
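A minimal sketch of the check, again with a flat array plus explicit shape standing in for the shape-erased tensor:

```lean
/-- Accept a reloaded buffer only when its recorded shape matches the current
    parameter's shape; otherwise report a mismatch instead of silently reusing it.
    Illustrative stand-in: the real check casts between shape-erased tensors. -/
def checkBufferShape (buf : Array Float) (bufShape paramShape : List Nat) :
    Except String (Array Float) :=
  if bufShape == paramShape then
    .ok buf
  else
    .error s!"optimizer buffer shape {bufShape} does not match parameter shape {paramShape}"
```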
Add an L2 regularization term to the gradient: g + weight_decay * param.
Note: this is the coupled weight decay used by classic SGD-style updates. For AdamW the integration step delegates to the canonical optimizer's decoupled update.
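A minimal elementwise sketch on flat Float arrays (the real helper operates on the shape-erased tensor values):

```lean
/-- Coupled (L2) weight decay folds the penalty into the gradient before the
    optimizer update: g := g + weight_decay * p, elementwise. -/
def coupledWeightDecay (grad param : Array Float) (weightDecay : Float) : Array Float :=
  (grad.zip param).map fun (g, p) => g + weightDecay * p
```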
Convert the training-loop Adam step number into the canonical optimizer state's previous step.
The public training helper receives the step being applied (1 for the first Adam/AdamW update).
NN.Runtime.Optim.Optimizers stores the previous step in the state and increments internally.
For direct calls with t = 0, we still return 0 so the helper stays total and behaves like a
first update rather than constructing a negative predecessor.
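The conversion itself is just a truncated Nat predecessor; a sketch with an illustrative name:

```lean
/-- Map the step being applied (1 on the first Adam/AdamW update) to the previous
    step stored in the canonical state. Nat subtraction truncates at zero, so a
    direct call with t = 0 also yields 0 and behaves like a first update. -/
def toPrevStep (t : Nat) : Nat :=
  t - 1

#eval (toPrevStep 1, toPrevStep 0)  -- (0, 0)
```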
Update each group's learning rate from its scheduler (if present) and advance the scheduler state.
This matches the common training-loop pattern: "read LR, then call scheduler.step()".
Build a map from parameter id to its ParamGroup.
Fails if an id appears in multiple groups (PyTorch also disallows overlapping param groups).
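A minimal sketch of the overlap check only, on bare id lists; the real helper also records which group each id maps to.

```lean
/-- Fold over the groups' id lists, tracking the ids seen so far; return none as
    soon as an id shows up in a second group. -/
def checkNoOverlap (groups : List (List Nat)) : Option (List Nat) :=
  groups.foldlM (fun seen ids =>
    ids.foldlM (fun seen id =>
      if seen.contains id then none else some (id :: seen)) seen) []

#eval ((checkNoOverlap [[0, 1], [2]]).isSome, (checkNoOverlap [[0, 1], [1]]).isSome)  -- (true, false)
```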
Apply one optimizer step to a parameter table.
Inputs:
- opt is the current optimizer state (including per-parameter buffers),
- params is the current parameter table,
- grads maps parameter ids to gradients (as produced by autograd).
Behavior:
- applies LR schedulers (if configured) per group,
- shape-checks gradients and state buffers against each parameter,
- updates per-parameter state buffers (momentum / Adam m,v / accumulators),
- returns the updated optimizer state and an updated parameter table.
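Putting the pieces together, here is a minimal sketch of the plain-SGD case on flat Float arrays, assuming an association-list parameter table and a HashMap gradient map; the name sgdStepSketch is illustrative, and the real step dispatches on OptimizerKind and threads the state buffers through.

```lean
import Std.Data.HashMap

/-- One plain-SGD step: for every parameter id with a gradient, apply
    p := p - lr * g elementwise; in this sketch, parameters without a gradient
    are returned unchanged. -/
def sgdStepSketch (lr : Float) (params : List (Nat × Array Float))
    (grads : Std.HashMap Nat (Array Float)) : List (Nat × Array Float) :=
  params.map fun (id, p) =>
    match grads.get? id with
    | some g => (id, (p.zip g).map fun (pi, gi) => pi - lr * gi)
    | none   => (id, p)
```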