TorchLean API

NN.API.Public.Seeded.Core

Seeded model builders

This module reopens NN.API.nn with the PyTorch-style seeded builder API. It sits on top of the pure builders from NN.API.Public.NN and allocates deterministic initialization seeds for users.

Model Builders and Seeding

TorchLean keeps initialization randomness explicit so examples are reproducible.

Typical patterns:

  1. Explicit seeds (best for proofs / reproducibility-sensitive code):

    • build with nn.pure.linear ... (seedW := ...) (seedB := ...), etc.
    • compose with seq! ... / >>>
  2. Script-style “manual seed once”:

Note: nn.Sequential lives in Type 2, so it cannot be returned directly from IO. We keep model building pure by drawing a base seed in IO and then calling nn.run.
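
A minimal Lean sketch of this pattern, under stated assumptions: Model and build are hypothetical stand-ins (not TorchLean definitions), and IO.rand is used only as one possible way to draw a base seed. It illustrates why a value living in Type 2 is built purely after the seed has been drawn, rather than returned from IO.

```lean
/-- Hypothetical stand-in for a model type that lives in `Type 2`. -/
structure Model : Type 2 where
  carrier : Type 1
  seed    : Nat

/-- Pure builder: all randomness comes from the explicit `seed` argument. -/
def build (seed : Nat) : Model :=
  { carrier := Type, seed := seed }

-- `IO : Type → Type`, so `IO Model` is not a well-typed return type.
-- An `IO` action can still draw the seed and use the pure result locally.
def main : IO Unit := do
  let seed ← IO.rand 0 1000000
  let model := build seed
  IO.println s!"built a model from seed {model.seed}"
```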

PyTorch-like global seeding convenience for seeded model builders.

This sets the global seed stream used by nn.runGlobal / nn.nextSeed.

    Embedding table initialization configuration (one-hot / token-distribution inputs).

    This is the TorchLean-friendly analogue of torch.nn.Embedding in the common setting where token ids are represented as one-hot vectors (or soft token distributions), so lookup is a matrix multiplication rather than integer indexing.

      Learned positional embedding configuration.

      This is a trainable parameter tensor of shape (seqLen × embedDim) that is broadcast across the leading batch dimension and added to the input.

        Sinusoidal positional encoding configuration.

        This is the classic (non-trainable) Transformer sinusoidal encoding, added to token embeddings. startPos is an absolute-position offset (useful for KV-cache decoding).

          Rotary positional embedding (RoPE) configuration.

          startPos is an absolute-position offset (useful for KV-cache decoding).

            Named-field Conv2d configuration (CHW layout).

            This is the public, PyTorch-like entry point for convolution in TorchLean. PyTorch analogue: torch.nn.Conv2d. See https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html.

                MaxPool2d configuration for CHW inputs.

                PyTorch analogue: torch.nn.MaxPool2d. See https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html.

                    AvgPool2d configuration for CHW inputs.

                    PyTorch analogue: torch.nn.AvgPool2d. See https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html.

                        LayerNorm configuration for batched (batch × seqLen × embedDim) tensors.

                        PyTorch analogue: torch.nn.LayerNorm. See https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html.

                          RMSNorm configuration for batched (batch × seqLen × embedDim) tensors.

                          This is a common alternative to LayerNorm in modern transformer architectures.

                            BatchNorm2d configuration (learned scale/shift).

                            PyTorch analogue: torch.nn.BatchNorm2d. See https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html.

                              InstanceNorm2d configuration (learned scale/shift).

                              PyTorch analogue: torch.nn.InstanceNorm2d. See https://pytorch.org/docs/stable/generated/torch.nn.InstanceNorm2d.html.

                                Multi-head self-attention configuration.

                                PyTorch analogue: torch.nn.MultiheadAttention (conceptually). See https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html.

                                  def NN.API.nn.globalAvgPoolCHW (c h w : ) {hC : c > 0} {hH : h > 0} {hW : w > 0} :

                                  Global average pooling over a CHW tensor.

                                  PyTorch analogue: torch.nn.AdaptiveAvgPool2d((1, 1)) followed by flattening.

                                    def NN.API.nn.globalAvgPoolNCHW (n c h w : ) {hN : n > 0} {hC : c > 0} {hH : h > 0} {hW : w > 0} :

                                    Global average pooling over an NCHW tensor (preserves the batch dimension).

                                      Seeded Builders (Default nn.*)

                                      For end-user code, the default nn.* layer constructors allocate initialization seeds automatically via nn.M (a deterministic seed-stream builder).

                                      Use nn.pure.* when you want to pass explicit seeds (proof-friendly / fully reproducible).

                                      @[reducible, inline]
                                      abbrev NN.API.nn.M (α : Type u_1) :
                                      Type u_1

                                      Seeded builder monad: a state monad over API.rand.SeedStream.

                                        def NN.API.nn.run {α : Type 2} (seed : ) (x : M α) :
                                        α

                                        Run a seeded builder starting from a base seed.

                                          def NN.API.nn.lift {α : Type 2} (x : α) :
                                          M α

                                          Lift a pure value into the seeded builder (consumes no seeds).

                                            def NN.API.nn.withSeed {α : Type 2} (k : α) :
                                            M α

                                            Consume one fresh seed and pass it to k.

                                              def NN.API.nn.withSeeds2 {α : Type 2} (k : α) :
                                              M α

                                              Consume two fresh seeds and pass them to k (in order).
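
The seeded-builder combinators above can be sketched in plain Lean as a state monad over a counter. This is an illustration only: SeedStream here is a hypothetical one-field structure, not API.rand.SeedStream, and the continuation k is assumed to be a function of the consumed seed(s), since this listing does not show its full type.

```lean
namespace SeedSketch

/-- Hypothetical seed stream: a counter holding the next seed to hand out. -/
structure SeedStream where
  next : Nat

/-- Sketch of a seeded builder monad: a state monad over the seed stream. -/
abbrev M := StateM SeedStream

/-- Hand out the current seed and advance the stream by one. -/
def nextSeed : M Nat :=
  modifyGet fun (s : SeedStream) => (s.next, { s with next := s.next + 1 })

/-- Run a builder from a base seed and discard the final stream state. -/
def run {α : Type} (seed : Nat) (x : M α) : α :=
  Id.run (StateT.run' x { next := seed })

/-- Lift a pure value; the stream is left untouched. -/
def lift {α : Type} (x : α) : M α := pure x

/-- Consume one fresh seed and pass it to `k`. -/
def withSeed {α : Type} (k : Nat → α) : M α := do
  let s ← nextSeed
  pure (k s)

/-- Consume two fresh seeds and pass them to `k`, in order. -/
def withSeeds2 {α : Type} (k : Nat → Nat → α) : M α := do
  let s₁ ← nextSeed
  let s₂ ← nextSeed
  pure (k s₁ s₂)

#eval run 7 (withSeeds2 fun a b => (a, b))   -- (7, 8)

end SeedSketch
```

Because the stream is threaded through the monad, the same base seed always yields the same sequence of per-layer seeds, which is what makes builds reproducible.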

                                                Elementwise ReLU. PyTorch analogue: torch.nn.ReLU / torch.nn.functional.relu.

                                                    Elementwise SiLU/Swish. PyTorch analogue: torch.nn.SiLU / torch.nn.functional.silu.

                                                        Elementwise GELU. PyTorch analogue: torch.nn.GELU / torch.nn.functional.gelu.

                                                            Elementwise sigmoid. PyTorch analogue: torch.nn.Sigmoid / torch.nn.functional.sigmoid.

                                                                Elementwise tanh. PyTorch analogue: torch.nn.Tanh / torch.nn.functional.tanh.

                                                                    Softmax. PyTorch analogue: torch.nn.Softmax / torch.nn.functional.softmax.

                                                                        Reduce-sum to a scalar. PyTorch analogue: torch.sum.

                                                                          Flatten any tensor into a 1D vector of length size s. PyTorch analogue: torch.flatten.

                                                                              Flatten a batched tensor N × σ into a matrix N × (size σ).

                                                                              PyTorch analogue: torch.flatten(x, start_dim=1).

                                                                                  Flatten a batched tensor starting at dimension 1 (keep dim0).

                                                                                  Synonym for flattenBatch, matching PyTorch’s start_dim=1 wording.

                                                                                    def NN.API.nn.maxPool2d {n inC inH inW : } (cfg : MaxPool2d) [NeZero cfg.kH] [NeZero cfg.kW] :
                                                                                    M (Sequential (Tensor.Shape.Images n inC inH inW) (Tensor.Shape.Images n inC ((inH - cfg.kH) / cfg.stride + 1) ((inW - cfg.kW) / cfg.stride + 1)))

                                                                                    MaxPool2d over batched NCHW inputs. NeZero instances replace explicit proofs that the kernel dimensions are nonzero.
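
The output extents in the signature follow the formula (in - k) / stride + 1 along each spatial axis. A small sketch, assuming the dimensions are natural numbers (the binder types are not shown in this listing); poolOut is a hypothetical helper, not a TorchLean function.

```lean
/-- Output extent along one spatial axis, mirroring
    `(inH - cfg.kH) / cfg.stride + 1`. -/
def poolOut (inDim k stride : Nat) : Nat :=
  (inDim - k) / stride + 1

#eval poolOut 28 2 2   -- 14
#eval poolOut 7 3 2    -- 3  ((7 - 3) / 2 = 2, plus 1)
#eval poolOut 5 7 1    -- 1  (`5 - 7` truncates to 0 on `Nat`)
```

On Nat, subtraction truncates at zero and division rounds down, so a kernel larger than the input still yields an extent of 1 at the type level rather than an error.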

                                                                                      def NN.API.nn.maxPool {n inC inH inW : } (cfg : MaxPool) [NeZero cfg.kH] [NeZero cfg.kW] :
                                                                                      M (Sequential (Tensor.Shape.Images n inC inH inW) (Tensor.Shape.Images n inC ((inH - cfg.kH) / cfg.stride + 1) ((inW - cfg.kW) / cfg.stride + 1)))

                                                                                      Max pooling over batched CHW images, allocating any required initialization seeds automatically.

                                                                                      Shorthand for maxPool2d.

                                                                                        def NN.API.nn.avgPool2d {n inC inH inW : } (cfg : AvgPool2d) [NeZero cfg.kH] [NeZero cfg.kW] :
                                                                                        M (Sequential (Tensor.Shape.Images n inC inH inW) (Tensor.Shape.Images n inC ((inH - cfg.kH) / cfg.stride + 1) ((inW - cfg.kW) / cfg.stride + 1)))

                                                                                        AvgPool2d over batched NCHW inputs (shape N×C×H×W, like PyTorch).

                                                                                          def NN.API.nn.avgPool {n inC inH inW : } (cfg : AvgPool) [NeZero cfg.kH] [NeZero cfg.kW] :
                                                                                          M (Sequential (Tensor.Shape.Images n inC inH inW) (Tensor.Shape.Images n inC ((inH - cfg.kH) / cfg.stride + 1) ((inW - cfg.kW) / cfg.stride + 1)))

                                                                                          Average pooling over batched CHW images, allocating any required initialization seeds automatically.

                                                                                          Shorthand for avgPool2d.

                                                                                            def NN.API.nn.linear (inDim outDim : ) (pfx : Spec.Shape := Spec.Shape.scalar) :
                                                                                            M (Sequential (pfx.appendDim inDim) (pfx.appendDim outDim))

                                                                                            Linear layer on the last axis (prefix-shape preserving).

                                                                                            PyTorch analogue: torch.nn.Linear. See https://pytorch.org/docs/stable/generated/torch.nn.Linear.html.

                                                                                            Unlike the runtime TorchLean layer constructor (which is vector-only), this public layer constructor follows PyTorch’s convention:

                                                                                            • if x has shape [..., inDim], linear inDim outDim returns a model of shape [..., outDim].

                                                                                            The leading “prefix” dimensions are treated as a batch (they are flattened to (numel(prefix), inDim), the affine map is applied once, and the result is reshaped back).
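
The flatten, apply, reshape sequence described above can be traced at the shape level. In this sketch shapes are modelled as plain List Nat rather than Spec.Shape, and numel and linearShapes are hypothetical helpers introduced only for illustration.

```lean
/-- Number of elements in a prefix shape (the product of its dimensions). -/
def numel (pfx : List Nat) : Nat := pfx.foldl (· * ·) 1

/-- Shapes visited by a prefix-preserving linear layer:
    input, flattened input, flattened output, reshaped output. -/
def linearShapes (pfx : List Nat) (inDim outDim : Nat) : List (List Nat) :=
  [pfx ++ [inDim], [numel pfx, inDim], [numel pfx, outDim], pfx ++ [outDim]]

#eval linearShapes [8, 16] 32 64
-- [[8, 16, 32], [128, 32], [128, 64], [8, 16, 64]]
```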

                                                                                              @[reducible, inline]
                                                                                              abbrev NN.API.nn.Linear (inDim outDim : ) (pfx : Spec.Shape := Spec.Shape.scalar) :
                                                                                              M (Sequential (pfx.appendDim inDim) (pfx.appendDim outDim))

                                                                                              Linear layer on the last axis (prefix-shape preserving).

                                                                                              PyTorch analogue: torch.nn.Linear. See https://pytorch.org/docs/stable/generated/torch.nn.Linear.html.

                                                                                              Unlike the runtime TorchLean layer constructor (which is vector-only), this public layer constructor follows PyTorch’s convention:

                                                                                              • if x has shape [..., inDim], linear inDim outDim returns a model of shape [..., outDim].

                                                                                              The leading “prefix” dimensions are treated as a batch (they are flattened to (numel(prefix), inDim), the affine map is applied once, and the result is reshaped back).

                                                                                                def NN.API.nn.linearV (inDim outDim : ) :

                                                                                                Vector-only linear layer, specialized to the scalar prefix shape.

                                                                                                  @[reducible, inline]
                                                                                                  abbrev NN.API.nn.LinearV (inDim outDim : ) :

                                                                                                  Vector-only linear layer, specialized to the scalar prefix shape.

                                                                                                    def NN.API.nn.rnn (seqLen inputSize hiddenSize : ) :
                                                                                                    M (Sequential (Tensor.Shape.Mat seqLen inputSize) (Tensor.Shape.Mat seqLen hiddenSize))

                                                                                                    Vanilla RNN layer (time-major sequence, no batch axis).

                                                                                                    Semantics: h_t = tanh(W [x_t; h_{t-1}] + b), with h_{-1} = 0.

                                                                                                    This is implemented by unrolling seqLen steps using existing TorchLean ops, so it runs on both CPU and CUDA backends.

                                                                                                    PyTorch analogue: torch.nn.RNN(inputSize, hiddenSize, nonlinearity="tanh") with batch_first=False, specialized to a single batch element.
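
A scalar sketch of the unrolled recurrence, under simplifying assumptions: input and hidden state are single Floats, so W [x_t; h_{t-1}] reduces to wx * x + wh * h. The names rnnStep and rnnUnroll are hypothetical and do not reflect how TorchLean implements the layer.

```lean
/-- One step: `h' = tanh (wx * x + wh * h + b)`. -/
def rnnStep (wx wh b h x : Float) : Float :=
  Float.tanh (wx * x + wh * h + b)

/-- Unroll over the sequence from `h_{-1} = 0`, collecting every `h_t`. -/
def rnnUnroll (wx wh b : Float) (xs : List Float) : List Float :=
  let step := fun (acc : Float × List Float) (x : Float) =>
    let h := rnnStep wx wh b acc.1 x
    (h, h :: acc.2)
  (xs.foldl step (0.0, [])).2.reverse

-- One hidden state per time step, so the output keeps the sequence length.
#eval (rnnUnroll 0.5 0.5 0.0 [1.0, 2.0, 3.0]).length   -- 3
```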

                                                                                                      def NN.API.nn.gru (seqLen inputSize hiddenSize : ) :
                                                                                                      M (Sequential (Tensor.Shape.Mat seqLen inputSize) (Tensor.Shape.Mat seqLen hiddenSize))

                                                                                                      GRU layer (time-major sequence, no batch axis).

                                                                                                      This is implemented by unrolling seqLen steps using existing TorchLean ops, so it runs on both CPU and CUDA backends.

                                                                                                      PyTorch analogue: torch.nn.GRU(inputSize, hiddenSize) with batch_first=False, specialized to a single batch element.

                                                                                                        def NN.API.nn.mamba (seqLen inputSize hiddenSize : ) :
                                                                                                        M (Sequential (Tensor.Shape.Mat seqLen inputSize) (Tensor.Shape.Mat seqLen hiddenSize))

                                                                                                        Trainable Mamba-style gated diagonal state-space layer.

                                                                                                        The layer is time-major and single-batch, matching the simple rnn/gru/lstm constructors: input (seqLen × inputSize), output (seqLen × hiddenSize). It is unrolled with differentiable TorchLean ops, so CPU and CUDA training use the same API.

                                                                                                          def NN.API.nn.lstm (seqLen inputSize hiddenSize : ) :
                                                                                                          M (Sequential (Tensor.Shape.Mat seqLen inputSize) (Tensor.Shape.Mat seqLen hiddenSize))

                                                                                                          LSTM layer (time-major sequence, no batch axis).

                                                                                                          This is implemented by unrolling seqLen steps using existing TorchLean ops, so it runs on both CPU and CUDA backends.

                                                                                                          PyTorch analogue: torch.nn.LSTM(inputSize, hiddenSize) with batch_first=False, specialized to a single batch element.

                                                                                                            def NN.API.nn.conv2d {n inC inH inW : } (cfg : Conv2d) [NeZero inC] [NeZero cfg.kH] [NeZero cfg.kW] :
                                                                                                            M (Sequential (Tensor.Shape.Images n inC inH inW) (Tensor.Shape.Images n cfg.outC ((inH + 2 * cfg.padding - cfg.kH) / cfg.stride + 1) ((inW + 2 * cfg.padding - cfg.kW) / cfg.stride + 1)))

                                                                                                            2D convolution over a batched image tensor (shape N×C×H×W, like PyTorch).
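
With padding, the output extent in the signature becomes (in + 2 * padding - k) / stride + 1. A short check of that arithmetic, again assuming natural-number dimensions; convOut is a hypothetical helper.

```lean
/-- Output extent along one spatial axis, mirroring
    `(inH + 2 * cfg.padding - cfg.kH) / cfg.stride + 1`. -/
def convOut (inDim k stride padding : Nat) : Nat :=
  (inDim + 2 * padding - k) / stride + 1

#eval convOut 32 3 1 1    -- 32   (3×3 kernel, padding 1 keeps the size)
#eval convOut 224 7 2 3   -- 112
#eval convOut 28 5 1 0    -- 24
```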

                                                                                                              def NN.API.nn.conv {n inC inH inW : } (cfg : Conv) [NeZero inC] [NeZero cfg.kH] [NeZero cfg.kW] :
                                                                                                              M (Sequential (Tensor.Shape.Images n inC inH inW) (Tensor.Shape.Images n cfg.outC ((inH + 2 * cfg.padding - cfg.kH) / cfg.stride + 1) ((inW + 2 * cfg.padding - cfg.kW) / cfg.stride + 1)))

                                                                                                              Convolution over batched CHW images, allocating initialization seeds automatically.

                                                                                                              Shorthand for conv2d.

                                                                                                                def NN.API.nn.batchNorm2d {n c h w : } (cfg : BatchNorm2d := { }) [NeZero n] [NeZero c] [NeZero h] [NeZero w] :

                                                                                                                BatchNorm2d over NCHW inputs, using NeZero to hide the positivity proofs.

                                                                                                                PyTorch analogue: torch.nn.BatchNorm2d. See https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html.

                                                                                                                  def NN.API.nn.batchNorm {n c h w : } (cfg : BatchNorm2d := { }) [NeZero n] [NeZero c] [NeZero h] [NeZero w] :

                                                                                                                  BatchNorm over batched CHW images, allocating initialization seeds automatically.

                                                                                                                  Shorthand for batchNorm2d.

                                                                                                                    def NN.API.nn.embedding (vocab embedDim : ) (cfg : Embedding := { }) {pfx : Spec.Shape} :
                                                                                                                    M (Sequential (pfx.appendDim vocab) (pfx.appendDim embedDim))

                                                                                                                    Embedding layer for one-hot / token-distribution inputs (no bias).

                                                                                                                    Input shape: [..., vocab]

                                                                                                                    Output shape: [..., embedDim]

                                                                                                                    PyTorch analogue: conceptually nn.Embedding(vocab, embedDim) but applied to one-hot inputs.
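
Why a one-hot input turns lookup into a matrix product can be seen on a tiny table. The sketch below uses Nat entries so the result is exact and easy to read; addVec and vecMat are hypothetical helpers, and a soft token distribution would use real-valued weights in the same way.

```lean
/-- Elementwise sum of two rows. -/
def addVec (a b : List Nat) : List Nat := List.zipWith (· + ·) a b

/-- `v · table`: weight each table row by the matching entry of `v`
    and sum the weighted rows. -/
def vecMat (v : List Nat) (table : List (List Nat)) (embedDim : Nat) : List Nat :=
  (v.zip table).foldl
    (fun (acc : List Nat) (p : Nat × List Nat) => addVec acc (p.2.map (p.1 * ·)))
    (List.replicate embedDim 0)

-- The one-hot vector for token 1 selects row 1 of the table.
#eval vecMat [0, 1, 0] [[1, 2], [3, 4], [5, 6]] 2   -- [3, 4]
```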

                                                                                                                      def NN.API.nn.sinusoidalPositionalEncoding {batch seqLen embedDim : } (cfg : SinusoidalPositionalEncoding := { }) :
                                                                                                                      M (Sequential (Spec.Shape.dim batch (Tensor.Shape.Mat seqLen embedDim)) (Spec.Shape.dim batch (Tensor.Shape.Mat seqLen embedDim)))

                                                                                                                      Add sinusoidal positional encodings to a batched (batch × seqLen × embedDim) tensor.

                                                                                                                      Implementation:

                                                                                                                      • precompute PE : (seqLen × embedDim) at initialization time (stored as a non-trainable buffer),
                                                                                                                      • broadcast it across the leading batch axis and add to the input.
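
The precomputation step can be sketched with the classic Transformer formula, PE(pos, 2i) = sin(pos / 10000^(2i/d)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d)). Whether TorchLean uses exactly this column layout, and how it applies startPos, is an assumption here; peEntry and peTable are hypothetical names.

```lean
/-- Entry `(pos, j)` of the sinusoidal table for width `embedDim`.
    `startPos` is assumed to shift the absolute position. -/
def peEntry (embedDim startPos pos j : Nat) : Float :=
  let i := (j / 2).toFloat                      -- index of the sin/cos pair
  let freq := Float.pow 10000.0 (-(2.0 * i) / embedDim.toFloat)
  let angle := (startPos + pos).toFloat * freq
  if j % 2 == 0 then Float.sin angle else Float.cos angle

/-- Precompute the full `(seqLen × embedDim)` table once. -/
def peTable (seqLen embedDim startPos : Nat) : List (List Float) :=
  (List.range seqLen).map fun pos =>
    (List.range embedDim).map fun j => peEntry embedDim startPos pos j

#eval (peTable 4 8 0).length   -- 4
```
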
                                                                                                                        def NN.API.nn.rope {batch numHeads seqLen headDim : } (cfg : RoPE := { }) :
                                                                                                                        M (Sequential (Spec.Shape.dim batch (Spec.Shape.dim numHeads (Tensor.Shape.Mat seqLen headDim))) (Spec.Shape.dim batch (Spec.Shape.dim numHeads (Tensor.Shape.Mat seqLen headDim))))

                                                                                                                        Apply RoPE to a batched multi-head tensor (batch × numHeads × seqLen × headDim).

                                                                                                                        This matches the standard identity:

                                                                                                                        rope(x) = x * cos + rotatePairs(x) * sin

                                                                                                                        where cos/sin depend only on (pos, dim) and broadcast across (batch, numHeads).

                                                                                                                        Notes:

                                                                                                                        • This layer is differentiable (gradients flow through the rotation), but it has no trainable parameters; the precomputed cos/sin tables are stored as non-trainable buffers.
                                                                                                                        • The pure spec version is in NN.Spec.Layers.PositionalEncoding (Spec.rope_apply_heads_spec).
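
On a single coordinate pair the identity above is an ordinary 2D rotation. The sketch assumes rotatePairs maps (a, b) to (-b, a); how TorchLean groups coordinates into pairs is not stated here, and rotatePair and ropePair are hypothetical names.

```lean
/-- Rotate one pair by a quarter turn: `(a, b) ↦ (-b, a)`. -/
def rotatePair (p : Float × Float) : Float × Float := (-p.2, p.1)

/-- `rope(x) = x * cos + rotatePairs(x) * sin`, applied to a single pair.
    Expands to `(a cos θ - b sin θ, b cos θ + a sin θ)`. -/
def ropePair (θ : Float) (p : Float × Float) : Float × Float :=
  let r := rotatePair p
  (p.1 * Float.cos θ + r.1 * Float.sin θ,
   p.2 * Float.cos θ + r.2 * Float.sin θ)

-- At angle 0 the pair is returned unchanged (cos 0 = 1, sin 0 = 0).
#eval ropePair 0.0 (1.0, 2.0)
```
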
                                                                                                                          def NN.API.nn.learnedPositionalEmbedding {batch seqLen embedDim : } (cfg : LearnedPositionalEmbedding := { }) :
                                                                                                                          M (Sequential (Spec.Shape.dim batch (Tensor.Shape.Mat seqLen embedDim)) (Spec.Shape.dim batch (Tensor.Shape.Mat seqLen embedDim)))

                                                                                                                          Add learned positional embeddings to a batched (batch × seqLen × embedDim) tensor.

                                                                                                                          PyTorch analogue: x + pos[:seqLen] where pos is a parameter table.
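The broadcast-and-add behaviour can be sketched on nested arrays. This is only an illustration with a hypothetical function name (`addPositional`); TorchLean works with shape-indexed tensors, not `Array (Array Float)`.

```lean
namespace PosEmbedSketch

/-- Add a (seqLen × embedDim) table `pos` to every sample of a
    (batch × seqLen × embedDim) input. The same table is reused for
    each batch element, which is the broadcast across the batch axis. -/
def addPositional (pos : Array (Array Float))
    (x : Array (Array (Array Float))) : Array (Array (Array Float)) :=
  x.map fun sample =>
    (sample.zip pos).map fun (row, p) =>
      (row.zip p).map fun (a, b) => a + b

-- One sample, two positions, embedDim = 2.
#eval addPositional #[#[0.5, 0.5], #[1.0, 1.0]] #[#[#[1.0, 2.0], #[3.0, 4.0]]]

end PosEmbedSketch
```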

def NN.API.nn.layerNorm {batch seqLen embedDim : ℕ} [NeZero seqLen] [NeZero embedDim] :
                                                                                                                            M (Sequential (Spec.Shape.dim batch (Tensor.Shape.Mat seqLen embedDim)) (Spec.Shape.dim batch (Tensor.Shape.Mat seqLen embedDim)))

                                                                                                                            Layer normalization over (batch × seqLen × embedDim) tensors.

                                                                                                                            This normalizes each embedDim-vector (per batch element, per sequence position), and applies learned affine parameters gamma and beta.

                                                                                                                            PyTorch analogue: torch.nn.LayerNorm(embedDim) on a tensor shaped (batch, seqLen, embedDim).

                                                                                                                            Implementation note: TorchLean uses NeZero to ensure seqLen and embedDim are positive, avoiding degenerate shapes.
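For a single embedDim-vector, the computation can be sketched as follows. The function name `layerNormVec` and the epsilon default of 1e-5 are assumptions (the latter mirrors PyTorch's default), and the biased variance follows the torch.nn.LayerNorm convention; TorchLean's own constants may differ.

```lean
namespace LayerNormSketch

/-- Normalize one vector to zero mean and unit variance, then apply the
    affine parameters: gamma * (x - mean) / sqrt(var + eps) + beta.
    Assumes `x` is non-empty, which is what `NeZero embedDim` guarantees. -/
def layerNormVec (gamma beta x : Array Float) (eps : Float := 1e-5) : Array Float :=
  let n := x.size.toFloat
  let mean := x.foldl (· + ·) 0.0 / n
  let var := x.foldl (fun acc v => acc + (v - mean) * (v - mean)) 0.0 / n
  let inv := 1.0 / Float.sqrt (var + eps)
  (Array.range x.size).map fun i =>
    gamma.getD i 1.0 * ((x.getD i 0.0 - mean) * inv) + beta.getD i 0.0

-- mean = 2, variance = 1, so the result is approximately #[-1, 1].
#eval layerNormVec #[1.0, 1.0] #[0.0, 0.0] #[1.0, 3.0]

end LayerNormSketch
```

The division by `n` is where an empty embedding dimension would go wrong, which is one reason to rule out zero-sized shapes at the type level.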


Multi-head self-attention. The nonzero sequence length requirement is supplied by a NeZero instance, so callers do not pass an explicit proof.

If mask is provided, it must be a boolean attention mask of shape (n × n), for example a causal mask.
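A causal mask of this shape can be sketched as below. The names `causalMask` and `applyMask` are hypothetical, and the convention that `true` means "query i may attend to key j" is an assumption; TorchLean's mask polarity should be checked against its own spec.

```lean
namespace MaskSketch

/-- n × n causal mask: entry (i, j) is true exactly when j ≤ i,
    so each position sees itself and earlier positions only. -/
def causalMask (n : Nat) : Array (Array Bool) :=
  (Array.range n).map fun i =>
    (Array.range n).map fun j => decide (j ≤ i)

/-- Replace disallowed scores with a large negative value so that a
    subsequent softmax assigns them (almost) zero weight. -/
def applyMask (mask : Array (Array Bool))
    (scores : Array (Array Float)) : Array (Array Float) :=
  (scores.zip mask).map fun (row, m) =>
    (row.zip m).map fun (s, keep) => if keep then s else -1.0e30

#eval causalMask 3
-- #[#[true, false, false], #[true, true, false], #[true, true, true]]

end MaskSketch
```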


                                                                                                                                Transformer encoder block.

                                                                                                                                This is transformerEncoderBlockWithMask; pass mask := ... to enable causal masking (or other attention masks).


                                                                                                                                  Stack cfg.layers copies of blocks.transformerEncoderBlock.

                                                                                                                                  This is transformerEncoderStackWithMask; pass mask := ... to enable causal masking (or other attention masks).


                                                                                                                                    ResNet-18 style BasicBlock over batched image tensors (N×C×H×W).


                                                                                                                                      Dropout layer (active in train mode, identity in eval mode).

                                                                                                                                      PyTorch analogue: torch.nn.Dropout.
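The train/eval distinction can be sketched with an explicit keep-mask in place of random sampling. The function name `dropoutWith` is hypothetical, and the 1 / (1 - p) rescaling of kept entries follows torch.nn.Dropout; it is an assumption that TorchLean uses the same scaling.

```lean
namespace DropoutSketch

/-- Inverted dropout with a caller-supplied mask.
    In train mode, kept entries are scaled by 1 / (1 - p) and dropped
    entries become 0. In eval mode the input is returned unchanged. -/
def dropoutWith (train : Bool) (p : Float)
    (keep : Array Bool) (x : Array Float) : Array Float :=
  if train then
    (x.zip keep).map fun (v, k) => if k then v / (1.0 - p) else 0.0
  else
    x

-- Train mode, p = 0.5: the kept entry is doubled, the dropped entry is 0.
#eval dropoutWith true 0.5 #[true, false] #[1.0, 2.0]

-- Eval mode: identity.
#eval dropoutWith false 0.5 #[true, false] #[1.0, 2.0]

end DropoutSketch
```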

                                                                                                                                        def NN.API.nn.runGlobal {α : Type} (x : M α) :
                                                                                                                                        IO α

Run a seeded builder against the global seed stream set by nn.manualSeed. The result type must live in Type.

                                                                                                                                        Note: model values like nn.Sequential live in Type 2, so they cannot be returned from IO. For models, use nn.run with an explicit base seed (obtained from nn.nextSeed).
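The idea of a builder monad that hands out deterministic seeds can be modelled with a plain state monad. This is a hypothetical sketch, not TorchLean's M: `SeedM`, `nextSeed`, and `run` below are illustrative definitions, and the "increment by one" seed schedule is an assumption made for readability.

```lean
namespace SeedSketch

/-- A builder is a state computation over the next seed to hand out. -/
abbrev SeedM := StateM Nat

/-- Return the current seed and advance the stream. -/
def nextSeed : SeedM Nat := do
  let s ← get
  set (s + 1)
  pure s

/-- Run a builder purely from an explicit base seed. -/
def run {α : Type} (base : Nat) (x : SeedM α) : α :=
  Id.run (StateT.run' x base)

/-- Two layers built in sequence receive two distinct seeds. -/
def twoSeeds : SeedM (Nat × Nat) := do
  let a ← nextSeed
  let b ← nextSeed
  pure (a, b)

#eval run 42 twoSeeds  -- (42, 43)

end SeedSketch
```

Since `run` is pure, the same base seed always yields the same result, which is the property that makes drawing one base seed in IO and then building purely a reproducible workflow.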


                                                                                                                                          Draw a fresh base seed from the global seed stream set by nn.manualSeed.


                                                                                                                                            Draw n fresh base seeds from the global seed stream.


                                                                                                                                              Naming Convenience #

                                                                                                                                              nn.run / nn.nextSeed are the core primitives, but in user code it is often clearer to read:

                                                                                                                                              Draw a fresh base seed from the global seed stream.

def NN.API.nn.withModel {σ τ : Spec.Shape} {β : Type} (mk : M (Sequential σ τ)) (k : Sequential σ τ → IO β) :
                                                                                                                                                IO β

                                                                                                                                                Build a model using the next global seed, then run a continuation.

                                                                                                                                                nn.Sequential lives in Type 2, so executable code passes the model to a continuation rather than returning it directly from IO.
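The universe constraint and the continuation-passing workaround can be reproduced in a few lines of core Lean. Everything here is a hypothetical stand-in: `ToyModel` is not nn.Sequential, and `withToyModel` takes the seed source as an argument, whereas nn.withModel reads the global seed stream.

```lean
namespace WithModelSketch

/-- A toy "model" forced into Type 2 by a field of type `Type 1`. -/
structure ToyModel : Type 2 where
  Carrier : Type 1
  seed : Nat

/-- Pure builder: a base seed determines the model. -/
def mkToy (seed : Nat) : ToyModel :=
  { Carrier := Type, seed := seed }

-- `IO ToyModel` does not type-check: `IO : Type → Type`, but `ToyModel : Type 2`.
-- Passing the model to a continuation avoids ever returning it from IO.
def withToyModel {β : Type} (nextSeed : IO Nat)
    (mk : Nat → ToyModel) (k : ToyModel → IO β) : IO β := do
  let seed ← nextSeed
  k (mk seed)

def main : IO Unit :=
  withToyModel (pure 42) mkToy fun m =>
    IO.println s!"built with seed {m.seed}"

end WithModelSketch
```

Only the continuation's result type β has to live in Type; the model itself never appears as an IO result, so its universe level is unconstrained.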
