TorchLean API

NN.API.Public

API Public #

PyTorch-like facade over the TorchLean API.

Most user code should be able to import NN and then work with the NN.API.* names collected here.

Most of the executable runtime machinery lives under API.TorchLean.*; this module collects the pieces into a smaller, PyTorch-shaped surface under NN.API.*.

PyTorch References #

This facade is inspired by the public shape of PyTorch:

TorchLean differs in two important ways:

This is the implementation module for the public facade. New user code should usually prefer import NN; use import NN.Entrypoint.API when you want only the PyTorch-shaped facade.

Facade policy:

@[reducible, inline]

Sequential model type (TorchLean Seq). This is the analogue of PyTorch nn.Sequential.

Instances For
    @[reducible, inline]
    abbrev NN.API.nn.LayerDef (σ τ : Spec.Shape) :

    Single-layer definition type (TorchLean LayerDef). This is the analogue of PyTorch nn.Module.

    Instances For

      Re-export common Seq helpers under API.nn.* so examples can stay on the public facade.

      This intentionally mirrors the TorchLean names to keep the mapping obvious.

      def NN.API.nn.of {σ τ : Spec.Shape} (layer : LayerDef σ τ) :

      Lift a single layer definition into a sequential model.

      Instances For

        All explicit-seed layer constructors live under nn.pure.*.

        The top-level nn.* namespace is reserved for the seeded builder API that allocates initialization seeds automatically (PyTorch-style ergonomics).

def NN.API.nn.pure.linear (inDim outDim : ℕ) (seedW seedB : ℕ := 0) (pfx : Spec.Shape := Spec.Shape.scalar) :
        Sequential (pfx.appendDim inDim) (pfx.appendDim outDim)

        Linear layer on the last axis (prefix-shape preserving).

        PyTorch analogue: torch.nn.Linear. See https://pytorch.org/docs/stable/generated/torch.nn.Linear.html.

        Unlike the lower-level TorchLean layer constructor (which is vector-only), this public facade matches PyTorch’s convention:

        • if x has shape [..., inDim], linear inDim outDim returns a model of shape [..., outDim].

        The leading “prefix” dimensions are treated as a batch (they are flattened to (numel(prefix), inDim), the affine map is applied once, and the result is reshaped back).
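As a usage sketch of the constructor above (namespace layout as printed on this page; the concrete dimensions and seeds are illustrative):

```lean
open NN.API.nn in
/-- A 784 → 10 affine head with the default scalar prefix shape.
    Dimensions and seeds here are purely illustrative. -/
def head := pure.linear 784 10 (seedW := 1) (seedB := 2)
```

Because `pfx` defaults to `Spec.Shape.scalar`, this model maps a plain 784-vector to a 10-vector; passing a nontrivial `pfx` batches the same affine map over the prefix dimensions.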

        Instances For
def NN.API.nn.pure.rnn (seqLen inputSize hiddenSize : ℕ) (seedW seedB : ℕ := 0) :
          Sequential (Tensor.Shape.Mat seqLen inputSize) (Tensor.Shape.Mat seqLen hiddenSize)

          Vanilla RNN layer (time-major sequence, no batch axis).

          Semantics: h_t = tanh(W [x_t; h_{t-1}] + b), with h_{-1} = 0.

          This is implemented by unrolling seqLen steps using existing TorchLean ops, so it runs on both CPU and CUDA backends.

PyTorch analogue: torch.nn.RNN(inputSize, hiddenSize, nonlinearity="tanh") with batch_first=false, specialized to a single batch element.
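The recurrence above can be written out over plain arrays. This is a stand-alone illustrative sketch, not the library's implementation (which stays inside differentiable TorchLean ops):

```lean
/-- One vanilla-RNN step: `h_t = tanh (W · [x_t; h_{t-1}] + b)`,
    with `W` holding `hidden` rows of length `input + hidden`.
    Illustrative only. -/
def rnnStep (W : Array (Array Float)) (b : Array Float)
    (x h : Array Float) : Array Float :=
  let xh := x ++ h
  let dot (row : Array Float) : Float :=
    (row.zip xh).foldl (fun acc wv => acc + wv.1 * wv.2) 0.0
  (W.zip b).map fun rb => Float.tanh (dot rb.1 + rb.2)
```

Folding `rnnStep` over the `seqLen` time steps, starting from the zero vector `h_{-1} = 0`, reproduces the unrolled semantics described above.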

          Instances For
def NN.API.nn.pure.gru (seqLen inputSize hiddenSize : ℕ) (seedW seedB : ℕ := 0) :
            Sequential (Tensor.Shape.Mat seqLen inputSize) (Tensor.Shape.Mat seqLen hiddenSize)

            GRU layer (time-major sequence, no batch axis).

            This is implemented by unrolling seqLen steps using existing TorchLean ops, so it runs on both CPU and CUDA backends.

PyTorch analogue: torch.nn.GRU(inputSize, hiddenSize) with batch_first=false, specialized to a single batch element.

            Instances For
def NN.API.nn.pure.mamba (seqLen inputSize hiddenSize : ℕ) (seedW seedB : ℕ := 0) :
              Sequential (Tensor.Shape.Mat seqLen inputSize) (Tensor.Shape.Mat seqLen hiddenSize)

              Trainable Mamba-style gated diagonal state-space layer.

              The layer is time-major and single-batch, matching the simple rnn/gru/lstm constructors: input (seqLen × inputSize), output (seqLen × hiddenSize). It is unrolled with differentiable TorchLean ops, so CPU and CUDA training use the same API.

              Instances For
def NN.API.nn.pure.lstm (seqLen inputSize hiddenSize : ℕ) (seedW seedB : ℕ := 0) :
                Sequential (Tensor.Shape.Mat seqLen inputSize) (Tensor.Shape.Mat seqLen hiddenSize)

                LSTM layer (time-major sequence, no batch axis).

                This is implemented by unrolling seqLen steps using existing TorchLean ops, so it runs on both CPU and CUDA backends.

PyTorch analogue: torch.nn.LSTM(inputSize, hiddenSize) with batch_first=false, specialized to a single batch element.

                Instances For

                  Embedding table initialization configuration (one-hot / token-distribution inputs).

                  This is the TorchLean-friendly analogue of torch.nn.Embedding in the common demo setting where token ids are represented as one-hot vectors (or soft token distributions), so lookup is a matrix multiplication rather than integer indexing.
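Why one-hot "lookup" is a matrix multiplication: the one-hot row selects exactly one row of the embedding table. A plain-`Array` sketch (names are illustrative, not library code):

```lean
/-- One-hot encoding of a token id. Illustrative only. -/
def oneHot (vocab tok : Nat) : Array Float :=
  (Array.range vocab).map fun i => if i = tok then 1.0 else 0.0

/-- `oneHot tok ·ᵀ table = table[tok]`: the vector–matrix product
    degenerates to a row selection. -/
def embedLookup (table : Array (Array Float)) (tok : Nat) : Array Float :=
  let x := oneHot table.size tok
  let zero := Array.mkArray table[0]!.size 0.0
  (x.zip table).foldl
    (fun acc xr => (acc.zip xr.2).map fun ar => ar.1 + xr.1 * ar.2)
    zero
```

A soft token distribution in place of the one-hot vector yields the corresponding convex mixture of embedding rows, which is what makes this formulation differentiable end to end.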

                  Instances For
def NN.API.nn.pure.embedding (vocab embedDim : ℕ) (cfg : Embedding := { }) (pfx : Spec.Shape := Spec.Shape.scalar) :
                    Sequential (pfx.appendDim vocab) (pfx.appendDim embedDim)

                    Embedding layer for one-hot / token-distribution inputs (no bias).

Input shape: [..., vocab]. Output shape: [..., embedDim].

                    PyTorch analogue: conceptually nn.Embedding(vocab, embedDim) but applied to one-hot inputs.

                    Instances For

                      Learned positional embedding configuration.

                      This is a trainable parameter tensor of shape (seqLen × embedDim) that is broadcast across the leading batch dimension and added to the input.
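The broadcast-add can be sketched over nested arrays; one `pos : seqLen × embedDim` table is shared by every batch element (illustrative names only):

```lean
/-- `out[b][t][d] = x[b][t][d] + pos[t][d]`: the positional table is
    broadcast across the leading batch axis. Sketch only. -/
def addPos (pos : Array (Array Float)) (x : Array (Array (Array Float))) :
    Array (Array (Array Float)) :=
  x.map fun seq =>
    (seq.zip pos).map fun rp =>
      (rp.1.zip rp.2).map fun ab => ab.1 + ab.2
```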

                      Instances For
def NN.API.nn.pure.learnedPositionalEmbedding {batch seqLen embedDim : ℕ} (cfg : LearnedPositionalEmbedding := { }) :
                        Sequential (Spec.Shape.dim batch (Tensor.Shape.Mat seqLen embedDim)) (Spec.Shape.dim batch (Tensor.Shape.Mat seqLen embedDim))

                        Add learned positional embeddings to a batched (batch × seqLen × embedDim) tensor.

                        PyTorch analogue: x + pos[:seqLen] where pos is a parameter table.

                        Instances For

                          Sinusoidal positional encoding configuration.

                          This is the classic (non-trainable) Transformer sinusoidal encoding, added to token embeddings. startPos is an absolute-position offset (useful for KV-cache decoding).

• startPos : ℕ

                            Absolute position offset for the first row of the encoding table.

                          Instances For
def NN.API.nn.pure.sinusoidalPositionalEncoding {batch seqLen embedDim : ℕ} (cfg : SinusoidalPositionalEncoding := { }) :
                            Sequential (Spec.Shape.dim batch (Tensor.Shape.Mat seqLen embedDim)) (Spec.Shape.dim batch (Tensor.Shape.Mat seqLen embedDim))

                            Add sinusoidal positional encodings to a batched (batch × seqLen × embedDim) tensor.

                            Implementation:

                            • precompute PE : (seqLen × embedDim) at initialization time (stored as a non-trainable buffer),
                            • broadcast it across the leading batch axis and add to the input.
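The precomputed buffer is the classic Transformer table; a stand-alone sketch of the computation, with the `startPos` offset described above (even columns get sin, odd columns cos):

```lean
/-- `PE[p][2i] = sin(pos / 10000^(2i/d))`, `PE[p][2i+1] = cos(...)`,
    where `pos = p + startPos`. Illustrative, not the library buffer. -/
def sinPE (seqLen embedDim startPos : Nat) : Array (Array Float) :=
  (Array.range seqLen).map fun p =>
    (Array.range embedDim).map fun j =>
      let pos := Float.ofNat (p + startPos)
      let expo := Float.ofNat (2 * (j / 2)) / Float.ofNat embedDim
      let angle := pos / Float.pow 10000.0 expo
      if j % 2 = 0 then Float.sin angle else Float.cos angle
```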
                            Instances For

                              Rotary positional embedding (RoPE) configuration.

                              startPos is an absolute-position offset (useful for KV-cache decoding).

• startPos : ℕ

                                Absolute position offset for the first row of RoPE angles.

                              Instances For
def NN.API.nn.pure.rope {batch numHeads seqLen headDim : ℕ} (cfg : RoPE := { }) :
                                Sequential (Spec.Shape.dim batch (Spec.Shape.dim numHeads (Tensor.Shape.Mat seqLen headDim))) (Spec.Shape.dim batch (Spec.Shape.dim numHeads (Tensor.Shape.Mat seqLen headDim)))

                                Apply RoPE to a batched multi-head tensor (batch × numHeads × seqLen × headDim).

                                This matches the standard identity:

                                rope(x) = x * cos + rotatePairs(x) * sin

                                where cos/sin depend only on (pos, dim) and broadcast across (batch, numHeads).
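The identity can be spelled out for a single position's head vector; a plain-array sketch (illustrative names, not the library implementation):

```lean
/-- The `rotatePairs` half of the identity:
    `(x₂ᵢ, x₂ᵢ₊₁) ↦ (−x₂ᵢ₊₁, x₂ᵢ)`. -/
def rotatePairs (x : Array Float) : Array Float :=
  (Array.range x.size).map fun i =>
    if i % 2 = 0 then -(x[i+1]!) else x[i-1]!

/-- `rope(x) = x * cos + rotatePairs(x) * sin`, elementwise. -/
def ropeVec (cos sin x : Array Float) : Array Float :=
  let r := rotatePairs x
  (Array.range x.size).map fun i => x[i]! * cos[i]! + r[i]! * sin[i]!
```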

                                Notes:

                                • This layer is differentiable (gradients flow through the rotation), but it has no trainable parameters; the precomputed cos/sin tables are stored as non-trainable buffers.
                                • The pure spec version is in NN.Spec.Layers.PositionalEncoding (Spec.rope_apply_heads_spec).
                                Instances For

                                  Elementwise ReLU. PyTorch analogue: torch.nn.ReLU / torch.nn.functional.relu.

                                  Instances For

                                    Elementwise SiLU/Swish. PyTorch analogue: torch.nn.SiLU / torch.nn.functional.silu.

                                    Instances For

                                      Elementwise GELU. PyTorch analogue: torch.nn.GELU / torch.nn.functional.gelu.

                                      Instances For

                                        Elementwise sigmoid. PyTorch analogue: torch.nn.Sigmoid / torch.nn.functional.sigmoid.

                                        Instances For

                                          Elementwise tanh. PyTorch analogue: torch.nn.Tanh / torch.nn.functional.tanh.

                                          Instances For

                                            Softmax. PyTorch analogue: torch.nn.Softmax / torch.nn.functional.softmax.

                                            Instances For

                                              Reduce-sum to a scalar. PyTorch analogue: torch.sum.

                                              Instances For

                                                Flatten any tensor into a 1D vector of length size s. PyTorch analogue: torch.flatten.

                                                Instances For

                                                  Flatten a batched tensor N × σ into a matrix N × (size σ).

                                                  PyTorch analogue: torch.flatten(x, start_dim=1).

                                                  Instances For

                                                    Flatten a batched tensor starting at dimension 1 (keep dim0).

                                                    Synonym for flattenBatch, matching PyTorch’s start_dim=1 wording.

                                                    Instances For
def NN.API.nn.pure.dropout {s : Spec.Shape} (p : Float) (seed : ℕ := 0) :

                                                      Dropout layer (active in train mode, identity in eval mode).

                                                      PyTorch analogue: torch.nn.Dropout.
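For orientation, train-mode dropout in PyTorch's (inverted-dropout) convention zeroes a unit with probability p and scales survivors by 1/(1−p), so eval mode is the identity with no rescaling. Whether TorchLean scales identically is an assumption here; the mask is passed in explicitly to keep the sketch deterministic:

```lean
/-- Inverted-dropout application given a precomputed keep mask.
    Illustrative sketch only; the library draws the mask from `seed`. -/
def dropoutApply (p : Float) (keep : Array Bool) (x : Array Float) : Array Float :=
  (x.zip keep).map fun xk => if xk.2 then xk.1 / (1.0 - p) else 0.0
```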

                                                      Instances For
def NN.API.nn.pure.flattenLinear {s : Spec.Shape} (outDim : ℕ) (seedW seedB : ℕ := 0) :

                                                        Convenience block: Flatten -> Linear.

                                                        This is common for "image to classifier head" demos.

                                                        Instances For

                                                          nn.functional mirrors torch.nn.functional: pure, stateless building blocks.

                                                          In TorchLean these are defined as derived ops over the small primitive Ops surface, so the same code works on both the eager backend and the compiled backend.

                                                          PyTorch references:

                                                          Batch Lifting #

                                                          batchDim0 n model wraps a single-example model σ → τ into a batched model (dim n σ) → (dim n τ) by running the underlying model once per batch element.

                                                          This is a correctness-first helper used to expose PyTorch-like N×... APIs even when a primitive only exists for the unbatched shape.
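At the shape level the lift reads roughly as follows (a signature sketch using the notation printed on this page; `batchDim0` is the library's name, the rest is paraphrase):

```lean
-- batchDim0 (n : ℕ) : Sequential σ τ → Sequential (Spec.Shape.dim n σ) (Spec.Shape.dim n τ)
-- Operationally it behaves like mapping the one-example model over the batch:
def mapBatch {α β : Type} (model : α → β) (batch : Array α) : Array β :=
  batch.map model
```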

                                                          Lift a single-example LayerDef σ τ to operate on a dimension-0 batch.

                                                          This is a correctness-first helper: it runs the underlying layer independently on each batch element. Prefer a primitive batched layer when one exists.

                                                          Instances For

                                                            Lift a sequential model to act pointwise on a leading dim0 batch axis.

                                                            Instances For

                                                              Note: some low-level TorchLean layers (notably conv/pool/norm) have Nat-side well-formedness proof arguments (e.g. kH ≠ 0).

                                                              The public path is record-based specs that hide those proofs via typeclasses like NeZero, so examples can stay PyTorch-like without relying on positional macros.

                                                              Named-field Conv2d configuration (CHW layout).

                                                              This is the public, PyTorch-like entry point for convolution in TorchLean. PyTorch analogue: torch.nn.Conv2d. See https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html.

• outC : ℕ

  Output channels.

• kH : ℕ

  Kernel height.

• kW : ℕ

  Kernel width.

• stride : ℕ

  Stride (shared for height/width).

• padding : ℕ

  Zero-padding (shared for height/width).

• seedK : ℕ

  Seed for deterministic kernel initialization.

• seedB : ℕ

  Seed for deterministic bias initialization.

• Initialization scheme for the kernel weights.

                                                              Instances For
@[reducible, inline]

Shorthand alias (Conv) for the named-field Conv2d configuration record above (CHW layout). PyTorch analogue: torch.nn.Conv2d.

                                                                Instances For
def NN.API.nn.pure.conv2dCHWWith {inC inH inW : ℕ} (cfg : Conv2d) (hInC : inC ≠ 0) (hKH : cfg.kH ≠ 0) (hKW : cfg.kW ≠ 0) :
                                                                  Sequential (Tensor.Shape.Image inC inH inW) (Tensor.Shape.Image cfg.outC ((inH + 2 * cfg.padding - cfg.kH) / cfg.stride + 1) ((inW + 2 * cfg.padding - cfg.kW) / cfg.stride + 1))

                                                                  2D convolution over a CHW tensor, using explicit well-formedness proofs.

                                                                  Instances For
def NN.API.nn.pure.conv2dCHW {inC inH inW : ℕ} (cfg : Conv2d) [NeZero inC] [NeZero cfg.kH] [NeZero cfg.kW] :
                                                                    Sequential (Tensor.Shape.Image inC inH inW) (Tensor.Shape.Image cfg.outC ((inH + 2 * cfg.padding - cfg.kH) / cfg.stride + 1) ((inW + 2 * cfg.padding - cfg.kW) / cfg.stride + 1))

                                                                    2D convolution over a CHW tensor, with a PyTorch-like named-field spec.

                                                                    This hides the Nat-side proof arguments via the NeZero typeclass.
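The output-size arithmetic in the signature can be checked on a familiar case; a 28×28 input with a 5×5 kernel, stride 1, and padding 2 keeps its spatial size:

```lean
-- (inH + 2 * padding - kH) / stride + 1 on a concrete case:
#eval (28 + 2 * 2 - 5) / 1 + 1  -- 28 ("same" padding for a 5×5 kernel)
```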

                                                                    Instances For
def NN.API.nn.pure.conv2d {n inC inH inW : ℕ} (cfg : Conv2d) [NeZero inC] [NeZero cfg.kH] [NeZero cfg.kW] :
                                                                      Sequential (Tensor.Shape.Images n inC inH inW) (Tensor.Shape.Images n cfg.outC ((inH + 2 * cfg.padding - cfg.kH) / cfg.stride + 1) ((inW + 2 * cfg.padding - cfg.kW) / cfg.stride + 1))

                                                                      2D convolution over a batched image tensor (shape N×C×H×W, like PyTorch).

                                                                      Instances For
def NN.API.nn.pure.convCHWWith {inC inH inW : ℕ} (cfg : Conv2d) (hInC : inC ≠ 0) (hKH : cfg.kH ≠ 0) (hKW : cfg.kW ≠ 0) :
                                                                        Sequential (Tensor.Shape.Image inC inH inW) (Tensor.Shape.Image cfg.outC ((inH + 2 * cfg.padding - cfg.kH) / cfg.stride + 1) ((inW + 2 * cfg.padding - cfg.kW) / cfg.stride + 1))

                                                                        2D convolution over a CHW tensor, using explicit well-formedness proofs.

                                                                        Instances For
def NN.API.nn.pure.convCHW {inC inH inW : ℕ} (cfg : Conv2d) [NeZero inC] [NeZero cfg.kH] [NeZero cfg.kW] :
                                                                          Sequential (Tensor.Shape.Image inC inH inW) (Tensor.Shape.Image cfg.outC ((inH + 2 * cfg.padding - cfg.kH) / cfg.stride + 1) ((inW + 2 * cfg.padding - cfg.kW) / cfg.stride + 1))

                                                                          2D convolution over a CHW tensor, with a PyTorch-like named-field spec.

                                                                          This hides the Nat-side proof arguments via the NeZero typeclass.

                                                                          Instances For
def NN.API.nn.pure.conv {n inC inH inW : ℕ} (cfg : Conv) [NeZero inC] [NeZero cfg.kH] [NeZero cfg.kW] :
                                                                            Sequential (Tensor.Shape.Images n inC inH inW) (Tensor.Shape.Images n cfg.outC ((inH + 2 * cfg.padding - cfg.kH) / cfg.stride + 1) ((inW + 2 * cfg.padding - cfg.kW) / cfg.stride + 1))

                                                                            Convolution over batched CHW images, using the PyTorch-style Conv2d config record.

                                                                            Shorthand for conv2d.

                                                                            Instances For

                                                                              MaxPool2d configuration for CHW inputs.

                                                                              PyTorch analogue: torch.nn.MaxPool2d. See https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html.

• kH : ℕ

  Kernel height.

• kW : ℕ

  Kernel width.

• stride : ℕ

  Stride (shared for height/width).

                                                                              Instances For
@[reducible, inline]

Shorthand alias (MaxPool) for the MaxPool2d configuration record above. PyTorch analogue: torch.nn.MaxPool2d.

                                                                                Instances For
def NN.API.nn.pure.maxPool2dWith {inC inH inW : ℕ} (cfg : MaxPool2d) (hKH : cfg.kH ≠ 0) (hKW : cfg.kW ≠ 0) :
                                                                                  Sequential (Tensor.Shape.Image inC inH inW) (Tensor.Shape.Image inC ((inH - cfg.kH) / cfg.stride + 1) ((inW - cfg.kW) / cfg.stride + 1))

                                                                                  MaxPool2d with explicit nonzero kernel proofs.

                                                                                  Instances For
def NN.API.nn.pure.maxPool2dCHW {inC inH inW : ℕ} (cfg : MaxPool2d) [NeZero cfg.kH] [NeZero cfg.kW] :
                                                                                    Sequential (Tensor.Shape.Image inC inH inW) (Tensor.Shape.Image inC ((inH - cfg.kH) / cfg.stride + 1) ((inW - cfg.kW) / cfg.stride + 1))

                                                                                    MaxPool2d over CHW inputs using NeZero to hide nonzero kernel proofs.
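The pooled-size arithmetic in the signature, checked on the standard 2×2 kernel with stride 2 (which halves each spatial dimension):

```lean
-- (inH - kH) / stride + 1 on a concrete case:
#eval (28 - 2) / 2 + 1  -- 14
```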

                                                                                    Instances For
def NN.API.nn.pure.maxPool2d {n inC inH inW : ℕ} (cfg : MaxPool2d) [NeZero cfg.kH] [NeZero cfg.kW] :
                                                                                      Sequential (Tensor.Shape.Images n inC inH inW) (Tensor.Shape.Images n inC ((inH - cfg.kH) / cfg.stride + 1) ((inW - cfg.kW) / cfg.stride + 1))

                                                                                      MaxPool2d using NeZero to hide nonzero kernel proofs.

                                                                                      Instances For
def NN.API.nn.pure.maxPoolWith {inC inH inW : ℕ} (cfg : MaxPool2d) (hKH : cfg.kH ≠ 0) (hKW : cfg.kW ≠ 0) :
                                                                                        Sequential (Tensor.Shape.Image inC inH inW) (Tensor.Shape.Image inC ((inH - cfg.kH) / cfg.stride + 1) ((inW - cfg.kW) / cfg.stride + 1))

                                                                                        Shorthand for maxPool2dWith (PyTorch-style).

                                                                                        Instances For
def NN.API.nn.pure.maxPoolCHW {inC inH inW : ℕ} (cfg : MaxPool2d) [NeZero cfg.kH] [NeZero cfg.kW] :
                                                                                          Sequential (Tensor.Shape.Image inC inH inW) (Tensor.Shape.Image inC ((inH - cfg.kH) / cfg.stride + 1) ((inW - cfg.kW) / cfg.stride + 1))

                                                                                          Shorthand for maxPool2dCHW (PyTorch-style).

                                                                                          Instances For
def NN.API.nn.pure.maxPool {n inC inH inW : ℕ} (cfg : MaxPool) [NeZero cfg.kH] [NeZero cfg.kW] :
                                                                                            Sequential (Tensor.Shape.Images n inC inH inW) (Tensor.Shape.Images n inC ((inH - cfg.kH) / cfg.stride + 1) ((inW - cfg.kW) / cfg.stride + 1))

                                                                                            Max pooling over batched CHW images, using the PyTorch-style MaxPool2d config record.

                                                                                            Shorthand for maxPool2d.

                                                                                            Instances For

                                                                                              AvgPool2d configuration for CHW inputs.

                                                                                              PyTorch analogue: torch.nn.AvgPool2d. See https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html.

• kH : ℕ

  Kernel height.

• kW : ℕ

  Kernel width.

• stride : ℕ

  Stride (shared for height/width).

                                                                                              Instances For
@[reducible, inline]

Shorthand alias (AvgPool) for the AvgPool2d configuration record above. PyTorch analogue: torch.nn.AvgPool2d.

                                                                                                Instances For
def NN.API.nn.pure.avgPool2dWith {inC inH inW : ℕ} (cfg : AvgPool2d) (hKH : cfg.kH ≠ 0) (hKW : cfg.kW ≠ 0) :
                                                                                                  Sequential (Tensor.Shape.Image inC inH inW) (Tensor.Shape.Image inC ((inH - cfg.kH) / cfg.stride + 1) ((inW - cfg.kW) / cfg.stride + 1))

                                                                                                  AvgPool2d with explicit nonzero kernel proofs.

                                                                                                  Instances For
def NN.API.nn.pure.avgPool2dCHW {inC inH inW : ℕ} (cfg : AvgPool2d) [NeZero cfg.kH] [NeZero cfg.kW] :
                                                                                                    Sequential (Tensor.Shape.Image inC inH inW) (Tensor.Shape.Image inC ((inH - cfg.kH) / cfg.stride + 1) ((inW - cfg.kW) / cfg.stride + 1))

                                                                                                    AvgPool2d over CHW inputs using NeZero to hide nonzero kernel proofs.

                                                                                                    Instances For
def NN.API.nn.pure.avgPool2d {n inC inH inW : ℕ} (cfg : AvgPool2d) [NeZero cfg.kH] [NeZero cfg.kW] :
                                                                                                      Sequential (Tensor.Shape.Images n inC inH inW) (Tensor.Shape.Images n inC ((inH - cfg.kH) / cfg.stride + 1) ((inW - cfg.kW) / cfg.stride + 1))

                                                                                                      AvgPool2d over batched NCHW inputs (shape N×C×H×W, like PyTorch).

                                                                                                      Instances For
def NN.API.nn.pure.avgPoolWith {inC inH inW : ℕ} (cfg : AvgPool2d) (hKH : cfg.kH ≠ 0) (hKW : cfg.kW ≠ 0) :
                                                                                                        Sequential (Tensor.Shape.Image inC inH inW) (Tensor.Shape.Image inC ((inH - cfg.kH) / cfg.stride + 1) ((inW - cfg.kW) / cfg.stride + 1))

                                                                                                        Shorthand for avgPool2dWith (PyTorch-style).

                                                                                                        Instances For
def NN.API.nn.pure.avgPoolCHW {inC inH inW : ℕ} (cfg : AvgPool2d) [NeZero cfg.kH] [NeZero cfg.kW] :
                                                                                                          Sequential (Tensor.Shape.Image inC inH inW) (Tensor.Shape.Image inC ((inH - cfg.kH) / cfg.stride + 1) ((inW - cfg.kW) / cfg.stride + 1))

                                                                                                          Shorthand for avgPool2dCHW (PyTorch-style).

                                                                                                          Instances For
def NN.API.nn.pure.avgPool {n inC inH inW : ℕ} (cfg : AvgPool) [NeZero cfg.kH] [NeZero cfg.kW] :
                                                                                                            Sequential (Tensor.Shape.Images n inC inH inW) (Tensor.Shape.Images n inC ((inH - cfg.kH) / cfg.stride + 1) ((inW - cfg.kW) / cfg.stride + 1))

                                                                                                            Average pooling over batched CHW images, using the PyTorch-style AvgPool2d config record.

                                                                                                            Shorthand for avgPool2d.

                                                                                                            Instances For
def NN.API.nn.pure.globalAvgPoolCHW (c h w : ℕ) {hC : c > 0} {hH : h > 0} {hW : w > 0} :

                                                                                                              Global average pooling over a CHW tensor.

                                                                                                              PyTorch analogue: torch.nn.AdaptiveAvgPool2d((1, 1)) followed by flattening.

                                                                                                              Instances For
def NN.API.nn.pure.globalAvgPoolNCHW (n c h w : ℕ) {hN : n > 0} {hC : c > 0} {hH : h > 0} {hW : w > 0} :

                                                                                                                Global average pooling over an NCHW tensor (preserves the batch dimension).

                                                                                                                Instances For

                                                                                                                  LayerNorm configuration for batched (batch x seqLen x embedDim) tensors.

                                                                                                                  PyTorch analogue: torch.nn.LayerNorm. See https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html.

• seedGamma : ℕ

  Seed for deterministic initialization of gamma (scale).

• seedBeta : ℕ

  Seed for deterministic initialization of beta (shift).

                                                                                                                  Instances For
def NN.API.nn.pure.layerNormWith {batch seqLen embedDim : ℕ} (cfg : LayerNorm) (hSeq : seqLen > 0) (hEmbed : embedDim > 0) :
                                                                                                                    Sequential (Spec.Shape.dim batch (Tensor.Shape.Mat seqLen embedDim)) (Spec.Shape.dim batch (Tensor.Shape.Mat seqLen embedDim))

                                                                                                                    Layer normalization over (batch × seqLen × embedDim) tensors, with explicit positivity proofs.

                                                                                                                    This matches the common Transformer usage: normalize each token’s embedDim-vector independently, with learnable scale/shift parameters gamma and beta.

                                                                                                                    PyTorch analogue: torch.nn.LayerNorm(embedDim) applied to a tensor of shape (batch, seqLen, embedDim).

                                                                                                                    Most users should call nn.layerNorm, which uses NeZero to discharge the positivity proofs.

                                                                                                                    Instances For
def NN.API.nn.pure.layerNorm {batch seqLen embedDim : ℕ} (cfg : LayerNorm := { }) [NeZero seqLen] [NeZero embedDim] :
                                                                                                                      Sequential (Spec.Shape.dim batch (Tensor.Shape.Mat seqLen embedDim)) (Spec.Shape.dim batch (Tensor.Shape.Mat seqLen embedDim))

                                                                                                                      Layer normalization over (batch × seqLen × embedDim) tensors.

                                                                                                                      This normalizes each embedDim-vector (per batch element, per sequence position), and applies learned affine parameters gamma and beta.

                                                                                                                      PyTorch analogue: torch.nn.LayerNorm(embedDim) on a tensor shaped (batch, seqLen, embedDim).

                                                                                                                      Implementation note: TorchLean uses NeZero to ensure seqLen and embedDim are positive, avoiding degenerate shapes.

                                                                                                                      Instances For
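A usage sketch (hedged: the config is left at its default `{ }`): LayerNorm over (batch = 4, seqLen = 16, embedDim = 64) tensors, where the `NeZero 16` and `NeZero 64` instances discharge the positivity requirements automatically:

```lean
-- Hedged sketch: layerNorm with shape parameters given explicitly by name.
example := NN.API.nn.pure.layerNorm (batch := 4) (seqLen := 16) (embedDim := 64)
```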

                                                                                                                        RMSNorm configuration for batched (batch x seqLen x embedDim) tensors.

                                                                                                                        This is a common alternative to LayerNorm in modern transformer architectures.

• seedGamma : ℕ

                                                                                                                          Seed for deterministic initialization of gamma (scale).

                                                                                                                        Instances For
def NN.API.nn.pure.rmsNormWith {batch seqLen embedDim : ℕ} (cfg : RMSNorm) (hSeq : seqLen > 0) (hEmbed : embedDim > 0) :
                                                                                                                          Sequential (Spec.Shape.dim batch (Tensor.Shape.Mat seqLen embedDim)) (Spec.Shape.dim batch (Tensor.Shape.Mat seqLen embedDim))

                                                                                                                          RMS normalization over (batch × seqLen × embedDim) tensors, with explicit positivity proofs.

                                                                                                                          This is like LayerNorm but without mean subtraction: we scale by the root-mean-square over the embedDim axis, and apply a learned scale gamma.

                                                                                                                          PyTorch analogue: many libraries provide an RMSNorm(embedDim) module; conceptually it is applied to tensors shaped (batch, seqLen, embedDim).

                                                                                                                          Most users should call nn.rmsNorm, which uses NeZero to discharge the positivity proofs.

                                                                                                                          Instances For
def NN.API.nn.pure.rmsNorm {batch seqLen embedDim : ℕ} (cfg : RMSNorm := { }) [NeZero seqLen] [NeZero embedDim] :
                                                                                                                            Sequential (Spec.Shape.dim batch (Tensor.Shape.Mat seqLen embedDim)) (Spec.Shape.dim batch (Tensor.Shape.Mat seqLen embedDim))

                                                                                                                            RMS normalization over (batch × seqLen × embedDim) tensors.

                                                                                                                            This normalizes by the root-mean-square over the embedDim axis (per batch element, per position), then applies a learned scale gamma.

                                                                                                                            Implementation note: TorchLean uses NeZero to ensure seqLen and embedDim are positive, avoiding degenerate shapes.

                                                                                                                            Instances For
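A sketch contrasting the two entry points (hedged: proof style is illustrative): `rmsNorm` finds the positivity facts via `NeZero` instances, while `rmsNormWith` takes them as explicit proof arguments, here discharged with `omega`:

```lean
-- Hedged sketch: the instance-driven and explicit-proof variants side by side.
example := NN.API.nn.pure.rmsNorm (batch := 2) (seqLen := 8) (embedDim := 32)

example := NN.API.nn.pure.rmsNormWith (batch := 2) (seqLen := 8) (embedDim := 32)
  { } (by omega) (by omega)
```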

                                                                                                                              BatchNorm2d configuration (learned scale/shift).

                                                                                                                              PyTorch analogue: torch.nn.BatchNorm2d. See https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html.

• seedGamma : ℕ

  Seed for deterministic initialization of gamma (scale).

• seedBeta : ℕ

  Seed for deterministic initialization of beta (shift).

                                                                                                                              Instances For
def NN.API.nn.pure.batchNorm2dNCHWWith {n c h w : ℕ} (cfg : BatchNorm2d) (hN : n > 0) (hC : c > 0) (hH : h > 0) (hW : w > 0) :

                                                                                                                                BatchNorm2d over NCHW inputs (train/eval is handled by Seq mode).

                                                                                                                                Instances For

                                                                                                                                  BatchNorm2d over NCHW inputs, using NeZero to hide the positivity proofs.

                                                                                                                                  PyTorch analogue: torch.nn.BatchNorm2d. See https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html.

                                                                                                                                  Instances For

                                                                                                                                    InstanceNorm2d configuration (learned scale/shift).

                                                                                                                                    PyTorch analogue: torch.nn.InstanceNorm2d. See https://pytorch.org/docs/stable/generated/torch.nn.InstanceNorm2d.html.

• seedGamma : ℕ

  Seed for deterministic initialization of gamma (scale).

• seedBeta : ℕ

  Seed for deterministic initialization of beta (shift).

                                                                                                                                    Instances For
def NN.API.nn.pure.instanceNorm2dWith {n c h w : ℕ} (cfg : InstanceNorm2d) (hN : n > 0) (hC : c > 0) (hH : h > 0) (hW : w > 0) :

                                                                                                                                      InstanceNorm2d over NCHW inputs, using explicit positivity proofs.

                                                                                                                                      PyTorch analogue: torch.nn.InstanceNorm2d. See https://pytorch.org/docs/stable/generated/torch.nn.InstanceNorm2d.html.

                                                                                                                                      Instances For

                                                                                                                                        InstanceNorm2d over NCHW inputs, using NeZero to hide the positivity proofs.

                                                                                                                                        PyTorch analogue: torch.nn.InstanceNorm2d. See https://pytorch.org/docs/stable/generated/torch.nn.InstanceNorm2d.html.

                                                                                                                                        Instances For
def NN.API.nn.pure.groupNorm2dNCHW (n c h w groups : ℕ) {hN : n > 0} {hC : c > 0} {hH : h > 0} {hW : w > 0} {hG : groups > 0} (hGE : c ≥ groups) (hDiv : c % groups = 0) (seedGamma seedBeta : ℕ := 0) :

                                                                                                                                          GroupNorm over NCHW inputs.

                                                                                                                                          PyTorch analogue: torch.nn.GroupNorm. See https://pytorch.org/docs/stable/generated/torch.nn.GroupNorm.html.

                                                                                                                                          Instances For

                                                                                                                                            Multi-head self-attention configuration.

                                                                                                                                            PyTorch analogue: torch.nn.MultiheadAttention (conceptually). See https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html.

• numHeads : ℕ

  Number of attention heads.

• headDim : ℕ

  Per-head embedding dimension.

• seedW : ℕ

  Base seed for deterministic parameter initialization.

                                                                                                                                            Instances For

                                                                                                                                              Multi-head self-attention with an explicit nonzero sequence length proof.

                                                                                                                                              If mask is provided, it is a boolean attention mask of shape (n × n) (e.g. causal masking).

                                                                                                                                              Instances For

                                                                                                                                                Multi-head self-attention using NeZero to hide the nonzero sequence length proof.

                                                                                                                                                If mask is provided, it is a boolean attention mask of shape (n × n) (e.g. causal masking).

                                                                                                                                                Instances For

                                                                                                                                                  Small set of activation choices for block builders.

                                                                                                                                                  PyTorch analogues:

                                                                                                                                                  • relu <-> torch.nn.ReLU
                                                                                                                                                  • gelu <-> torch.nn.GELU
                                                                                                                                                  • silu <-> torch.nn.SiLU
                                                                                                                                                  • tanh <-> torch.nn.Tanh
                                                                                                                                                  • sigmoid <-> torch.nn.Sigmoid
                                                                                                                                                  Instances For

                                                                                                                                                    Interpret an Activation as a TorchLean layer.

                                                                                                                                                    Instances For

                                                                                                                                                      MLP (multi-layer perceptron) configuration.

                                                                                                                                                      This is a lightweight builder that produces a sequential stack of linear layers with activations and optional dropout.

                                                                                                                                                      PyTorch analogue: a hand-written nn.Sequential(Linear(...), ReLU(), ..., Linear(...)).

• hidden : List ℕ

                                                                                                                                                        Hidden layer widths (each entry creates a Linear -> Activation stage).

                                                                                                                                                      • activation : Activation

                                                                                                                                                        Activation used after each hidden linear layer.

                                                                                                                                                      • dropout? : Option Float

                                                                                                                                                        Optional dropout probability after each activation.

• seedBase : ℕ

                                                                                                                                                        Base seed used to deterministically initialize all linear layers (and dropout if present).

                                                                                                                                                      Instances For
def NN.API.nn.pure.blocks.mlpGo (act : Activation) (dropout? : Option Float) (inDim : ℕ) (hidden : List ℕ) (outDim seed : ℕ) :

                                                                                                                                                        Internal recursion for mlp.

                                                                                                                                                        This builds the sequential stack stage-by-stage, threading a seed so each linear (and optional dropout) layer gets a deterministic initialization key.

                                                                                                                                                        Instances For
def NN.API.nn.pure.blocks.mlp (inDim outDim : ℕ) (cfg : MLP := { }) :

                                                                                                                                                          Build an MLP as a sequential stack of linear layers and activations.

                                                                                                                                                          This is a small "PyTorch-shaped" helper: a typical call looks like: API.nn.blocks.mlp 784 10 { hidden := [128, 128], activation := .relu }.

                                                                                                                                                          Instances For
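Expanding the typical call from the docstring into a sketch (hedged: spelled here with the explicit-seed `nn.pure.blocks` name shown in this signature, rather than the seeded `API.nn.blocks` facade): a 784 → 128 → 128 → 10 MLP with ReLU after each hidden linear layer:

```lean
-- Hedged sketch: each entry of `hidden` contributes a Linear -> ReLU stage.
example := NN.API.nn.pure.blocks.mlp 784 10
  { hidden := [128, 128], activation := .relu }
```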

                                                                                                                                                            Conv2d + activation (+ optional dropout) block configuration (CHW layout).

                                                                                                                                                            This compact helper is used by vision examples before moving to larger curated blocks.

                                                                                                                                                            • conv : Conv2d

                                                                                                                                                              Conv hyperparameters and seeds.

                                                                                                                                                            • activation : Activation

                                                                                                                                                              Activation applied after the convolution.

                                                                                                                                                            • dropout? : Option Float

                                                                                                                                                              Optional dropout probability after the activation.

• seedDropout : ℕ

                                                                                                                                                              Seed for dropout RNG (only used when dropout? is present).

                                                                                                                                                            Instances For
def NN.API.nn.pure.blocks.conv2dAct {inC inH inW : ℕ} (cfg : Conv2dAct) [NeZero inC] [NeZero cfg.conv.kH] [NeZero cfg.conv.kW] :
                                                                                                                                                              Sequential (Tensor.Shape.Image inC inH inW) (Tensor.Shape.Image cfg.conv.outC ((inH + 2 * cfg.conv.padding - cfg.conv.kH) / cfg.conv.stride + 1) ((inW + 2 * cfg.conv.padding - cfg.conv.kW) / cfg.conv.stride + 1))

                                                                                                                                                              Conv2d -> Activation -> (optional Dropout) over CHW inputs.

                                                                                                                                                              Instances For
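A shape sketch (hedged: the `Conv2d` field names `outC`/`kH`/`kW`/`padding`/`stride` are taken from the return type above, and unspecified fields are assumed to default sensibly). A 3 → 16 channel 3×3 convolution with padding 1 and stride 1 preserves the 32×32 spatial size, since (32 + 2·1 - 3) / 1 + 1 = 32:

```lean
-- Hedged sketch: Conv2d -> ReLU on a single CHW image; no dropout stage.
example := NN.API.nn.pure.blocks.conv2dAct (inC := 3) (inH := 32) (inW := 32)
  { conv := { outC := 16, kH := 3, kW := 3, padding := 1, stride := 1 }
    activation := .relu }
```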

                                                                                                                                                                Vision blocks #

                                                                                                                                                                These are small, named-field building blocks intended for public examples:

                                                                                                                                                                They are intentionally conservative: the goal is readability and stable typing, not maximum coverage.

                                                                                                                                                                Configuration for a common vision block: Conv2d -> BatchNorm2d -> Activation -> (optional Dropout).

                                                                                                                                                                This is used by conv2dNormActCHW (single-image CHW) and conv2dNormAct (batched NCHW). We keep deterministic seed allocation explicit via seedBase so examples stay reproducible.

                                                                                                                                                                • conv : Conv2d

                                                                                                                                                                  Conv hyperparameters (seeds inside this record are ignored; use seedBase).

                                                                                                                                                                • activation : Activation

                                                                                                                                                                  Activation after normalization.

                                                                                                                                                                • dropout? : Option Float

                                                                                                                                                                  Optional dropout applied after the activation.

• seedBase : ℕ

                                                                                                                                                                  Base seed for deterministic init (derived seeds are allocated in a fixed order).

                                                                                                                                                                Instances For
def NN.API.nn.pure.blocks.conv2dNormActCHW {inC inH inW : ℕ} (cfg : Conv2dNormAct) [NeZero inC] [NeZero cfg.conv.kH] [NeZero cfg.conv.kW] [NeZero cfg.conv.outC] :
                                                                                                                                                                  Sequential (Tensor.Shape.Image inC inH inW) (Tensor.Shape.Image cfg.conv.outC ((inH + 2 * cfg.conv.padding - cfg.conv.kH) / cfg.conv.stride + 1) ((inW + 2 * cfg.conv.padding - cfg.conv.kW) / cfg.conv.stride + 1))

                                                                                                                                                                  Conv2d -> BatchNorm -> Activation -> (optional Dropout), over a single CHW image (no batch axis).

                                                                                                                                                                  Seed allocation (relative to seedBase):

                                                                                                                                                                  Instances For

                                                                                                                                                                    Configuration for conv2dNormActPool*: a Conv2dNormAct block followed by max-pooling.

                                                                                                                                                                    This matches the common “conv-bn-act-pool” pattern used in small CNNs.

• block : Conv2dNormAct

  Conv/BN/activation/dropout block configuration.

                                                                                                                                                                    • pool : MaxPool2d

                                                                                                                                                                      Pooling hyperparameters (defaults to 2×2 stride-2 max pool).

                                                                                                                                                                    Instances For
def NN.API.nn.pure.blocks.conv2dNormActPoolCHW {inC inH inW : ℕ} (cfg : Conv2dNormActPool) [NeZero inC] [NeZero cfg.block.conv.kH] [NeZero cfg.block.conv.kW] [NeZero cfg.block.conv.outC] [NeZero cfg.pool.kH] [NeZero cfg.pool.kW] :
                                                                                                                                                                      Sequential (Tensor.Shape.Image inC inH inW) (Tensor.Shape.Image cfg.block.conv.outC (((inH + 2 * cfg.block.conv.padding - cfg.block.conv.kH) / cfg.block.conv.stride + 1 - cfg.pool.kH) / cfg.pool.stride + 1) (((inW + 2 * cfg.block.conv.padding - cfg.block.conv.kW) / cfg.block.conv.stride + 1 - cfg.pool.kW) / cfg.pool.stride + 1))

                                                                                                                                                                      conv2dNormActCHW followed by MaxPool2dCHW.

                                                                                                                                                                      Instances For
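The composed output size is just the conv formula followed by the pool formula from the signature above. For a 28×28 input with a 3×3 stride-1 pad-1 conv and the default 2×2 stride-2 max pool, the arithmetic works out as:

```lean
-- conv: (28 + 2*1 - 3) / 1 + 1 = 28, then pool: (28 - 2) / 2 + 1 = 14
#eval ((28 + 2 * 1 - 3) / 1 + 1 - 2) / 2 + 1  -- 14
```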
def NN.API.nn.pure.blocks.conv2dNormAct {n inC inH inW : ℕ} (cfg : Conv2dNormAct) [NeZero n] [NeZero inC] [NeZero cfg.conv.kH] [NeZero cfg.conv.kW] [NeZero cfg.conv.outC] :
                                                                                                                                                                        Sequential (Tensor.Shape.Images n inC inH inW) (Tensor.Shape.Images n cfg.conv.outC ((inH + 2 * cfg.conv.padding - cfg.conv.kH) / cfg.conv.stride + 1) ((inW + 2 * cfg.conv.padding - cfg.conv.kW) / cfg.conv.stride + 1))

                                                                                                                                                                        Conv2d -> BatchNorm2d -> Activation -> (optional Dropout), over batched image tensors (N×C×H×W).

                                                                                                                                                                        This is the public PyTorch-like path: examples should build CNNs directly over batched images.

                                                                                                                                                                        Instances For
def NN.API.nn.pure.blocks.conv2dNormActPool {n inC inH inW : ℕ} (cfg : Conv2dNormActPool) [NeZero n] [NeZero inC] [NeZero cfg.block.conv.kH] [NeZero cfg.block.conv.kW] [NeZero cfg.block.conv.outC] [NeZero cfg.pool.kH] [NeZero cfg.pool.kW] :
                                                                                                                                                                          Sequential (Tensor.Shape.Images n inC inH inW) (Tensor.Shape.Images n cfg.block.conv.outC (((inH + 2 * cfg.block.conv.padding - cfg.block.conv.kH) / cfg.block.conv.stride + 1 - cfg.pool.kH) / cfg.pool.stride + 1) (((inW + 2 * cfg.block.conv.padding - cfg.block.conv.kW) / cfg.block.conv.stride + 1 - cfg.pool.kW) / cfg.pool.stride + 1))

                                                                                                                                                                          conv2dNormAct followed by MaxPool2d, over batched image tensors.

                                                                                                                                                                          Instances For

                                                                                                                                                                            Residual/skip-connection wrapper as a single LayerDef.

                                                                                                                                                                            Given inner : Seq s s, this builds a layer that computes x |-> inner(x) + x.

                                                                                                                                                                            PyTorch analogue: x + f(x) blocks used throughout ResNets and Transformers.

                                                                                                                                                                            Instances For

                                                                                                                                                                              Lift residualLayer into a sequential model.

                                                                                                                                                                              Instances For

                                                                                                                                                                                Branching (skip connections) #

                                                                                                                                                                                Seq is linear, but we sometimes want a PyTorch-like x |-> f(x) + g(x) block.

                                                                                                                                                                                We expose this as a single LayerDef whose parameter list is params(f) ++ params(g) and whose forward pass runs both programs and adds their outputs.

                                                                                                                                                                                Combine two sequential branches into a single layer that adds their outputs.

                                                                                                                                                                                The resulting layer runs both f and g on the same input x and returns f(x) + g(x). Parameters are concatenated as params(f) ++ params(g).

                                                                                                                                                                                Instances For

                                                                                                                                                                                  Combine two models with the same input/output shapes by summing their outputs.

                                                                                                                                                                                  This is a typed “residual add” helper: addBranches f g represents the model x ↦ f(x) + g(x), and its parameter list is the concatenation of the two branches’ parameter lists.

                                                                                                                                                                                  Instances For
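A minimal sketch, assuming `addBranches` is exported under `NN.API.nn` alongside the helpers above, and using two `pure.linear` branches with matching `[8] → [8]` shapes:

```lean
open NN.API.nn in
-- x ↦ f(x) + g(x); parameters are params(f) ++ params(g).
def twoBranch :=
  addBranches (pure.linear 8 8 (seedW := 1)) (pure.linear 8 8 (seedW := 3))
```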

                                                                                                                                                                                    ResNet BasicBlock #

                                                                                                                                                                                    We provide a typed and composable ResNet-18 style BasicBlock over CHW tensors.

                                                                                                                                                                                    Key idea: we use a small canonical stride-2 formula down2 (matching GraphSpec/Models/resnet18) so projection shortcuts typecheck cleanly without leaking Nat arithmetic at call sites.

                                                                                                                                                                                    @[reducible, inline]

                                                                                                                                                                                    Canonical stride-2 spatial downsampling formula used by ResNet blocks.

                                                                                                                                                                                    down2 h = (h - 1) / 2 + 1 = ceil(h / 2).

                                                                                                                                                                                    This matches the output-size formula for common stride-2 layers used in ResNet downsampling (e.g. 3×3 conv with padding 1, or 1×1 conv with padding 0).

                                                                                                                                                                                    Instances For

                                                                                                                                                                                      down2 is always positive (used to discharge NeZero goals).
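Concretely, down2 rounds up, so odd and even sizes behave as follows (plain Nat arithmetic):

```lean
#eval (7 - 1) / 2 + 1  -- 4 = ⌈7 / 2⌉
#eval (8 - 1) / 2 + 1  -- 4 = ⌈8 / 2⌉
#eval (9 - 1) / 2 + 1  -- 5 = ⌈9 / 2⌉
```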

theorem NN.API.nn.pure.blocks.conv3_same_out_eq {h : ℕ} (hh : h > 0) :
                                                                                                                                                                                      (h + 2 * 1 - 3) / 1 + 1 = h

                                                                                                                                                                                      Shape arithmetic helper: 3×3 conv with stride 1 and padding 1 preserves a positive spatial size.

                                                                                                                                                                                      This matches the standard conv output formula used by conv2dCHW.

theorem NN.API.nn.pure.blocks.conv1_same_out_eq {h : ℕ} (hh : h > 0) :
                                                                                                                                                                                      (h + 2 * 0 - 1) / 1 + 1 = h

                                                                                                                                                                                      Shape arithmetic helper: 1×1 conv with stride 1 and padding 0 preserves a positive spatial size.
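Concrete instances of the two shape lemmas above close by plain Nat computation:

```lean
example : (32 + 2 * 1 - 3) / 1 + 1 = 32 := rfl  -- 3×3 conv, stride 1, padding 1
example : (32 + 2 * 0 - 1) / 1 + 1 = 32 := rfl  -- 1×1 conv, stride 1, padding 0
```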

                                                                                                                                                                                      ResNet helper: 3×3 convolution with padding 1, stride 1 (shape-preserving), over CHW images.

                                                                                                                                                                                      Instances For

                                                                                                                                                                                        ResNet helper: 3×3 convolution with padding 1, stride 2 (spatial downsampling via down2), over CHW images.

                                                                                                                                                                                        Instances For

                                                                                                                                                                                          ResNet helper: 1×1 convolution with stride 1 (shape-preserving), over CHW images.

                                                                                                                                                                                          Instances For

                                                                                                                                                                                            ResNet helper: 1×1 convolution with stride 2 (spatial downsampling via down2), over CHW images.

                                                                                                                                                                                            Instances For

                                                                                                                                                                                              ResNet helper: 3×3 convolution over batched images (NCHW-style), preserving spatial size.

                                                                                                                                                                                              Instances For

                                                                                                                                                                                                ResNet helper: 3×3 convolution over batched images (NCHW-style), downsampling via down2.

                                                                                                                                                                                                Instances For

                                                                                                                                                                                                  ResNet helper: 1×1 convolution over batched images (NCHW-style), preserving spatial size.

                                                                                                                                                                                                  Instances For

                                                                                                                                                                                                    ResNet helper: 1×1 convolution over batched images (NCHW-style), downsampling via down2.

                                                                                                                                                                                                    Instances For

                                                                                                                                                                                                      ResNet-style "basic block" configuration (CHW layout).

                                                                                                                                                                                                      PyTorch reference (conceptual): torchvision.models.resnet.BasicBlock (see https://pytorch.org/vision/stable/models/resnet.html).

• outC : ℕ

                                                                                                                                                                                                        Number of output channels produced by the block.

                                                                                                                                                                                                      • downsample : Bool

                                                                                                                                                                                                        If true, use stride-2 downsampling + projection shortcut; otherwise preserve spatial dims.

                                                                                                                                                                                                      • activation : Activation

                                                                                                                                                                                                        Activation used inside the block (and after the residual addition).

• seedBase : ℕ

                                                                                                                                                                                                        Base seed used to derive deterministic per-layer seeds inside the block.

                                                                                                                                                                                                      Instances For
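As an illustration only: a downsampling block configuration built from the documented fields might look like the following. The structure's name is not shown on this page, so `BasicBlockCfg` is a placeholder, and `Activation.relu` is an assumed constructor name.

```lean
-- Placeholder structure name; the fields are the ones documented above.
def blockCfg : BasicBlockCfg :=
  { outC := 128, downsample := true, activation := Activation.relu, seedBase := 100 }
```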

ResNet-style "basic block" (CHW layout).

                                                                                                                                                                                                        This public building block follows the standard ResNet basic-block pattern: conv3x3 -> BN -> act -> conv3x3 -> BN with a residual/skip connection.

                                                                                                                                                                                                        PyTorch references (for the conceptual shape):

                                                                                                                                                                                                        • Torchvision ResNet: https://pytorch.org/vision/stable/models/resnet.html
                                                                                                                                                                                                        Instances For

                                                                                                                                                                                                          ResNet-18 style BasicBlock over batched image tensors (N×C×H×W).

                                                                                                                                                                                                          Instances For

                                                                                                                                                                                                            Config record for transformerEncoderBlock.

                                                                                                                                                                                                            Separating the config as a structure makes it easier to write readable examples and keep seed management deterministic.

• numHeads : ℕ

                                                                                                                                                                                                              Number of attention heads.

• headDim : ℕ

                                                                                                                                                                                                              Per-head embedding dimension.

• ffnHidden : ℕ

                                                                                                                                                                                                              Hidden dimension of the feed-forward network.

                                                                                                                                                                                                            • activation : Activation

                                                                                                                                                                                                              Activation used in the feed-forward network.

                                                                                                                                                                                                            • dropout? : Option Float

                                                                                                                                                                                                              Optional dropout probability for examples; none means no dropout.

• seedBase : ℕ

                                                                                                                                                                                                              Base seed used to derive deterministic per-layer seeds inside the block.

                                                                                                                                                                                                            Instances For

Transformer encoder block (with optional attention mask).

                                                                                                                                                                                                              This follows the familiar pattern: (residual MHA) -> LayerNorm -> (residual FFN) -> LayerNorm.

                                                                                                                                                                                                              PyTorch analogue:

                                                                                                                                                                                                              • torch.nn.TransformerEncoderLayer (https://pytorch.org/docs/stable/generated/torch.nn.TransformerEncoderLayer.html)
                                                                                                                                                                                                              Instances For

                                                                                                                                                                                                                Transformer encoder block.

                                                                                                                                                                                                                This is transformerEncoderBlockWithMask; pass mask := ... to enable causal masking (or other attention masks).

                                                                                                                                                                                                                Instances For
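For illustration, a small encoder block config using the documented fields of `TransformerEncoderBlock` (the `Activation.gelu` constructor name is an assumption):

```lean
-- dModel = numHeads * headDim = 4 * 16 = 64 for this block.
def encCfg : TransformerEncoderBlock :=
  { numHeads := 4, headDim := 16, ffnHidden := 256,
    activation := Activation.gelu, dropout? := none, seedBase := 0 }
```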

                                                                                                                                                                                                                  Config record for transformerEncoderStack.

This builds cfg.layers copies of transformerEncoderBlock, allocating seeds in a fixed stride.

• layers : ℕ

                                                                                                                                                                                                                    Layer stack.

• template : TransformerEncoderBlock

  Template config for each block (its seedBase is ignored; we allocate per-layer seeds).

• seedBase : ℕ

                                                                                                                                                                                                                    Base seed for the whole stack.

• seedStride : ℕ

                                                                                                                                                                                                                    Seed stride between consecutive blocks (must exceed the per-block seed footprint).

                                                                                                                                                                                                                  Instances For
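The seed allocation is a simple stride: block i receives seedBase + i * seedStride. For example, with seedBase = 1000 and seedStride = 100 over 3 layers:

```lean
#eval (List.range 3).map (fun i => 1000 + i * 100)  -- [1000, 1100, 1200]
```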
def NN.API.nn.pure.blocks.transformerStackGoWithMask {batch n dModel : ℕ} [NeZero n] [NeZero dModel] (template : TransformerEncoderBlock) (seedBase seedStride : ℕ) (mask : Option (Spec.Tensor Bool (Spec.Shape.dim n (Spec.Shape.dim n Spec.Shape.scalar))) := none) (layerIdx remaining : ℕ) :

                                                                                                                                                                                                                    Internal recursion for transformerEncoderStack.

                                                                                                                                                                                                                    Builds remaining blocks starting at layerIdx, allocating each block's seedBase as seedBase + layerIdx * seedStride.

                                                                                                                                                                                                                    Instances For
def NN.API.nn.pure.blocks.transformerStackGo {batch n dModel : ℕ} [NeZero n] [NeZero dModel] (template : TransformerEncoderBlock) (seedBase seedStride layerIdx remaining : ℕ) :

                                                                                                                                                                                                                      Internal recursion for transformerEncoderStack (unmasked).

                                                                                                                                                                                                                      This is transformerStackGoWithMask with mask := none.

                                                                                                                                                                                                                      Instances For

                                                                                                                                                                                                                        Stack cfg.layers copies of blocks.transformerEncoderBlock.

                                                                                                                                                                                                                        This is the TorchLean analogue of composing torch.nn.TransformerEncoderLayer into a torch.nn.TransformerEncoder (modulo the fact that TorchLean uses Seq composition).

                                                                                                                                                                                                                        Instances For

                                                                                                                                                                                                                          Stack cfg.layers copies of blocks.transformerEncoderBlock.

                                                                                                                                                                                                                          This is transformerEncoderStackWithMask; pass mask := ... to enable causal masking (or other attention masks).

                                                                                                                                                                                                                          Instances For
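As a usage sketch of the two stack constructors above (illustrative only: `cfg` and `causalMask` are assumed values, not part of the documented signatures), the unmasked and masked variants differ only in the `mask` argument:

```lean
-- Sketch: `cfg` is a stack configuration carrying a `layers` count as
-- described above; `causalMask` is an assumed attention-mask value.
open NN.API.nn.pure.blocks in
def plainEncoder  := transformerEncoderStack cfg
def maskedEncoder := transformerEncoderStackWithMask cfg (mask := causalMask)
```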

                                                                                                                                                                                                                            Transformer encoder followed by a flatten+linear classification head.

                                                                                                                                                                                                                            PyTorch analogue (roughly): nn.TransformerEncoder(...) + pooling/flattening + nn.Linear.

                                                                                                                                                                                                                            Instances For
def NN.API.nn.pure.heads.classifier {s : Spec.Shape} (classes : ℕ) (seedW seedB : ℕ := 0) :

                                                                                                                                                                                                                              Classification head: Flatten -> Linear.

                                                                                                                                                                                                                              This is a small convenience wrapper around nn.flattenLinear.

                                                                                                                                                                                                                              Instances For
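A minimal usage sketch (hypothetical: `featShape` is an assumed feature shape; only `classifier` itself comes from the signature above):

```lean
-- Flatten features of shape `featShape`, then map them to 10 class scores.
-- Both seeds default to 0, matching the signature above.
def head10 := NN.API.nn.pure.heads.classifier (s := featShape) 10
```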
def NN.API.nn.pure.heads.regressor {s : Spec.Shape} (outDim : ℕ := 1) (seedW seedB : ℕ := 0) :

                                                                                                                                                                                                                                Regression head: Flatten -> Linear with outDim outputs.

                                                                                                                                                                                                                                Instances For
def NN.API.nn.pure.heads.classifierBatch {n : ℕ} {s : Spec.Shape} (classes : ℕ) (seedW seedB : ℕ := 0) :

                                                                                                                                                                                                                                  Flatten(start_dim=1) -> Linear head for batched tensors.

Input: N × σ
Output: Mat N classes

                                                                                                                                                                                                                                  Instances For
def NN.API.nn.pure.heads.regressorBatch {n : ℕ} {s : Spec.Shape} (outDim : ℕ := 1) (seedW seedB : ℕ := 0) :

                                                                                                                                                                                                                                    Batched regression head: Flatten(start_dim=1) -> Linear(_, outDim) producing Mat N outDim.

                                                                                                                                                                                                                                    Instances For
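A shape-oriented sketch of the two batched heads (hypothetical: the batch size 32, `featShape`, and `outDim := 3` are assumptions for illustration; the output types follow the docstrings above):

```lean
-- classifierBatch: a batch of 32 inputs of shape `featShape` ↦ Mat 32 10
def clsHead := NN.API.nn.pure.heads.classifierBatch (n := 32) (s := featShape) 10
-- regressorBatch: a batch of 32 inputs of shape `featShape` ↦ Mat 32 3
def regHead := NN.API.nn.pure.heads.regressorBatch (n := 32) (s := featShape) (outDim := 3)
```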

                                                                                                                                                                                                                                      Optimizer configs for the high-level training helpers.

                                                                                                                                                                                                                                      These mirror common PyTorch optimizers (by name and default hyperparameters), but they produce a TorchLean trainer config rather than a mutable optimizer object.

                                                                                                                                                                                                                                      PyTorch references:

                                                                                                                                                                                                                                      @[reducible, inline]

                                                                                                                                                                                                                                      Optimizer hyperparameter configuration for the supervised training helpers.

                                                                                                                                                                                                                                      We keep this small for examples and lightweight trainers. It mirrors a few common PyTorch optimizers by name/defaults, but it does not try to cover the full option surface of torch.optim.*.

                                                                                                                                                                                                                                      Instances For
                                                                                                                                                                                                                                        @[reducible, inline]

                                                                                                                                                                                                                                        SGD optimizer config.

                                                                                                                                                                                                                                        PyTorch analogue: torch.optim.SGD (https://pytorch.org/docs/stable/generated/torch.optim.SGD.html).

                                                                                                                                                                                                                                        Instances For
                                                                                                                                                                                                                                          @[reducible, inline]

                                                                                                                                                                                                                                          Momentum SGD optimizer config (PyTorch-style default momentum = 0.9).

                                                                                                                                                                                                                                          This is just sgd lr momentum with a different default.

                                                                                                                                                                                                                                          Instances For
                                                                                                                                                                                                                                            @[reducible, inline]

                                                                                                                                                                                                                                            Adam optimizer config with standard defaults.

                                                                                                                                                                                                                                            PyTorch analogue: torch.optim.Adam (https://pytorch.org/docs/stable/generated/torch.optim.Adam.html).

                                                                                                                                                                                                                                            Instances For
                                                                                                                                                                                                                                              @[reducible, inline]

                                                                                                                                                                                                                                              AdamW optimizer config with standard defaults (PyTorch-style weightDecay = 0.01).

                                                                                                                                                                                                                                              PyTorch analogue: torch.optim.AdamW (https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html).

                                                                                                                                                                                                                                              Instances For
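An illustrative sketch of the four configs side by side (the exact argument names and identifier spellings below are assumptions; only the constructor names, defaults, and PyTorch analogues come from the docstrings above):

```lean
-- Sketch only: argument names (`lr := …`) are assumed for illustration.
def optSgd      := sgd (lr := 0.1)
def optMomentum := momentum (lr := 0.1)  -- momentum defaults to 0.9
def optAdam     := adam (lr := 1e-3)
def optAdamW    := adamW (lr := 1e-3)    -- weightDecay defaults to 0.01
```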
                                                                                                                                                                                                                                                @[reducible, inline]

                                                                                                                                                                                                                                                Reduction mode for losses that start as elementwise tensors.

                                                                                                                                                                                                                                                PyTorch analogy: reduction="mean" or reduction="sum".

                                                                                                                                                                                                                                                Instances For
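The two modes aggregate an elementwise loss tensor in the usual ways; as a sketch (the constructor names are assumptions mirroring the PyTorch strings):

```lean
-- mean: (1/n) · Σᵢ lᵢ   sum: Σᵢ lᵢ   over the elementwise losses lᵢ.
inductive Reduction where
  | mean  -- PyTorch reduction="mean"
  | sum   -- PyTorch reduction="sum"
```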

                                                                                                                                                                                                                                                  High-level training helpers.

This namespace is designed for executable demos: it wires together a model, a loss, an optimizer config, and data into a runnable training loop (the tasks, runners, and fit* helpers below).

                                                                                                                                                                                                                                                  It stays intentionally lightweight: rather than hiding everything behind a large framework, it exposes a small set of default building blocks so tutorials can focus on models and verification.

                                                                                                                                                                                                                                                  PyTorch Mapping #

These helpers correspond to the training loop code you would typically write around torch.optim.* and a loss function: zeroing gradients, computing the loss, calling backward(), and stepping the optimizer.

                                                                                                                                                                                                                                                  @[reducible, inline]
                                                                                                                                                                                                                                                  abbrev NN.API.train.Task (σ τ : Spec.Shape) :

                                                                                                                                                                                                                                                  A supervised task is just a model plus a choice of loss.

                                                                                                                                                                                                                                                  Instances For
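Since a task is just a model paired with a loss, constructing one is a small record literal. A sketch (the field names `model` and `loss`, as well as `myModel`/`myLoss`, are assumptions; only the `Task` abbreviation itself is documented above):

```lean
def myTask : NN.API.train.Task σIn τOut :=
  { model := myModel   -- a Sequential σIn τOut
    loss  := myLoss }  -- a loss on outputs of shape τOut
```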
                                                                                                                                                                                                                                                    @[reducible, inline]

                                                                                                                                                                                                                                                    A fully instantiated supervised task runner.

                                                                                                                                                                                                                                                    This bundles:

                                                                                                                                                                                                                                                    • the imperative ScalarModule (parameters/buffers stored in refs),
                                                                                                                                                                                                                                                    • compiled predictors and loss functions for both .train and .eval modes (so switching mode is cheap),
                                                                                                                                                                                                                                                    • and the current mode stored in an IO.Ref.

                                                                                                                                                                                                                                                    The mode influences both operator behavior (e.g. dropout/batchnorm) and whether buffers are updated during training.

                                                                                                                                                                                                                                                    Instances For
                                                                                                                                                                                                                                                      @[reducible, inline]

                                                                                                                                                                                                                                                      Stateful training loop object: a Runner plus an optimizer state and a step counter.

                                                                                                                                                                                                                                                      This is the TorchLean analogue of holding a PyTorch optimizer object plus the model, ready to step() on batches.

                                                                                                                                                                                                                                                      Instances For
                                                                                                                                                                                                                                                        @[reducible, inline]

                                                                                                                                                                                                                                                        Step-based training configuration for fit / fitDataset.

                                                                                                                                                                                                                                                        Fields:

                                                                                                                                                                                                                                                        • steps: number of parameter updates,
                                                                                                                                                                                                                                                        • optimizer: optimizer hyperparameters,
                                                                                                                                                                                                                                                        • scheduler: optional learning-rate schedule (applied per step),
                                                                                                                                                                                                                                                        • logEvery: progress printing frequency (0 disables logging).
                                                                                                                                                                                                                                                        Instances For
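A sketch of a step-based config (the type name `FitConfig` and the `sgdConfig` value are assumptions; the four fields come from the list above):

```lean
def fitCfg : NN.API.train.FitConfig :=
  { steps     := 1000       -- number of parameter updates
    optimizer := sgdConfig  -- assumed optimizer-hyperparameter value
    scheduler := none       -- no learning-rate schedule
    logEvery  := 100 }      -- print progress every 100 steps
```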
                                                                                                                                                                                                                                                          @[reducible, inline]

                                                                                                                                                                                                                                                          Epoch-based training configuration for fitLoader (data-loader training).

                                                                                                                                                                                                                                                          Fields:

                                                                                                                                                                                                                                                          • epochs: number of epochs (each epoch iterates once over the loader),
                                                                                                                                                                                                                                                          • optimizer: optimizer hyperparameters,
                                                                                                                                                                                                                                                          • scheduler: optional learning-rate schedule (applied per step/epoch depending on helper),
                                                                                                                                                                                                                                                          • logEvery: progress printing frequency (0 disables logging).
                                                                                                                                                                                                                                                          Instances For
                                                                                                                                                                                                                                                            @[reducible, inline]

                                                                                                                                                                                                                                                            Small summary returned by fit* helpers.

                                                                                                                                                                                                                                                            By default, before and after are mean loss values, but the type is polymorphic so callers can report other scalars in the same shape.

                                                                                                                                                                                                                                                            Instances For

                                                                                                                                                                                                                                                              Most of API.train.* is just a public re-export of TorchLean.Trainer.*.

                                                                                                                                                                                                                                                              We use export (rather than rewriting 1-line forwarders) so this file stays small and avoids duplicating implementation details at the facade layer.

                                                                                                                                                                                                                                                              Metric Artifacts #

                                                                                                                                                                                                                                                              The public training facade also exposes TorchLean's lightweight metric artifact format. This is the local equivalent of “log scalars during a run, then inspect them later”: write a JSON TrainLog, view it with the training widgets, or adapt the JSON to an external tracker such as Weights & Biases.

                                                                                                                                                                                                                                                              A runner bundled with the task that created it.

                                                                                                                                                                                                                                                              This is an ergonomic wrapper around Runner α task: it remembers the dependent task, so tutorial code can call tr.predict x, tr.fit cfg samples, etc. without repeatedly writing (task := task).

                                                                                                                                                                                                                                                              • task : Task σ τ

                                                                                                                                                                                                                                                                The supervised task: model plus loss.

                                                                                                                                                                                                                                                              • runner : Runner α self.task

                                                                                                                                                                                                                                                                The instantiated runner for task.

                                                                                                                                                                                                                                                              Instances For
                                                                                                                                                                                                                                                                def NN.API.train.TaskRunner.ofRunner {σ τ : Spec.Shape} {task : Task σ τ} {α : Type} [Semantics.Scalar α] [DecidableEq Spec.Shape] (runner : Runner α task) :
                                                                                                                                                                                                                                                                TaskRunner σ τ α

                                                                                                                                                                                                                                                                Bundle an existing runner with its task.

                                                                                                                                                                                                                                                                Instances For
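A sketch of the ergonomics the bundle buys (hypothetical: `fitCfg`, `samples`, and `x` are assumed bindings; `tr.fit` and `tr.predict` are the calls named in the docstring above):

```lean
-- `tr` already remembers its task, so no `(task := task)` is needed.
def demo (tr : NN.API.train.TaskRunner σIn τOut Float) : IO Unit := do
  let _report ← tr.fit fitCfg samples  -- train on a sample set
  let _y ← tr.predict x                -- run the model on one input
  pure ()
```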

                                                                                                                                                                                                                                                                  Get the current model parameters from a bundled runner.

                                                                                                                                                                                                                                                                  Instances For

                                                                                                                                                                                                                                                                    Read the current mode (.train or .eval) from a bundled runner.

                                                                                                                                                                                                                                                                    Instances For

                                                                                                                                                                                                                                                                      Set the mode (.train or .eval) on a bundled runner.

                                                                                                                                                                                                                                                                      Instances For

                                                                                                                                                                                                                                                                        Switch a bundled runner to training mode.

                                                                                                                                                                                                                                                                        Instances For

                                                                                                                                                                                                                                                                          Switch a bundled runner to evaluation mode.

                                                                                                                                                                                                                                                                          Instances For

                                                                                                                                                                                                                                                                            Check whether a bundled runner is in training mode.

                                                                                                                                                                                                                                                                            Instances For

                                                                                                                                                                                                                                                                              Predict on one input tensor using the bundled runner's active mode.

                                                                                                                                                                                                                                                                              Instances For

                                                                                                                                                                                                                                                                                Predict on a list of inputs using the bundled runner's active mode.

                                                                                                                                                                                                                                                                                Instances For

                                                                                                                                                                                                                                                                                  Mean loss over an entire dataset for a bundled runner.

                                                                                                                                                                                                                                                                                  Instances For
                                                                                                                                                                                                                                                                                    def NN.API.train.TaskRunner.fit {σ τ : Spec.Shape} {α : Type} [Semantics.Scalar α] [DecidableEq Spec.Shape] [ToString α] [Add α] [Div α] [Zero α] [Coe α] [Runtime.Scalar α] (tr : TaskRunner σ τ α) (cfg : FitConfig) (samples : List (sample.Supervised α σ τ)) :

                                                                                                                                                                                                                                                                                    Fit a bundled runner on an explicit list of samples for a fixed number of steps.

                                                                                                                                                                                                                                                                                    Instances For
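A minimal usage sketch, assuming `Float` satisfies the instance constraints listed in the signature and that `cfg` and `samples` are built elsewhere; the result of `fit` is discarded because its return type is not shown above:

```lean
-- Illustrative sketch only: run a fixed number of optimizer steps
-- over an explicit sample list with a bundled runner.
open NN.API.train in
def quickFit {σ τ : Spec.Shape} [DecidableEq Spec.Shape]
    (tr : TaskRunner σ τ Float) (cfg : FitConfig)
    (samples : List (sample.Supervised Float σ τ)) : IO Unit := do
  discard <| tr.fit cfg samples
```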

Fit a bundled runner on a Dataset for a fixed number of steps.

Instances For

Fit a bundled runner using a DataLoader for a fixed number of epochs.

Instances For
def NN.API.train.runTask {σ τ : Spec.Shape} (task : Task σ τ) (args : List String) (k : {α : Type} → [inst : Semantics.Scalar α] → [inst_1 : DecidableEq Spec.Shape] → [ToString α] → [Runtime.Scalar α] → TaskRunner σ τ α → List String → IO Unit) :

CLI-oriented runner entry point that passes a bundled TaskRunner to the continuation.

This mirrors train.run, but removes the need to keep threading (task := task) after instantiation.

Instances For
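A hedged sketch of a CLI entry point built on this helper; `myTask` is a hypothetical Task value and the printed message is illustrative:

```lean
-- `runTask` instantiates a scalar type from `args`, bundles the task with a
-- runner, and hands the runner plus the remaining CLI args to the continuation.
def main (args : List String) : IO Unit :=
  NN.API.train.runTask myTask args fun tr rest => do
    IO.println s!"remaining CLI args: {rest}"
    -- train or evaluate with the bundled runner `tr` here
    pure ()
```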
def NN.API.train.accuracyOneHotBatched {σ : Spec.Shape} {classes batch : ℕ} {task : Task (Spec.Shape.dim batch σ) (Spec.Shape.dim batch (Tensor.Shape.Vec classes))} {α : Type} [Semantics.Scalar α] [DecidableEq Spec.Shape] [LT α] [DecidableRel fun (x1 x2 : α) => x1 > x2] (runner : Runner α task) (samples : List (sample.Batch α batch σ (Tensor.Shape.Vec classes))) :

Count correct predictions in a one-hot labeled batched dataset.

This is the minibatch analogue of accuracyOneHot: the task already has a leading dim0 batch axis, so we score each row of the batch independently and accumulate totals.

Returns (correct, total) where total = batch * numBatches.

Instances For
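A sketch that turns the (correct, total) pair into a readable report; that the count comes back in IO is an assumption, since the return type is not shown above:

```lean
-- Illustrative only: `runner` and `batches` are assumed to have the
-- matching batched shapes from the signature above.
def reportAccuracy {σ : Spec.Shape} {classes batch : ℕ}
    {task : Task (Spec.Shape.dim batch σ)
                 (Spec.Shape.dim batch (Tensor.Shape.Vec classes))}
    (runner : Runner Float task)
    (batches : List (sample.Batch Float batch σ (Tensor.Shape.Vec classes))) :
    IO Unit := do
  let (correct, total) ← NN.API.train.accuracyOneHotBatched runner batches
  IO.println s!"accuracy: {correct} / {total}"
```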
def NN.API.train.meanLossDataset {σ τ : Spec.Shape} {task : Task σ τ} {α : Type} [Semantics.Scalar α] [DecidableEq Spec.Shape] [ToString α] [Add α] [Div α] [Zero α] [Coe ℕ α] (runner : Runner α task) (dataset : Runtime.Autograd.Train.Dataset (sample.Supervised α σ τ)) :
IO α

Mean loss over an entire dataset (useful for quick before/after reports).

Instances For
def NN.API.train.fit {σ τ : Spec.Shape} {task : Task σ τ} {α : Type} [Semantics.Scalar α] [DecidableEq Spec.Shape] [ToString α] [Add α] [Div α] [Zero α] [Coe ℕ α] [Runtime.Scalar α] (runner : Runner α task) (cfg : FitConfig) (samples : List (sample.Supervised α σ τ)) :

Fit on an explicit list of samples for a fixed number of steps.

Instances For
def NN.API.train.fitDataset {σ τ : Spec.Shape} {task : Task σ τ} {α : Type} [Semantics.Scalar α] [DecidableEq Spec.Shape] [ToString α] [Add α] [Div α] [Zero α] [Coe ℕ α] [Runtime.Scalar α] (runner : Runner α task) (cfg : FitConfig) (dataset : Runtime.Autograd.Train.Dataset (sample.Supervised α σ τ)) :

Fit on a Dataset for a fixed number of steps.

Instances For
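The helpers above compose naturally into the quick before/after report the meanLossDataset docstring mentions. A sketch, assuming `Float` satisfies the listed instance constraints:

```lean
open NN.API.train in
def fitWithReport {σ τ : Spec.Shape} {task : Task σ τ} [DecidableEq Spec.Shape]
    (runner : Runner Float task) (cfg : FitConfig)
    (ds : Runtime.Autograd.Train.Dataset (sample.Supervised Float σ τ)) :
    IO Unit := do
  let before ← meanLossDataset runner ds   -- mean loss before training
  discard <| fitDataset runner cfg ds      -- fixed number of steps
  let after ← meanLossDataset runner ds    -- mean loss after training
  IO.println s!"mean loss: {before} -> {after}"
```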

Fit using a DataLoader for a fixed number of epochs.

Instances For
structure NN.API.train.StepEvent (α : Type) :

Callback event fired after each training step.

• epoch : ℕ

  Current epoch number.

• step : ℕ

  Global optimizer-step counter.

• loss : α

  Loss reported for this step.

Instances For

Callback event fired at the end of an epoch (how many steps ran).

• epoch : ℕ

  Epoch number that just completed.

• steps : ℕ

  Number of steps executed in the epoch.

Instances For
structure NN.API.train.Callbacks (α : Type) :

Hooks for instrumenting fitLoaderBatched-style training loops.

These are lightweight by design (IO callbacks). If you want richer logging, consider building a wrapper in your own project that translates these events into structured JSON/metrics.

Instances For

No-op callbacks.

Instances For

Combine two callback collections by running them in sequence.

Instances For
@[implicit_reducible]

The identity element for callbacks: a no-op callback collection.

@[implicit_reducible]

Callbacks form a monoid under sequential composition.

def NN.API.train.onTrainStart {α : Type} (action : IO Unit) :

Build callbacks that run action once at the start of training.

Instances For
def NN.API.train.onStep {α : Type} (f : StepEvent α → IO Unit) :

Build callbacks that observe every training step.

Instances For
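For example, a per-step logger built from the StepEvent fields documented above (epoch, step, loss); illustrative only:

```lean
open NN.API.train in
def printEveryStep : Callbacks Float :=
  onStep fun ev =>
    IO.println s!"epoch {ev.epoch}, step {ev.step}: loss = {ev.loss}"
```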

Build callbacks that run at the end of each epoch.

Instances For

Build callbacks that run once at the end of training, with the final report.

Instances For
def NN.API.train.logLossEvery {α : Type} [ToString α] (every : ℕ := 1) :

Callback helper: log the loss every every steps (if every > 0).

Instances For
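A sketch combining the periodic logger with a start banner. The monoid structure documented above suggests `*` for sequential composition; that spelling is an assumption:

```lean
open NN.API.train in
-- Log every 10th optimizer step; `every := 0` would disable logging,
-- since the docstring only fires the logger when `every > 0`.
def tenStepLogger : Callbacks Float := logLossEvery (every := 10)

-- Sequential composition via the Monoid instance (assumed `*` spelling).
def hooks : Callbacks Float :=
  onTrainStart (IO.println "starting training") * tenStepLogger
```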
                                                                                                                                                                                                                                                                                                                          def NN.API.train.withMode {σ τ : Spec.Shape} {task : Task σ τ} {α : Type} [Semantics.Scalar α] [DecidableEq Spec.Shape] {β : Type} (runner : Runner α task) (value : TorchLean.NN.Mode) (action : IO β) :
                                                                                                                                                                                                                                                                                                                          IO β

                                                                                                                                                                                                                                                                                                                          Run an action with the runner temporarily switched to value mode.

                                                                                                                                                                                                                                                                                                                          This is useful for "evaluate on a validation set during training" in callback-based loops.

                                                                                                                                                                                                                                                                                                                          Instances For
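A sketch of the "evaluate on a validation set during training" pattern the docstring describes; `valSet` is a hypothetical held-out dataset:

```lean
open NN.API.train in
def validationLoss {σ τ : Spec.Shape} {task : Task σ τ} [DecidableEq Spec.Shape]
    (runner : Runner Float task)
    (valSet : Runtime.Autograd.Train.Dataset (sample.Supervised Float σ τ)) :
    IO Float :=
  -- temporarily switch to eval mode, score, then restore the previous mode
  withMode runner .eval do
    meanLossDataset runner valSet
```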
                                                                                                                                                                                                                                                                                                                            def NN.API.train.meanLossModuleLoader {σ τ : Spec.Shape} {n : } {paramShapes : List Spec.Shape} {α : Type} [Semantics.Scalar α] [DecidableEq Spec.Shape] [ToString α] [Add α] [Div α] [Zero α] [Coe α] (module : Runtime.Autograd.TorchLean.ScalarModule α paramShapes [Spec.Shape.dim n σ, Spec.Shape.dim n τ]) (loader : Data.BatchLoader α n σ τ) :
                                                                                                                                                                                                                                                                                                                            IO α

                                                                                                                                                                                                                                                                                                                            Mean loss for an already-instantiated scalar module over a typed minibatch loader.

                                                                                                                                                                                                                                                                                                                            This is the general streaming evaluation path used by the runtime examples. It is deliberately not CIFAR-specific: any supervised task whose loss module consumes [dim n σ, dim n τ] can use the same loader. The loader stores ordinary per-example samples (x : σ, y : τ); this helper asks Data.epoch for raw minibatches and calls Data.collateSupervised to build one shape-typed batch at a time.

                                                                                                                                                                                                                                                                                                                            Two details are important for larger examples:

                                                                                                                                                                                                                                                                                                                            • We force shuffle := false for evaluation so before/after metrics are deterministic.
• We do not call Data.BatchLoader.batchDataset, because that would materialize every collated minibatch at once. Streaming keeps the same API usable for image, sequence, and scientific ML examples whose batch tensors are much larger than those of small tabular datasets.
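A minimal sketch of the streaming evaluation call, assuming a module and loader have already been built (`lossModule` and `testLoader` are placeholder names, not part of this API):

```lean
-- Hedged sketch: `lossModule` is assumed to come from
-- `TorchLean.Module.instantiateWithOptions`, and `testLoader` from the
-- public `Data.batchLoader`; neither is defined by this facade.
def evalTestLoss : IO Unit := do
  -- Streams one collated minibatch at a time; nothing is materialized up front.
  let meanLoss ← NN.API.train.meanLossModuleLoader lossModule testLoader
  IO.println s!"test mean loss: {meanLoss}"
```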
                                                                                                                                                                                                                                                                                                                            Instances For
def NN.API.train.meanLossBatchLoader {σ τ : Spec.Shape} {n : ℕ} {task : Task (Spec.Shape.dim n σ) (Spec.Shape.dim n τ)} {α : Type} [Semantics.Scalar α] [DecidableEq Spec.Shape] [ToString α] [Add α] [Div α] [Zero α] [Coe α] (runner : Runner α task) (loader : Data.BatchLoader α n σ τ) :
                                                                                                                                                                                                                                                                                                                              IO α

                                                                                                                                                                                                                                                                                                                              Mean loss over a typed minibatch loader through a train.Runner.

                                                                                                                                                                                                                                                                                                                              This is the runner-facing wrapper around meanLossModuleLoader. Use it when the example is built around train.run, task modes, and the proof-facing trainer abstraction. Use meanLossModuleLoader directly when the example has already instantiated a runtime TorchLean.Module.ScalarModule, which is the common fast path for CUDA demos.

                                                                                                                                                                                                                                                                                                                              Instances For
def NN.API.train.accuracyOneHotBatchLoader {σ : Spec.Shape} {classes batch : ℕ} {task : Task (Spec.Shape.dim batch σ) (Spec.Shape.dim batch (Tensor.Shape.Vec classes))} {α : Type} [Semantics.Scalar α] [DecidableEq Spec.Shape] [LT α] [DecidableRel fun (x1 x2 : α) => x1 > x2] (runner : Runner α task) (loader : Data.BatchLoader α batch σ (Tensor.Shape.Vec classes)) :

                                                                                                                                                                                                                                                                                                                                One-hot accuracy over a typed minibatch loader without materializing all collated batches.

                                                                                                                                                                                                                                                                                                                                Instances For
def NN.API.train.fitModuleLoaderWith {σ τ : Spec.Shape} {n : ℕ} {paramShapes : List Spec.Shape} {α : Type} [Semantics.Scalar α] [DecidableEq Spec.Shape] [ToString α] [Add α] [Div α] [Zero α] [Coe α] (module : Runtime.Autograd.TorchLean.ScalarModule α paramShapes [Spec.Shape.dim n σ, Spec.Shape.dim n τ]) (optimizer : Runtime.Autograd.TorchLean.Optimizer α paramShapes) (epochs : ℕ) (loader : Data.BatchLoader α n σ τ) (callbacks : Callbacks α := Callbacks.empty) :

                                                                                                                                                                                                                                                                                                                                  Train a runtime scalar module from a typed minibatch loader.

                                                                                                                                                                                                                                                                                                                                  This is the shared "real epoch loop" for model examples that instantiate a module directly with TorchLean.Module.instantiateWithOptions, including CUDA runs. It mirrors the PyTorch structure:

                                                                                                                                                                                                                                                                                                                                  1. create an optimizer state for the module parameters;
                                                                                                                                                                                                                                                                                                                                  2. for each epoch, ask the general Data.batchLoader for shuffled raw batches;
                                                                                                                                                                                                                                                                                                                                  3. collate each raw batch into a shape-typed (xBatch, yBatch) sample;
                                                                                                                                                                                                                                                                                                                                  4. report the scalar loss through callbacks;
                                                                                                                                                                                                                                                                                                                                  5. run forward/backward/optimizer.step through TorchLean.Module.stepWith.

                                                                                                                                                                                                                                                                                                                                  The function is polymorphic in the input shape σ, target shape τ, batch size n, scalar type α, parameter shapes, and optimizer. It is not an image-specific helper. CNN, ResNet, ViT, MLP, sequence, operator-learning, and future model demos should all be able to use this path whenever their supervised loss module has input shapes [dim n σ, dim n τ].
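The five steps above can be driven from a few lines; a hedged sketch assuming `lossModule`, an optimizer `sgd` for the module's parameter shapes, and `trainLoader` already exist (all placeholder names):

```lean
-- Hypothetical sketch of the shared epoch loop. `lossModule`, `sgd`, and
-- `trainLoader` stand in for an instantiated module, its optimizer, and a
-- typed loader; none is defined by this facade.
def trainDemo : IO Unit := do
  -- 10 epochs; `callbacks` defaults to `Callbacks.empty`, so nothing is
  -- printed. The result is discarded since its exact type is not shown here.
  let _ ← NN.API.train.fitModuleLoaderWith lossModule sgd 10 trainLoader
```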

                                                                                                                                                                                                                                                                                                                                  Instances For
def NN.API.train.fitLoaderWith {σ τ : Spec.Shape} {n : ℕ} {task : Task (Spec.Shape.dim n σ) (Spec.Shape.dim n τ)} {α : Type} [Semantics.Scalar α] [DecidableEq Spec.Shape] [ToString α] [Add α] [Div α] [Zero α] [Coe α] [Runtime.Scalar α] (runner : Runner α task) (cfg : LoaderFitConfig) (loader : Data.BatchLoader α n σ τ) (callbacks : Callbacks α := Callbacks.empty) :

                                                                                                                                                                                                                                                                                                                                    Train from a runner-backed loader with explicit callbacks instead of inline printing in example code.

                                                                                                                                                                                                                                                                                                                                    This is the proof/trainer-facing public escape hatch for PyTorch-style custom loops:

                                                                                                                                                                                                                                                                                                                                    • keep the optimizer/scheduler logic in the library,
                                                                                                                                                                                                                                                                                                                                    • inject logging, evaluation, and probe reporting through callbacks.

                                                                                                                                                                                                                                                                                                                                    This path keeps the Runner abstraction, including task modes and scheduler support. For CUDA-heavy tutorials that already have a TorchLean.Module.ScalarModule, prefer fitModuleLoaderWith; both paths consume the same general API.Data.batchLoader.
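A corresponding runner-facing sketch, assuming a `runner`, a `LoaderFitConfig` value `cfg`, and a loader are in scope (all placeholder names):

```lean
-- Hedged sketch: train through the Runner abstraction, then report the
-- post-training mean loss. `runner`, `cfg`, and `trainLoader` are assumed,
-- and the fit result is discarded since its exact type is not shown here.
def runnerFitDemo : IO Unit := do
  let _ ← NN.API.train.fitLoaderWith runner cfg trainLoader
  let after ← NN.API.train.meanLossBatchLoader runner trainLoader
  IO.println s!"mean loss after fit: {after}"
```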

                                                                                                                                                                                                                                                                                                                                    Instances For

                                                                                                                                                                                                                                                                                                                                      Public minibatch training path.

                                                                                                                                                                                                                                                                                                                                      data.batchLoader produces a typed BatchLoader (with a type-level batch size n), and this helper bridges from an untyped runtime loader into the typed training loop.

                                                                                                                                                                                                                                                                                                                                      Instances For
                                                                                                                                                                                                                                                                                                                                        def NN.API.train.stepper {σ τ : Spec.Shape} {task : Task σ τ} {α : Type} [Semantics.Scalar α] [DecidableEq Spec.Shape] [ToString α] [Add α] [Div α] [Zero α] [Coe α] [Runtime.Scalar α] (runner : Runner α task) (optimizer : optim.Optimizer) (scheduler : Option TorchLean.Schedulers.Config := none) :
                                                                                                                                                                                                                                                                                                                                        IO (Stepper α task)

                                                                                                                                                                                                                                                                                                                                        Create a Stepper loop for a runner and optimizer (optionally with an LR scheduler).

                                                                                                                                                                                                                                                                                                                                        This corresponds to the “inner training loop” state in typical PyTorch code: an optimizer state plus (optional) schedule state, ready to step on a batch.

                                                                                                                                                                                                                                                                                                                                        Instances For
                                                                                                                                                                                                                                                                                                                                          def NN.API.train.step {σ τ : Spec.Shape} {task : Task σ τ} {α : Type} [Semantics.Scalar α] [DecidableEq Spec.Shape] (loop : Stepper α task) (sample : sample.Supervised α σ τ) :
                                                                                                                                                                                                                                                                                                                                          IO α

                                                                                                                                                                                                                                                                                                                                          Run one optimization step on a single supervised sample (one batch).

                                                                                                                                                                                                                                                                                                                                          Instances For
                                                                                                                                                                                                                                                                                                                                            def NN.API.train.epoch {σ τ : Spec.Shape} {task : Task σ τ} {α : Type} [Semantics.Scalar α] [DecidableEq Spec.Shape] (loop : Stepper α task) (samples : List (sample.Supervised α σ τ)) :
                                                                                                                                                                                                                                                                                                                                            IO (List α)

                                                                                                                                                                                                                                                                                                                                            Run one epoch over a list of supervised samples, returning the per-step losses.
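stepper, step, and epoch compose into a hand-rolled loop in the PyTorch style; a hedged sketch assuming a `runner`, an optimizer `opt`, and a list of supervised `samples` (all placeholder names):

```lean
-- Hypothetical custom training loop: build the stepper once, then run a few
-- epochs by hand. `runner`, `opt`, and `samples` are assumed to exist.
def customLoop : IO Unit := do
  let loop ← NN.API.train.stepper runner opt   -- no LR scheduler
  for e in [0:3] do
    let losses ← NN.API.train.epoch loop samples
    IO.println s!"epoch {e}: ran {losses.length} steps"
```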

                                                                                                                                                                                                                                                                                                                                            Instances For

                                                                                                                                                                                                                                                                                                                                              Small Reporting Helpers (IO) #

                                                                                                                                                                                                                                                                                                                                              These helpers keep tutorial code readable by factoring out common "print a loss/accuracy table" patterns. They do not affect semantics: they only call the underlying train.* functions and print human-facing summaries.

def NN.API.train.Report.reportProbes {β : Type} (title : String) (probes : List β) (lineOf : β → IO String) :

                                                                                                                                                                                                                                                                                                                                              Print a titled list of probe lines.

                                                                                                                                                                                                                                                                                                                                              Instances For
                                                                                                                                                                                                                                                                                                                                                def NN.API.train.Report.reportMeanLoss {σ τ : Spec.Shape} {task : Task σ τ} {α : Type} [Semantics.Scalar α] [DecidableEq Spec.Shape] [ToString α] [Add α] [Div α] [Zero α] [Coe α] (runner : Runner α task) (dataset : Runtime.Autograd.Train.Dataset (sample.Supervised α σ τ)) (label : String) :

                                                                                                                                                                                                                                                                                                                                                Convenience: mean loss on a dataset, printed with a label.

                                                                                                                                                                                                                                                                                                                                                Instances For
def NN.API.train.Report.reportMeanLossLoader {σ τ : Spec.Shape} {batch : ℕ} {task : Task (Spec.Shape.dim batch σ) (Spec.Shape.dim batch τ)} {α : Type} [Semantics.Scalar α] [DecidableEq Spec.Shape] [ToString α] [Add α] [Div α] [Zero α] [Coe α] (runner : Runner α task) (loader : Data.BatchLoader α batch σ τ) (label : String) :

                                                                                                                                                                                                                                                                                                                                                  Convenience: mean loss on a typed minibatch loader, streamed batch by batch.

                                                                                                                                                                                                                                                                                                                                                  Instances For
def NN.API.train.Report.reportMeanLossModuleLoader {σ τ : Spec.Shape} {batch : ℕ} {paramShapes : List Spec.Shape} {α : Type} [Semantics.Scalar α] [DecidableEq Spec.Shape] [ToString α] [Add α] [Div α] [Zero α] [Coe α] (module : Runtime.Autograd.TorchLean.ScalarModule α paramShapes [Spec.Shape.dim batch σ, Spec.Shape.dim batch τ]) (loader : Data.BatchLoader α batch σ τ) (label : String) :

                                                                                                                                                                                                                                                                                                                                                    Convenience: mean loss on a typed minibatch loader for an already-instantiated runtime module.

                                                                                                                                                                                                                                                                                                                                                    Use this in direct CUDA/runtime examples to avoid building a Runner only for logging. The data path is still the same public loader path: Data.batchLoader plus Data.collateSupervised.

                                                                                                                                                                                                                                                                                                                                                    Instances For
def NN.API.train.Report.reportClassProbes {σ : Spec.Shape} {classes : ℕ} {task : Task σ (Spec.Shape.dim classes Spec.Shape.scalar)} {α : Type} [Semantics.Scalar α] [DecidableEq Spec.Shape] [ToString α] [LT α] [DecidableRel fun (x1 x2 : α) => x1 > x2] (runner : Runner α task) (probes : List (String × Spec.Tensor α σ × ℕ)) (title : String := "predictions") (includeLogits : Bool := false) :

                                                                                                                                                                                                                                                                                                                                                      Report predicted classes on a list of named probes.

                                                                                                                                                                                                                                                                                                                                                      Each probe entry is (name, x, expectedClass). If includeLogits := true, also prints the raw model outputs.
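A hedged sketch of the probe format, assuming two input tensors `xZero` and `xOne` of the model's input shape σ (placeholder names):

```lean
-- Hypothetical probe report: each entry is (name, input, expected class).
-- `runner`, `xZero`, and `xOne` are placeholders, not defined here; the
-- report's result is discarded since its exact type is not shown.
def showProbes : IO Unit := do
  let probes := [("zero", xZero, 0), ("one", xOne, 1)]
  let _ ← NN.API.train.Report.reportClassProbes runner probes
    (title := "digit predictions") (includeLogits := true)
```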

                                                                                                                                                                                                                                                                                                                                                      Instances For
def NN.API.train.Report.reportClassProbesBatchedFromSingle {σ : Spec.Shape} {classes batch : ℕ} {task : Task (Spec.Shape.dim batch σ) (Spec.Shape.dim batch (Tensor.Shape.Vec classes))} {α : Type} [Semantics.Scalar α] [DecidableEq Spec.Shape] [ToString α] [LT α] [DecidableRel fun (x1 x2 : α) => x1 > x2] (runner : Runner α task) (probes : List (String × Spec.Tensor α σ × ℕ)) (title : String := "predictions") (includeLogits : Bool := false) :

                                                                                                                                                                                                                                                                                                                                                        Report predicted classes on a list of named probes, for a batched model.

                                                                                                                                                                                                                                                                                                                                                        This expects probes of the unbatched input shape σ and replicates each probe across the batch axis, then reports the prediction for row 0.

                                                                                                                                                                                                                                                                                                                                                        Instances For
def NN.API.train.Report.reportLossAccuracyOneHot {σ : Spec.Shape} {classes : ℕ} {task : Task σ (Spec.Shape.dim classes Spec.Shape.scalar)} {α : Type} [Semantics.Scalar α] [DecidableEq Spec.Shape] [ToString α] [LT α] [DecidableRel fun (x1 x2 : α) => x1 > x2] (runner : Runner α task) (dataset : Runtime.Autograd.Train.Dataset (sample.Supervised α σ (Spec.Shape.dim classes Spec.Shape.scalar))) (label : String) :

                                                                                                                                                                                                                                                                                                                                                          Convenience: mean loss + one-hot accuracy on a dataset, printed with a label.

                                                                                                                                                                                                                                                                                                                                                          Instances For
def NN.API.train.Report.reportLossAccuracyOneHotBatched {σ : Spec.Shape} {classes batch : ℕ} {task : Task (Spec.Shape.dim batch σ) (Spec.Shape.dim batch (Tensor.Shape.Vec classes))} {α : Type} [Semantics.Scalar α] [DecidableEq Spec.Shape] [ToString α] [LT α] [DecidableRel fun (x1 x2 : α) => x1 > x2] (runner : Runner α task) (dataset : Runtime.Autograd.Train.Dataset (sample.Batch α batch σ (Tensor.Shape.Vec classes))) (label : String) :

                                                                                                                                                                                                                                                                                                                                                            Batched variant of reportLossAccuracyOneHot.

                                                                                                                                                                                                                                                                                                                                                            Instances For
def NN.API.train.Report.reportLossAccuracyOneHotLoader {σ : Spec.Shape} {classes batch : ℕ} {task : Task (Spec.Shape.dim batch σ) (Spec.Shape.dim batch (Tensor.Shape.Vec classes))} {α : Type} [Semantics.Scalar α] [DecidableEq Spec.Shape] [ToString α] [Add α] [Div α] [Zero α] [Coe ℕ α] [LT α] [DecidableRel fun (x1 x2 : α) => x1 > x2] (runner : Runner α task) (loader : Data.BatchLoader α batch σ (Tensor.Shape.Vec classes)) (label : String) :

                                                                                                                                                                                                                                                                                                                                                              Loader variant of reportLossAccuracyOneHotBatched, streaming through minibatches.

                                                                                                                                                                                                                                                                                                                                                              Instances For
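As a usage sketch, these reporting helpers slot in after training. The names `myRunner`, `trainSet`, and `testSet` below are hypothetical placeholders; only `reportLossAccuracyOneHot` itself comes from this module:

```lean
open NN.API.train.Report

-- Hypothetical sketch: a `Runner` named `myRunner` and the two datasets
-- are assumed to have been built by the surrounding training code.
def reportSplits : IO Unit := do
  reportLossAccuracyOneHot myRunner trainSet "train"  -- mean loss + one-hot accuracy
  reportLossAccuracyOneHot myRunner testSet "test"
```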

                                                                                                                                                                                                                                                                                                                                                                Model Builders and Seeding #

                                                                                                                                                                                                                                                                                                                                                                TorchLean keeps initialization randomness explicit so examples are reproducible.

                                                                                                                                                                                                                                                                                                                                                                Typical patterns:

                                                                                                                                                                                                                                                                                                                                                                1. Explicit seeds (best for proofs / reproducibility-sensitive code):

• build with nn.pure.linear ... (seedW := ...) (seedB := ...), etc.
                                                                                                                                                                                                                                                                                                                                                                  • compose with seq! ... / >>>
                                                                                                                                                                                                                                                                                                                                                                2. Script-style “manual seed once”:

Note: nn.Sequential lives in Type 2, so it cannot be returned directly from IO (whose result type must live in Type). We keep model building pure by drawing a base seed in IO and then calling nn.run.
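A sketch of pattern 1, assuming the shapes shown in the `nn.pure.linear` docstring above (the layer sizes and seed values are illustrative, and an activation between the layers is omitted for brevity):

```lean
open NN.API

-- Pattern 1: explicit seeds. The model is a pure value, so it is
-- reproducible by construction and usable in proofs.
def mlp : nn.Sequential (Spec.Shape.scalar.appendDim 784)
                        (Spec.Shape.scalar.appendDim 10) :=
  nn.pure.linear 784 32 (seedW := 1) (seedB := 2)
    >>> nn.pure.linear 32 10 (seedW := 3) (seedB := 4)
```

Pattern 2 uses the same pure builders but draws the base seed once in IO and hands it to `nn.run`, whose exact signature is not shown here.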

                                                                                                                                                                                                                                                                                                                                                                PyTorch-like global seeding convenience for seeded model builders.

                                                                                                                                                                                                                                                                                                                                                                This sets the global seed stream used by nn.runGlobal / nn.nextSeed.

                                                                                                                                                                                                                                                                                                                                                                Instances For
                                                                                                                                                                                                                                                                                                                                                                  @[reducible, inline]

                                                                                                                                                                                                                                                                                                                                                                  Embedding table initialization configuration (one-hot / token-distribution inputs).

                                                                                                                                                                                                                                                                                                                                                                  This is the TorchLean-friendly analogue of torch.nn.Embedding in the common demo setting where token ids are represented as one-hot vectors (or soft token distributions), so lookup is a matrix multiplication rather than integer indexing.

                                                                                                                                                                                                                                                                                                                                                                  Instances For
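Concretely, if E is the vocab × embedDim table and oneHot(t) the one-hot row for token id t, the lookup performed here is a matrix product that selects row t of E; a soft token distribution instead yields the corresponding convex combination of rows:

```latex
\mathrm{embed}(t) \;=\; \mathrm{oneHot}(t)\,E \;=\; E_{t,\cdot}
```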
                                                                                                                                                                                                                                                                                                                                                                    @[reducible, inline]

                                                                                                                                                                                                                                                                                                                                                                    Learned positional embedding configuration.

                                                                                                                                                                                                                                                                                                                                                                    This is a trainable parameter tensor of shape (seqLen × embedDim) that is broadcast across the leading batch dimension and added to the input.

                                                                                                                                                                                                                                                                                                                                                                    Instances For
                                                                                                                                                                                                                                                                                                                                                                      @[reducible, inline]

                                                                                                                                                                                                                                                                                                                                                                      Sinusoidal positional encoding configuration.

                                                                                                                                                                                                                                                                                                                                                                      This is the classic (non-trainable) Transformer sinusoidal encoding, added to token embeddings. startPos is an absolute-position offset (useful for KV-cache decoding).

                                                                                                                                                                                                                                                                                                                                                                      Instances For
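For reference, the classic Transformer sinusoidal encoding this configuration computes, where p is the position within the window, startPos the absolute offset, i the feature-pair index, and d the embedding dimension:

```latex
\mathrm{PE}(p,\,2i) = \sin\!\left(\frac{p + \mathrm{startPos}}{10000^{2i/d}}\right),
\qquad
\mathrm{PE}(p,\,2i+1) = \cos\!\left(\frac{p + \mathrm{startPos}}{10000^{2i/d}}\right)
```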
                                                                                                                                                                                                                                                                                                                                                                        @[reducible, inline]

                                                                                                                                                                                                                                                                                                                                                                        Rotary positional embedding (RoPE) configuration.

                                                                                                                                                                                                                                                                                                                                                                        startPos is an absolute-position offset (useful for KV-cache decoding).

                                                                                                                                                                                                                                                                                                                                                                        Instances For
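For reference, RoPE rotates each consecutive pair of feature dimensions by a position-dependent angle. With absolute position p = pos + startPos and per-pair frequency θᵢ = 10000^(−2i/d):

```latex
\begin{pmatrix} x'_{2i} \\ x'_{2i+1} \end{pmatrix}
=
\begin{pmatrix}
\cos(p\,\theta_i) & -\sin(p\,\theta_i) \\
\sin(p\,\theta_i) & \cos(p\,\theta_i)
\end{pmatrix}
\begin{pmatrix} x_{2i} \\ x_{2i+1} \end{pmatrix}
```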
                                                                                                                                                                                                                                                                                                                                                                          @[reducible, inline]

                                                                                                                                                                                                                                                                                                                                                                          Named-field Conv2d configuration (CHW layout).

                                                                                                                                                                                                                                                                                                                                                                          This is the public, PyTorch-like entry point for convolution in TorchLean. PyTorch analogue: torch.nn.Conv2d. See https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html.

                                                                                                                                                                                                                                                                                                                                                                          Instances For
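Assuming the usual PyTorch convolution convention (dilation 1), the output spatial size per axis is:

```latex
H_{\mathrm{out}} \;=\; \left\lfloor \frac{H + 2\,\mathrm{padding} - \mathrm{kernelSize}}{\mathrm{stride}} \right\rfloor + 1
```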
                                                                                                                                                                                                                                                                                                                                                                              @[reducible, inline]

                                                                                                                                                                                                                                                                                                                                                                              MaxPool2d configuration for CHW inputs.

                                                                                                                                                                                                                                                                                                                                                                              PyTorch analogue: torch.nn.MaxPool2d. See https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html.

                                                                                                                                                                                                                                                                                                                                                                              Instances For
                                                                                                                                                                                                                                                                                                                                                                                  @[reducible, inline]

                                                                                                                                                                                                                                                                                                                                                                                  AvgPool2d configuration for CHW inputs.

                                                                                                                                                                                                                                                                                                                                                                                  PyTorch analogue: torch.nn.AvgPool2d. See https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html.

                                                                                                                                                                                                                                                                                                                                                                                  Instances For
                                                                                                                                                                                                                                                                                                                                                                                      @[reducible, inline]

                                                                                                                                                                                                                                                                                                                                                                                      LayerNorm configuration for batched (batch x seqLen x embedDim) tensors.

                                                                                                                                                                                                                                                                                                                                                                                      PyTorch analogue: torch.nn.LayerNorm. See https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html.

                                                                                                                                                                                                                                                                                                                                                                                      Instances For
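For reference, with mean μ and variance σ² taken over the last (embedDim) axis and learned scale γ and shift β:

```latex
y \;=\; \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} \odot \gamma + \beta
```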
                                                                                                                                                                                                                                                                                                                                                                                        @[reducible, inline]

                                                                                                                                                                                                                                                                                                                                                                                        RMSNorm configuration for batched (batch x seqLen x embedDim) tensors.

                                                                                                                                                                                                                                                                                                                                                                                        This is a common alternative to LayerNorm in modern transformer architectures.

                                                                                                                                                                                                                                                                                                                                                                                        Instances For
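For reference, RMSNorm skips the mean subtraction and (in its standard form) the shift, normalizing by the root mean square over the last axis of size d:

```latex
y \;=\; \frac{x}{\sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^2 + \epsilon}} \odot \gamma
```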
                                                                                                                                                                                                                                                                                                                                                                                          @[reducible, inline]

                                                                                                                                                                                                                                                                                                                                                                                          BatchNorm2d configuration (learned scale/shift).

                                                                                                                                                                                                                                                                                                                                                                                          PyTorch analogue: torch.nn.BatchNorm2d. See https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html.

                                                                                                                                                                                                                                                                                                                                                                                          Instances For
                                                                                                                                                                                                                                                                                                                                                                                            @[reducible, inline]

InstanceNorm2d configuration (learned scale/shift).

PyTorch analogue: torch.nn.InstanceNorm2d. See https://pytorch.org/docs/stable/generated/torch.nn.InstanceNorm2d.html.

Instances For
  @[reducible, inline]

  Multi-head self-attention configuration.

  PyTorch analogue: torch.nn.MultiheadAttention (conceptually). See https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html.

  Instances For
def NN.API.nn.globalAvgPoolCHW (c h w : ℕ) {hC : c > 0} {hH : h > 0} {hW : w > 0} :

Global average pooling over a CHW tensor.

PyTorch analogue: torch.nn.AdaptiveAvgPool2d((1, 1)) followed by flattening.

Instances For
def NN.API.nn.globalAvgPoolNCHW (n c h w : ℕ) {hN : n > 0} {hC : c > 0} {hH : h > 0} {hW : w > 0} :

Global average pooling over an NCHW tensor (preserves the batch dimension).

Instances For

Seeded Builders (Default nn.*) #

For end-user code, the default nn.* layer constructors allocate initialization seeds automatically via nn.M (a deterministic seed-stream builder).

Use nn.pure.* when you want to pass explicit seeds (proof-friendly / fully reproducible).
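As a sketch of the difference, using only constructor names documented on this page (the exact shape arguments and named-argument syntax are assumptions):

```lean
open NN.API

-- Explicit seeds via nn.pure.*: every seed appears in the term, so the
-- resulting parameters are reproducible by construction.
def densePure : Sequential (Spec.Shape.scalar.appendDim 784) (Spec.Shape.scalar.appendDim 128) :=
  nn.pure.linear 784 128 (seedW := 1) (seedB := 2)

-- Automatic seeds via the default nn.*: build inside nn.M and discharge
-- the seed stream once with nn.run, starting from a single base seed.
def denseSeeded : Sequential (Spec.Shape.scalar.appendDim 784) (Spec.Shape.scalar.appendDim 128) :=
  nn.run 0 (nn.linear 784 128)
```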

@[reducible, inline]
abbrev NN.API.nn.M (α : Type u_1) :
Type u_1

Seeded builder monad: a state monad over API.rand.SeedStream.

Instances For
  def NN.API.nn.run {α : Type 2} (seed : ℕ) (x : M α) :
  α

  Run a seeded builder starting from a base seed.

  Instances For
    def NN.API.nn.lift {α : Type 2} (x : α) :
    M α

    Lift a pure value into the seeded builder (consumes no seeds).

    Instances For
      def NN.API.nn.withSeed {α : Type 2} (k : ℕ → α) :
      M α

      Consume one fresh seed and pass it to k.

      Instances For
        def NN.API.nn.withSeeds2 {α : Type 2} (k : ℕ → ℕ → α) :
        M α

        Consume two fresh seeds and pass them to k (in order).

        Instances For
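Since nn.M is a state monad, these combinators compose with ordinary do-notation. The following sketch assumes the signatures withSeed : (ℕ → α) → M α and withSeeds2 : (ℕ → ℕ → α) → M α (reconstructed from the docstrings; the originals lost their non-ASCII characters):

```lean
open NN.API

-- Draw one seed, then two more; `nn.lift` consumes none.
def seedDemo : nn.M (ℕ × ℕ × ℕ) := do
  let a ← nn.withSeed id
  let bc ← nn.withSeeds2 (fun b c => (b, c))
  nn.lift (a, bc.1, bc.2)

-- `nn.run base seedDemo` evaluates the stream deterministically:
-- the same base seed always yields the same triple.
```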

Elementwise ReLU. PyTorch analogue: torch.nn.ReLU / torch.nn.functional.relu.

Instances For

  Elementwise SiLU/Swish. PyTorch analogue: torch.nn.SiLU / torch.nn.functional.silu.

  Instances For

    Elementwise GELU. PyTorch analogue: torch.nn.GELU / torch.nn.functional.gelu.

    Instances For

      Elementwise sigmoid. PyTorch analogue: torch.nn.Sigmoid / torch.nn.functional.sigmoid.

      Instances For

        Elementwise tanh. PyTorch analogue: torch.nn.Tanh / torch.nn.functional.tanh.

        Instances For

          Softmax. PyTorch analogue: torch.nn.Softmax / torch.nn.functional.softmax.

          Instances For

            Reduce-sum to a scalar. PyTorch analogue: torch.sum.

            Instances For

              Flatten any tensor into a 1D vector of length size s. PyTorch analogue: torch.flatten.

              Instances For

                Flatten a batched tensor N × σ into a matrix N × (size σ).

                PyTorch analogue: torch.flatten(x, start_dim=1).

                Instances For

                  Flatten a batched tensor starting at dimension 1 (keep dim0).

                  Synonym for flattenBatch, matching PyTorch’s start_dim=1 wording.

                  Instances For
                                                                                                                                                                                                                                                                                                                                                                                                                                  def NN.API.nn.maxPool2d {n inC inH inW : } (cfg : MaxPool2d) [NeZero cfg.kH] [NeZero cfg.kW] :
                                                                                                                                                                                                                                                                                                                                                                                                                                  M (Sequential (Tensor.Shape.Images n inC inH inW) (Tensor.Shape.Images n inC ((inH - cfg.kH) / cfg.stride + 1) ((inW - cfg.kW) / cfg.stride + 1)))

                                                                                                                                                                                                                                                                                                                                                                                                                                  MaxPool2d using NeZero to hide nonzero kernel proofs.

                                                                                                                                                                                                                                                                                                                                                                                                                                  Instances For
                                                                                                                                                                                                                                                                                                                                                                                                                                    def NN.API.nn.maxPool {n inC inH inW : } (cfg : MaxPool) [NeZero cfg.kH] [NeZero cfg.kW] :
                                                                                                                                                                                                                                                                                                                                                                                                                                    M (Sequential (Tensor.Shape.Images n inC inH inW) (Tensor.Shape.Images n inC ((inH - cfg.kH) / cfg.stride + 1) ((inW - cfg.kW) / cfg.stride + 1)))

                                                                                                                                                                                                                                                                                                                                                                                                                                    Max pooling over batched CHW images, allocating any required initialization seeds automatically.

                                                                                                                                                                                                                                                                                                                                                                                                                                    Shorthand for maxPool2d.

                                                                                                                                                                                                                                                                                                                                                                                                                                    Instances For
                                                                                                                                                                                                                                                                                                                                                                                                                                      def NN.API.nn.avgPool2d {n inC inH inW : } (cfg : AvgPool2d) [NeZero cfg.kH] [NeZero cfg.kW] :
                                                                                                                                                                                                                                                                                                                                                                                                                                      M (Sequential (Tensor.Shape.Images n inC inH inW) (Tensor.Shape.Images n inC ((inH - cfg.kH) / cfg.stride + 1) ((inW - cfg.kW) / cfg.stride + 1)))

                                                                                                                                                                                                                                                                                                                                                                                                                                      AvgPool2d over batched NCHW inputs (shape N×C×H×W, like PyTorch).

Instances For
def NN.API.nn.avgPool {n inC inH inW : ℕ} (cfg : AvgPool) [NeZero cfg.kH] [NeZero cfg.kW] :
M (Sequential (Tensor.Shape.Images n inC inH inW) (Tensor.Shape.Images n inC ((inH - cfg.kH) / cfg.stride + 1) ((inW - cfg.kW) / cfg.stride + 1)))

Average pooling over batched CHW images, allocating any required initialization seeds automatically.

Shorthand for avgPool2d.

Instances For
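As a quick sanity check on the output-shape formula in the signature above ((in − k) / stride + 1, with truncating natural-number division), here is a worked example; the sizes (28×28 input, 2×2 kernel, stride 2) are chosen for illustration and are not taken from this page:

```latex
H_{\mathrm{out}} = \left\lfloor \frac{28 - 2}{2} \right\rfloor + 1 = 13 + 1 = 14,
\qquad
W_{\mathrm{out}} = \left\lfloor \frac{28 - 2}{2} \right\rfloor + 1 = 14
```

So a 28×28 image pooled with a 2×2 kernel at stride 2 halves each spatial dimension, matching the `Tensor.Shape.Images` result type.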
def NN.API.nn.linear (inDim outDim : ℕ) (pfx : Spec.Shape := Spec.Shape.scalar) :
M (Sequential (pfx.appendDim inDim) (pfx.appendDim outDim))

Linear layer on the last axis (prefix-shape preserving).

PyTorch analogue: torch.nn.Linear. See https://pytorch.org/docs/stable/generated/torch.nn.Linear.html.

Unlike the lower-level TorchLean layer constructor (which is vector-only), this public facade matches PyTorch's convention:

• if x has shape [..., inDim], then linear inDim outDim returns a model producing shape [..., outDim].

The leading "prefix" dimensions are treated as a batch: they are flattened to (numel(prefix), inDim), the affine map is applied once, and the result is reshaped back.

Instances For
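Schematically, the flatten-apply-reshape step described above can be written as follows, with X the input reshaped to a matrix and W, b the layer's weight and bias. The exact weight orientation (x Wᵀ, as in PyTorch's nn.Linear) is an assumption for the illustration, not stated on this page:

```latex
X \in \mathbb{R}^{\mathrm{numel}(\mathrm{pfx}) \times \mathrm{inDim}}
\;\longmapsto\;
Y = X W^{\top} + \mathbf{1}\, b^{\top}
\in \mathbb{R}^{\mathrm{numel}(\mathrm{pfx}) \times \mathrm{outDim}}
```

Y is then reshaped back to pfx.appendDim outDim, so the prefix axes pass through untouched.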
def NN.API.nn.linearV (inDim outDim : ℕ) :

Vector-only linear layer alias.

This is shorthand for nn.linear inDim outDim at scalar prefix shape, so examples do not need to mention pfx := Spec.Shape.scalar.

Instances For
def NN.API.nn.rnn (seqLen inputSize hiddenSize : ℕ) :
M (Sequential (Tensor.Shape.Mat seqLen inputSize) (Tensor.Shape.Mat seqLen hiddenSize))

Vanilla RNN layer (time-major sequence, no batch axis).

Semantics: h_t = tanh(W [x_t; h_{t-1}] + b), with h_{-1} = 0.

This is implemented by unrolling seqLen steps using existing TorchLean ops, so it runs on both the CPU and CUDA backends.

PyTorch analogy: torch.nn.RNN(inputSize, hiddenSize, nonlinearity="tanh") with batch_first=False, specialized to a single batch element.

Instances For
def NN.API.nn.gru (seqLen inputSize hiddenSize : ℕ) :
M (Sequential (Tensor.Shape.Mat seqLen inputSize) (Tensor.Shape.Mat seqLen hiddenSize))

GRU layer (time-major sequence, no batch axis).

This is implemented by unrolling seqLen steps using existing TorchLean ops, so it runs on both the CPU and CUDA backends.

PyTorch analogy: torch.nn.GRU(inputSize, hiddenSize) with batch_first=False, specialized to a single batch element.

Instances For
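For reference, the torch.nn.GRU recurrence cited above is, per the PyTorch documentation (whether TorchLean's unrolled implementation uses exactly this gate and bias layout is not stated on this page):

```latex
\begin{aligned}
r_t &= \sigma\bigl(W_{ir} x_t + b_{ir} + W_{hr} h_{t-1} + b_{hr}\bigr) \\
z_t &= \sigma\bigl(W_{iz} x_t + b_{iz} + W_{hz} h_{t-1} + b_{hz}\bigr) \\
n_t &= \tanh\bigl(W_{in} x_t + b_{in} + r_t \odot (W_{hn} h_{t-1} + b_{hn})\bigr) \\
h_t &= (1 - z_t) \odot n_t + z_t \odot h_{t-1}
\end{aligned}
```

By analogy with nn.rnn above, the initial hidden state is presumably h_{-1} = 0.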
def NN.API.nn.mamba (seqLen inputSize hiddenSize : ℕ) :
M (Sequential (Tensor.Shape.Mat seqLen inputSize) (Tensor.Shape.Mat seqLen hiddenSize))

Trainable Mamba-style gated diagonal state-space layer.

The layer is time-major and single-batch, matching the simple rnn/gru/lstm constructors: input (seqLen × inputSize), output (seqLen × hiddenSize). It is unrolled with differentiable TorchLean ops, so CPU and CUDA training use the same API.

Instances For
def NN.API.nn.lstm (seqLen inputSize hiddenSize : ℕ) :
M (Sequential (Tensor.Shape.Mat seqLen inputSize) (Tensor.Shape.Mat seqLen hiddenSize))

LSTM layer (time-major sequence, no batch axis).

This is implemented by unrolling seqLen steps using existing TorchLean ops, so it runs on both the CPU and CUDA backends.

PyTorch analogy: torch.nn.LSTM(inputSize, hiddenSize) with batch_first=False, specialized to a single batch element.

Instances For
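For reference, the torch.nn.LSTM recurrence cited above is, per the PyTorch documentation (again, whether TorchLean's unrolling matches this gate and bias layout exactly is not stated on this page):

```latex
\begin{aligned}
i_t &= \sigma\bigl(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}\bigr) \\
f_t &= \sigma\bigl(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}\bigr) \\
g_t &= \tanh\bigl(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}\bigr) \\
o_t &= \sigma\bigl(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}\bigr) \\
c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}
```

Only the hidden sequence h_0, …, h_{seqLen−1} appears in the output type; the cell state c_t is internal to the unrolled computation.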
def NN.API.nn.conv2d {n inC inH inW : ℕ} (cfg : Conv2d) [NeZero inC] [NeZero cfg.kH] [NeZero cfg.kW] :
M (Sequential (Tensor.Shape.Images n inC inH inW) (Tensor.Shape.Images n cfg.outC ((inH + 2 * cfg.padding - cfg.kH) / cfg.stride + 1) ((inW + 2 * cfg.padding - cfg.kW) / cfg.stride + 1)))

2D convolution over a batched image tensor (shape N×C×H×W, like PyTorch).

Instances For
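The result type encodes the usual convolution output-size formula, (in + 2·padding − k) / stride + 1 with truncating natural-number division. A worked example with illustrative values (28×28 input, kH = kW = 3, padding = 1, stride = 1; not taken from this page):

```latex
H_{\mathrm{out}} = \left\lfloor \frac{28 + 2 \cdot 1 - 3}{1} \right\rfloor + 1 = 27 + 1 = 28
```

With these settings the spatial size is preserved, the familiar "same padding" configuration for a 3×3 kernel.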
def NN.API.nn.conv {n inC inH inW : ℕ} (cfg : Conv) [NeZero inC] [NeZero cfg.kH] [NeZero cfg.kW] :
M (Sequential (Tensor.Shape.Images n inC inH inW) (Tensor.Shape.Images n cfg.outC ((inH + 2 * cfg.padding - cfg.kH) / cfg.stride + 1) ((inW + 2 * cfg.padding - cfg.kW) / cfg.stride + 1)))

Convolution over batched CHW images, allocating initialization seeds automatically.

Shorthand for conv2d.

Instances For
def NN.API.nn.batchNorm2d {n c h w : ℕ} (cfg : BatchNorm2d := { }) [NeZero n] [NeZero c] [NeZero h] [NeZero w] :

BatchNorm2d over NCHW inputs, using NeZero to hide the positivity proofs.

PyTorch analogue: torch.nn.BatchNorm2d. See https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html.

Instances For
def NN.API.nn.batchNorm {n c h w : ℕ} (cfg : BatchNorm2d := { }) [NeZero n] [NeZero c] [NeZero h] [NeZero w] :

BatchNorm over batched CHW images, allocating initialization seeds automatically.

Shorthand for batchNorm2d.

Instances For
def NN.API.nn.embedding (vocab embedDim : ℕ) (cfg : Embedding := { }) {pfx : Spec.Shape} :
M (Sequential (pfx.appendDim vocab) (pfx.appendDim embedDim))

Embedding layer for one-hot / token-distribution inputs (no bias).

Input shape: [..., vocab]. Output shape: [..., embedDim].

PyTorch analogue: conceptually nn.Embedding(vocab, embedDim), but applied to one-hot inputs.

Instances For
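Why one-hot inputs recover nn.Embedding-style lookups: a bias-free linear map applied to a one-hot row vector selects a single row of the weight matrix. Writing W for the vocab × embedDim weight (the name and orientation are assumed here for illustration):

```latex
e_i^{\top} W = W_{i,\ast} \in \mathbb{R}^{\mathrm{embedDim}},
\qquad
e_i \text{ the } i\text{-th standard basis vector of } \mathbb{R}^{\mathrm{vocab}}
```

A general token distribution x (non-negative entries summing to 1) yields xᵀW, a convex mixture of embedding rows, which is why the docstring admits "token-distribution" inputs as well as one-hot ones.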
def NN.API.nn.sinusoidalPositionalEncoding {batch seqLen embedDim : ℕ} (cfg : SinusoidalPositionalEncoding := { }) :
M (Sequential (Spec.Shape.dim batch (Tensor.Shape.Mat seqLen embedDim)) (Spec.Shape.dim batch (Tensor.Shape.Mat seqLen embedDim)))

Add sinusoidal positional encodings to a batched (batch × seqLen × embedDim) tensor.

Implementation:

• precompute PE : (seqLen × embedDim) at initialization time (stored as a non-trainable buffer),
• broadcast it across the leading batch axis and add to the input.
                                                                                                                                                                                                                                                                                                                                                                                                                                                                Instances For
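For reference, the precomputed table conventionally follows the sinusoidal scheme from the original Transformer paper; the base 10000 and the even/odd sine/cosine split are the standard choices, and the exact constants used here may be governed by cfg:

```latex
PE_{(pos,\,2i)}   = \sin\!\bigl(pos \,/\, 10000^{2i/\mathrm{embedDim}}\bigr), \qquad
PE_{(pos,\,2i+1)} = \cos\!\bigl(pos \,/\, 10000^{2i/\mathrm{embedDim}}\bigr)
```

Since the table depends only on (pos, dim), it is the same for every batch element, which is why a single non-trainable (seqLen × embedDim) buffer suffices.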
def NN.API.nn.rope {batch numHeads seqLen headDim : ℕ} (cfg : RoPE := { }) :
M (Sequential (Spec.Shape.dim batch (Spec.Shape.dim numHeads (Tensor.Shape.Mat seqLen headDim))) (Spec.Shape.dim batch (Spec.Shape.dim numHeads (Tensor.Shape.Mat seqLen headDim))))

Apply RoPE (rotary positional embedding) to a batched multi-head tensor (batch × numHeads × seqLen × headDim).

This matches the standard identity:

rope(x) = x * cos + rotatePairs(x) * sin

where cos/sin depend only on (pos, dim) and broadcast across (batch, numHeads).

Notes:

• This layer is differentiable (gradients flow through the rotation), but it has no trainable parameters; the precomputed cos/sin tables are stored as non-trainable buffers.
• The pure spec version is in NN.Spec.Layers.PositionalEncoding (Spec.rope_apply_heads_spec).

Instances For
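Unfolding the identity per feature pair: writing θᵢ = base^(−2i/headDim) (base conventionally 10000, possibly configurable via cfg), each pair (x₂ᵢ, x₂ᵢ₊₁) at position pos is rotated by the angle pos·θᵢ:

```latex
\begin{pmatrix} x'_{2i} \\ x'_{2i+1} \end{pmatrix}
= \begin{pmatrix} \cos(pos\,\theta_i) & -\sin(pos\,\theta_i) \\ \sin(pos\,\theta_i) & \phantom{-}\cos(pos\,\theta_i) \end{pmatrix}
  \begin{pmatrix} x_{2i} \\ x_{2i+1} \end{pmatrix}
```

This is exactly rope(x) = x * cos + rotatePairs(x) * sin once rotatePairs sends (x₂ᵢ, x₂ᵢ₊₁) to (−x₂ᵢ₊₁, x₂ᵢ).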
def NN.API.nn.learnedPositionalEmbedding {batch seqLen embedDim : ℕ} (cfg : LearnedPositionalEmbedding := { }) :
M (Sequential (Spec.Shape.dim batch (Tensor.Shape.Mat seqLen embedDim)) (Spec.Shape.dim batch (Tensor.Shape.Mat seqLen embedDim)))

Add learned positional embeddings to a batched (batch × seqLen × embedDim) tensor.

PyTorch analogue: x + pos[:seqLen] where pos is a parameter table.

Instances For

def NN.API.nn.layerNorm {batch seqLen embedDim : ℕ} [NeZero seqLen] [NeZero embedDim] :
M (Sequential (Spec.Shape.dim batch (Tensor.Shape.Mat seqLen embedDim)) (Spec.Shape.dim batch (Tensor.Shape.Mat seqLen embedDim)))

Layer normalization over (batch × seqLen × embedDim) tensors.

This normalizes each embedDim-vector (per batch element, per sequence position), then applies the learned affine parameters gamma and beta.

PyTorch analogue: torch.nn.LayerNorm(embedDim) on a tensor shaped (batch, seqLen, embedDim).

Implementation note: TorchLean uses NeZero to ensure seqLen and embedDim are positive, avoiding degenerate shapes.

Instances For
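Concretely, for each batch element and sequence position, layer normalization computes the following, where μ and σ² are taken over the embedDim axis and ε is a small stabilizer whose exact value is implementation-defined (PyTorch defaults to 1e-5):

```latex
y_j = \gamma_j \cdot \frac{x_j - \mu}{\sqrt{\sigma^2 + \varepsilon}} + \beta_j,
\qquad
\mu = \frac{1}{d}\sum_{j=1}^{d} x_j,
\quad
\sigma^2 = \frac{1}{d}\sum_{j=1}^{d} (x_j - \mu)^2,
\quad d = \mathrm{embedDim}
```

The NeZero embedDim hypothesis is what makes the division by d meaningful.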

Multi-head self-attention using NeZero to hide the nonzero-sequence-length proof.

If mask is provided, it is a boolean attention mask of shape (n × n) (e.g. for causal masking).

Instances For
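Per head, the conventional scaled dot-product attention that such a layer computes is the following, where M is the additive form of the boolean mask (−∞ at masked positions, 0 elsewhere) and d_k is the per-head dimension:

```latex
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}} + M\right) V
```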

Transformer encoder block.

This is transformerEncoderBlockWithMask; pass mask := ... to enable causal masking (or other attention masks).

Instances For
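For orientation, the classic post-norm encoder block from the original Transformer composes self-attention and a feed-forward sublayer with residual connections; whether this block uses post-norm (shown below) or the pre-norm placement of layer norm is determined by the implementation, not by this docstring:

```latex
y = \mathrm{LayerNorm}\bigl(x + \mathrm{MultiHeadAttention}(x)\bigr), \qquad
\mathrm{out} = \mathrm{LayerNorm}\bigl(y + \mathrm{FFN}(y)\bigr)
```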

Stack cfg.layers copies of blocks.transformerEncoderBlock.

This is transformerEncoderStackWithMask; pass mask := ... to enable causal masking (or other attention masks).

Instances For

ResNet-18 style BasicBlock over batched image tensors (N×C×H×W).

Instances For
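In the standard ResNet-18 formulation the BasicBlock computes a residual sum followed by a final activation, with the shortcut being the identity when shapes match and a projection otherwise; the details here are the conventional recipe, not a guarantee about this implementation:

```latex
\mathrm{out} = \mathrm{ReLU}\bigl(\mathcal{F}(x) + \mathrm{shortcut}(x)\bigr),
\qquad
\mathcal{F} = \mathrm{BN} \circ \mathrm{conv}_{3\times 3} \circ \mathrm{ReLU} \circ \mathrm{BN} \circ \mathrm{conv}_{3\times 3}
```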

Dropout layer (active in train mode, identity in eval mode).

PyTorch analogue: torch.nn.Dropout.

Instances For
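Assuming the inverted-dropout convention that PyTorch uses (scaling at train time so that eval mode is the plain identity, consistent with the docstring above), with drop probability p:

```latex
\text{train:}\quad y = \frac{m \odot x}{1 - p}, \quad m_i \sim \mathrm{Bernoulli}(1 - p);
\qquad
\text{eval:}\quad y = x
```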
def NN.API.nn.runGlobal {α : Type} (x : M α) :
IO α

Run a seeded builder using the global seed stream set by nn.manualSeed (for results in Type).

Note: model values like nn.Sequential live in Type 2, so they cannot be returned from IO. For models, use nn.run with an explicit base seed (obtained from nn.nextSeed).

Instances For
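A minimal sketch of the pattern the note describes, using only names documented in this module (nn.manualSeed, nn.nextSeed, nn.build); the layer choice and its arguments are illustrative assumptions:

```lean
def main : IO Unit := do
  nn.manualSeed 42
  -- A bare `ℕ` lives in `Type`, so it can be returned from IO...
  let seed ← nn.nextSeed
  -- ...while the model itself is built purely from that base seed,
  -- because `nn.Sequential` (in `Type 2`) cannot cross the IO boundary.
  let model := nn.build seed
    (nn.layerNorm (batch := 2) (seqLen := 4) (embedDim := 8))
  IO.println "model built"
```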

Draw a fresh base seed from the global seed stream set by nn.manualSeed.

Instances For

Draw n fresh base seeds from the global seed stream.

Instances For

Naming Convenience #

nn.run / nn.nextSeed are the core primitives, but in user code it is often clearer to read the aliases below:

@[reducible, inline]
abbrev NN.API.nn.build {α : Type 2} (seed : ℕ) (x : M α) :
α

Alias for nn.run (PyTorch-style wording: build/init a model from a base seed).

Instances For

@[reducible, inline]

Alias for nn.nextSeed (draw a fresh base seed from the global seed stream).

Instances For
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            def NN.API.nn.withModel {σ τ : Spec.Shape} {β : Type} (mk : M (Sequential σ τ)) (k : Sequential σ τIO β) :
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            IO β

                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            Build a model using the next global seed, then run a continuation.

                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            Why this exists: nn.Sequential lives in Type 2, so we can't directly return a model from IO. This helper keeps model construction pure while letting executable code avoid repeating the nextSeed/run pattern.

                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            Instances For
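A minimal usage sketch (assuming M is a monad, so pure can lift an already-built model; nn.of and nn.pure.linear are the facade constructors documented above):

```lean
-- Sketch: build a 4→2 linear model under the global seed stream, then use it
-- in IO. `pure` assumes `M` is a monad; the printed message is illustrative.
open NN.API in
def demo : IO Unit :=
  nn.withModel (pure (nn.of (nn.pure.linear 4 2))) fun _model => do
    IO.println "model built with a fresh global seed"
```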

Autograd helpers (grad/vjp/jacobian) over TorchLean programs.

This namespace is conceptually similar to PyTorch autograd + functorch/torch.func:

PyTorch references:

@[reducible, inline]
abbrev NN.API.autograd.model.Params {σ τ : Spec.Shape} (model : TorchLean.NN.Seq σ τ) (α : Type) :

Parameter list type for a given model (a TList over Seq.paramShapes).

Instances For
@[reducible, inline]

Loss function over a model output and a target.

This is expressed in terms of RefTy so it works uniformly for eager execution and compiled execution.

Instances For
@[reducible, inline]
abbrev NN.API.autograd.model.linearParams {α : Type} {inDim outDim seedW seedB : ℕ} (w : Spec.Tensor α (Tensor.Shape.Mat outDim inDim)) (b : Spec.Tensor α (Tensor.Shape.Vec outDim)) :
Params (TorchLean.Layers.linear inDim outDim seedW seedB) α

Pack explicit weight and bias tensors for a single Layers.linear model.

Instances For
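A minimal sketch of how this packs into the Params type. The tensors w and b are taken as hypotheses rather than constructed, since tensor literals are out of scope here:

```lean
-- Sketch: explicit parameters for a 3→2 `Layers.linear` model (the seeds 0/0
-- are irrelevant once the parameters are supplied explicitly).
open NN.API.autograd.model in
example (w : Spec.Tensor Float (Tensor.Shape.Mat 2 3))
        (b : Spec.Tensor Float (Tensor.Shape.Vec 2)) :
    Params (TorchLean.Layers.linear 3 2 0 0) Float :=
  linearParams w b
```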
@[reducible, inline]

Mean-squared error loss (mse) between yhat and y.

Instances For
@[reducible, inline]

Cross-entropy loss between logits and one-hot targets. PyTorch analogue: nn.CrossEntropyLoss.

Instances For
@[reducible, inline]

Detach the model output before feeding it into a loss.

This is useful when you want to compute a metric loss without backpropagating through it.

Instances For

Gradient of a model loss w.r.t. the model parameters.

This is the common training use case (PyTorch analogue: loss.backward() followed by parameter updates).

Instances For
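A hedged sketch of a single training step. The gradParams argument order is assumed to mirror gradX (model, loss, params, x, target), and sgdStep is a hypothetical parameter-update helper, not part of this facade:

```lean
-- Sketch only: `sgdStep` is hypothetical, and the `gradParams` signature is
-- assumed to mirror `gradX` (model, loss, params, x, target).
def trainStep {σ τ υ : Spec.Shape} (model : TorchLean.NN.Seq σ τ)
    (loss : TorchLean.Autodiff.Model.OutputLoss τ υ)
    (params : TorchLean.Autodiff.Model.Params model Float)
    (x : Spec.Tensor Float σ) (target : Spec.Tensor Float υ) :
    IO (TorchLean.Autodiff.Model.Params model Float) := do
  let grads ← NN.API.autograd.model.gradParams model loss params x target
  pure (sgdStep params grads 0.01)  -- hypothetical SGD update over the TList
```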

Gradient of the loss w.r.t. the inputs (x and target).

Instances For
def NN.API.autograd.model.gradX {σ τ υ : Spec.Shape} (model : TorchLean.NN.Seq σ τ) (loss : TorchLean.Autodiff.Model.OutputLoss τ υ) {α : Type} [Semantics.Scalar α] [DecidableEq Spec.Shape] (params : TorchLean.Autodiff.Model.Params model α) (x : Spec.Tensor α σ) (target : Spec.Tensor α υ) :
IO (Spec.Tensor α σ)

Convenience: gradient of the loss w.r.t. x.

Instances For

Convenience: gradient of the loss w.r.t. the target argument.

Instances For
structure NN.API.autograd.model.ValueAndGrads {σ τ υ : Spec.Shape} (model : TorchLean.NN.Seq σ τ) (α : Type) :

Forward+backward result for a scalar loss built from a model output.

PyTorch comparison: this is the "compute loss + backward" payload, but with shapes tracked.

Instances For

Run loss(model(params, x), target) and compute gradients w.r.t.:

• model parameters,
• x,
• target.

This hides the CompiledScalar/argument-pack boilerplate for the common "one sample" case.

Instances For
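A sketch of the intended call pattern. The runner name valueAndGrads is inferred from the ValueAndGrads result type, and any field names on the result are assumptions based on the description above:

```lean
-- Sketch: one forward+backward pass for a single sample. The function name
-- `valueAndGrads` and the result's field names are assumed, not confirmed.
def inspect {σ τ υ : Spec.Shape} (model : TorchLean.NN.Seq σ τ)
    (loss : TorchLean.Autodiff.Model.OutputLoss τ υ)
    (params : TorchLean.Autodiff.Model.Params model Float)
    (x : Spec.Tensor Float σ) (target : Spec.Tensor Float υ) : IO Unit := do
  let _r ← NN.API.autograd.model.valueAndGrads model loss params x target
  IO.println "loss and gradients (params, x, target) computed in one pass"
```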

Return just (loss_value, grad_params).

Instances For

valueAndGradParams, but convert the 0-dim loss tensor to a scalar α.

Instances For

Return (loss_value, grad_x).

Instances For

Return (loss_value, grad_target).

Instances For

Vector-Jacobian product (VJP) w.r.t. model parameters.

This is the "grad of outputs back into parameters" primitive. It is useful for custom losses or analysis tooling when you already have a seed tensor seedOut : τ.

Instances For
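As an illustrative sketch only: the exact signature of vjpParams is not reproduced on this page, so the call below assumes it mirrors the documented vjpInput signature while returning parameter-structured gradients; the instance binders are likewise assumptions.

```lean
-- Hypothetical usage sketch: pull a cotangent `seedOut` on the output
-- back into parameter-structured gradients. The signature of `vjpParams`
-- is assumed to mirror the documented `vjpInput`.
example {σ τ : Spec.Shape} {α : Type} [Semantics.Scalar α] [DecidableEq Spec.Shape]
    (model : TorchLean.NN.Seq σ τ)
    (params : TorchLean.Autodiff.Model.Params model α)
    (x : Spec.Tensor α σ) (seedOut : Spec.Tensor α τ) :
    IO (TorchLean.Autodiff.Model.Params model α) :=
  NN.API.autograd.model.vjpParams model params x seedOut
```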

VJP w.r.t. the model input.

This returns a one-element TList to match the general "inputs list" API shape. For the common case, use vjpInput to get the tensor directly.

Instances For

def NN.API.autograd.model.vjpInput {σ τ : Spec.Shape} (model : TorchLean.NN.Seq σ τ) {α : Type} [Semantics.Scalar α] [DecidableEq Spec.Shape] (params : TorchLean.Autodiff.Model.Params model α) (x : Spec.Tensor α σ) (seedOut : Spec.Tensor α τ) :
    IO (Spec.Tensor α σ)

Convenience wrapper: unwrap vjpInputs to return just dx.

Instances For
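The documented signature above can be exercised directly. A minimal sketch, assuming the surrounding NN imports and instances are in scope:

```lean
-- Sketch using the documented signature: `vjpInput` unwraps the
-- one-element `TList` produced by `vjpInputs`, returning `dx` directly.
example {σ τ : Spec.Shape} {α : Type} [Semantics.Scalar α] [DecidableEq Spec.Shape]
    (model : TorchLean.NN.Seq σ τ)
    (params : TorchLean.Autodiff.Model.Params model α)
    (x : Spec.Tensor α σ) (seedOut : Spec.Tensor α τ) :
    IO (Spec.Tensor α σ) :=
  NN.API.autograd.model.vjpInput model params x seedOut
```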

Reverse-mode Jacobian (jacrev) of the model output w.r.t. parameters.

Returns an array of parameter-structured gradients: one entry per output coordinate. This mirrors the usual "jacrev returns a stack of per-output gradients" shape.

Instances For
def NN.API.autograd.model.jvpParams {σ τ υ : Spec.Shape} (model : TorchLean.NN.Seq σ τ) (loss : TorchLean.Autodiff.Model.OutputLoss τ υ) {α : Type} [Semantics.Scalar α] [DecidableEq Spec.Shape] (params : TorchLean.Autodiff.Model.Params model α) (x : Spec.Tensor α σ) (target : Spec.Tensor α υ) (vparams : TorchLean.Autodiff.Model.Params model α) :
    IO α

Jacobian-vector product (JVP) of a scalar loss w.r.t. parameters.

This is the directional derivative in the direction vparams. Conceptually: d/dt loss(params + t*vparams, x, target), evaluated at t = 0.

Instances For
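A sketch against the documented signature, computing the directional derivative of the loss along vparams (imports and instances assumed in scope):

```lean
-- Directional derivative d/dt loss(params + t·vparams, x, target) at t = 0,
-- using the documented `jvpParams` signature.
example {σ τ υ : Spec.Shape} {α : Type} [Semantics.Scalar α] [DecidableEq Spec.Shape]
    (model : TorchLean.NN.Seq σ τ)
    (loss : TorchLean.Autodiff.Model.OutputLoss τ υ)
    (params vparams : TorchLean.Autodiff.Model.Params model α)
    (x : Spec.Tensor α σ) (target : Spec.Tensor α υ) :
    IO α :=
  NN.API.autograd.model.jvpParams model loss params x target vparams
```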

Hessian-vector product (HVP) of a scalar loss w.r.t. parameters.

Returns a parameter-structured tensor list of the same shape as params.

Instances For
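For intuition, the HVP computes the product of the Hessian with vparams without ever materializing the full Hessian; it is the directional derivative of the gradient:

```latex
\nabla^2 L(\theta)\, v \;=\; \left.\frac{d}{dt}\,\nabla L(\theta + t\, v)\right|_{t=0}
```

A common implementation strategy is therefore a forward-over-reverse composition (a JVP through the reverse-mode gradient); whether TorchLean uses exactly that composition is not stated here.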

In PyTorch terms, this is the "functorch" style: differentiate plain functions, not modules.

@[reducible, inline]

Type of a pure tensor function expressed in RefTy form.

This matches the calling convention expected by TorchLean.Program/autodiff compilation.

Instances For

Forward-mode Jacobian (jacfwd) for a pure tensor function.

Instances For

Hessian for a scalar-valued function.

Instances For

Vector-Jacobian product (VJP) for a pure function.

Instances For

Reverse-mode Jacobian (jacrev) of a pure tensor function.

Returns the Jacobian rows as an array of doutput/dinput tensors.

Instances For

Gradient of a scalar-valued function w.r.t. its input.

Instances For

Returns the pair (value, grad) for a scalar-valued function at x.

Instances For

Like valueAndGrad, but converts the 0-dim value tensor to a scalar α.

Instances For
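An illustrative sketch only: the exact PureFn index arguments, the result type, and the namespace qualification below are assumptions, since only the docstrings are reproduced on this page. It shows the intended one-pass value-plus-gradient usage.

```lean
-- Hypothetical sketch: obtain the scalar loss value and its gradient
-- in one call, instead of a separate forward pass plus `grad`.
-- `PureFn`'s indices and the required instances are assumptions.
example {σ : Spec.Shape} [Semantics.Scalar Float] [DecidableEq Spec.Shape]
    (f : NN.API.autograd.PureFn σ Spec.Shape.scalar)
    (x : Spec.Tensor Float σ) :
    IO (Float × Spec.Tensor Float σ) :=
  NN.API.autograd.valueAndGradScalar f x
```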