TorchLean API

NN.Runtime.Autograd.TorchLean.Module

Module

TorchLean module wrappers with PyTorch-style ergonomics.

TorchLean already provides the core ingredients:

This file adds a thin “nn.Module-style” wrapper so users can:

Important: dtype selection is handled in NN.API.DType (because it picks the Lean type α). The module definitions here are polymorphic in α, so the same module can be:

Small helpers

Cast a Float tensor to a backend scalar type α by mapping a scalar cast function.

This is mainly used to turn tensorND!-authored Float initializers into Float/IEEE32Exec/etc.

    def Runtime.Autograd.TorchLean.Module.castTList {α : Type} (cast : Float → α) {ss : List Spec.Shape} :
    TList Float ss → TList α ss

    List-shaped version of castTensor for TorchLean's TList parameter bundles.
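The cast pattern can be illustrated with a small self-contained model in core Lean 4. This is a sketch under simplifying assumptions: `ToyTensor`, `ToyTList`, `ToyTList.cast` and `ToyTList.head` are hypothetical stand-ins, not TorchLean's `Tensor`, `TList` or `castTList`, and shapes are collapsed to flat element counts. It shows that mapping a scalar `Float → α`-style function over every tensor leaves the type-level shape list untouched.

```lean
/-- Hypothetical toy tensor: `n` scalars addressed by a flat index (not TorchLean's `Tensor`). -/
structure ToyTensor (α : Type) (n : Nat) where
  get : Fin n → α

/-- Apply a scalar function to every element; the size index is unchanged. -/
def ToyTensor.map {α β : Type} {n : Nat} (f : α → β) (t : ToyTensor α n) : ToyTensor β n :=
  ⟨fun i => f (t.get i)⟩

/-- Hypothetical shape-indexed bundle: exactly one tensor per entry of `ns`, in order. -/
inductive ToyTList (α : Type) : List Nat → Type where
  | nil : ToyTList α []
  | cons {n : Nat} {ns : List Nat} : ToyTensor α n → ToyTList α ns → ToyTList α (n :: ns)

/-- Cast every scalar in every tensor; the index `ns` is carried through unchanged. -/
def ToyTList.cast {α β : Type} (f : α → β) : {ns : List Nat} → ToyTList α ns → ToyTList β ns
  | _, .nil => .nil
  | _, .cons t ts => .cons (t.map f) (ToyTList.cast f ts)

/-- First tensor of a non-empty bundle (the `nil` case is ruled out by the index). -/
def ToyTList.head {α : Type} : {n : Nat} → {ns : List Nat} → ToyTList α (n :: ns) → ToyTensor α n
  | _, _, .cons t _ => t
```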

      Runtime Float Initializers

      Runtime initializer for a Float parameter.

      The usual ScalarModuleDef.initParams path stores initializers as typed Lean tensors. That is the right representation when the initial value itself is part of the Lean object being inspected. For large Float runs, it is better to allocate runtime storage from a compact initialization scheme and synchronize the host tensor only when parameters are explicitly read back.

      The design mirrors the storage-first APIs used by mainstream runtimes:

      • PyTorch exposes in-place initializers such as torch.nn.init.uniform_, torch.nn.init.xavier_uniform_, and torch.nn.init.kaiming_uniform_ for already-allocated tensors: https://pytorch.org/docs/stable/nn.init.html.
      • PyTorch's meta-device / to_empty path separates "module structure exists" from "real storage is materialized", after which users explicitly initialize parameters: https://docs.pytorch.org/docs/main/meta.html.

      TorchLean keeps the semantic parameter type (Tensor Float s) available, but this runtime path lets CPU/CUDA execution initialize real storage directly.

      • zeros : FloatInit

        Fill with zeros. PyTorch analogue: torch.nn.init.zeros_.

      • ones : FloatInit

        Fill with ones. PyTorch analogue: torch.nn.init.ones_.

      • uniform (lo hi : Float) (seed : ℕ := 0) : FloatInit

        Uniform distribution over [lo, hi), using TorchLean's deterministic runtime RNG.

      • xavierUniform (fanIn fanOut : ℕ) (seed : ℕ := 0) : FloatInit

        Xavier/Glorot uniform with explicit fan-in and fan-out.

      • kaimingUniform (fanIn : ℕ) (seed : ℕ := 0) : FloatInit

        Kaiming/He uniform with explicit fan-in.

      • flat (values : FloatArray) : FloatInit

        Exact row-major payload. Used for imported checkpoints or generated tensors.
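A rough Lean mirror of the constructors above, together with the interval each random scheme would sample from. The bounds are assumptions borrowed from PyTorch's documented defaults (Xavier with gain 1, so b = √(6 / (fanIn + fanOut)); Kaiming with the ReLU gain √2, so b = √2 · √(3 / fanIn) = √(6 / fanIn)). These docs do not say which gains TorchLean applies, and `FloatInitSketch` and `sampleRange` are hypothetical names.

```lean
/-- Hypothetical mirror of the `FloatInit` constructors documented above. -/
inductive FloatInitSketch where
  | zeros
  | ones
  | uniform (lo hi : Float) (seed : Nat := 0)
  | xavierUniform (fanIn fanOut : Nat) (seed : Nat := 0)
  | kaimingUniform (fanIn : Nat) (seed : Nat := 0)
  | flat (values : FloatArray)

/-- Interval `(lo, hi)` that a random scheme samples from; `none` for deterministic schemes.
    Assumed gains (PyTorch defaults): 1 for Xavier, √2 for Kaiming with ReLU. -/
def sampleRange : FloatInitSketch → Option (Float × Float)
  | .uniform lo hi _ => some (lo, hi)
  | .xavierUniform fanIn fanOut _ =>
      let b := Float.sqrt (6.0 / (fanIn + fanOut).toFloat)   -- 1 * √(6 / (fanIn + fanOut))
      some (-b, b)
  | .kaimingUniform fanIn _ =>
      let b := Float.sqrt (6.0 / fanIn.toFloat)              -- √2 * √(3 / fanIn)
      some (-b, b)
  | _ => none
```

The seed is matched with `_` explicitly; leaving it out of a pattern could make Lean insert the default `0` as a literal pattern.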

        A shape-indexed initialization plan.

        This is the typed runtime-initialization API for modules with a known parameter shape list. It is the initialization analogue of TList: the type says there is exactly one initializer for each parameter shape, in the same order. That rules out the runtime failure mode in which a plain list has one initializer too few or too many.

        The initializers themselves are runtime schemes rather than proof objects. The proof-facing story is still the ordinary Tensor Float s parameter value; this plan only controls how the executable Float runtime materializes those tensors on CPU or CUDA.

          Forget the shape index when interoperating with list-based callers.

            The type index is not decorative: forgetting a Plan ss to a list produces exactly ss.length initializers. This checked fact lets the runtime API avoid the usual "initializer list does not match parameter list" class of bugs once a plan has been built.

            Recover a shape-indexed plan from a plain list.

            List-based callers still enter through this boundary, but the runtime converts them immediately into the shape-indexed representation before touching any parameters.
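The plan-versus-list boundary can be sketched in core Lean 4. `Shape`, `Plan`, `Plan.toList`, `Plan.length_toList` and `Plan.ofList?` are hypothetical names modeled on the description above rather than TorchLean's actual declarations, and initializers are left abstract as a type parameter `Init`. The theorem is the "index is not decorative" fact; `ofList?` is the single place where a plain list can fail to match.

```lean
/-- Shapes as lists of dimensions (stand-in for `Spec.Shape`). -/
abbrev Shape := List Nat

/-- Hypothetical shape-indexed plan: exactly one initializer per parameter shape, in order. -/
inductive Plan (Init : Type) : List Shape → Type where
  | nil : Plan Init []
  | cons {s : Shape} {ss : List Shape} : Init → Plan Init ss → Plan Init (s :: ss)

/-- Forget the shape index. -/
def Plan.toList {Init : Type} : {ss : List Shape} → Plan Init ss → List Init
  | _, .nil => []
  | _, .cons i p => i :: Plan.toList p

/-- The index is not decorative: a plan for `ss` always yields `ss.length` initializers. -/
theorem Plan.length_toList {Init : Type} :
    {ss : List Shape} → (p : Plan Init ss) → (Plan.toList p).length = ss.length
  | _, .nil => rfl
  | _, .cons _ p => by simp [Plan.toList, Plan.length_toList p]

/-- Boundary check: a plain list becomes a plan only if its length matches `ss`. -/
def Plan.ofList? {Init : Type} : (ss : List Shape) → List Init → Option (Plan Init ss)
  | [], [] => some .nil
  | _ :: rest, i :: inits => (Plan.ofList? rest inits).map (Plan.cons i)
  | _, _ => none
```

Validation happens once, in `ofList?`; code that consumes a `Plan` never needs to re-check lengths.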

              Product of a list of dimensions, used for convolutional receptive-field sizes.

                Infer (fanIn, fanOut) from a parameter shape using the common linear/conv convention.

                For a matrix shaped [out, in], this returns (in, out). For convolution-like weights shaped [outChannels, inChannels, k1, ..., kd], it returns:

                fanIn  = inChannels  * k1 * ... * kd
                fanOut = outChannels * k1 * ... * kd
                

                This is the same fan convention documented by PyTorch's Xavier/Kaiming initialization utilities.
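A minimal Lean sketch of this fan convention. `dimsProd` and `fanInOut?` are hypothetical names. Returning `none` for shapes of rank below 2 is an assumption of this sketch, since the docs do not say how such shapes are treated.

```lean
/-- Product of a list of dimensions (`1` for the empty list). -/
def dimsProd (ds : List Nat) : Nat := ds.foldl (· * ·) 1

/-- `(fanIn, fanOut)` for `[out, in]` or `[outChannels, inChannels, k1, ..., kd]`.
    Rank-0/1 shapes give `none` (an assumption of this sketch). -/
def fanInOut? : List Nat → Option (Nat × Nat)
  | outC :: inC :: kernel =>
      let rf := dimsProd kernel   -- receptive-field size k1 * ... * kd (1 for a plain matrix)
      some (inC * rf, outC * rf)
  | _ => none
```

For a 16×3×5×5 convolution weight, the receptive field is 25, so fanIn = 3 · 25 = 75 and fanOut = 16 · 25 = 400.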

                  Build a Xavier initializer by deriving fan-in/fan-out from a Linear/Conv-style weight shape.

                    Build a Kaiming initializer by deriving fan-in from a Linear/Conv-style weight shape.

                      Convenience initializer for a matrix weight stored as [outDim, inDim].

                        Convenience initializer for a ReLU-style matrix weight stored as [outDim, inDim].

                          Deterministic unit sample used by CPU/runtime initialization.

                          The CUDA path below uses Cuda.Buffer.randUniform, which is keyed by the same SplitMix64 family. Exact CPU/CUDA bit equality is not the contract here; reproducible initialization for a fixed path is. Tests that need exact CUDA RNG parity use the lower-level CUDA RNG stress tests.
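For readers unfamiliar with SplitMix64, here is a minimal Lean sketch of the standard generator (the constants and shifts of the widely used public-domain `splitmix64.c` reference code), plus one way to turn a (seed, flat index) pair into a unit sample. `splitMixGamma`, `splitMix64`, `toUnit` and `unitSample` are hypothetical names, and keying by index this way is an assumption rather than TorchLean's documented scheme. Because the SplitMix64 state only advances by a constant γ, the i-th output can be reached directly from `seed + i * γ`.

```lean
/-- SplitMix64 increment (the 64-bit golden-ratio constant). -/
def splitMixGamma : UInt64 := 0x9E3779B97F4A7C15

/-- One SplitMix64 step, as in the reference `next()`: returns `(output, nextState)`. -/
def splitMix64 (state : UInt64) : UInt64 × UInt64 :=
  let s := state + splitMixGamma
  let z := (s ^^^ (s >>> 30)) * 0xBF58476D1CE4E5B9
  let z := (z ^^^ (z >>> 27)) * 0x94D049BB133111EB
  (z ^^^ (z >>> 31), s)

/-- Keep the top 53 bits and scale by 2^-53, giving a Float in [0, 1). -/
def toUnit (z : UInt64) : Float :=
  (z >>> 11).toFloat / 9007199254740992.0

/-- Hypothetical keyed sample: the `i`-th output of the stream seeded with `seed`, in [0, 1).
    The state advances by a constant, so index `i` is reached directly from `seed + i * γ`. -/
def unitSample (seed i : UInt64) : Float :=
  toUnit (splitMix64 (seed + i * splitMixGamma)).1
```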

                            Scalar value generated by a FloatInit at a row-major flat index.

                              Materialize an initializer as a host FloatArray.

                              CPU execution uses this path directly. CUDA uses it only when the initializer is already an exact flat payload; analytic initializers such as uniform/Xavier/Kaiming are created on the runtime side.
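A sketch of the CPU-side shape of this path: compute each scalar from its row-major flat index, then collect `numel` of them into an array. `InitSketch`, `valueAt` and `materialize` are hypothetical. The unit sampler is passed in as a parameter because its exact definition is not given in these docs, and the uniform case applies the same `lo + (hi - lo) * u` map that the CUDA uniform path describes.

```lean
/-- Hypothetical, trimmed-down initializer (only the cases this sketch needs). -/
inductive InitSketch where
  | zeros
  | ones
  | uniform (lo hi : Float) (seed : Nat)
  | flat (values : Array Float)

/-- Scalar at row-major flat index `i`; `sampler seed i` is any deterministic map into [0, 1). -/
def valueAt (sampler : Nat → Nat → Float) : InitSketch → Nat → Float
  | .zeros, _ => 0.0
  | .ones, _ => 1.0
  | .uniform lo hi seed, i => lo + (hi - lo) * sampler seed i
  | .flat vs, i => vs.getD i 0.0   -- out-of-range reads give 0 (an assumption of this sketch)

/-- Host materialization: one `valueAt` call per flat index `0 .. numel - 1`. -/
def materialize (sampler : Nat → Nat → Float) (init : InitSketch) (numel : Nat) : Array Float :=
  Array.ofFn (n := numel) fun i => valueAt sampler init i.val
```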

                                Checked conversion to the current CUDA buffer API's UInt32 element count.
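A checked element-count conversion might look like the following sketch. `toUInt32Count?` is a hypothetical name, and rejecting (rather than clamping) oversized counts is an assumption about what "checked" means here.

```lean
/-- Hypothetical checked conversion: counts that do not fit in a UInt32 are rejected
    instead of silently wrapping modulo 2^32. -/
def toUInt32Count? (n : Nat) : Option UInt32 :=
  if n < UInt32.size then some n.toUInt32 else none
```

A 65536 × 65536 tensor has exactly 2^32 elements, so it is already one past the largest representable count.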

                                  Allocate a CUDA buffer filled with U(lo, hi).

                                  The implementation keeps all element generation on the runtime side: it first creates a CUDA uniform buffer in [0, 1), then computes lo + (hi - lo) * u with CUDA buffer ops.

                                    Allocate a CUDA buffer for a FloatInit.

                                    For analytic schemes (zeros, ones, uniform, xavierUniform, kaimingUniform), this avoids building a large nested Lean tensor. For .flat, the caller already supplied the exact payload, so we upload that payload directly.

                                      Materialize a runtime initializer as a normal host tensor. Used for CPU execution.

                                        Host slots for a parameter list before runtime initialization installs the real values.

                                        CUDA runtime initialization immediately replaces these with CUDA mirrors and marks the host values stale. These entries still give the existing Param type a valid host slot for later explicit readback.

                                          Apply a shape-indexed initialization plan to an already-created parameter list.

                                          The key point is that the shape list appears on both sides of the type:

                                          Torch.ParamList Float ss → RuntimeInit.Plan ss → IO Unit
                                          

                                          So Lean checks the bookkeeping that Python frameworks usually check at runtime: every parameter gets one initializer, and no extra initializer is silently ignored.
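To make the lockstep argument concrete, here is a self-contained Lean sketch in which both the parameter list and the plan are indexed by the same `List Shape`. `Shape`, `ParamList`, `ConstPlan` and `applyPlan` are hypothetical. Parameters are simulated with `IO.Ref (Array Float)` host buffers, and each initializer is just a constant fill standing in for `FloatInit`.

```lean
abbrev Shape := List Nat

/-- Number of elements in a shape. -/
def Shape.numel (s : Shape) : Nat := s.foldl (· * ·) 1

/-- Hypothetical parameter list: one mutable host buffer per shape. -/
inductive ParamList : List Shape → Type where
  | nil : ParamList []
  | cons {s : Shape} {ss : List Shape} : IO.Ref (Array Float) → ParamList ss → ParamList (s :: ss)

/-- Hypothetical plan whose initializers are constant fills (a stand-in for `FloatInit`). -/
inductive ConstPlan : List Shape → Type where
  | nil : ConstPlan []
  | cons {s : Shape} {ss : List Shape} : Float → ConstPlan ss → ConstPlan (s :: ss)

/-- Walk both lists in lockstep. Both share the index `ss`, so a length mismatch
    is a type error, not a runtime check. -/
def applyPlan : {ss : List Shape} → ParamList ss → ConstPlan ss → IO Unit
  | [], .nil, .nil => pure ()
  | s :: _, .cons buf ps, .cons c plan => do
      buf.set (List.replicate (Shape.numel s) c).toArray
      applyPlan ps plan
```

With a two-parameter `ParamList`, passing a one-element plan such as `.cons 0.5 .nil` is rejected at elaboration time.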

                                            Apply a runtime list of initializers after checking it against the parameter shapes.

                                            Plan ss is the typed form used by the initializer engine. This entrypoint is for places where the initializer list comes from outside Lean's typechecker, such as a checkpoint, JSON file, or CLI experiment.

                                              Scalar-loss module (training)

                                              A scalar-loss module definition:

                                              • initParams is stored as Float constants (easy to write with tensorND!),
                                              • loss is polymorphic in the scalar backend (same code works for Float/IEEE32Exec/…).

                                              You can instantiate this definition as a ScalarModule under a chosen backend and dtype.
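The split between Float-authored initial values and a backend-polymorphic loss can be modeled in miniature. `ScalarDefSketch`, `lineFit` and `ScalarDefSketch.initAt` are hypothetical names. The real loss type shown for `ScalarModule.create` is written against a monad `m`, an `Ops m α` instance and curried references, not bare scalars with `Sub`/`Mul` as here.

```lean
/-- Hypothetical miniature of a scalar-loss definition: a Float initial parameter for
    authoring, and a loss written once for any scalar type with `-` and `*`. -/
structure ScalarDefSketch where
  initW : Float
  loss : {α : Type} → [Sub α] → [Mul α] → (w x y : α) → α

/-- Squared error of the one-parameter model `w * x` against the target `y`. -/
def lineFit : ScalarDefSketch where
  initW := 0.5
  loss := fun w x y => (w * x - y) * (w * x - y)

/-- "Instantiate" at a backend scalar type by casting the Float initializer. -/
def ScalarDefSketch.initAt {α : Type} (d : ScalarDefSketch) (cast : Float → α) : α :=
  cast d.initW
```

The same `loss` term is evaluated exactly over `Int` and approximately over `Float`, which is the polymorphism the bullet list describes.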

                                                Runtime module instance (the thing you "run").

                                                This wraps Torch.ScalarTrainer, but exposes a more Module-like set of methods.

                                                  def Runtime.Autograd.TorchLean.Module.ScalarModule.create {α : Type} [Context α] [DecidableEq Spec.Shape] [Torch.Internal.CudaBridge.TensorConv α] {paramShapes inputShapes : List Spec.Shape} (opts : Options := { }) (initRequiresGrad : List Bool := List.replicate paramShapes.length true) (loss : {m : Type → Type} → [Monad m] → [inst : Ops m α] → CurriedRef (fun (s : Spec.Shape) => Torch.Ops.Ref m α s) (paramShapes ++ inputShapes) (m (Torch.Ops.Ref m α Spec.Shape.scalar))) (initParams : TList α paramShapes) :
                                                  IO (ScalarModule α paramShapes inputShapes)

                                                  Create a runtime scalar-loss module from an explicit loss program and initial parameter values.

                                                  This is the low-level constructor; public training code starts from a ScalarModuleDef and calls ScalarModuleDef.instantiate.

                                                    def Runtime.Autograd.TorchLean.Module.ScalarModule.forward {α : Type} [Context α] [DecidableEq Spec.Shape] {paramShapes inputShapes : List Spec.Shape} (m : ScalarModule α paramShapes inputShapes) (xs : TList α inputShapes) :

                                                    Run the forward pass and return the scalar loss value.

                                                      def Runtime.Autograd.TorchLean.Module.ScalarModule.backward {α : Type} [Context α] [DecidableEq Spec.Shape] {paramShapes inputShapes : List Spec.Shape} (m : ScalarModule α paramShapes inputShapes) (xs : TList α inputShapes) :
                                                      IO (TList α paramShapes)

                                                      Run one forward/backward pass and return gradients for all parameters.

                                                        def Runtime.Autograd.TorchLean.Module.ScalarModule.step {α : Type} [Context α] [DecidableEq Spec.Shape] {paramShapes inputShapes : List Spec.Shape} (m : ScalarModule α paramShapes inputShapes) (lr : α) (xs : TList α inputShapes) :

                                                        Convenience "one-step SGD": compute gradients and apply an SGD update with learning rate lr.

                                                          def Runtime.Autograd.TorchLean.Module.ScalarModule.initOptim {α : Type} [Context α] [DecidableEq Spec.Shape] {paramShapes inputShapes : List Spec.Shape} (m : ScalarModule α paramShapes inputShapes) (opt : Optim.Optimizer α paramShapes) :
                                                          IO opt.State

                                                          Initialize an optimizer state for this module's parameters.

                                                            def Runtime.Autograd.TorchLean.Module.ScalarModule.stepWith {α : Type} [Context α] [DecidableEq Spec.Shape] {paramShapes inputShapes : List Spec.Shape} (m : ScalarModule α paramShapes inputShapes) (opt : Optim.Optimizer α paramShapes) (st : opt.State) (xs : TList α inputShapes) :
                                                            IO opt.State

                                                            Run one optimizer step using an explicit optimizer and its state.

                                                            This mirrors a PyTorch training step:

                                                            1. compute gradients (backwardT)
                                                            2. update parameters via opt.step and return the new optimizer state
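The two-step structure above can be sketched with a toy optimizer interface whose state is explicit and returned, as in `initOptim`/`stepWith`. `OptimizerSketch`, `sgd`, `momentum`, `stepWith` and `trainWith` here are hypothetical. They operate on a single `Float` parameter and take the gradient as a plain function instead of running `backwardT`. The momentum rule (`v ← βv + g`, `p ← p - lr·v`) follows PyTorch's SGD-with-momentum form and is chosen only to show a non-trivial state.

```lean
/-- Hypothetical optimizer interface with explicit, caller-held state. -/
structure OptimizerSketch (P : Type) where
  State : Type
  init : P → State
  update : State → (params grads : P) → P × State

/-- Plain SGD: no state, `p ← p - lr * g`. -/
def sgd (lr : Float) : OptimizerSketch Float where
  State := Unit
  init _ := ()
  update st p g := (p - lr * g, st)

/-- SGD with momentum: the state is the velocity buffer. -/
def momentum (lr beta : Float) : OptimizerSketch Float where
  State := Float
  init _ := 0.0
  update v p g :=
    let v' := beta * v + g
    (p - lr * v', v')

/-- One step: (1) compute the gradient, (2) apply the update; return new params and state. -/
def stepWith {P : Type} (grad : P → P) (opt : OptimizerSketch P) (st : opt.State) (p : P) :
    P × opt.State :=
  opt.update st p (grad p)

/-- Repeat `stepWith` for `steps` iterations, threading the optimizer state explicitly. -/
def trainWith {P : Type} (grad : P → P) (opt : OptimizerSketch P) (p0 : P) (steps : Nat) : P :=
  Id.run do
    let mut p := p0
    let mut st := opt.init p0
    for _ in [0:steps] do
      let (p', st') := stepWith grad opt st p
      p := p'
      st := st'
    return p
```

On `(w - 3)^2` (gradient `2 * (w - 3)`), plain SGD with `lr = 0.1` shrinks the error by a factor of 0.8 per step, so both optimizers end up very close to `w = 3` after 200 steps.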
                                                              def Runtime.Autograd.TorchLean.Module.ScalarModule.params {α : Type} [Context α] [DecidableEq Spec.Shape] {paramShapes inputShapes : List Spec.Shape} (m : ScalarModule α paramShapes inputShapes) :
                                                              IO (TList α paramShapes)

                                                              Fetch the current parameter values as a shape-indexed list.

                                                                def Runtime.Autograd.TorchLean.Module.ScalarModule.setParams {α : Type} [Context α] [DecidableEq Spec.Shape] {paramShapes inputShapes : List Spec.Shape} (m : ScalarModule α paramShapes inputShapes) (ps : TList α paramShapes) :

                                                                Overwrite all parameter values.

                                                                  def Runtime.Autograd.TorchLean.Module.ScalarModule.trainSGD {α : Type} [Context α] [DecidableEq Spec.Shape] [ToString α] {paramShapes inputShapes : List Spec.Shape} (m : ScalarModule α paramShapes inputShapes) (lr : α) (steps : ℕ) (samples : List (TList α inputShapes)) (logEvery : ℕ := 1) :

                                                                  Train with vanilla SGD for a fixed number of steps on a fixed list of samples.

                                                                    def Runtime.Autograd.TorchLean.Module.ScalarModule.trainWith {α : Type} [Context α] [DecidableEq Spec.Shape] [ToString α] {paramShapes inputShapes : List Spec.Shape} (m : ScalarModule α paramShapes inputShapes) (opt : Optim.Optimizer α paramShapes) (st0 : opt.State) (steps : ℕ) (samples : List (TList α inputShapes)) (logEvery : ℕ := 1) :
                                                                    IO opt.State

                                                                    Like trainSGD, but with an explicit optimizer and an initial optimizer state; the final optimizer state is returned.

                                                                      def Runtime.Autograd.TorchLean.Module.ScalarModule.meanLoss {α : Type} [Context α] [DecidableEq Spec.Shape] [ToString α] {paramShapes inputShapes : List Spec.Shape} (m : ScalarModule α paramShapes inputShapes) (samples : List (TList α inputShapes)) :
                                                                      IO α

                                                                      Compute the mean loss over a list of samples (no parameter updates).

                                                                        def Runtime.Autograd.TorchLean.Module.ScalarModuleDef.instantiateWith {α : Type} [Context α] [DecidableEq Spec.Shape] [Torch.Internal.CudaBridge.TensorConv α] {paramShapes inputShapes : List Spec.Shape} (d : ScalarModuleDef paramShapes inputShapes) (cast : Float → α) (opts : Options) :
                                                                        IO (ScalarModule α paramShapes inputShapes)

                                                                        Instantiate a ScalarModuleDef by casting Float initializers to α and choosing Torch options.

                                                                        This is the most general constructor. The shorter instantiate entrypoint chooses standard runtime options before calling this function.

                                                                          def Runtime.Autograd.TorchLean.Module.ScalarModuleDef.instantiateFloatWithRuntimePlan {paramShapes inputShapes : List Spec.Shape} (d : ScalarModuleDef paramShapes inputShapes) (opts : Options) (plan : RuntimeInit.Plan paramShapes) :
                                                                          IO (ScalarModule Float paramShapes inputShapes)

                                                                          Instantiate a Float module using runtime parameter initializers.

                                                                          This is the runtime-initialized sibling of instantiateWith. Instead of first building every initial parameter as a full Lean tensor, it creates minimal zero host tensors and then applies a shape-indexed runtime plan to the module parameters. In CUDA mode those initializers allocate device buffers directly and mark the host copies stale; public parameter readback still synchronizes them through the existing CUDA mirror machinery.
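The "stale host copy" bookkeeping can be modeled with three mutable cells. This is a sketch only: `MirroredParam` and its methods are hypothetical, and a second `IO.Ref` stands in for the CUDA buffer, so nothing here touches a GPU. Initialization writes the device side and marks the host stale; an explicit read synchronizes once and clears the flag.

```lean
/-- Hypothetical parameter with a host copy and a simulated device buffer. -/
structure MirroredParam where
  host : IO.Ref (Array Float)
  device : IO.Ref (Array Float)   -- stand-in for a CUDA buffer
  hostStale : IO.Ref Bool

/-- Minimal empty host slot; nothing on the device yet. -/
def MirroredParam.new : IO MirroredParam := do
  let host ← IO.mkRef (#[] : Array Float)
  let device ← IO.mkRef (#[] : Array Float)
  let hostStale ← IO.mkRef false
  return { host, device, hostStale }

/-- Runtime initialization writes the device buffer directly and marks the host copy stale. -/
def MirroredParam.initOnDevice (p : MirroredParam) (vals : Array Float) : IO Unit := do
  p.device.set vals
  p.hostStale.set true

/-- Explicit readback: copy device → host only if the host copy is stale. -/
def MirroredParam.read (p : MirroredParam) : IO (Array Float) := do
  if (← p.hostStale.get) then
    p.host.set (← p.device.get)
    p.hostStale.set false
  p.host.get
```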

                                                                            def Runtime.Autograd.TorchLean.Module.ScalarModuleDef.instantiateFloatWithRuntimeInit {paramShapes inputShapes : List Spec.Shape} (d : ScalarModuleDef paramShapes inputShapes) (opts : Options) (inits : List RuntimeInit.FloatInit) :
                                                                            IO (ScalarModule Float paramShapes inputShapes)

                                                                            Instantiate a Float module from a plain initializer list.

                                                                            This wrapper is useful at file/JSON boundaries. Internally it immediately checks the list against paramShapes and then delegates to instantiateFloatWithRuntimePlan, so the actual parameter mutation still goes through the shape-indexed path.

                                                                              def Runtime.Autograd.TorchLean.Module.ScalarModuleDef.instantiate {α : Type} [Context α] [DecidableEq Spec.Shape] [Torch.Internal.CudaBridge.TensorConv α] {paramShapes inputShapes : List Spec.Shape} (d : ScalarModuleDef paramShapes inputShapes) (cast : Float → α) (backend : Backend := Torch.Backend.eager) :
                                                                              IO (ScalarModule α paramShapes inputShapes)

                                                                              Convenience instantiator that chooses only the backend (.eager or .compiled).
