TorchLean API

NN.Runtime.Autograd.Torch.Core.Trainer

Torch Trainer Helpers

Ops instances, parameter lists, and scalar trainer construction for the Torch-style runtime. This is the bridge from backend-generic model code to executable training loops.

@[reducible, inline]

Monad used for the eager Ops instance: read an Internal.EagerSession α and execute in IO.

This is the backend that makes Ops programs execute immediately by mutating a hidden runtime tape.
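
A minimal sketch of the eager pattern follows. It assumes a toy session whose only state is a mutable tape of op names. ToySession, ToyEagerM, runOp and demo are hypothetical names for illustration only. The real monad reads an Internal.EagerSession α, whose contents are not shown on this page.

```lean
-- Hypothetical toy session: its only state is a mutable tape of op names.
structure ToySession where
  tape : IO.Ref (Array String)

-- Reader over the session, executing in IO.
abbrev ToyEagerM := ReaderT ToySession IO

-- A primitive runs immediately and appends its name to the hidden tape.
def runOp (name : String) (f : Float → Float) (x : Float) : ToyEagerM Float := do
  let s ← read
  s.tape.modify (·.push name)
  return f x

def demo : IO (Array String) := do
  let tape ← IO.mkRef (#[] : Array String)
  let prog : ToyEagerM Float := do
    let y ← runOp "double" (· * 2) 3.0
    runOp "negate" (fun v => -v) y
  let _ ← ReaderT.run prog { tape := tape }
  tape.get

#eval demo  -- the tape lists the ops in the order they executed
```

The program never passes the tape around explicitly. The reader layer supplies the session, and each op mutates it as a side effect.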

@[implicit_reducible]

Ops instance for the eager Torch-style runtime.

This interprets Ops primitives by immediately executing them against the hidden mutable tape in the current Internal.EagerSession.

@[implicit_reducible]

Ops instance for the compiled graph-building monad GraphM.

This interprets Ops primitives by recording typed IR nodes (rather than executing immediately). See Runtime.Autograd.Compiled.GraphM and Torch.LinkedSession for how these graphs are later run.
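
A toy typeclass can illustrate the split between the two instances. MiniOps, Node and prog are hypothetical stand-ins, not the TorchLean Ops API. prog is written once against the class. The Id instance evaluates each op as it is issued, which plays the role of the eager instance. The StateM instance appends one IR node per op and returns its index, which plays the role of GraphM.

```lean
-- Hypothetical miniature of a backend-generic op interface.
class MiniOps (m : Type → Type) (ref : outParam Type) where
  const : Float → m ref
  add   : ref → ref → m ref

-- Written once against the interface: computes (a + b) + a.
def prog {m : Type → Type} {ref : Type} [Monad m] [MiniOps m ref] : m ref := do
  let a ← MiniOps.const 1.5
  let b ← MiniOps.const 2.0
  let c ← MiniOps.add a b
  MiniOps.add c a

-- Eager interpretation: each op is evaluated as soon as it is issued.
instance : MiniOps Id Float where
  const x := pure x
  add a b := pure (a + b)

-- Hypothetical IR node type for the recording interpretation.
inductive Node where
  | const (x : Float)
  | add (i j : Nat)
  deriving Repr

-- Recording interpretation: each op appends a node and returns its index.
instance : MiniOps (StateM (Array Node)) Nat where
  const x := modifyGet fun g => (g.size, g.push (.const x))
  add i j := modifyGet fun g => (g.size, g.push (.add i j))

#eval Id.run (prog (m := Id))                                         -- a value
#eval (Id.run (StateT.run (prog (m := StateM (Array Node))) #[])).2   -- a graph
```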

Heterogeneous list of trainable parameters, indexed by a list of shapes.

This is the Torch front-end analogue of "a parameter vector" (like model.parameters() in PyTorch), but with shapes tracked at the type level.
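
The sketch below shows shape-indexed parameter lists. It assumes a shape is a List Nat and a tensor is a flat Float buffer. MiniTensor and MiniTList are hypothetical analogues of Spec.Tensor and TList. Unlike a real tensor type, this sketch does not check the buffer size against the shape. Only the list of shapes is tracked by the type.

```lean
-- Hypothetical stand-ins: a shape is a list of dimension sizes, and a tensor is a
-- flat buffer tagged with its shape at the type level (the size is not checked here).
abbrev Shape := List Nat

structure MiniTensor (s : Shape) where
  data : Array Float

-- Heterogeneous list indexed by the shapes of its elements.
inductive MiniTList : List Shape → Type where
  | nil  : MiniTList []
  | cons {s : Shape} {ss : List Shape} : MiniTensor s → MiniTList ss → MiniTList (s :: ss)

-- A weight of shape [2, 3] and a bias of shape [3]; the type records both shapes.
def params : MiniTList [[2, 3], [3]] :=
  .cons ⟨#[0, 0, 0, 0, 0, 0]⟩ (.cons ⟨#[0, 0, 0]⟩ .nil)

-- A pack whose shape list does not match is a type error, not a runtime error:
-- def bad : MiniTList [[2, 3], [3]] := .cons ⟨#[0]⟩ .nil   -- rejected
```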

Materialize the SGD update v - lr * g in a single traversal.

This is used by sgdStepFast as a runtime-performance optimization that avoids building deep thunk chains when training for many steps.
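
The sketch below compares a lazily composed update with a materialized one on plain Float values. lazyStep, materializedStep, lazyAfter and materializedAfter are all hypothetical. They model "deep thunk chains" as nested closures. That is an assumption about the shape of the problem, not a description of TorchLean's internal representation.

```lean
-- Lazy representation: every step wraps the previous parameters in a new closure,
-- so after n steps reading one entry walks a chain of n pending updates.
def lazyStep (v : Nat → Float) (lr : Float) (g : Nat → Float) : Nat → Float :=
  fun i => v i - lr * g i

-- Materialized representation: v - lr * g is computed in one traversal and stored.
def materializedStep (v g : List Float) (lr : Float) : List Float :=
  List.zipWith (fun vi gi => vi - lr * gi) v g

-- n lazy steps build an n-deep closure chain; n materialized steps do not.
def lazyAfter (n : Nat) : Nat → Float :=
  (List.range n).foldl (fun v _ => lazyStep v 0.1 (fun _ => 1.0)) (fun _ => 0.0)

def materializedAfter (n : Nat) : List Float :=
  (List.range n).foldl (fun v _ => materializedStep v [1.0, 1.0] 0.1) [0.0, 0.0]

#eval (lazyAfter 10 0, materializedAfter 10)
```

Both paths compute the same numbers. Only the cost of reading the result differs, which matches the claim that sgdStepFast is mathematically equivalent to sgdStep.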

Allocate a fresh ParamList from an initial TList of parameter tensors.

Each tensor becomes an IO.Ref so it can be updated by optimizer steps.
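
The sketch below shows the allocation step. It assumes parameters are plain Array Float buffers rather than typed tensors in a TList. allocParams and demoAlloc are hypothetical names.

```lean
-- Hypothetical: allocate one mutable cell per initial parameter buffer.
def allocParams (init : List (Array Float)) : IO (List (IO.Ref (Array Float))) := do
  let mut refs : List (IO.Ref (Array Float)) := []
  for t in init do
    refs := refs ++ [← IO.mkRef t]
  return refs

def demoAlloc : IO (List (Array Float)) := do
  let refs ← allocParams [#[1.0, 2.0], #[3.0]]
  -- An optimizer step mutates each cell in place...
  for r in refs do
    r.modify (·.map (· * 0.5))
  -- ...and later reads observe the updated values.
  refs.mapM fun (r : IO.Ref (Array Float)) => r.get

#eval demoAlloc
```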

Allocate a fresh ParamList from an initial TList of parameter tensors, with explicit requiresGrad flags.

Returns an error when the flag list length does not match the parameter shape list length.
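
The length check can be sketched as follows. The sketch reports the error through Except, which may differ from the real function's error channel. pairWithFlags is a hypothetical name.

```lean
-- Hypothetical: pair each parameter with its requiresGrad flag, rejecting
-- mismatched lengths instead of silently truncating as List.zip would.
def pairWithFlags (params : List (Array Float)) (flags : List Bool) :
    Except String (List (Array Float × Bool)) :=
  if params.length = flags.length then
    .ok (params.zip flags)
  else
    .error s!"expected {params.length} requiresGrad flags, got {flags.length}"

#eval pairWithFlags [#[1.0], #[2.0]] [true]   -- rejected: one flag for two parameters
```

The explicit check matters because List.zip silently drops the surplus elements of the longer list. Without it, a short flag list would leave some parameters without a flag instead of raising an error.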

Read the current parameter values as a TList aligned with the shape list.

Read parameter values, synchronizing CUDA-resident mirrors first when necessary.

Overwrite the current parameter values from a TList aligned with the shape list.

def Runtime.Autograd.Torch.ParamList.sgdStep {α : Type} [Context α] {ss : List Spec.Shape} :
ParamList α ss → (lr : α) → TList α ss → IO Unit

Apply an SGD step p := p - lr * g to each parameter that has requiresGrad = true.

The gradients gs (the TList α ss argument) must be aligned with the parameter shapes ss.
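
Here is a pure sketch of the masked update. It assumes each parameter is stored together with its aligned gradient and its requiresGrad flag. Slot and sgdStepSketch are hypothetical names. The real sgdStep updates IO.Ref cells in place instead of returning new values.

```lean
-- Hypothetical slot: a parameter, its aligned gradient, and its requiresGrad flag.
structure Slot where
  value : List Float
  grad : List Float
  requiresGrad : Bool

-- p := p - lr * g for every slot with requiresGrad = true; frozen slots are left as-is.
def sgdStepSketch (lr : Float) (slots : List Slot) : List Slot :=
  slots.map fun s =>
    if s.requiresGrad then
      { s with value := List.zipWith (fun p g => p - lr * g) s.value s.grad }
    else
      s

-- The second slot is frozen, so only the first one moves.
#eval (sgdStepSketch 0.1 [⟨[1.0], [2.0], true⟩, ⟨[1.0], [2.0], false⟩]).map Slot.value
```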

def Runtime.Autograd.Torch.ParamList.sgdStepFast {α : Type} [Context α] {ss : List Spec.Shape} :
ParamList α ss → (lr : α) → TList α ss → IO Unit

Like sgdStep, but uses a fully materialized update (subScaleMaterialize) for speed.

This is a runtime-performance knob; mathematically it is equivalent to sgdStep.

structure Runtime.Autograd.Torch.ScalarTrainer (α : Type) (paramShapes inputShapes : List Spec.Shape) :

Bundle a scalar-loss training loop for a fixed parameter pack and input signature.

This is the low-level trainer object used by module-backed execution:

• forward computes a scalar loss,
• backward computes gradients w.r.t. parameters,
• step applies an optimizer update (typically SGD),
• getParams reads current parameter values.

• params : ParamList α paramShapes

  Mutable trainable parameter pack.

• forward : Curried.Fn α inputShapes (IO (Spec.Tensor α Spec.Shape.scalar))

  Compute the scalar loss for a curried input pack.

• backward : Curried.Fn α inputShapes (IO (TList α paramShapes))

  Compute gradients aligned with paramShapes for a curried input pack.

• step : α → Curried.Fn α inputShapes (IO Unit)

  Apply one SGD-style update for a curried input pack.

• adamStep? : Option (α → α → α → α → Curried.Fn α inputShapes (IO Unit))

  Optional Adam update path.

  In eager CUDA mode this is a device-gradient/device-moment update path. Other backends expose none and should use the generic optimizer wrappers.

• adamWStep? : Option (α → α → α → α → α → Curried.Fn α inputShapes (IO Unit))

  Optional AdamW update path.

  In eager CUDA mode this is a device-gradient/device-moment update path with decoupled weight decay. Other backends expose none and should use the generic optimizer wrappers.

• getParams : IO (TList α paramShapes)

  Read current parameter values, synchronizing device mirrors if needed.
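
A one-parameter toy that closes over a single IO.Ref can illustrate the trainer bundle. ToyTrainer, mkToyTrainer and trainDemo are hypothetical. The loss (w * x - 1)^2 and its hand-written gradient 2 * (w * x - 1) * x stand in for the autograd-derived forward and backward. The optional Adam fields are omitted.

```lean
-- Hypothetical miniature of a scalar trainer: one Float parameter, one Float input.
structure ToyTrainer where
  forward   : Float → IO Float          -- scalar loss for an input
  backward  : Float → IO Float          -- d loss / d w for an input
  step      : Float → Float → IO Unit   -- (lr, input): one SGD update
  getParams : IO Float

def mkToyTrainer (w0 : Float) : IO ToyTrainer := do
  let w ← IO.mkRef w0
  -- loss(w; x) = (w * x - 1)^2
  let fwd (x : Float) : IO Float := do
    let wv ← w.get
    return (wv * x - 1) * (wv * x - 1)
  -- d loss / d w = 2 * (w * x - 1) * x
  let grad (x : Float) : IO Float := do
    let wv ← w.get
    return 2 * (wv * x - 1) * x
  -- SGD update w := w - lr * grad
  let stp (lr x : Float) : IO Unit := do
    let g ← grad x
    w.modify (· - lr * g)
  return { forward := fwd, backward := grad, step := stp, getParams := w.get }

def trainDemo : IO Float := do
  let t ← mkToyTrainer 0.0
  for _ in List.range 50 do
    t.step 0.1 2.0
  t.getParams

#eval trainDemo  -- w approaches 0.5, the minimizer of (2 * w - 1)^2
```

All four fields share the same hidden cell. This mirrors how step and getParams in the real bundle both act on the params pack.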

Extract gradients (as a typed TList) for a list of eager TensorRefs from a dense gradient array.
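
The extraction can be sketched under two assumptions. First, a TensorRef is just a node index on the tape. Second, the dense array holds one gradient buffer per node. NodeId and extractGrads are hypothetical names. Returning none for an out-of-range index is a choice made by this sketch.

```lean
-- Hypothetical: a tape-leaf handle is just the index of its node on the tape.
abbrev NodeId := Nat

-- One gradient buffer per tape node; pick out the leaves' gradients in ref order,
-- so the result stays aligned with the parameter list the refs came from.
def extractGrads (dense : Array (Array Float)) (refs : List NodeId) :
    Option (List (Array Float)) :=
  refs.mapM fun (i : NodeId) => dense[i]?

#eval extractGrads #[#[1.0], #[2.0, 3.0], #[4.0]] [2, 0]
```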

Record all parameters as tape leaves in an eager session, returning their corresponding TensorRefs.

This is the eager analogue of "using" a parameter pack during a forward pass.
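
The sketch below records leaves on a tape. It assumes the tape is an IO.Ref over an array of nodes and that refs are node indices. TapeNode, recordLeaves and demoLeaves are hypothetical names.

```lean
-- Hypothetical tape: leaves carry a value and a requiresGrad flag; ops point at inputs.
inductive TapeNode where
  | leaf (value : Array Float) (requiresGrad : Bool)
  | op (name : String) (inputs : List Nat)

-- Push every parameter as a leaf and return the node indices (the "refs"), in order.
def recordLeaves (tape : IO.Ref (Array TapeNode)) (params : List (Array Float)) :
    IO (List Nat) :=
  params.mapM fun (p : Array Float) =>
    tape.modifyGet fun t => (t.size, t.push (.leaf p true))

def demoLeaves : IO (List Nat) := do
  let tape ← IO.mkRef (#[] : Array TapeNode)
  recordLeaves tape [#[1.0, 2.0], #[3.0]]

#eval demoLeaves  -- the refs are the node indices 0 and 1
```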

Record all input tensors as tape leaves in an eager session, returning their corresponding TensorRefs.

def Runtime.Autograd.Torch.scalarTrainer {α : Type} [Context α] [Internal.CudaBridge.TensorConv α] [DecidableEq Spec.Shape] {paramShapes inputShapes : List Spec.Shape} (opts : Options := { }) (initRequiresGrad : List Bool := List.replicate paramShapes.length true) (loss : {m : Type → Type} → [Monad m] → [inst : Ops m α] → CurriedRef (fun (s : Spec.Shape) => Ops.Ref m α s) (paramShapes ++ inputShapes) (m (Ops.Ref m α Spec.Shape.scalar))) :
Curried.Fn α paramShapes (IO (ScalarTrainer α paramShapes inputShapes))

Build a ScalarTrainer from an initial parameter pack and a backend-generic loss definition.

loss is written once against the Ops interface over a concatenated context paramShapes ++ inputShapes. Depending on opts.backend, we either:

• compile the loss once (compiled backend), or
• execute it eagerly by building a runtime tape each step (eager backend).
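
The difference between the two strategies can be sketched by counting how often a loss written once gets traced. Backend, traceLoss and runSteps are hypothetical names. The real options record and its backend names may differ.

```lean
-- Hypothetical stand-in for the backend choice carried by the options record.
inductive Backend where
  | eager
  | compiled

-- Hypothetical loss "definition": tracing it bumps a counter and yields an evaluator
-- for (w * x - 1)^2, standing in for a compiled graph or a freshly built tape.
def traceLoss (traces : IO.Ref Nat) : IO (Float → Float → Float) := do
  traces.modify (· + 1)
  return fun w x => (w * x - 1) * (w * x - 1)

-- Evaluate the loss once per step; return the summed loss and the number of traces.
def runSteps (b : Backend) (steps : Nat) : IO (Float × Nat) := do
  let traces ← IO.mkRef (0 : Nat)
  let mut total : Float := 0
  match b with
  | .compiled =>
    let f ← traceLoss traces          -- trace once up front...
    for _ in List.range steps do
      total := total + f 0.5 3.0      -- ...and reuse it on every step
  | .eager =>
    for _ in List.range steps do
      let f ← traceLoss traces        -- re-trace on every step
      total := total + f 0.5 3.0
  return (total, ← traces.get)

#eval runSteps .compiled 5   -- 1 trace
#eval runSteps .eager 5      -- 5 traces
```

Both runs produce the same summed loss. Only the number of traces changes, which is the trade-off the backend option controls.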