TorchLean API

NN.Runtime.Optim.Optimizers

Optimizers #

Optimizers for TorchLean runtime training.

This file implements the core math of common gradient-based optimizers as pure functions on typed tensors Tensor α s.

Why “pure functions”?

In PyTorch, optimizers mutate parameters in-place and keep state in Python objects. In TorchLean, we want the update rule itself to be explicit and easy to reuse:

The intent is to mimic the standard textbook formulas closely. We do not try to reproduce every implementation detail of torch.optim.* (e.g. foreach kernels, fused updates, or every optional flag); those live at a different layer than the math we specify here.

Where this file sits in the stack:

That separation is deliberate: the formula appears once, while runtime adapters and public API configuration can evolve independently around it.

Why each optimizer has its own State structure:

The generic abstraction lives one layer up:

So this file intentionally favors small canonical state records over an inheritance hierarchy.

References (original algorithms / common variants):

PyTorch references (for API/parameter naming):

def Optim.scalarPowNat {α : Type} [One α] [Mul α] (x : α) :
ℕ → α

Integer exponentiation for scalar optimizer coefficients.

We use an explicit Nat → α recursion instead of x ^ (n : Nat) because Context α provides Pow α α (for runtime scalar exponentiation), but not Pow α Nat.
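For orientation, the recursion that the simp lemmas below pin down is the textbook one. A minimal sketch (illustrative only; the file's actual definition may differ in details, and the numeral 1 assumes the usual OfNat instance for One):

def scalarPowNatSketch {α : Type} [One α] [Mul α] (x : α) : Nat → α
  | 0 => 1                                -- scalarPowNat_zero: exponentiation starts at 1
  | n + 1 => scalarPowNatSketch x n * x   -- scalarPowNat_succ: multiply by x once more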

Instances For
    @[simp]
    theorem Optim.scalarPowNat_zero {α : Type} [One α] [Mul α] (x : α) :
    scalarPowNat x 0 = 1

    Scalar exponentiation starts at 1.

    @[simp]
    theorem Optim.scalarPowNat_succ {α : Type} [One α] [Mul α] (x : α) (n : ℕ) :
    scalarPowNat x (n + 1) = scalarPowNat x n * x

    Successor case for scalar exponentiation.

    Shared utilities #

    def Optim.OptimizerUtils.updateMomentumBuf {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {s : Spec.Shape} (buf : Spec.Tensor α s) (momentum : α) (grads : Spec.Tensor α s) :
    Spec.Tensor α s

    Momentum-style buffer update μ * buf + g.

    Instances For
      def Optim.OptimizerUtils.mkAdaptiveLR {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {s : Spec.Shape} (lr epsilon : α) (denom : Spec.Tensor α s) :
      Spec.Tensor α s

      Elementwise “adaptive learning rate” tensor lr / (sqrt(denom) + ε).

      This is shared by AdaGrad/RMSProp/Adam-style optimizers.
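      For example, with lr = 0.1, ε = 10⁻⁸, and an entry of denom equal to 4, the corresponding entry of the result is 0.1 / (2 + 10⁻⁸) ≈ 0.05.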

      Instances For

        SGD #

        structure Optim.SGD.State (α : Type) (s : Spec.Shape) : Type

        SGD state (per parameter tensor).

        We only store the learning rate here.

        • lr : α

          Learning rate.

        Instances For
          def Optim.SGD.init {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {s : Spec.Shape} (lr : α) :
          Spec.Tensor α s → State α s

          Initialize SGD state.

          The parameter tensor is unused; we keep it in the signature so optimizers share the same “init from parameters” calling convention.

          Instances For
            @[simp]
            theorem Optim.SGD.init_lr {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {s : Spec.Shape} (lr : α) (params : Spec.Tensor α s) :
            (init lr params).lr = lr

            SGD initialization records exactly the requested learning rate.

            def Optim.SGD.update {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {s : Spec.Shape} (state : State α s) (params grads : Spec.Tensor α s) :
            Spec.Tensor α s

            One SGD step: p ← p - lr * g.

            PyTorch analogy: the core of torch.optim.SGD without momentum/weight-decay extras.
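            A usage sketch (illustrative, not a verbatim test from the library; it assumes the relevant TorchLean modules are imported and instances such as Context α are available):

            example {α : Type} [Context α] [DecidableRel (fun (x1 x2 : α) => x1 > x2)]
                {s : Spec.Shape} (lr : α) (params grads : Spec.Tensor α s) :
                Spec.Tensor α s :=
              -- one SGD step on a single parameter tensor
              Optim.SGD.update (Optim.SGD.init lr params) params grads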

            Instances For

              Momentum SGD #

              structure Optim.MomentumSGD.State (α : Type) (s : Spec.Shape) : Type

              Momentum SGD state (per parameter tensor).

              We store a momentum buffer buf and a momentum coefficient μ. Update rule:

              • buf ← μ buf + g
              • p ← p - lr * buf

              This matches PyTorch's SGD momentum behavior when dampening = 0 and nesterov = false.

              • lr : α

                Learning rate.

              • momentum : α

                Momentum coefficient μ.

              • buf : Spec.Tensor α s

                Momentum buffer buf.

              Instances For
                def Optim.MomentumSGD.init {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {s : Spec.Shape} (lr momentum : α) :
                Spec.Tensor α s → State α s

                Initialize momentum SGD with a zero buffer.

                Instances For
                  @[simp]
                  theorem Optim.MomentumSGD.init_buf {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {s : Spec.Shape} (lr momentum : α) (params : Spec.Tensor α s) :
                  (init lr momentum params).buf = Spec.fill 0 s

                  Momentum-SGD starts with a zero momentum buffer.

                  def Optim.MomentumSGD.update {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {s : Spec.Shape} (state : State α s) (params grads : Spec.Tensor α s) :
                  State α s × Spec.Tensor α s

                  One momentum-SGD step (returns updated state and parameters).
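                  A state-threading sketch (illustrative; same assumptions as the SGD sketch above). The returned state carries the momentum buffer into the next step:

                  example {α : Type} [Context α] [DecidableRel (fun (x1 x2 : α) => x1 > x2)]
                      {s : Spec.Shape} (lr μ : α) (params grads : Spec.Tensor α s) :
                      Spec.Tensor α s :=
                    let st0 := Optim.MomentumSGD.init lr μ params
                    -- first step: returns (updated state, updated parameters)
                    let (st1, p1) := Optim.MomentumSGD.update st0 params grads
                    -- second step reuses the momentum buffer carried in st1
                    (Optim.MomentumSGD.update st1 p1 grads).2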

                  Instances For

                    AdaGrad #

                    structure Optim.AdaGrad.State (α : Type) (s : Spec.Shape) : Type

                    AdaGrad state (per parameter tensor).

                    We store an accumulator G of squared gradients (same shape as the parameters). The effective step size is scaled by 1 / (sqrt(G) + ε).

                    • lr : α

                      Base learning rate.

                    • epsilon : α

                      Numerical stability constant ε.

                    • accumulator : Spec.Tensor α s

                      Accumulated squared gradients.

                    Instances For
                      def Optim.AdaGrad.init {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {s : Spec.Shape} (lr epsilon : α) :
                      Spec.Tensor α s → State α s

                      Initialize AdaGrad with zero accumulator.

                      Instances For
                        @[simp]
                        theorem Optim.AdaGrad.init_accumulator {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {s : Spec.Shape} (lr epsilon : α) (params : Spec.Tensor α s) :
                        (init lr epsilon params).accumulator = Spec.fill 0 s

                        AdaGrad starts with a zero squared-gradient accumulator.

                        def Optim.AdaGrad.update {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {s : Spec.Shape} (state : State α s) (params grads : Spec.Tensor α s) :
                        State α s × Spec.Tensor α s

                        One AdaGrad step (returns updated state and parameters).
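                        Elementwise equations (the standard AdaGrad rule, consistent with the accumulator and step scaling described above):

                        • G ← G + g²
                        • p ← p - lr * g / (sqrt(G) + ε)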

                        Instances For

                          RMSProp #

                          structure Optim.RMSProp.State (α : Type) (s : Spec.Shape) : Type

                          RMSProp state (per parameter tensor).

                          We store an EMA of squared gradients (accumulator), often called square_avg in PyTorch code.

                          • lr : α

                            Learning rate.

                          • decay : α

                            Decay coefficient for the EMA of squared gradients (often called alpha).

                          • epsilon : α

                            Numerical stability constant ε.

                          • accumulator : Spec.Tensor α s

                            EMA of squared gradients.

                          Instances For
                            def Optim.RMSProp.init {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {s : Spec.Shape} (lr decay epsilon : α) :
                            Spec.Tensor α s → State α s

                            Initialize RMSProp with zero accumulator.

                            Instances For
                              @[simp]
                              theorem Optim.RMSProp.init_accumulator {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {s : Spec.Shape} (lr decay epsilon : α) (params : Spec.Tensor α s) :
                              (init lr decay epsilon params).accumulator = Spec.fill 0 s

                              RMSProp starts with a zero running average of squared gradients.

                              def Optim.RMSProp.update {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {s : Spec.Shape} (state : State α s) (params grads : Spec.Tensor α s) :
                              State α s × Spec.Tensor α s

                              One RMSProp step (returns updated state and parameters).
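                              Elementwise equations (the standard RMSProp rule without momentum or centering, consistent with the EMA accumulator described above):

                              • v ← decay * v + (1 - decay) * g²
                              • p ← p - lr * g / (sqrt(v) + ε)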

                              Instances For

                                Adam #

                                structure Optim.Adam.State (α : Type) (s : Spec.Shape) : Type

                                Adam state (per parameter tensor).

                                We store first/second moment EMAs (m, v) and a step counter t used for bias correction.

                                • lr : α

                                  Learning rate.

                                • beta1 : α

                                  First moment decay β₁.

                                • beta2 : α

                                  Second moment decay β₂.

                                • epsilon : α

                                  Numerical stability constant ε.

                                • m : Spec.Tensor α s

                                  First moment EMA.

                                • v : Spec.Tensor α s

                                  Second moment EMA.

                                • t : ℕ

                                  Step counter (used for bias correction).

                                Instances For
                                  def Optim.Adam.init {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {s : Spec.Shape} (lr beta1 beta2 epsilon : α) :
                                  Spec.Tensor α s → State α s

                                  Initialize Adam with m = 0, v = 0, and t = 0.

                                  Instances For
                                    @[simp]
                                    theorem Optim.Adam.init_t {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {s : Spec.Shape} (lr beta1 beta2 epsilon : α) (params : Spec.Tensor α s) :
                                    (init lr beta1 beta2 epsilon params).t = 0

                                    Adam starts at step 0.

                                    @[simp]
                                    theorem Optim.Adam.init_m {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {s : Spec.Shape} (lr beta1 beta2 epsilon : α) (params : Spec.Tensor α s) :
                                    (init lr beta1 beta2 epsilon params).m = Spec.fill 0 s

                                    Adam starts with a zero first-moment buffer.

                                    @[simp]
                                    theorem Optim.Adam.init_v {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {s : Spec.Shape} (lr beta1 beta2 epsilon : α) (params : Spec.Tensor α s) :
                                    (init lr beta1 beta2 epsilon params).v = Spec.fill 0 s

                                    Adam starts with a zero second-moment buffer.

                                    def Optim.Adam.update {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {s : Spec.Shape} (state : State α s) (params grads : Spec.Tensor α s) :
                                    State α s × Spec.Tensor α s

                                    One Adam step (returns updated state and parameters).

                                    Equations (elementwise):

                                    • m ← β₁ m + (1-β₁) g
                                    • v ← β₂ v + (1-β₂) g²
                                    • m̂ ← m / (1-β₁ᵗ)
                                    • v̂ ← v / (1-β₂ᵗ)
                                    • p ← p - lr * m̂ / (sqrt(v̂) + ε)

                                    The ε placement matches Kingma and Ba: it is added after sqrt(v̂).
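                                    For example, with β₁ = 0.9 the first step (t = 1) gives m̂ = m / (1 - 0.9) = 10 m, which compensates for the zero-initialized first moment.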

                                    Instances For
                                      @[simp]
                                      theorem Optim.Adam.update_t {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {s : Spec.Shape} (state : State α s) (params grads : Spec.Tensor α s) :
                                      (update state params grads).1.t = state.t + 1

                                      Adam increments its step counter by one on every update.

                                      AdamW #

                                      structure Optim.AdamW.State (α : Type) (s : Spec.Shape) : Type

                                      AdamW state (per parameter tensor).

                                      AdamW is “Adam + decoupled weight decay”. The key point is that weight decay is applied as a separate parameter decay term rather than being folded into the gradient that feeds the moments.

                                      • lr : α

                                        Learning rate.

                                      • beta1 : α

                                        First moment decay β₁.

                                      • beta2 : α

                                        Second moment decay β₂.

                                      • epsilon : α

                                        Numerical stability constant ε.

                                      • weight_decay : α

                                        Weight decay coefficient wd.

                                      • m : Spec.Tensor α s

                                        First moment EMA.

                                      • v : Spec.Tensor α s

                                        Second moment EMA.

                                      • t : ℕ

                                        Step counter (used for bias correction).

                                      Instances For
                                        def Optim.AdamW.init {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {s : Spec.Shape} (lr weight_decay beta1 beta2 epsilon : α) :
                                        Spec.Tensor α s → State α s

                                        Initialize AdamW state for a parameter tensor (moments start at 0).

                                        Instances For
                                          @[simp]
                                          theorem Optim.AdamW.init_weight_decay {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {s : Spec.Shape} (lr weightDecay beta1 beta2 epsilon : α) (params : Spec.Tensor α s) :
                                          (init lr weightDecay beta1 beta2 epsilon params).weight_decay = weightDecay

                                          AdamW initialization records the requested decoupled weight-decay coefficient.

                                          @[simp]
                                          theorem Optim.AdamW.init_t {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {s : Spec.Shape} (lr weightDecay beta1 beta2 epsilon : α) (params : Spec.Tensor α s) :
                                          (init lr weightDecay beta1 beta2 epsilon params).t = 0

                                          AdamW starts at step 0.

                                          def Optim.AdamW.update {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {s : Spec.Shape} (state : State α s) (params grads : Spec.Tensor α s) :
                                          State α s × Spec.Tensor α s

                                          One AdamW step (returns updated state and parameters).

                                          We implement the decoupled form from the AdamW paper:

                                          • update Adam moments using the raw gradient g,
                                          • apply weight decay directly to the parameters (p ← p - lr * wd * p),
                                          • then apply the Adam update.

                                          This is the same single-step ordering used by torch.optim.AdamW.
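                                          By contrast, coupled (L2-style) weight decay would fold wd * p into the gradient before the moment updates, so the decay term would also be rescaled by the adaptive denominator; decoupling keeps it a plain lr * wd * p shrinkage.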

                                          Instances For
                                            @[simp]
                                            theorem Optim.AdamW.update_t {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {s : Spec.Shape} (state : State α s) (params grads : Spec.Tensor α s) :
                                            (update state params grads).1.t = state.t + 1

                                            AdamW increments its step counter by one on every update.

                                            Adadelta #

                                            structure Optim.Adadelta.State (α : Type) (s : Spec.Shape) : Type

                                            Adadelta state (per parameter tensor).

                                            We store two EMAs:

                                            • v: EMA of squared gradients,
                                            • u: EMA of squared updates.
                                            • lr : α

                                              Learning rate (some presentations fix it to 1; we keep it explicit).

                                            • rho : α

                                              Decay coefficient ρ.

                                            • epsilon : α

                                              Numerical stability constant ε.

                                            • v : Spec.Tensor α s

                                              EMA of squared gradients.

                                            • u : Spec.Tensor α s

                                              EMA of squared updates.

                                            Instances For
                                              def Optim.Adadelta.init {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {s : Spec.Shape} (lr rho epsilon : α) :
                                               Spec.Tensor α s → State α s

                                              Initialize Adadelta state for a parameter tensor (EMAs start at 0).

                                              Instances For
                                                @[simp]
                                                theorem Optim.Adadelta.init_v {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {s : Spec.Shape} (lr rho epsilon : α) (params : Spec.Tensor α s) :
                                                (init lr rho epsilon params).v = Spec.fill 0 s

                                                Adadelta starts with a zero squared-gradient EMA.

                                                @[simp]
                                                theorem Optim.Adadelta.init_u {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {s : Spec.Shape} (lr rho epsilon : α) (params : Spec.Tensor α s) :
                                                (init lr rho epsilon params).u = Spec.fill 0 s

                                                Adadelta starts with a zero squared-update EMA.

                                                def Optim.Adadelta.update {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {s : Spec.Shape} (state : State α s) (params grads : Spec.Tensor α s) :
                                                State α s × Spec.Tensor α s

                                                One Adadelta step (returns updated state and parameters).

                                                Elementwise equations:

                                                • v ← ρ v + (1-ρ) g²
                                                • Δp ← - lr * (sqrt(u + ε) / sqrt(v + ε)) ⊙ g
                                                • p ← p + Δp
                                                • u ← ρ u + (1-ρ) (Δp)²

                                                The ε placement is inside the RMS terms, matching Zeiler's Adadelta update.
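                                                 Note that Δp uses the u from the previous step (u is updated only after the parameter step); with u = 0 initially, the first update direction is scaled by sqrt(ε) / sqrt(v + ε).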

                                                Instances For

                                                  Projected / low-rank optimizers #

                                                   structure Optim.GaLore.Projector (α : Type) (full low : Spec.Shape) : Type

                                                  A shape-safe gradient projector.

                                                  GaLore-style training periodically builds a low-rank subspace for a large matrix parameter, projects the gradient into that subspace, runs a base optimizer there, and lifts the update back to the original parameter shape. This record is deliberately only the algebraic interface: the expensive policy that computes or refreshes the projector belongs to the runtime layer.
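                                                   Concretely, the record has to supply the two maps used by projectedSGDUpdate below; a hedged sketch with illustrative field names (not necessarily the actual ones):

                                                   structure ProjectorSketch (α : Type) (full low : Spec.Shape) where
                                                     /-- Project a full-shape gradient into the low-rank subspace. -/
                                                     project : Spec.Tensor α full → Spec.Tensor α low
                                                     /-- Lift a low-rank update back to the full parameter shape. -/
                                                     lift : Spec.Tensor α low → Spec.Tensor α full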

                                                  Instances For

                                                    Identity projector, useful for tests and for the theorem that projected SGD reduces to SGD.

                                                    Instances For
                                                     structure Optim.GaLore.SGDState (α : Type) (full low : Spec.Shape) : Type

                                                      GaLore-style projected SGD state for one tensor.

                                                      This is not a full GaLore implementation by itself: it specifies the update once a projector is available. A practical trainer still needs a refresh schedule and a way to build projectors for large matrix parameters.

                                                      • lr : α

                                                        Learning rate used after the gradient has been projected and lifted.

                                                      • projector : Projector α full low

                                                        Current gradient projector.

                                                      Instances For
                                                        def Optim.GaLore.projectedSGDUpdate {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {full low : Spec.Shape} (state : SGDState α full low) (params grads : Spec.Tensor α full) :
                                                        Spec.Tensor α full

                                                        One projected-SGD update: p ← p - lr * lift(project(g)).

                                                        Instances For
                                                          theorem Optim.GaLore.projectedSGDUpdate_identity_eq_sgd {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {s : Spec.Shape} (lr : α) (params grads : Spec.Tensor α s) :
                                                          projectedSGDUpdate { lr := lr, projector := identityProjector } params grads = SGD.update { lr := lr } params grads

                                                          With the identity projector, projected SGD is exactly ordinary SGD.

                                                          This is the main sanity check for the GaLore extension point: adding a projection backend cannot silently change the base optimizer when the backend is the identity.

                                                          Muon-style orthogonalized momentum #

                                                          Orthogonalization backend for a matrix-shaped update.

                                                          Muon uses a momentum buffer and then replaces the raw momentum direction by an approximately orthogonalized update, commonly via Newton-Schulz iterations. TorchLean keeps this as an explicit backend so the pure update rule is testable before CUDA kernels are introduced.

                                                          • apply : Spec.Tensor α s → Spec.Tensor α s

                                                            Convert a momentum buffer into the direction used for the parameter update.

                                                          Instances For

                                                            The identity orthogonalizer; with this backend Muon reduces to momentum SGD.

                                                            Instances For
                                                              structure Optim.Muon.State (α : Type) (s : Spec.Shape) : Type

                                                              Per-parameter state for Muon-style momentum with an explicit orthogonalization backend.

                                                              • lr : α

                                                                Learning rate.

                                                              • momentum : α

                                                                Momentum coefficient.

                                                              • buf : Spec.Tensor α s

                                                                Momentum buffer.

                                                              • orthogonalizer : Orthogonalizer α s

                                                                Backend that turns the momentum buffer into the update direction.

                                                              Instances For
                                                                def Optim.Muon.init {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {s : Spec.Shape} (lr momentum : α) (orthogonalizer : Orthogonalizer α s) :
                                                                Spec.Tensor α s → State α s

                                                                Initialize Muon-style state with a zero momentum buffer.

                                                                Instances For
                                                                  def Optim.Muon.update {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {s : Spec.Shape} (state : State α s) (params grads : Spec.Tensor α s) :
                                                                  State α s × Spec.Tensor α s

                                                                  One Muon-style update:

                                                                  • update the momentum buffer,
                                                                  • orthogonalize the buffer,
                                                                  • subtract the scaled orthogonalized direction.

                                                                  For actual Muon, use a matrix-shaped s and a Newton-Schulz orthogonalizer. The generic shape here keeps the definition reusable for tests and for future batched matrix layouts.
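                                                                  In symbols, with O the orthogonalizer's apply function (a reading of the three steps above, consistent with the identity-orthogonalizer theorem below):

                                                                  • buf ← μ buf + g
                                                                  • p ← p - lr * O(buf)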

                                                                  Instances For
                                                                    theorem Optim.Muon.update_identity_param_eq_momentumSGD {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {s : Spec.Shape} (lr momentum : α) (buf params grads : Spec.Tensor α s) :
                                                                    (update { lr := lr, momentum := momentum, buf := buf, orthogonalizer := identityOrthogonalizer } params grads).2 = (MomentumSGD.update { lr := lr, momentum := momentum, buf := buf } params grads).2

                                                                    With the identity orthogonalizer, Muon's parameter update is exactly momentum SGD's parameter update.

                                                                    The state records are different because Muon carries an orthogonalizer backend, but the parameter direction is the same when that backend is id.