TorchLean API

NN.API.Common #

Small, practical helpers used across TorchLean workflows and entrypoints: typed tensor constructors, casting between scalar backends, and Except/IO glue.

Common Helpers For Workflows And Small Programs #

We keep this module small and practical: it contains helper functions that show up repeatedly in executable workflows and tutorials.

What belongs here: typed tensor constructors, casts between scalar backends, Except/IO glue, and shared CLI-flag parsing for the runnable examples.

What does not belong here: anything tied to a particular model, dataset, or runtime backend.

PyTorch Mapping Notes #

If you're coming from PyTorch, this module plays the role of the small shared utility layer often written around torch.tensor constructors, dtype casts, argparse-style flag parsing, and wandb-style run logging.

The main difference is that TorchLean tracks shapes in types, so tensor constructors here return Except String ... rather than silently reshaping.

def NN.API.Common.castTensor {α : Type} (cast : Float → α) {s : Spec.Shape} (t : Spec.Tensor Float s) :
    Spec.Tensor α s

Cast a tensor elementwise while preserving its type-level shape.

def NN.API.Common.orThrow {α : Type} (tag : String := "Example") :
    Except String α → IO α

Convert an Except String α into an IO α, raising a tagged userError on failure.

This is handy in main functions where we want the process to exit with a readable message.
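As a sketch of the intended pattern (the `tensorF` call and its arguments here are illustrative, not verified against the actual API), a `main` can route any failing constructor through `orThrow`:

```lean
-- Hypothetical sketch: build a tensor, exiting with a tagged
-- userError message if construction fails.
def main : IO Unit := do
  let _t ← NN.API.Common.orThrow (tag := "demo")
    (NN.API.Common.tensorF (α := Float) id [2, 2] [1.0, 2.0, 3.0, 4.0])
  IO.println "tensor built"
```

On failure the process exits with the userError message, presumably prefixed by the tag.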

def NN.API.Common.check (tag msg : String) (b : Bool) :
    IO Unit

Fail with a tagged userError if a boolean condition is false.

This is a small convenience for examples that want short, readable checks: Common.check exeName "loss finite" (loss == loss).

def NN.API.Common.writeBeforeAfterLossLog (path : System.FilePath) (title : String) (steps : ℕ) (loss0 loss1 : Float) (notes : Array String := #[]) :
    IO Unit

Write a standard JSON training artifact for routines that record an initial and final loss.

This is a convenience wrapper around Runtime.Training.TrainLog.beforeAfterLoss and the stable TrainLog JSON format. It is intentionally independent of any particular model, dataset, or runtime backend.
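A hedged usage sketch follows; the path, title, and notes are invented for illustration:

```lean
-- Sketch: record the loss before and after a short training run.
def logDemoRun (loss0 loss1 : Float) : IO Unit :=
  NN.API.Common.writeBeforeAfterLossLog
    "artifacts/demo_train.json"   -- path (illustrative)
    "demo"                        -- title
    100                           -- steps
    loss0 loss1
    (notes := #["optimizer=adam", "lr=1e-3"])
```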

def NN.API.Common.writeBeforeAfterLossLogTo (dest : Runtime.Training.LogDestination) (title : String) (steps : ℕ) (loss0 loss1 : Float) (notes : Array String := #[]) :
    IO Unit

Write a before/after loss log to an explicit logging destination.

LogDestination.disabled is a no-op, mirroring wandb's disabled mode for quick smoke tests.
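For example, a smoke test can pass the disabled destination so no artifact is written (a sketch; argument spellings are illustrative):

```lean
-- Sketch: LogDestination.disabled makes the whole call a no-op,
-- so the test exercises the training code without touching disk.
def smokeLog (loss0 loss1 : Float) : IO Unit :=
  NN.API.Common.writeBeforeAfterLossLogTo
    Runtime.Training.LogDestination.disabled
    "smoke" 1 loss0 loss1
```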

def NN.API.Common.writeCurveLog (path : System.FilePath) (title : String) (curve : Runtime.Training.Curve) (seriesName : String := "loss") (notes : Array String := #[]) :
    IO Unit

Write a one-series scalar curve as a standard TrainLog JSON artifact.

def NN.API.Common.writeCurveLogTo (dest : Runtime.Training.LogDestination) (title : String) (curve : Runtime.Training.Curve) (seriesName : String := "loss") (notes : Array String := #[]) :
    IO Unit

Write a one-series scalar curve to an explicit logging destination.

Common CLI result for training commands that accept --steps/--epochs and --log.

def NN.API.Common.parseLoggedTrainFlags (exeName : String) (args : List String) (defaultLogPath : System.FilePath) (defaultSteps : ℕ := 1) (allowZeroSteps : Bool := false) :

Parse common training flags: positive --steps/--epochs plus optional --log.

--log <path> writes the standard local JSON artifact. --log false, --log off, or --log none disables artifact writing while preserving the parsed step count.
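A sketch of a command entrypoint using this parser, assuming (as with the constructors above) that it returns Except String over a result record, and that the record exposes a steps field — both assumptions:

```lean
def demoMain (args : List String) : IO Unit := do
  -- Parse --steps/--epochs and --log, failing with a readable error.
  -- The `steps` field name is an assumption for illustration.
  let flags ← NN.API.Common.orThrow (tag := "demo")
    (NN.API.Common.parseLoggedTrainFlags "demo" args "artifacts/demo.json")
  IO.println s!"running {flags.steps} optimizer steps"
```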

Training flags shared by the runnable model examples.

Most model commands should not each define their own parser for the same knobs. They can parse model-specific data flags first, then call parseModelTrainFlags for the common optimizer loop:

• --steps / --epochs: how many optimizer passes the example should run;
• --log: where to write a TrainLog JSON artifact;
• --lr: Adam learning rate.

Special examples can still extend this record with extra fields, but the default path stays one shared parser rather than one local TrainOptions clone per model.

• Shared step/epoch count and logging destination.

• lr : Float

  Learning rate for the default Adam optimizer used by examples.

def NN.API.Common.parseModelTrainFlags (exeName : String) (args : List String) (defaultLogPath : System.FilePath) (defaultSteps : ℕ := 1) (defaultLr : Float := 1e-3) (allowZeroSteps : Bool := false) :

Parse the standard model-training flags: --steps/--epochs, --log, and --lr.
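The intended division of labor can be sketched as below; the `steps` and `lr` field names are assumptions for illustration:

```lean
-- Sketch: parse model-specific data flags first (elided here), then
-- call the shared parser for the common optimizer knobs.
def mlpMain (args : List String) : IO Unit := do
  let flags ← NN.API.Common.orThrow (tag := "mlp")
    (NN.API.Common.parseModelTrainFlags "mlp" args "artifacts/mlp.json"
      (defaultLr := 3e-4))
  IO.println s!"steps={flags.steps}, lr={flags.lr}"
```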

Default Adam optimizer constructor used by supervised and vision examples.

The reusable part is the optimizer convention, not the model. Individual examples still own their architecture and loss, while this helper keeps the Adam hyperparameter spelling identical across MLP, CNN, ResNet, ViT, and similar small demos.

def NN.API.Common.runAnyOrFloat (exeName : String) (args : List String) (preferFloat : List String → Bool) (banner : Runtime.Autograd.Torch.Options → String) (anyK : {α : Type} → [Semantics.Scalar α] → [DecidableEq Spec.Shape] → [ToString α] → [Runtime.Scalar α] → (Float → α) → Runtime.Autograd.Torch.Options → List String → IO Unit) (floatK : Runtime.Autograd.Torch.Options → List String → IO Unit) (printOk : Bool := true) :
    IO Unit

Run an executable with the standard TorchLean runtime parser, using the polymorphic scalar path by default and switching to the Float path when requested.

This is the common shape for public examples that support all executable scalar backends, but need the Float path for CUDA bridges, decoded probes, or JSON artifacts whose metrics are stored as Float.
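The two-continuation shape can be sketched as follows; the `--cuda` flag and the banner text are invented for illustration:

```lean
-- Sketch: run the generic scalar path by default, and the Float path
-- when the remaining args ask for CUDA (flag name illustrative).
def main (args : List String) : IO Unit :=
  NN.API.Common.runAnyOrFloat "demo" args
    (preferFloat := fun rest => rest.contains "--cuda")
    (banner := fun _opts => "demo banner")
    (anyK := fun {α} [Semantics.Scalar α] [DecidableEq Spec.Shape]
        [ToString α] [Runtime.Scalar α] (cast : Float → α) _opts _rest => do
      IO.println s!"generic path, one = {cast 1.0}")
    (floatK := fun _opts _rest => IO.println "float path")
```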

def NN.API.Common.listGen {α : Type} (n : ℕ) (f : ℕ → α) :
    List α

List generator: [0, 1, ..., n-1] mapped through f.
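Per that description, squaring each index gives:

```lean
-- listGen n f = [f 0, f 1, ..., f (n-1)]
#eval NN.API.Common.listGen 4 (fun i => i * i)
-- per the description above: [0, 1, 4, 9]
```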

def NN.API.Common.tensorF {α : Type} [Context α] (cast : Float → α) (dims : List ℕ) (xs : List Float) :

Build an N-D tensor from a raw list of floats and cast it into the chosen scalar backend.

Fails if xs.length ≠ numel(dims).
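A minimal sketch, assuming the result is an Except String as the failure condition implies:

```lean
-- Sketch: a 2×3 tensor from six floats; a wrong-length list yields
-- Except.error rather than a silent reshape.
def buildDemo : IO Unit :=
  match NN.API.Common.tensorF (α := Float) id [2, 3]
      [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] with
  | .ok _t => IO.println "ok: built a 2×3 tensor"
  | .error e => IO.println s!"error: {e}"
```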

def NN.API.Common.tensorFGen {α : Type} [Context α] (cast : Float → α) (dims : List ℕ) (f : ℕ → Float) :

Generate an N-D tensor by calling f for each element index, then cast into the chosen backend.

The function f is indexed by the flat element index 0..numel-1.

def NN.API.Common.tensorFGen! {α : Type} [Context α] (cast : Float → α) (dims : List ℕ) (f : ℕ → Float) :

Generate an N-D tensor by calling f for each flat element index, with no failure case.

This is the “total” sibling of tensorFGen: since we generate exactly numel(dims) values, the reshape cannot fail, so we avoid an Except. When you want to build a deterministic constant tensor for an example, this is usually the right tool.
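For instance, a constant tensor for an example can be a one-liner (a sketch; argument spellings are illustrative):

```lean
-- Sketch: a deterministic 2×2 tensor of ones; total, so no Except.
def ones2x2 :=
  NN.API.Common.tensorFGen! (α := Float) id [2, 2] (fun _ => 1.0)
```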

def NN.API.Common.tensorFGenShape! {α : Type} [Context α] (cast : Float → α) (s : Spec.Shape) (f : ℕ → Float) :

Generate a tensor of a known shape s by calling f for each flat element index, then cast into the chosen backend.

This is a convenience wrapper used in examples to avoid the common boilerplate:

let xDyn : Tensor α (shapeOfDims s.toList) := ...
let x : Tensor α s := by simpa using xDyn
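
With the wrapper, that boilerplate should collapse to a single call (a sketch continuing the fragment above, with the same α, s, cast, and f; assuming the result is typed directly at shape s):

```lean
-- Sketch: the wrapper returns the tensor at shape s directly,
-- so the `by simpa` cast is no longer needed.
let x : Tensor α s := NN.API.Common.tensorFGenShape! cast s f
```
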
                                    
1D vector tensor constructor specialized to shape Vec n.

Fails if xs.length ≠ n.

def NN.API.Common.vecFGen {α : Type} [Context α] (cast : Float → α) (n : ℕ) (f : ℕ → Float) :

Generator variant of vecF.

def NN.API.Common.matF {α : Type} [Context α] (cast : Float → α) (rows cols : ℕ) (xs : List Float) :

2D matrix tensor constructor specialized to shape Mat rows cols.

Fails if xs.length ≠ rows * cols.

def NN.API.Common.matFGen {α : Type} [Context α] (cast : Float → α) (rows cols : ℕ) (f : ℕ → Float) :

Generator variant of matF.

def NN.API.Common.runWithDType (title : String) (args : List String) (k : {α : Type} → [Semantics.Scalar α] → [DecidableEq Spec.Shape] → [ToString α] → (Float → α) → IO Unit) :
    IO Unit

Run a workflow once under a dtype selected from args (via --dtype / --float32-mode).

This logs the chosen dtype and then calls k with a cast function Float → α for the selected scalar backend.

In particular:

• --dtype=float selects Lean's builtin Float (trusted semantics, executable),
• --dtype=float32 selects TorchLean's verified IEEE32 executable semantics,
• --dtype=complex selects TorchLean's parametric complex scalar over Float32,
• --dtype=real selects ℝ (proof-only; errors at runtime).

def NN.API.Common.runWithRuntimeDType (title : String) (args : List String) (k : {α : Type} → [Semantics.Scalar α] → [DecidableEq Spec.Shape] → [ToString α] → [Runtime.Scalar α] → IO Unit) :
    IO Unit

Like runWithDType, but also provides an API.Runtime.Scalar α instance.

Use this when your workflow uses numeric literals (1.0, -3.5, etc.) at runtime.
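A sketch of a dtype-generic workflow passed as the continuation k (the body is illustrative):

```lean
-- Sketch: a workflow generic over the selected scalar backend;
-- run it with e.g. `--dtype=float32` on the command line.
def main (args : List String) : IO Unit :=
  NN.API.Common.runWithDType "demo" args
    (fun {α} [Semantics.Scalar α] [DecidableEq Spec.Shape] [ToString α]
        (cast : Float → α) => do
      let two : α := cast 2.0
      IO.println s!"two = {two}")
```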

@[reducible, inline]
abbrev NN.API.Common.mainWithRuntimeDType (title : String) (args : List String) (k : {α : Type} → [Semantics.Scalar α] → [DecidableEq Spec.Shape] → [ToString α] → [Runtime.Scalar α] → IO Unit) :
    IO Unit

Convenience wrapper for runnable binaries that need runtime float-literal injection.