TorchLean API

NN.API.Common.Core

NN.API.Common

Shared helpers used across TorchLean workflows and entrypoints: typed tensor constructors, casting between scalar backends, and shared Except/IO utilities.

Common Helpers For Workflows And Entry Points

This module contains helper functions that show up repeatedly in executable workflows.

What belongs here:

What does not belong here:

PyTorch Mapping Notes

If you're coming from PyTorch, this module plays the role of the shared utility layer often written around:

The main difference is that TorchLean tracks shapes in types, so tensor constructors here return Except String ... rather than silently reshaping.

def NN.API.Common.castTensor {α : Type} (cast : Float → α) {s : Spec.Shape} (t : Spec.Tensor Float s) :

Cast a tensor elementwise while preserving its type-level shape.
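The following minimal, self-contained Lean 4 sketch shows why an elementwise cast cannot change the shape: the length lives in the type, and the mapped result keeps the same index. Vec and Vec.cast are hypothetical stand-ins written for illustration. They are not TorchLean's Spec.Tensor or castTensor.

```lean
/-- Hypothetical stand-in for a shape-indexed tensor: a list whose length
    is fixed by the type index `n` (not TorchLean's `Spec.Tensor`). -/
structure Vec (α : Type) (n : Nat) where
  data : List α
  length_eq : data.length = n

/-- Elementwise cast. The result carries the same index `n`, so the type
    checker rules out any change of shape. -/
def Vec.cast {α β : Type} (f : α → β) {n : Nat} (v : Vec α n) : Vec β n :=
  ⟨v.data.map f, by rw [List.length_map]; exact v.length_eq⟩

-- Usage: move a length-3 vector from a Nat "backend" to Float.
def counts : Vec Nat 3 := ⟨[1, 2, 3], rfl⟩
def countsF : Vec Float 3 := counts.cast Nat.toFloat
```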
def NN.API.Common.orThrow {α : Type} (tag : String := "Example") : Except String α → IO α

Convert an Except String α into an IO α, raising a tagged userError on failure.

This is handy in main functions where we want the process to exit with a readable message.
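The Except-to-IO bridge takes only a few lines of plain Lean 4. orThrowSketch is a hypothetical name. The "[tag] message" format is an assumption, because the source only says that a failure raises a tagged userError.

```lean
/-- Hypothetical sketch: unwrap an `Except String α` inside `IO`,
    turning a failure into a tagged `IO.userError`. -/
def orThrowSketch {α : Type} (tag : String := "Example") : Except String α → IO α
  | .ok a => pure a
  | .error msg => throw (IO.userError s!"[{tag}] {msg}")

-- Usage in a `main`-style entry point: a parse failure stops the process
-- with a "[train] ..." message instead of continuing with a bad value.
def runDemo : IO Unit := do
  let steps ← orThrowSketch "train" (Except.ok 10 : Except String Nat)
  IO.println s!"steps = {steps}"
```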

Standard location for model-example training logs under data/model_zoo.

def NN.API.Common.check (tag msg : String) (b : Bool) :

Fail with a tagged userError if a boolean condition is false.

Executable workflows use this for named precondition checks such as Common.check exeName "loss finite" (loss == loss).
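The sketch below implements a tagged check and explains why the loss == loss idiom works: under IEEE-754 comparison, NaN is the only value that is not equal to itself. checkSketch, notNaN, and the message format are hypothetical. This idiom rejects NaN only. An infinite loss still compares equal to itself and passes.

```lean
/-- Hypothetical sketch: fail with a tagged `IO.userError` when `b` is false. -/
def checkSketch (tag msg : String) (b : Bool) : IO Unit := do
  unless b do throw (IO.userError s!"{tag}: check failed: {msg}")

/-- IEEE-754 `x == x` is false exactly when `x` is NaN.
    Infinities compare equal to themselves, so they pass this test. -/
def notNaN (x : Float) : Bool := x == x
```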

Validate that a natural-number CLI flag is positive.

def NN.API.Common.resolvePositiveNatFlag (exeName flag : String) (value? : Option ℕ) (default : ℕ) :

Resolve an optional natural-number CLI flag against a default and require that the result is strictly positive.

Example parsers use this helper for the common "optional flag + default + positivity check" case instead of restating the same getD / requirePositiveNatFlag sequence.
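Here is a sketch of the "optional flag + default + positivity check" sequence. The rendered signature above omits the return type, so the Except String Nat result and the error text are assumptions. resolvePositiveNatFlagSketch is a hypothetical name.

```lean
/-- Hypothetical sketch: apply the default, then require a positive result.
    The `Except String Nat` result type and message text are assumptions. -/
def resolvePositiveNatFlagSketch (exeName flag : String) (value? : Option Nat)
    (default : Nat) : Except String Nat :=
  let v := value?.getD default
  if v == 0 then .error s!"{exeName}: {flag} must be positive, got 0"
  else .ok v
```

With the flag omitted, the default is used. An explicit 0 is rejected even when the default is positive.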

Write a prepared TrainLog JSON artifact and report the file path.

Write a prepared TrainLog to a destination that may be disabled.

def NN.API.Common.writeBeforeAfterLossLog (path : System.FilePath) (title : String) (steps : ℕ) (loss0 loss1 : Float) (notes : Array String := #[]) :

Write a standard JSON training artifact for routines that record an initial and final loss.

The function uses Runtime.Training.TrainLog.beforeAfterLoss and the stable TrainLog JSON format. The output schema is independent of the model, dataset, and runtime backend.
def NN.API.Common.writeBeforeAfterLossLogTo (dest : Runtime.Training.LogDestination) (title : String) (steps : ℕ) (loss0 loss1 : Float) (notes : Array String := #[]) :

Write a before/after loss log to an explicit logging destination.

LogDestination.disabled is a no-op, mirroring wandb's disabled mode for runs where metrics should stay on stdout only.

Print the standard before/after loss summary returned by fit helpers.

First and last point of a scalar training curve, ready for summaries.

• finalStep : ℕ
  Step used for the final point.

• first : Float
  First recorded metric value.

• last : Float
  Last recorded metric value.

Read the first and last values from a scalar curve.

If curve.steps is empty, the final step falls back to the last value index. Empty value arrays are reported as errors instead of printing a fake 0.0 summary.
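The first/last rule can be sketched over a simplified curve. CurveSketch, with its hypothetical steps and values fields, and CurveEndpoints are illustrative stand-ins for Runtime.Training.Curve and the record above, not their actual definitions.

```lean
/-- Hypothetical curve shape: optional step labels plus metric values. -/
structure CurveSketch where
  steps : Array Nat
  values : Array Float

/-- Mirrors the first/last record described above. -/
structure CurveEndpoints where
  finalStep : Nat
  first : Float
  last : Float

/-- Empty `values` is an error; empty `steps` falls back to the last value index. -/
def CurveSketch.endpoints? (c : CurveSketch) : Except String CurveEndpoints :=
  match c.values[0]?, c.values.back? with
  | some f, some l =>
    .ok { finalStep := c.steps.back?.getD (c.values.size - 1), first := f, last := l }
  | _, _ => .error "curve has no recorded values"
```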

Require first/last scalar values from a training curve, or raise a user-facing error.

Print the standard first/last loss summary for a scalar training curve.

def NN.API.Common.writeCurveLog (path : System.FilePath) (title : String) (curve : Runtime.Training.Curve) (seriesName : String := "loss") (notes : Array String := #[]) :

Write a one-series scalar curve as a standard TrainLog JSON artifact.

def NN.API.Common.writeCurveLogTo (dest : Runtime.Training.LogDestination) (title : String) (curve : Runtime.Training.Curve) (seriesName : String := "loss") (notes : Array String := #[]) :

Write a one-series scalar curve to an explicit logging destination.

Common CLI result for training commands that accept --steps, --batch-size, and --log.

• steps : ℕ
  Number of optimizer updates.

• batchSize : ℕ
  Number of samples consumed by one public in-memory training step.

• Logging destination. Use --log false / off / none to disable.

• logPath : System.FilePath
  Path where the JSON TrainLog should be written.

• cudaMemWatch : ℕ
  CUDA allocator telemetry cadence shared by fixed-sample and custom training loops.

def NN.API.Common.parseLoggedTrainFlags (exeName : String) (args : List String) (defaultLogPath : System.FilePath) (defaultSteps : ℕ := 1) (allowZeroSteps : Bool := false) :

Parse common training flags: positive --steps, positive optional --batch-size, plus optional --log.

--log <path> writes the standard local JSON artifact. --log false, --log off, or --log none disables artifact writing while preserving the parsed step count.
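This sketch interprets the --log value according to the rule stated above. LogDestSketch, flagValue?, and resolveLog are hypothetical helpers. Only the split --log <value> form is handled, and the default path is a caller-supplied placeholder.

```lean
/-- Hypothetical mirror of a log destination: disabled, or a JSON file path. -/
inductive LogDestSketch where
  | disabled
  | file (path : String)
  deriving Repr, DecidableEq

/-- `false`, `off`, and `none` disable artifact writing; any other value
    is treated as an output path. -/
def LogDestSketch.ofFlag (value : String) : LogDestSketch :=
  if ["false", "off", "none"].contains value then .disabled else .file value

/-- Value following `flag` in split form (`--log off`); `--log=off` is not handled. -/
def flagValue? (flag : String) : List String → Option String
  | f :: tl@(v :: _) => if f == flag then some v else flagValue? flag tl
  | _ => none

/-- An absent `--log` falls back to the caller's default path. -/
def resolveLog (args : List String) (defaultPath : String) : LogDestSketch :=
  match flagValue? "--log" args with
  | some v => LogDestSketch.ofFlag v
  | none => .file defaultPath
```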

Training flags shared by runnable model examples.

This covers the knobs almost every example needs: --steps, --log, CUDA memory watching, and --lr. Model files should reuse this record and add only flags that change that model's actual behavior, such as text generation settings or evaluation probes.

def NN.API.Common.parseModelTrainFlags (exeName : String) (args : List String) (defaultLogPath : System.FilePath) (defaultSteps : ℕ := 1) (defaultLr : Float := 1e-3) (allowZeroSteps : Bool := false) :

Parse the standard model-training flags: --steps, --log, --lr, and CUDA telemetry.

Model-training flags plus an RNG/data-order seed.

Use this when the command needs reproducible initialization, synthetic data, or shuffled row order in addition to the standard training flags.

def NN.API.Common.parseSeededModelTrainFlags (exeName : String) (args : List String) (defaultLogPath : System.FilePath) (defaultSteps : ℕ := 1) (defaultLr : Float := 1e-3) (allowZeroSteps : Bool := false) :

Parse the standard model-training flags together with --seed.

Examples use this when model initialization, synthetic data, or row order needs a reproducible seed.

Progress Cadence Helpers

def NN.API.Common.shouldLogStep (logEvery done : ℕ) :

Return whether a completed training step should emit a progress report.

The convention is shared across example trainers: logEvery = 0 disables progress output; otherwise a report is emitted whenever the completed-step count done is an exact multiple of logEvery.
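The cadence rule translates directly into a boolean test. shouldLogStepSketch is a hypothetical re-implementation of the stated convention, not TorchLean's source.

```lean
/-- Hypothetical re-implementation of the stated convention:
    `logEvery = 0` turns progress output off; otherwise report when the
    completed-step count `done` is a multiple of `logEvery`. -/
def shouldLogStepSketch (logEvery done : Nat) : Bool :=
  logEvery != 0 && done % logEvery == 0

-- Steps that report during a 100-step run with logEvery = 25.
#eval (List.range' 1 100).filter (shouldLogStepSketch 25)  -- [25, 50, 75, 100]
```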

CUDA Memory Watch Helpers

Choose a CUDA memory-watch cadence for public examples.

Users can pass --cuda-mem-watch N to choose an exact cadence. When no cadence is supplied, long CUDA runs sample about ten times over the requested training horizon. Short runs and CPU runs stay quiet by default, so the examples do not print allocator telemetry unless it is likely to be useful.
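Here is a sketch of the default cadence policy. The source says only that long CUDA runs sample roughly ten times and that short or CPU runs stay quiet. The 1000-step threshold for a "long" run and the use of 0 to mean "no telemetry" are assumptions made for illustration.

```lean
/-- Hypothetical cadence policy. Assumptions: `0` means "no telemetry" and
    a run counts as long at `longRun` steps; neither is stated by the source. -/
def chooseMemWatchSketch (requested? : Option Nat) (onCuda : Bool) (steps : Nat)
    (longRun : Nat := 1000) : Nat :=
  match requested? with
  | some n => n                                   -- explicit --cuda-mem-watch N wins
  | none =>
    if onCuda && decide (steps ≥ longRun) then
      max 1 (steps / 10)                          -- about ten samples per run
    else 0                                        -- short or CPU runs stay quiet
```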

Standard TrainLog note for the effective CUDA memory-watch cadence.

State for a simple CUDA-memory drift detector.

The first reported sample becomes the baseline. Later samples compare current CUDA free memory against that baseline and warn once if the observed per-step drop projects failure before the requested run length.
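This sketch follows the drift detector as described. The first sample sets the baseline, and later samples estimate an average per-step drop. The detector warns once if that drop would exhaust the remaining free memory before the run ends. DriftSketch, its fields, and the integer-byte arithmetic are hypothetical.

```lean
/-- Hypothetical detector state; the first sample becomes the baseline. -/
structure DriftSketch where
  baseline? : Option (Nat × Nat) := none  -- (step, free bytes) at the first sample
  warned : Bool := false

/-- Feed one sample. Returns the updated state and whether to warn now.
    The average per-step drop since the baseline is projected over the
    remaining steps; the detector warns at most once. -/
def DriftSketch.observe (s : DriftSketch) (step freeBytes totalSteps : Nat) :
    DriftSketch × Bool :=
  match s.baseline? with
  | none => ({ s with baseline? := some (step, freeBytes) }, false)
  | some (step0, free0) =>
    if s.warned || decide (step ≤ step0) || decide (free0 ≤ freeBytes) then (s, false)
    else
      let dropPerStep := (free0 - freeBytes) / (step - step0)  -- bytes per step
      if dropPerStep * (totalSteps - step) > freeBytes then
        ({ s with warned := true }, true)
      else (s, false)

-- Baseline of 1000 free bytes at step 10, then 500 free at step 20 of a
-- 100-step run: 50 bytes/step over 80 more steps exceeds 500, so it warns.
#eval ((DriftSketch.observe {} 10 1000 100).1.observe 20 500 100).2  -- true
```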

Maybe print a one-line CUDA allocator report.

The report samples the native allocator at a fixed cadence and warns if the observed free-memory slope would cross zero before the requested training horizon.

Default Adam optimizer constructor used by supervised and vision examples.

The reusable part is the optimizer convention, not the model. Individual examples still own their architecture and loss, while this helper keeps the Adam hyperparameter spelling identical across MLP, CNN, ResNet, ViT, and similar model commands.

def NN.API.Common.runAnyOrFloat (exeName : String) (args : List String) (preferFloat : List String → Bool) (banner : Runtime.Autograd.Torch.Options → String) (anyK : {α : Type} → [Semantics.Scalar α] → [DecidableEq Spec.Shape] → [ToString α] → [Runtime.Scalar α] → (Float → α) → Runtime.Autograd.Torch.Options → List String → IO Unit) (floatK : Runtime.Autograd.Torch.Options → List String → IO Unit) (printOk : Bool := true) :

Run an executable with the standard TorchLean runtime parser, using the polymorphic scalar path by default and switching to the Float path when requested.

This is the common shape for public examples that support all executable scalar backends, but need the Float path for CUDA bridges, decoded reports, or JSON artifacts whose metrics are stored as Float.

def NN.API.Common.runAnyOrFloatNoCast (exeName : String) (args : List String) (preferFloat : List String → Bool) (banner : Runtime.Autograd.Torch.Options → String) (anyK : {α : Type} → [Semantics.Scalar α] → [DecidableEq Spec.Shape] → [ToString α] → [Runtime.Scalar α] → Runtime.Autograd.Torch.Options → List String → IO Unit) (floatK : Runtime.Autograd.Torch.Options → List String → IO Unit) (printOk : Bool := true) :

Run an executable on either the selected scalar backend or the concrete Float path, for commands whose generic branch does not need an explicit Float → α cast helper.

Run an executable on the concrete Float runtime path.

We use this for runnable training commands that rely on Float-only machinery (CPU/CUDA eager execution, native kernels) or produce Float-valued artifacts such as JSON loss curves. Commands that need to expose another scalar backend can use runAnyOrFloat.

Runtime-flag normalization

Detect a request for the compiled backend, written either as --backend=compiled or in split-flag form (--backend compiled).
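Both spellings can be detected with a short recursive scan over the argument list. wantsCompiledBackend is a hypothetical helper, because the rendered page omits the real function's name.

```lean
/-- Hypothetical helper: true if the arguments request the compiled backend
    as `--backend=compiled` or as the split pair `--backend compiled`. -/
def wantsCompiledBackend : List String → Bool
  | [] => false
  | "--backend=compiled" :: _ => true
  | "--backend" :: "compiled" :: _ => true
  | _ :: rest => wantsCompiledBackend rest
```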
def NN.API.Common.forceGpuArgs (exeName : String) (args : List String) (extraFlags : List String := ["--fast-kernels"]) :

Reject --cpu and add CUDA/runtime flags expected by GPU-first commands.

This helper only rewrites command-line intent before the standard runtime parser runs. It does not change the lower-level runtime semantics.
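Here is a sketch of the argument rewrite. The --cuda flag name, the Except result type, and the error text are hypothetical. --fast-kernels is the documented default for extraFlags. Required flags are appended only if the user has not already passed them.

```lean
/-- Hypothetical sketch of GPU-first argument normalization. The `--cuda` flag,
    the `Except` result, and the message text are assumptions. -/
def forceGpuArgsSketch (exeName : String) (args : List String)
    (extraFlags : List String := ["--fast-kernels"]) (gpuFlag : String := "--cuda") :
    Except String (List String) :=
  if args.contains "--cpu" then
    .error s!"{exeName}: --cpu is not supported by this GPU-first command"
  else
    -- Append only the required flags that are not already present.
    .ok (args ++ (gpuFlag :: extraFlags).filter (fun f => !args.contains f))
```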

Reject the proof-compiled backend for commands that require eager runtime execution.

def NN.API.Common.forceGpuEagerArgs (exeName : String) (args : List String) (extraFlags : List String := ["--fast-kernels"]) :

Reject the compiled backend and force the CUDA-first flags expected by eager GPU examples.

Run a Float-only command after normalizing its runtime flags.

GPU-first examples use this to keep the public Runtime.runFloat path while inserting required CUDA/eager flags before the standard parser runs.

def NN.API.Common.runGpuFloat (exeName : String) (args : List String) (banner : Runtime.Autograd.Torch.Options → String) (k : Runtime.Autograd.Torch.Options → List String → IO Unit) (extraFlags : List String := ["--fast-kernels"]) (printOk : Bool := true) :

Run a Float-only command after forcing CUDA runtime flags.

def NN.API.Common.runGpuEagerFloat (exeName : String) (args : List String) (banner : Runtime.Autograd.Torch.Options → String) (k : Runtime.Autograd.Torch.Options → List String → IO Unit) (extraFlags : List String := ["--fast-kernels"]) (printOk : Bool := true) :

Run a Float-only command after forcing CUDA eager-runtime flags.

Common model-run parsers

Shared corpus-window or training-window count used by finite cyclic examples.

• windows : ℕ
  Number of windows used by the training set, sampler, or cyclic schedule.

Parse the shared --windows flag.

The model decides how to use the windows; this helper just keeps the flag spelling and positivity check consistent.

Optional parameter-checkpoint load/save paths shared by runnable model commands.

Parse the shared --load-params / --save-params flags.

Diffusion schedule knobs shared by model-zoo diffusion commands.

• T : ℕ
  Number of diffusion timesteps in the schedule.

• betaStart : Float
  First beta value in the schedule.

• betaEnd : Float
  Final beta value in the schedule.

def NN.API.Common.DiffusionScheduleFlags.parse (args : List String) (defaultT : ℕ := 100) (defaultBetaStart : Float := 1e-4) (defaultBetaEnd : Float := 0.12) :

Parse shared diffusion schedule flags: --T, --beta-start, and --beta-end.
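These three knobs are enough to build a schedule. The source does not say how TorchLean spaces the betas, so this sketch assumes a linear schedule, beta_t = betaStart + t / (T - 1) * (betaEnd - betaStart) for t = 0, ..., T - 1. It uses the documented defaults T = 100, betaStart = 1e-4, and betaEnd = 0.12. linearBetasSketch is a hypothetical name.

```lean
/-- Assumed linear beta schedule (the spacing is not specified by the source):
    beta_t = betaStart + t / (T - 1) * (betaEnd - betaStart), t = 0, ..., T - 1. -/
def linearBetasSketch (T : Nat) (betaStart betaEnd : Float) : Array Float :=
  if T ≤ 1 then (List.replicate T betaStart).toArray
  else (List.range T).toArray.map fun t =>
    betaStart + (betaEnd - betaStart) * t.toFloat / (T - 1).toFloat

-- Documented defaults: T = 100, betaStart = 1e-4, betaEnd = 0.12.
def defaultBetas : Array Float := linearBetasSketch 100 1e-4 0.12
```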

Standard TrainLog metadata for a diffusion schedule.

Optional image artifacts emitted by image-generation or reconstruction commands.

Parse shared image artifact flags for generation/reconstruction commands.

This only parses artifact paths and the optional reconstruction timestep. The model still decides which images it can write.

Standard TrainLog metadata for requested image artifacts.

Common paired-NPY dataset flags for scientific supervised examples.

This shape is for commands that train on one (x,y) tensor pair and evaluate/report on a held-out pair. The model file supplies the tensor shapes and file defaults.

• trainRows : ℕ
  Number of rows loaded from the prepared training tensors.

• testRows : ℕ
  Number of rows loaded from the prepared held-out tensors.

• evalRows : ℕ
  Prefix length used for deterministic train/test loss reports.

• Training input .npy path.

• Training target .npy path.

• Held-out input .npy path.

• Held-out target .npy path.

def NN.API.Common.PairedNpyEvalFlags.parse (args : List String) (defaultTrainX defaultTrainY defaultTestX defaultTestY : System.FilePath) (defaultTrainRows defaultTestRows : ℕ) (defaultEvalRows : ℕ := 16) :

Parse common train/test paired-NPY flags.

Commands supply their default paths and row counts. This parser keeps the repeated --train-rows, --test-rows, --eval-rows, --x, --y, --test-x, and --test-y flags in one place.

Standard TrainLog metadata for paired train/test NPY tensors.

Optional CSV artifact path for commands that emit one tabular diagnostic.

Parse the shared optional --plot-csv artifact path.

Generic NPY-backed labeled dataset flags.

Examples provide default paths and data-preparation hints. The repeated flags are --seed, --n-total, --x, and --y.

• Prepared feature/image tensor path.

• Prepared label/target tensor path.

• nRows : ℕ
  Number of rows to read from the prepared arrays.

• seed : ℕ
  Data-loader seed.

def NN.API.Common.NpyDataFlags.parse (args : List String) (defaultX defaultY : System.FilePath) (defaultRows : ℕ) :

Parse the standard --seed, --n-total, --x, and --y flags for NPY-backed datasets.

Examples provide dataset-specific defaults; this parser keeps the repeated NPY dataset flags consistent.

Standard TrainLog metadata for an NPY-backed dataset branch.

Built-in image dataset branches shared by image-model commands.

Parse --dataset, --cifar10, and --imagenet64, rejecting ambiguous selectors.
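This sketch parses dataset selectors and rejects conflicting choices. ImageDatasetSketch, the names accepted by --dataset, and the error messages are assumptions. Repeating the same selector is treated as unambiguous here.

```lean
/-- Hypothetical mirror of the built-in image dataset branches. -/
inductive ImageDatasetSketch where
  | cifar10
  | imagenet64
  deriving Repr, DecidableEq

/-- Names accepted by `--dataset <name>` in this sketch (assumed spellings). -/
def ImageDatasetSketch.ofName? : String → Option ImageDatasetSketch
  | "cifar10" => some .cifar10
  | "imagenet64" => some .imagenet64
  | _ => none

/-- Collect every dataset selector that appears in `args`. -/
def collectSelectors : List String → Except String (List ImageDatasetSketch)
  | [] => .ok []
  | "--cifar10" :: rest => (ImageDatasetSketch.cifar10 :: ·) <$> collectSelectors rest
  | "--imagenet64" :: rest => (ImageDatasetSketch.imagenet64 :: ·) <$> collectSelectors rest
  | ["--dataset"] => .error "--dataset expects a value"
  | "--dataset" :: name :: rest =>
    match ImageDatasetSketch.ofName? name with
    | some d => (d :: ·) <$> collectSelectors rest
    | none => .error s!"unknown dataset: {name}"
  | _ :: rest => collectSelectors rest

/-- `none` means "no selector given"; conflicting selectors are rejected. -/
def parseDatasetSketch (args : List String) : Except String (Option ImageDatasetSketch) := do
  match (← collectSelectors args).eraseDups with
  | [] => pure none
  | [d] => pure (some d)
  | _ => throw "ambiguous dataset selectors"
```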

                                                                                                                        Shared forecasting-window data flags: paths, window count, report offset, and seed.

                                                                                                                        • Prepared input-window tensor path.

                                                                                                                        • Prepared target-window tensor path.

                                                                                                                        • windows :

                                                                                                                          Number of forecasting windows to use.

                                                                                                                        • reportOffset :

                                                                                                                          Report window index used for before/after forecast display.

                                                                                                                        • seed :

                                                                                                                          Data-loader seed.

                                                                                                                        Instances For

                                                                                                                          Standard TrainLog metadata for forecasting-window datasets.

                                                                                                                          Instances For

                                                                                                                            Parse a supervised NPY dataset, logged training flags, and require that no command-specific arguments remain.

                                                                                                                            Dataset-specific commands provide parseData, which owns defaults such as CIFAR or ImageNet paths; this helper keeps the reusable "NPY data + logged TrainLog" path in one place.

                                                                                                                            Instances For
                                                                                                                              def NN.API.Common.parseNpyModelTrainFlags (exeName : String) (args : List String) (defaultLogPath : System.FilePath) (defaultSteps : := 1) (defaultLr : Float := 1e-3) (parseData : List StringExcept String (NpyDataFlags × List String)) :

                                                                                                                              Parse a supervised NPY dataset and standard model-training flags.

                                                                                                                              The remaining arguments are returned so callers can still pass runtime/backend flags through a higher-level runner.

                                                                                                                              Instances For
                                                                                                                                def NN.API.Common.parseForecastWindowModelTrainFlags (exeName : String) (args : List String) (defaultLogPath : System.FilePath) (defaultSteps : := 100) (defaultLr : Float := 1e-2) (parseData : List StringExcept String (ForecastWindowDataFlags × List String)) :

                                                                                                                                Parse forecasting-window data flags plus standard model-training flags.

                                                                                                                                The data parser comes from the caller because file defaults often depend on a dataset directory or a preparation script.

                                                                                                                                Instances For

                                                                                                                                  Common arguments for a model command that reads one supervised CSV.

def NN.API.Common.parseCsvModelTrainFlags (exeName : String) (args : List String) (defaultCsv defaultLogPath : System.FilePath) (defaultSteps : ℕ := 1) (defaultLr : Float := 1e-3) (allowZeroSteps : Bool := false) :

Parse common flags for a supervised CSV model runner.

Each model file supplies its own default CSV path, default log path, step count, and learning rate.

def NN.API.Common.listGen {α : Type} (n : ℕ) (f : ℕ → α) : List α

List generator: returns [f 0, f 1, ..., f (n-1)], i.e. [0, 1, ..., n-1] mapped through f.
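The same contract can be written in core Lean as a minimal sketch. listGenSketch is a hypothetical name, and the real listGen may be implemented differently. The lemma records that the result always has exactly n elements.

```lean
/-- Hypothetical re-implementation of the listGen contract: [f 0, f 1, ..., f (n-1)]. -/
def listGenSketch {α : Type} (n : Nat) (f : Nat → α) : List α :=
  (List.range n).map f

/-- The generated list has exactly `n` entries. -/
theorem listGenSketch_length {α : Type} (n : Nat) (f : Nat → α) :
    (listGenSketch n f).length = n := by
  simp [listGenSketch]

#eval listGenSketch 4 (fun i => i * i)   -- [0, 1, 4, 9]

example : listGenSketch 3 (fun i => i + 10) = [10, 11, 12] := rfl
```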

def NN.API.Common.tensorF {α : Type} [Context α] (cast : Float → α) (dims : List ℕ) (xs : List Float) :

Build an N-D tensor from a raw list of floats and cast it into the chosen scalar backend.

Returns an error (in Except String) if xs.length ≠ numel(dims), where numel(dims) is the product of the dimensions.
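The failure condition can be sketched with two hypothetical helpers, numelSketch and checkNumelSketch. These are illustrations, not TorchLean functions. In this sketch, numel is the product of the dimensions, so the empty product makes dims = [] count as one element.

```lean
/-- Product of the dimensions (the empty product is 1). -/
def numelSketch (dims : List Nat) : Nat :=
  dims.foldl (· * ·) 1

/-- Hypothetical length check mirroring tensorF's documented failure case. -/
def checkNumelSketch (dims : List Nat) (xs : List Float) : Except String (List Float) :=
  if xs.length = numelSketch dims then
    .ok xs
  else
    .error s!"expected {numelSketch dims} values for dims {dims}, got {xs.length}"

#eval numelSketch [2, 3]   -- 6

-- Five values cannot fill a 2 × 3 tensor, so this reports an error.
#eval match checkNumelSketch [2, 3] [1, 2, 3, 4, 5] with
  | .ok _ => "ok"
  | .error e => e
```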

def NN.API.Common.tensorFGen {α : Type} [Context α] (cast : Float → α) (dims : List ℕ) (f : ℕ → Float) :

Generate an N-D tensor by calling f for each element index, then cast into the chosen backend.

The function f is indexed by the flat element index 0..numel-1.

def NN.API.Common.tensorFGen! {α : Type} [Context α] (cast : Float → α) (dims : List ℕ) (f : ℕ → Float) :

Generate an N-D tensor by calling f for each flat element index, with no failure case.

This is the “total” sibling of tensorFGen. Because exactly numel(dims) values are generated, the reshape cannot fail and no Except is needed. It is usually the right tool for building a deterministic constant tensor in an example.
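The reason the generator variant needs no Except can be stated as a small lemma. In this sketch, numelOf and genFlatSketch are hypothetical names that are assumed to model the flat-index generation described above. Generating over List.range (numelOf dims) always yields exactly numelOf dims values, so a length check like tensorF's can never fail.

```lean
/-- Product of the dimensions (hypothetical helper; the empty product is 1). -/
def numelOf (dims : List Nat) : Nat :=
  dims.foldl (· * ·) 1

/-- Generate one value per flat index 0 .. numel-1. -/
def genFlatSketch (dims : List Nat) (f : Nat → Float) : List Float :=
  (List.range (numelOf dims)).map f

/-- The generated list always has exactly `numelOf dims` entries. -/
theorem genFlatSketch_length (dims : List Nat) (f : Nat → Float) :
    (genFlatSketch dims f).length = numelOf dims := by
  simp [genFlatSketch]

#eval genFlatSketch [2, 2] (fun i => i.toFloat / 2)
```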

def NN.API.Common.tensorFGenShape! {α : Type} [Context α] (cast : Float → α) (s : Spec.Shape) (f : ℕ → Float) :

Generate a tensor of a known shape s by calling f for each flat element index, then cast into the chosen backend.

This packages the standard shape-cast pattern used when a tensor is generated from flat indices:

    let xDyn : Tensor α (shapeOfDims s.toList) := ...
    let x : Tensor α s := by simpa using xDyn
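The shape-cast pattern can be reproduced on a toy shape-indexed type. TensorSketch, shapeOfDimsSketch and atKnownShape are hypothetical stand-ins for Tensor α s and shapeOfDims. The sketch assumes, as the pattern requires, a simp lemma that rewrites the computed shape back to s. Here the conversion is deliberately only propositionally the identity, so a plain type ascription would not typecheck.

```lean
/-- Toy shape-indexed container (hypothetical; stands in for `Tensor α s`). -/
structure TensorSketch (s : List Nat) where
  data : Array Float

/-- Hypothetical dims → shape conversion; `dims ++ []` is not definitionally `dims`. -/
def shapeOfDimsSketch (dims : List Nat) : List Nat :=
  dims ++ []

@[simp] theorem shapeOfDimsSketch_eq (dims : List Nat) :
    shapeOfDimsSketch dims = dims := by
  simp [shapeOfDimsSketch]

/-- Build at the computed shape, then let `simpa` rewrite the index to `s`. -/
def atKnownShape (s : List Nat) (xDyn : TensorSketch (shapeOfDimsSketch s)) :
    TensorSketch s := by
  simpa using xDyn
```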
                                                                                                                                              

1D vector tensor constructor specialized to shape Vec n.

Fails if xs.length ≠ n.

def NN.API.Common.vecFGen {α : Type} [Context α] (cast : Float → α) (n : ℕ) (f : ℕ → Float) :

Generator variant of vecF.

def NN.API.Common.matF {α : Type} [Context α] (cast : Float → α) (rows cols : ℕ) (xs : List Float) :

2D matrix tensor constructor specialized to shape Mat rows cols.

Fails if xs.length ≠ rows * cols.

def NN.API.Common.matFGen {α : Type} [Context α] (cast : Float → α) (rows cols : ℕ) (f : Float) :

Generator variant of matF.

def NN.API.Common.runWithDType (title : String) (args : List String) (k : {α : Type} → [Semantics.Scalar α] → [DecidableEq Spec.Shape] → [ToString α] → (Float → α) → IO Unit) :

Run a workflow once under a dtype selected from args (via --dtype / --float32-mode).

This logs the chosen dtype and then calls k with a cast function Float → α for the selected scalar backend.

In particular:

• --dtype=float selects Lean's built-in Float (trusted semantics, executable),
• --dtype=float32 selects TorchLean's verified IEEE32 executable semantics,
• --dtype=complex selects TorchLean's parametric complex scalar over Float32,
• --dtype=real selects ℝ (proof-only; errors at runtime).
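The dispatch can be sketched in continuation-passing style. The workflow body k is written once, generic in α, and the runner chooses the concrete scalar type and its cast. This is a hypothetical, reduced model: DTypeSketch, selectDTypeSketch and runWithDTypeSketch are not TorchLean names. It covers only float and float32, ignores --float32-mode, and assumes a Lean toolchain that provides Float32 and Float.toFloat32. Defaulting to float when --dtype is absent is also an assumption of the sketch.

```lean
/-- Hypothetical dtype tags; only two executable Lean types are modelled. -/
inductive DTypeSketch where
  | float
  | float32

/-- Read `--dtype=...` from the arguments (defaulting to `float` is assumed). -/
def selectDTypeSketch (args : List String) : Except String DTypeSketch :=
  match args.find? (fun a => a.startsWith "--dtype=") with
  | none => .ok .float
  | some "--dtype=float" => .ok .float
  | some "--dtype=float32" => .ok .float32
  | some other => .error s!"unsupported in this sketch: {other}"

/-- Continuation-passing dispatch: `k` is generic in `α` and receives the
    `Float → α` cast for whichever backend was selected. -/
def runWithDTypeSketch (args : List String)
    (k : {α : Type} → [ToString α] → (Float → α) → IO Unit) : IO Unit := do
  match selectDTypeSketch args with
  | .error e => throw (IO.userError e)
  | .ok .float =>
      IO.println "dtype: float"
      k (fun x : Float => x)
  | .ok .float32 =>
      IO.println "dtype: float32"
      k Float.toFloat32

/-- Usage: the body never names a concrete scalar type. -/
def main (args : List String) : IO Unit :=
  runWithDTypeSketch args fun cast => IO.println (cast 1.5)
```

Passing the cast as an argument keeps k independent of any particular backend. The Semantics.Scalar and DecidableEq Spec.Shape instances from the real signature are omitted here.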
def NN.API.Common.runWithRuntimeDType (title : String) (args : List String) (k : {α : Type} → [Semantics.Scalar α] → [DecidableEq Spec.Shape] → [ToString α] → [Runtime.Scalar α] → IO Unit) :

Like runWithDType, but the continuation k also receives an API.Runtime.Scalar α instance and takes no explicit cast argument.

Use this when your workflow uses numeric literals (1.0, -3.5, etc.) at runtime.
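Why an instance helps can be sketched as follows. When k receives no cast argument, generic code still needs some way to turn literals into α. The class below is a hypothetical stand-in that assumes API.Runtime.Scalar plays a comparable role. Its actual members are not shown in this documentation.

```lean
/-- Hypothetical stand-in for a runtime scalar class: it lets generic code
    turn `Float` literals into whatever scalar type was selected. -/
class RuntimeScalarSketch (α : Type) where
  ofFloat : Float → α

instance : RuntimeScalarSketch Float := ⟨id⟩

/-- Generic workflow body: no cast argument is passed in; the literals
    1.0 and -3.5 reach `α` through the instance instead. -/
def workflowSketch {α : Type} [ToString α] [Add α] [RuntimeScalarSketch α] : IO Unit := do
  let x : α := RuntimeScalarSketch.ofFloat 1.0 + RuntimeScalarSketch.ofFloat (-3.5)
  IO.println s!"1.0 + -3.5 = {x}"

#eval workflowSketch (α := Float)
```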

@[reducible, inline]
abbrev NN.API.Common.mainWithRuntimeDType (title : String) (args : List String) (k : {α : Type} → [Semantics.Scalar α] → [DecidableEq Spec.Shape] → [ToString α] → [Runtime.Scalar α] → IO Unit) :

Entry-point alias for runnable binaries that need runtime float-literal injection.
