TorchLean API

NN.Runtime.Autograd.Train.Optim

Optimizer integration for Runtime.Autograd #

This module is the training-loop side of autograd: it takes a gradient map produced by Runtime.Autograd and applies parameter updates.

PyTorch analogy:

All updates are shape checked and implemented using the pure Spec tensor operators, so they can be used in eager execution or lowered into the compiled IR.

Formula ownership:

The important rule is: this file must not define a second public optimizer-formula surface. The step implementation below constructs canonical optimizer states from the dynamic parameter-table buffers and calls NN.Runtime.Optim.Optimizers directly. The only local algebra left here is training-loop glue that is not represented by the canonical pure states, such as coupled weight-decay preprocessing and PyTorch-style momentum dampening/Nesterov handling.

Parameter table #

This section provides the parameter registry used by the training loop.

Unlike PyTorch (where parameters are objects with identity), we use an explicit Nat id so that gradients and optimizer state can be keyed in pure maps and serialized deterministically.

A single trainable parameter entry.

This is the Runtime.Autograd equivalent of a "parameter tensor" in PyTorch, except we make the identifier explicit (id : Nat) so we can key gradients and optimizer state in pure maps.

  • id : Nat

    Stable identifier used to key gradients and optimizer state.

  • name : Option String

    Optional human-readable name (e.g. module path); used only for reporting/debugging.

  • value : AnyTensor α

    The parameter value, stored as an AnyTensor (shape erased).

Instances For
    @[reducible, inline]
    def Runtime.Autograd.Train.ParamTable (α : Type) :
    Type

    A flat list of parameters used by the training loop.

    Instances For

      Constructors #

      Create a ParamEntry from a typed tensor.

      This is mostly a convenience for assembling a ParamTable from known-shaped tensors.

      Instances For

        List of ids for quick membership checks.

        Instances For

          Find a parameter entry by id.

          Instances For

            Get a typed tensor from the table, with shape checking.

            Instances For
              def Runtime.Autograd.Train.ParamTable.set {α : Type} (ps : ParamTable α) (id : Nat) (value : AnyTensor α) :
              ParamTable α

              Replace a parameter entry value by id.

              Instances For
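
To make the id-keyed registry pattern concrete, here is a minimal self-contained sketch in which entries hold flat Float arrays instead of AnyTensor values; Entry, Table, and the helper definitions below are stand-ins introduced for illustration, not the module's actual definitions:

```lean
-- Minimal stand-in for an id-keyed parameter registry (illustrative only;
-- the real ParamTable stores shape-erased AnyTensor values).
structure Entry where
  id    : Nat
  name  : Option String := none
  value : Array Float

abbrev Table := List Entry

-- List of ids for quick membership checks.
def Table.ids (ps : Table) : List Nat :=
  ps.map (·.id)

-- Find an entry by id.
def Table.find (ps : Table) (id : Nat) : Option Entry :=
  ps.find? (·.id == id)

-- Replace an entry's value by id, leaving other entries untouched.
def Table.set (ps : Table) (id : Nat) (value : Array Float) : Table :=
  ps.map fun e => if e.id == id then { e with value } else e
```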

                Scheduler wrapper #

                Learning-rate scheduler wrapper used by the training loop.

                PyTorch analogy: this plays the role of torch.optim.lr_scheduler.* objects, except we keep the state as an inductive value and expose a pure getLR/advance API.

                Instances For
                  def Runtime.Autograd.Train.LRScheduler.getLR {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] :
                  LRScheduler α → α

                  Read current learning rate from the scheduler state.

                  Instances For

                    Advance scheduler state by one step.

                    Instances For
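
A self-contained sketch of the pure getLR/advance shape, using a hypothetical constant/StepLR-style pair of constructors (the actual LRScheduler constructors are not listed here):

```lean
-- Illustrative pure scheduler state: a constant rate and a StepLR-like rule
-- that decays the base rate by `gamma` once per completed period
-- (hypothetical constructors; not the library's actual LRScheduler).
inductive Sched where
  | constant (lr : Float)
  | stepLR (baseLR gamma : Float) (stepSize cur : Nat)

def Sched.getLR : Sched → Float
  | .constant lr => lr
  | .stepLR base gamma stepSize cur =>
      base * gamma ^ (cur / stepSize).toFloat

def Sched.advance : Sched → Sched
  | .constant lr => .constant lr
  | .stepLR base gamma stepSize cur => .stepLR base gamma stepSize (cur + 1)

-- Training-loop pattern: read the rate, then step the scheduler.
def readThenStep (s : Sched) : Float × Sched :=
  (s.getLR, s.advance)
```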

                      Optimizer configuration #

                      Which optimizer update rule to apply.

                      PyTorch analogy: these correspond roughly to torch.optim.SGD, Adam, AdamW, etc.

                      Instances For

                        Optimizer hyperparameters for a subset of parameters.

                        PyTorch analogy: this is a single entry in the optimizer's param-group list (optimizer.param_groups).

                        • params : List Nat

                          Parameter ids that belong to this group.

                        • lr : α

                          Base learning rate (possibly overridden by scheduler on each step).

                        • weight_decay : α

                          L2 regularization coefficient (behavior depends on the optimizer kind; see AdamW).

                        • momentum : α

                          Momentum factor (SGD with momentum).

                        • dampening : α

                          Dampening for momentum updates.

                        • nesterov : Bool

                          Use Nesterov variant for momentum updates.

                        • beta1 : α

                          Adam beta1 parameter (exponential decay for the first moment).

                        • beta2 : α

                          Adam beta2 parameter (exponential decay for the second moment).

                        • epsilon : α

                          Numerical stability term used by adaptive optimizers.

                        • rho : α

                          "Rho" decay parameter for RMSProp/AdaDelta style optimizers.

                        • scheduler : Option (LRScheduler α)

                          Optional learning-rate scheduler for this group.

                        Instances For
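
For illustration, a trimmed-down stand-alone group record with a few of the fields above and PyTorch-style SGD defaults; GroupCfg and its default values are assumptions for the sketch, not the module's ParamGroup:

```lean
-- Hypothetical, trimmed-down param group over Float hyperparameters.
structure GroupCfg where
  params       : List Nat      -- parameter ids in this group
  lr           : Float
  weight_decay : Float := 0.0
  momentum     : Float := 0.0
  dampening    : Float := 0.0
  nesterov     : Bool  := false

-- A typical SGD-with-momentum group covering parameters 0, 1 and 2.
def exampleGroup : GroupCfg :=
  { params := [0, 1, 2], lr := 0.1, momentum := 0.9, weight_decay := 1e-4 }
```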

                          Full optimizer state used by the training loop.

                          This mirrors PyTorch's optimizer state:

                          • a global step counter,
                          • hyperparameter groups,
                          • and per-parameter state buffers keyed by parameter id (Nat).
                          Instances For

                            A pure state snapshot for saving/restoring optimizer state.

                            PyTorch analogy: this is the data carried by optimizer.state_dict() (modulo naming/layout). We use association lists instead of HashMap so the result is deterministic and easy to serialize.

                            • Optimizer algorithm used to interpret the stored buffers.

                            • step : Nat

                              Global optimizer step at the time the snapshot was taken.

                            • groups : List (ParamGroup α)

                              Parameter groups, including scheduler state and hyperparameters.

                            • momentum_buf : List (Nat × AnyTensor α)

                              Momentum buffers keyed by parameter id.

                            • m : List (Nat × AnyTensor α)

                              Adam-family first-moment buffers keyed by parameter id.

                            • v : List (Nat × AnyTensor α)

                              Adam-family second-moment buffers keyed by parameter id.

                            • acc : List (Nat × AnyTensor α)

                              AdaGrad/RMSProp/Adadelta accumulator buffers keyed by parameter id.

                            • acc2 : List (Nat × AnyTensor α)

                              Adadelta second accumulator buffers keyed by parameter id.

                            Instances For

                              Serialize optimizer state to a pure record.

                              PyTorch analogy: this is the "export" step for state_dict().

                              Instances For

                                Restore optimizer state from a state dict.

                                PyTorch analogy: this is the "import" step for load_state_dict(...).

                                Instances For
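
One reason for association lists: a hash map's iteration order is unspecified, so exporting buffers sorted by parameter id keeps snapshots deterministic. A self-contained sketch with Float stand-ins for the AnyTensor buffers (insertSorted, exportBuf, and importBuf are hypothetical helper names):

```lean
import Std.Data.HashMap

-- Insert into an id-sorted association list (ascending parameter id).
def insertSorted (p : Nat × Float) : List (Nat × Float) → List (Nat × Float)
  | [] => [p]
  | q :: qs => if p.1 ≤ q.1 then p :: q :: qs else q :: insertSorted p qs

-- "Export": dump a per-parameter buffer map as a deterministic, sorted list.
def exportBuf (m : Std.HashMap Nat Float) : List (Nat × Float) :=
  m.toList.foldr insertSorted []

-- "Import": rebuild the map from the stored association list.
def importBuf (l : List (Nat × Float)) : Std.HashMap Nat Float :=
  l.foldl (fun m p => m.insert p.1 p.2) ∅
```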

                                  Optimizer step #

                                  Create a zero-filled buffer with the same shape as a parameter value.

                                  Instances For

                                    Lookup a per-parameter state buffer, initializing it with zeros if absent.

                                    This is used for momentum/Adam accumulator initialization (PyTorch does this lazily on first step).

                                    Instances For
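
The lazy-initialization pattern, sketched over flat Float buffers: look up the buffer and fall back to zeros of the parameter's size (getOrZero is a hypothetical name; the real helper works on shape-checked tensors):

```lean
import Std.Data.HashMap

-- Look up a per-parameter state buffer; if it has never been touched,
-- fall back to a zero buffer of the parameter's size.
def getOrZero (bufs : Std.HashMap Nat (Array Float)) (id paramSize : Nat) :
    Array Float :=
  bufs.getD id (Array.mkArray paramSize 0.0)
```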

                                      Shape-check and cast an optimizer state buffer to match the current parameter value.

                                      This prevents silent shape mismatches when reloading a checkpoint into a model with different parameter shapes.

                                      Instances For
                                        def Runtime.Autograd.Train.Optim.addWeightDecay {α : Type} [Context α] {s : Spec.Shape} (param grad : Spec.Tensor α s) (weight_decay : α) :
                                        Spec.Tensor α s

                                        Add an L2 regularization term to the gradient: g + weight_decay * param.

                                        Note: this is the coupled weight decay used by classic SGD-style updates. For AdamW the integration step delegates to the canonical optimizer's decoupled update.

                                        Instances For
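
A direct rendering of the formula over flat Float buffers (illustrative; the module itself uses shape-checked Spec.Tensor operators):

```lean
-- Coupled L2 weight decay: fold the decay term into the gradient before the
-- optimizer update, g' = g + weight_decay * p.
def addWeightDecayFlat (param grad : Array Float) (weightDecay : Float) : Array Float :=
  (grad.zip param).map fun (g, p) => g + weightDecay * p
```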

                                          Convert the training-loop Adam step number into the canonical optimizer state's previous step.

                                          The public training helper receives the step being applied (1 for the first Adam/AdamW update). NN.Runtime.Optim.Optimizers stores the previous step in the state and increments internally. For direct calls with t = 0, we still return 0 so the helper stays total and behaves like a first update rather than constructing a negative predecessor.

                                          Instances For
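
In Nat arithmetic this is just truncated subtraction, which is already total; a sketch with a hypothetical helper name:

```lean
-- Convert the 1-based "step being applied" into the canonical state's stored
-- previous step. Nat subtraction truncates at zero, so t = 0 also maps to 0
-- and behaves like a first update.
def toPrevStep (t : Nat) : Nat := t - 1

#eval toPrevStep 1  -- 0: the first Adam/AdamW update sees previous step 0
#eval toPrevStep 0  -- 0: total; no negative predecessor is constructed
```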
                                            def Runtime.Autograd.Train.Optim.updateGroupSchedulers {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] (groups : List (ParamGroup α)) :
                                            List (ParamGroup α)

                                            Update each group's learning rate from its scheduler (if present) and advance the scheduler state.

                                            This matches the common training-loop pattern: "read LR, then call scheduler.step()".

                                            Instances For
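
The per-group pattern, sketched with a minimal group record whose optional scheduler is reduced to a decay factor (LrGroup and its fields are assumptions for the sketch):

```lean
-- "Read LR, then step the scheduler" applied across groups (illustrative;
-- here a scheduler is just a multiplicative decay applied once per call).
structure LrGroup where
  lr    : Float
  gamma : Option Float := none   -- stand-in for an optional scheduler state

def updateSchedulers (groups : List LrGroup) : List LrGroup :=
  groups.map fun g =>
    match g.gamma with
    | none       => g                              -- no scheduler: keep lr
    | some gamma => { g with lr := g.lr * gamma }  -- read new lr, "advance"
```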

                                              Build a map from parameter id to its ParamGroup.

                                              Fails if an id appears in multiple groups (PyTorch also disallows overlapping param groups).

                                              Instances For
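
A self-contained sketch of the overlap check: build an id-to-group-index map and fail when an id is claimed twice (the real helper maps ids to ParamGroup values; buildGroupMap here is an illustrative stand-in):

```lean
import Std.Data.HashMap

-- Map each parameter id to the index of its group, rejecting ids that appear
-- in more than one group (mirrors PyTorch's rejection of overlapping
-- param groups).
def buildGroupMap (groups : List (List Nat)) :
    Except String (Std.HashMap Nat Nat) := do
  let mut acc : Std.HashMap Nat Nat := ∅
  let mut gi := 0
  for g in groups do
    for id in g do
      if acc.contains id then
        throw s!"parameter id {id} appears in multiple groups"
      acc := acc.insert id gi
    gi := gi + 1
  return acc
```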
                                                def Runtime.Autograd.Train.Optim.step {α : Type} [Context α] [DecidableEq Spec.Shape] [DecidableRel fun (x1 x2 : α) => x1 > x2] (opt : OptimizerState α) (params : ParamTable α) (grads : Std.HashMap Nat (AnyTensor α)) :

                                                Apply one optimizer step to a parameter table.

                                                Inputs:

                                                • opt is the current optimizer state (including per-parameter buffers),
                                                • params is the current parameter table,
                                                • grads maps parameter ids to gradients (as produced by autograd).

                                                Behavior:

                                                • applies LR schedulers (if configured) per group,
                                                • shape-checks gradients and state buffers against each parameter,
                                                • updates per-parameter state buffers (momentum / Adam m,v / accumulators),
                                                • returns the updated optimizer state and an updated parameter table.
                                                Instances For
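
To make the data flow concrete, here is a self-contained single-step sketch over flat Float buffers keyed by Nat id, using plain SGD with coupled weight decay and momentum; MiniState, axpy, and miniStep are illustrative stand-ins, and the real step instead shape-checks AnyTensor buffers, applies per-group schedulers, and delegates the update formulas to NN.Runtime.Optim.Optimizers:

```lean
import Std.Data.HashMap

-- One SGD-with-momentum step over flat Float buffers keyed by parameter id.
structure MiniState where
  lr          : Float
  momentum    : Float
  weightDecay : Float
  bufs        : Std.HashMap Nat (Array Float) := ∅   -- momentum buffers

-- a * x + y, elementwise
def axpy (a : Float) (x y : Array Float) : Array Float :=
  (x.zip y).map fun (xi, yi) => a * xi + yi

def miniStep (st : MiniState) (params grads : Std.HashMap Nat (Array Float)) :
    MiniState × Std.HashMap Nat (Array Float) := Id.run do
  let mut bufs := st.bufs
  let mut out  := params
  for (id, p) in params do
    if let some g := grads.get? id then
      -- coupled weight decay: g' = g + weight_decay * p
      let g' := axpy st.weightDecay p g
      -- momentum buffer, zero-initialized lazily: b' = momentum * b + g'
      let b  := bufs.getD id (Array.mkArray p.size 0.0)
      let b' := axpy st.momentum b g'
      bufs := bufs.insert id b'
      -- parameter update: p' = p - lr * b'
      out := out.insert id (axpy (-st.lr) b' p)
  return ({ st with bufs }, out)
```

Parameters with no entry in the gradient map are left untouched, which matches the behavior described above: only ids present in grads receive an update.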