Optimizers #
Optimizers for TorchLean runtime training.
This file implements the core math of common gradient-based optimizers as pure functions on
typed tensors Tensor α s.
Why “pure functions”?
In PyTorch, optimizers mutate parameters in-place and keep state in Python objects. In TorchLean, we want the update rule itself to be explicit and easy to reuse:
- eager demos can call the update directly,
- the runtime training engine can store state in maps keyed by parameter ids,
- and proofs can refer to the same update equations.
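To make the calling convention concrete, here is a minimal sketch of the pure-step shape, specialized to plain types (the name `StepFn` and the simplified types are illustrative only, not the library's actual signatures):

```lean
/-- Illustrative sketch of a pure per-tensor optimizer step:
`(state, param, grad) ↦ (state', param')`. The library's real `step`
functions work on `Spec.Tensor α s` rather than bare type parameters. -/
def StepFn (State Param Grad : Type) : Type :=
  State → Param → Grad → State × Param
```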
The intent is to mimic the standard textbook formulas closely. We do not try to reproduce every
implementation detail of torch.optim.* (e.g. foreach kernels, fused updates, or every optional
flag); those live at a different layer than the math we specify here.
Where this file sits in the stack:
- this file owns the scalar-polymorphic, per-tensor update equations;
- `NN.Runtime.Autograd.TorchLean.Optim` lifts those equations to runtime parameter lists; and
- `NN.API.Runtime` exposes ergonomic `optim.sgd`, `optim.adam`, and related configuration helpers.
That separation is deliberate: the formula appears once, while runtime adapters and public API configuration can evolve independently around it.
Why each optimizer has its own State structure:
- Lean structures do not inherit from one another the way Python classes do.
- More importantly, optimizer state is not uniform: SGD stores only `lr`, momentum SGD stores a buffer, Adam/AdamW store two moment buffers and a step counter, Adadelta stores gradient/update EMAs, and Muon/GaLore carry backend functions.
- Keeping these as separate typed states makes impossible states unrepresentable. For example, an SGD state cannot accidentally contain a stale Adam `v` buffer, and AdamW cannot forget its decoupled `weight_decay` coefficient.
The generic abstraction lives one layer up:
- `Runtime.Autograd.TorchLean.Optim.Optimizer` packages `init`/`step` for shape-indexed parameter lists, like a typed analogue of a PyTorch optimizer object.
- `Runtime.Autograd.Train.OptimizerState` handles dynamic parameter groups and checkpoint-style maps for the training-loop API.
So this file intentionally favors small canonical state records over an inheritance hierarchy.
References (original algorithms / common variants):
- AdaGrad (Duchi–Hazan–Singer, 2011): https://jmlr.org/papers/v12/duchi11a.html
- RMSProp (Hinton lecture notes; widely used variant): https://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf
- Adam (Kingma–Ba, 2015): https://arxiv.org/abs/1412.6980
- AdamW / decoupled weight decay (Loshchilov–Hutter, 2019): https://arxiv.org/abs/1711.05101
- Adadelta (Zeiler, 2012): https://arxiv.org/abs/1212.5701
- SGD + momentum in deep learning practice (Sutskever et al., 2013): https://arxiv.org/abs/1301.4083
- GaLore / low-rank gradient projection (Zhao et al., 2024): https://arxiv.org/abs/2403.03507
- Muon-style momentum with orthogonalized matrix updates (Jordan et al., 2024): https://kellerjordan.github.io/posts/muon/
PyTorch references (for API/parameter naming):
- `torch.optim` overview: https://pytorch.org/docs/stable/optim.html
- `torch.optim.SGD`: https://pytorch.org/docs/stable/generated/torch.optim.SGD.html
- `torch.optim.Adagrad`: https://pytorch.org/docs/stable/generated/torch.optim.Adagrad.html
- `torch.optim.RMSprop`: https://pytorch.org/docs/stable/generated/torch.optim.RMSprop.html
- `torch.optim.Adam`: https://pytorch.org/docs/stable/generated/torch.optim.Adam.html
- `torch.optim.AdamW`: https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html
- `torch.optim.Adadelta`: https://pytorch.org/docs/stable/generated/torch.optim.Adadelta.html
Scalar exponentiation starts at 1.
Successor case for scalar exponentiation.
Shared utilities #
Momentum-style buffer update μ * buf + g.
Elementwise “adaptive learning rate” tensor lr / (sqrt(denom) + ε).
This is shared by AdaGrad/RMSProp/Adam-style optimizers.
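A minimal scalar sketch of this helper (the name `adaptiveLR` is illustrative; the library's definition is elementwise over tensors):

```lean
-- Scalar sketch of the shared adaptive-learning-rate helper:
-- lr / (sqrt(denom) + ε). The real version applies this elementwise.
def adaptiveLR (lr eps denom : Float) : Float :=
  lr / (Float.sqrt denom + eps)
```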
SGD #
SGD state (per parameter tensor).
We only store the learning rate here.
- lr : α
Learning rate.
Initialize SGD state.
The parameter tensor is unused; we keep it in the signature so optimizers share the same “init from parameters” calling convention.
SGD initialization records exactly the requested learning rate.
One SGD step: p ← p - lr * g.
PyTorch analogy: the core of torch.optim.SGD without momentum/weight-decay extras.
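A scalar sketch with a tiny usage check (illustrative name, `Float` instead of tensors):

```lean
-- Scalar sketch of the SGD update p ← p - lr * g.
def sgdStep (lr p g : Float) : Float :=
  p - lr * g

#eval sgdStep 0.1 1.0 0.5  -- ≈ 0.95
```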
Momentum SGD #
Momentum SGD state (per parameter tensor).
We store a momentum buffer buf and a momentum coefficient μ.
Update rule:
- buf ← μ * buf + g
- p ← p - lr * buf
This matches PyTorch's SGD momentum behavior when dampening = 0 and nesterov = false.
- lr : α
Learning rate.
- momentum : α
Momentum coefficient `μ`.
- buf : Spec.Tensor α s
Momentum buffer `buf`.
Initialize momentum SGD with a zero buffer.
Momentum-SGD starts with a zero momentum buffer.
One momentum-SGD step (returns updated state and parameters).
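A scalar sketch of the two-step rule above (illustrative names, `Float` instead of tensors):

```lean
-- Scalar sketch: buf ← μ * buf + g, then p ← p - lr * buf.
structure MomentumState where
  lr momentum buf : Float

def momentumSgdStep (st : MomentumState) (p g : Float) :
    MomentumState × Float :=
  let buf' := st.momentum * st.buf + g
  ({ st with buf := buf' }, p - st.lr * buf')
```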
AdaGrad #
AdaGrad state (per parameter tensor).
We store an accumulator G of squared gradients (same shape as the parameters). The effective
step size is scaled by 1 / (sqrt(G) + ε).
- lr : α
Base learning rate.
- epsilon : α
Numerical stability constant `ε`.
- accumulator : Spec.Tensor α s
Accumulated squared gradients.
Initialize AdaGrad with zero accumulator.
AdaGrad starts with a zero squared-gradient accumulator.
One AdaGrad step (returns updated state and parameters).
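For reference, the elementwise equations implied by the description above:
- G ← G + g²
- p ← p - lr * g / (sqrt(G) + ε)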
RMSProp #
RMSProp state (per parameter tensor).
We store an EMA of squared gradients (accumulator), often called square_avg in PyTorch code.
- lr : α
Learning rate.
- decay : α
Decay coefficient for the EMA of `g²` (often called `alpha`).
- epsilon : α
Numerical stability constant `ε`.
- accumulator : Spec.Tensor α s
EMA of squared gradients.
Initialize RMSProp with zero accumulator.
RMSProp starts with a zero running average of squared gradients.
One RMSProp step (returns updated state and parameters).
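For reference, the elementwise equations, writing the `decay` field as α:
- v ← α v + (1-α) g²
- p ← p - lr * g / (sqrt(v) + ε)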
Adam #
Adam state (per parameter tensor).
We store first/second moment EMAs (m, v) and a step counter t used for bias correction.
- lr : α
Learning rate.
- beta1 : α
First moment decay `β₁`.
- beta2 : α
Second moment decay `β₂`.
- epsilon : α
Numerical stability constant `ε`.
- m : Spec.Tensor α s
First moment EMA.
- v : Spec.Tensor α s
Second moment EMA.
- t : ℕ
Step counter (used for bias correction).
Initialize Adam with m = 0, v = 0, and t = 0.
Adam starts at step 0.
Adam starts with a zero first-moment buffer.
Adam starts with a zero second-moment buffer.
One Adam step (returns updated state and parameters).
Equations (elementwise):
- m ← β₁ m + (1-β₁) g
- v ← β₂ v + (1-β₂) g²
- m̂ ← m / (1-β₁ᵗ)
- v̂ ← v / (1-β₂ᵗ)
- p ← p - lr * m̂ / (sqrt(v̂) + ε)
The ε placement matches Kingma and Ba: it is added after sqrt(v̂).
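The same equations as an executable scalar sketch (the names `AdamState`/`adamStep` are illustrative; the library's step works on `Spec.Tensor α s`):

```lean
structure AdamState where
  lr beta1 beta2 epsilon : Float
  m v : Float
  t : Nat

def adamStep (st : AdamState) (p g : Float) : AdamState × Float :=
  let t'   := st.t + 1
  let m'   := st.beta1 * st.m + (1 - st.beta1) * g
  let v'   := st.beta2 * st.v + (1 - st.beta2) * g * g
  let mhat := m' / (1 - Float.pow st.beta1 t'.toFloat)  -- bias correction
  let vhat := v' / (1 - Float.pow st.beta2 t'.toFloat)
  ({ st with m := m', v := v', t := t' },
   p - st.lr * mhat / (Float.sqrt vhat + st.epsilon))   -- ε after sqrt(v̂)
```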
Adam increments its step counter by one on every update.
AdamW #
AdamW state (per parameter tensor).
AdamW is “Adam + decoupled weight decay”. The key point is that weight decay is applied as a separate parameter decay term rather than being folded into the gradient that feeds the moments.
- lr : α
Learning rate.
- beta1 : α
First moment decay `β₁`.
- beta2 : α
Second moment decay `β₂`.
- epsilon : α
Numerical stability constant `ε`.
- weight_decay : α
Weight decay coefficient `wd`.
- m : Spec.Tensor α s
First moment EMA.
- v : Spec.Tensor α s
Second moment EMA.
- t : ℕ
Step counter (used for bias correction).
Initialize AdamW state for a parameter tensor (moments start at 0).
AdamW initialization records the requested decoupled weight-decay coefficient.
AdamW starts at step 0.
One AdamW step (returns updated state and parameters).
We implement the decoupled form from the AdamW paper:
- update Adam moments using the raw gradient `g`,
- apply weight decay directly to the parameters (`p ← p - lr * wd * p`),
- then apply the Adam update.
This is the same single-step ordering used by torch.optim.AdamW.
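A scalar sketch of that ordering, reusing the illustrative `AdamState`/`adamStep` from the Adam section above (`adamwStep` is likewise a hypothetical name):

```lean
-- Decoupled weight decay: shrink p directly, then apply the Adam step.
-- The moments are fed the raw gradient g, not a decayed one.
def adamwStep (wd : Float) (st : AdamState) (p g : Float) :
    AdamState × Float :=
  let pDecayed := p - st.lr * wd * p   -- p ← p - lr * wd * p
  adamStep st pDecayed g               -- moments use raw g
```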
AdamW increments its step counter by one on every update.
Adadelta #
Adadelta state (per parameter tensor).
We store two EMAs: an EMA of squared gradients (`v`) and an EMA of squared updates (`u`).
- lr : α
Learning rate (often set to `1` in some presentations; we keep it explicit).
- rho : α
Decay coefficient `ρ`.
- epsilon : α
Numerical stability constant `ε`.
- v : Spec.Tensor α s
EMA of squared gradients.
- u : Spec.Tensor α s
EMA of squared updates.
Initialize Adadelta state for a parameter tensor (EMAs start at 0).
Adadelta starts with a zero squared-gradient EMA.
Adadelta starts with a zero squared-update EMA.
One Adadelta step (returns updated state and parameters).
Elementwise equations:
- v ← ρ v + (1-ρ) g²
- Δ ← (sqrt(u + ε) / sqrt(v + ε)) * g
- u ← ρ u + (1-ρ) Δ²
- p ← p - lr * Δ
The ε placement is inside the RMS terms, matching Zeiler's Adadelta update.
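A scalar sketch of these equations (illustrative names, `Float` instead of tensors):

```lean
structure AdadeltaState where
  lr rho epsilon : Float
  v u : Float

def adadeltaStep (st : AdadeltaState) (p g : Float) :
    AdadeltaState × Float :=
  let v'    := st.rho * st.v + (1 - st.rho) * g * g
  -- ε sits inside both RMS terms, per Zeiler's update
  let delta := (Float.sqrt (st.u + st.epsilon)
                / Float.sqrt (v' + st.epsilon)) * g
  let u'    := st.rho * st.u + (1 - st.rho) * delta * delta
  ({ st with v := v', u := u' }, p - st.lr * delta)
```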
Projected / low-rank optimizers #
A shape-safe gradient projector.
GaLore-style training periodically builds a low-rank subspace for a large matrix parameter, projects the gradient into that subspace, runs a base optimizer there, and lifts the update back to the original parameter shape. This record is deliberately only the algebraic interface: the expensive policy that computes or refreshes the projector belongs to the runtime layer.
- project : Spec.Tensor α full → Spec.Tensor α low
Project a full gradient into the low-rank optimizer space.
- lift : Spec.Tensor α low → Spec.Tensor α full
Lift a low-rank update back to the full parameter shape.
Identity projector, useful for tests and for the theorem that projected SGD reduces to SGD.
GaLore-style projected SGD state for one tensor.
This is not a full GaLore implementation by itself: it specifies the update once a projector is available. A practical trainer still needs a refresh schedule and a way to build projectors for large matrix parameters.
- lr : α
Learning rate used after the gradient has been projected and lifted.
- projector : Projector α full low
Current gradient projector.
One projected-SGD update: p ← p - lr * lift(project(g)).
With the identity projector, projected SGD is exactly ordinary SGD.
This is the main sanity check for the GaLore extension point: adding a projection backend cannot silently change the base optimizer when the backend is the identity.
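A scalar model of the projector interface and this sanity check (all names here are illustrative; the library's record is shape-indexed over `full` and `low`):

```lean
-- Scalar stand-in for the shape-safe projector record.
structure Projector where
  project : Float → Float
  lift    : Float → Float

def Projector.identity : Projector := ⟨fun x => x, fun x => x⟩

-- One projected-SGD update: p ← p - lr * lift(project(g)).
def projectedSgdStep (lr : Float) (P : Projector) (p g : Float) : Float :=
  p - lr * P.lift (P.project g)

-- With the identity projector this is definitionally plain SGD.
theorem projectedSgd_identity (lr p g : Float) :
    projectedSgdStep lr Projector.identity p g = p - lr * g := rfl
```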
Muon-style orthogonalized momentum #
Orthogonalization backend for a matrix-shaped update.
Muon uses a momentum buffer and then replaces the raw momentum direction by an approximately orthogonalized update, commonly via Newton-Schulz iterations. TorchLean keeps this as an explicit backend so the pure update rule is testable before CUDA kernels are introduced.
- apply : Spec.Tensor α s → Spec.Tensor α s
Convert a momentum buffer into the direction used for the parameter update.
The identity orthogonalizer; with this backend Muon reduces to momentum SGD.
Per-parameter state for Muon-style momentum with an explicit orthogonalization backend.
- lr : α
Learning rate.
- momentum : α
Momentum coefficient.
- buf : Spec.Tensor α s
Momentum buffer.
- orthogonalizer : Orthogonalizer α s
Backend that turns the momentum buffer into the update direction.
Initialize Muon-style state with a zero momentum buffer.
One Muon-style update:
- update the momentum buffer,
- orthogonalize the buffer,
- subtract the scaled orthogonalized direction.
For actual Muon, use a matrix-shaped s and a Newton-Schulz orthogonalizer. The generic shape here
keeps the definition reusable for tests and for future batched matrix layouts.
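A scalar sketch of the three steps (illustrative names; real Muon uses a matrix-shaped tensor and a Newton-Schulz backend):

```lean
structure MuonState where
  lr momentum buf : Float
  orthogonalize : Float → Float  -- backend; stands in for Newton-Schulz

def muonStep (st : MuonState) (p g : Float) : MuonState × Float :=
  let buf' := st.momentum * st.buf + g        -- 1. update momentum buffer
  let dir  := st.orthogonalize buf'           -- 2. orthogonalize the buffer
  ({ st with buf := buf' }, p - st.lr * dir)  -- 3. subtract scaled direction
```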
With the identity orthogonalizer, Muon's parameter update is exactly momentum SGD's parameter update.
The state records are different because Muon carries an orthogonalizer backend, but the parameter
direction is the same when that backend is id.