TorchLean API

NN.Runtime.Autograd.Torch.Core

Torch Core #

PyTorch-style imperative front-end (eager runtime).

This wraps the eager runtime tape (Runtime.Autograd.Tape) behind an IO.Ref so that user code can read more like PyTorch.

This is purely a convenience layer; correctness/proof connections live elsewhere (e.g. Runtime.Autograd.Compiled / Proofs.Autograd.Algebra.Graph).
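As an orientation sketch, a forward pass in the eager front-end looks roughly as follows. `EagerSession.new` and `sumAll` are assumed names for the session allocator and sum-reduction wrapper described later on this page; `param`, `use`, and `mul` follow the signatures documented below.

```lean
-- Illustrative only: helper names marked above are assumptions.
open Runtime.Autograd.Torch.Internal in
def forwardOnce (w0 : Spec.Tensor Float (.dim 4 .scalar)) : IO Unit := do
  let s  ← EagerSession.new (α := Float)   -- fresh tape + side tables
  let wP ← s.param w0 (name := some "w")   -- mutable parameter object
  let w  ← s.use wP                        -- record it as a tape leaf
  let y  ← s.mul w w                       -- eager op, recorded on the tape
  let _loss ← s.sumAll y                   -- reduce to a scalar
  pure ()
```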

TList is the dependently-typed heterogeneous tensor list used by the proved IR.

We re-export it here under the Torch front-end namespace because many user-facing helpers (trainer APIs, parameter packs, etc.) are naturally expressed as TLists.

    Execution backend for the Torch-style front-end.

    • .eager: build and execute a runtime tape directly (imperative, PyTorch-like).
    • .compiled: record typed IR and run a compiled tape (proof-friendly path, see Torch.LinkedSession / TorchLean.Session).

    This is intentionally not a CUDA Graph selector. CUDA is controlled by Options.device on the eager backend; CUDA Graph capture/replay will require a distinct persistent-buffer backend.

      Execution device selector (PyTorch comparison: cpu vs cuda).

      Current scope: .cpu (host execution) and .cuda (eager CUDA execution).

        Options controlling the behavior of the Torch-style front-end.

        PyTorch comparison: these are roughly "session/global" settings (e.g. default requires_grad, and runtime-only performance toggles).

        • backend : Backend

          Execution backend selection.

        • requiresGradByDefault : Bool

          Default requires_grad value for newly created parameters/inputs when a caller omits it.

        • seed :

          Global deterministic seed for demo-style randomness.

          TorchLean keeps the semantic core pure and seed-threaded (JAX-style), so this is best understood as a convenient default seed knob that user code can thread into:

          • model initialization (per-layer init keys),
          • dataset shuffles / sampling,
          • and session-level RNG state (dropout, etc.).

          PyTorch analogue: torch.manual_seed(seed).

        • fastKernels : Bool

          Enable runtime-only fast kernels for a few hot ops in the eager backend.

          This is an execution/performance flag; it is not used by the proof-linked compilation path.

        • fastGpuMatmulPrecision : FastKernels.GpuMatmulPrecision

          GPU precision for fast-kernel matmul over Lean Float tensors.

          .fp32 matches the eager CUDA buffer stack. .fp64 selects the double-precision DGEMM path for matmul-only Float workloads that intentionally want double precision on GPU.

        • useGpu : Bool

          Eager execution on CUDA.

          When true and backend = .eager, the eager session uses the CUDA tape (Runtime.Autograd.Cuda.Tape) and Torch ops must route to CUDA implementations (no implicit CPU fallback).

          Compiled backend behavior is unchanged.

        • strictCuda : Bool

          Strict CUDA mode (eager backend only).

          Note: in CUDA eager mode TorchLean does not fall back to CPU per-op; missing CUDA ops always throw. This flag is retained for API compatibility.
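Assuming standard structure-literal syntax, a CPU eager configuration might be spelled as follows (field names are as documented above; whether the structure provides defaults, and the exact type of `seed`, are not shown on this page):

```lean
-- Hedged sketch of an Options value for the CPU eager backend.
def opts : Runtime.Autograd.Torch.Options :=
  { backend := .eager
    requiresGradByDefault := true
    seed := 0
    fastKernels := true
    fastGpuMatmulPrecision := .fp32
    useGpu := false
    strictCuda := false }
```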

          Read the device selector corresponding to useGpu.

            Set the device selector.

              Opaque handle to a tensor value in the current session/tape.

              This is the TorchLean analogue of a PyTorch Tensor object whose "identity" is a node/leaf id in the autograd tape. The phantom shape index s makes shape mismatches explicit at compile time.

              • id :

                Node/leaf identifier in the owning session tape.
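For example, because the shape index is part of the type, an elementwise op cannot mix shapes. A sketch (`add` follows the binary-op wrapper pattern documented below):

```lean
open Runtime.Autograd.Torch in
example (s : Internal.EagerSession Float)
    (a : TensorRef Float (.dim 2 .scalar))
    (b : TensorRef Float (.dim 3 .scalar)) : IO Unit := do
  let _ ← s.add a a   -- fine: both sides have shape (.dim 2 .scalar)
  -- let _ ← s.add a b -- rejected at compile time: shape indices differ
  pure ()
```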

                def Runtime.Autograd.Torch.instReprTensorRef.repr {α✝ : Type} {s✝ : Spec.Shape} [Repr α✝] :
                TensorRef α✝ s✝ → Std.Format

                @[implicit_reducible]
                instance Runtime.Autograd.Torch.instReprTensorRef {α✝ : Type} {s✝ : Spec.Shape} [Repr α✝] :
                Repr (TensorRef α✝ s✝)

                  Handle to a Nat stored in the session's non-differentiable environment.

                  This is used to model index-like inputs (class labels, gather indices, etc.) which should not receive gradients.

                  • id :

                    Index into the session's non-differentiable Nat environment.

                    Handle to a contiguous block of k Nats in the session's non-differentiable environment.

                    • start :

                      Start offset of the contiguous k-element block in the session's Nat environment.

                      Trainable parameter: a mutable tensor value plus metadata.

                      PyTorch comparison: analogous to torch.nn.Parameter, except the parameter becomes part of the autograd graph only when you use it in a particular session/tape.

                      • name : Option String

                        Optional user-facing name for logging/debugging.

                      • value : IO.Ref (Spec.Tensor α s)

                        Value at the current point.

                      • cudaValue :

                        Optional CUDA-resident mirror of value.

                        The eager CUDA trainer uses this as a lightweight persistent-parameter cache: repeated forward passes can reuse the device buffer instead of uploading the host tensor every step. The host value remains the public source for CPU/proof-oriented APIs and is synchronized on explicit readback.

                      • hostCurrent : IO.Ref Bool

                        Whether value is known to match cudaValue.

                        CUDA optimizer steps mark this false after updating only the device mirror. Public parameter readback synchronizes and flips it back to true.

                      • requiresGrad : Bool

                        Whether this parameter receives accumulated gradients and optimizer updates.

                        Type-erased parameter wrapper.

                        This exists so session code can store heterogeneous parameter shapes in a single HashMap keyed by leaf id (used for SGD updates).

                        • Runtime shape of the erased parameter.

                        • requiresGrad : Bool

                          Whether the underlying parameter receives optimizer updates.

                        • get : IO (AnyTensor α)

                          Read the current parameter value with its runtime shape.

                        • set : AnyTensor αIO Unit

                          Overwrite the current parameter value, checking shape at the call site.

                        • setCuda : Cuda.AnyBufferIO Unit

                          Store a CUDA buffer mirror without forcing an immediate host download.

                          Release a cached CUDA mirror, if one exists.

                          CUDA buffers are external objects whose native finalizer tolerates repeated cleanup attempts, but long training loops should not wait for GC pressure before returning replaced parameter mirrors to the allocator.

                            Package a typed Param α s as an AnyParam α, checking shape on set.

                            This is the bridge that allows generic optimizers/update routines to operate over heterogeneous parameter packs.
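For instance, a generic update routine can be written against the erased interface alone, using only `get`, `set`, and `requiresGrad` (a sketch; `AnyTensor` is treated as opaque):

```lean
-- Hypothetical generic pass over a heterogeneous parameter pack.
def applyToAll (ps : Array (AnyParam Float))
    (f : AnyTensor Float → AnyTensor Float) : IO Unit := do
  for p in ps do
    if p.requiresGrad then
      p.set (f (← p.get))   -- shape is re-checked inside `set`
```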

                              Convenience: throw IO.userError on a .error result.

                                Eager backend internals #

                                The eager backend for the backend-generic Ops interface needs a small tape-backed session to thread an IO.Ref to the runtime tape.

                                This is intentionally kept under Torch.Internal.*; the public session-style API is Runtime.Autograd.TorchLean.Session.

                                CUDA Bridge (Upload/Download) #

                                The CUDA eager tape stores float32 device buffers (Runtime.Autograd.Cuda.Buffer) paired with a runtime Shape (Runtime.Autograd.Cuda.AnyBuffer).

                                 The Torch eager front-end still presents the spec-level Tensor α s API, so in CUDA mode we need upload and download conversions between spec-level tensors and these device buffers.

                                The helper namespace gives CUDA bridge conversions stable call sites and a clear boundary.

                                Conversions required by the eager CUDA tape path.

                                  Float implementation #

                                  Float CUDA conversions: upload/download via row-major FloatArray.

                                  Generic CPU-preserving fallback for scalar types without a CUDA wire-format bridge.

                                  Many TorchLean sessions are scalar-polymorphic on CPU, while the eager CUDA tape stores float32 buffers. The fallback keeps CPU execution available for proof-oriented scalar backends and fails loudly if a CUDA-only conversion is actually requested. Add a higher-priority TensorConv α instance for scalar types that have a deliberate float32 wire representation.

                                  Shape helpers for CUDA kernels #

                                  Runtime dimension list as an Array Nat (outermost-first).

                                    axisMap as an array.

                                      Synchronize a CUDA-updated parameter back to its host tensor, if needed.

                                      This is deliberately explicit. Training hot paths keep parameters resident on device; public readback APIs call this helper before exposing parameter tensors to the Lean side.
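A sketch of the staleness protocol in terms of the `Param` fields documented above (the actual device download is elided; `readParam` is a hypothetical name):

```lean
def readParam {sh : Spec.Shape} (p : Param Float sh) :
    IO (Spec.Tensor Float sh) := do
  let fresh ← p.hostCurrent.get
  if !fresh then
    -- download the CUDA mirror into `value` here (details elided),
    -- then mark the host copy as current again
    p.hostCurrent.set true
  p.value.get
```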

                                        Store/update the CUDA mirror of a parameter and mark the host tensor stale.

                                          Overwrite a host parameter value and invalidate any stale CUDA mirror.

                                            Internal eager session: a mutable runtime tape plus side tables.

                                            This is the state needed to offer a PyTorch-like API where "tensors" are opaque references and ops mutate a hidden tape stored in an IO.Ref.

                                            Notes:

                                            • tape stores values and backward closures (Runtime.Autograd.Tape).
                                            • paramsByLeaf remembers which tape leaf ids correspond to trainable parameters (for SGD).
                                            • nats stores non-differentiable Nat inputs used for indexing-like operations.
                                              Allocate a fresh eager session with an empty tape and empty side tables.

                                                Force-free a CUDA buffer allocation; the external finalizer is safe to call twice.

                                                  Force-release a shape-erased CUDA buffer.

                                                    Release current CUDA tape values that are not persistent parameter mirrors.

                                                    Eager CUDA training creates many temporary buffers per step. Relying only on external-object finalizers can produce high transient memory pressure in long runs, so reset/step paths call this before discarding the current tape snapshot.

                                                      Release CUDA tape values after an optimizer step.

                                                      Unlike releaseCudaTapeNonParamValues, this may release trainable parameter leaf buffers too. In a CUDA optimizer step, trainable parameters have already been written to fresh persistent mirrors, so the leaf buffers from the just-consumed tape are stale. Non-trainable parameter leaves are still backed by their persistent mirrors, so those stay cached.

                                                        Release a dense CUDA gradient array after an optimizer has consumed it.

                                                          Ask the native allocator to return/free pages after a large CUDA eager step.

                                                            Reset the tape and side tables.

                                                            PyTorch comparison: like starting a fresh forward pass where the autograd graph is discarded.

                                                              def Runtime.Autograd.Torch.Internal.EagerSession.param {α : Type} (s : EagerSession α) {sh : Spec.Shape} (init : Spec.Tensor α sh) (name : Option String := none) (requiresGrad : Option Bool := none) :
                                                              IO (Param α sh)

                                                              Create a mutable parameter object (not yet on the tape).

                                                              To record this parameter on the session tape, call use, which reads the parameter and records it as a leaf.

                                                                Read back the concrete tensor value stored at a TensorRef.

                                                                This is a dynamic check: we ensure the id exists on the tape and the stored shape matches sh.

                                                                  Record a constant leaf (non-differentiable) on the tape.

                                                                  PyTorch comparison: like constructing a tensor with requires_grad=False.

                                                                    Stop-gradient boundary.

                                                                    Forward semantics: identity (detach(x) = x). Backward semantics: no gradient flows to x.
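A small consequence worth spelling out (assuming the eager wrapper is named `detach`, matching the description above): with `z := detach x`, the product `x * z` has gradient `z` with respect to `x`, not `2 * x`, because the detached factor is treated as a constant by backward.

```lean
example {sh : Spec.Shape} (s : Internal.EagerSession Float)
    (x : TensorRef Float sh) : IO Unit := do
  let z ← s.detach x   -- forward: z = x; backward: ∂z/∂x = 0
  let y ← s.mul x z    -- ∂y/∂x = z (only the live factor contributes)
  pure ()
```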

                                                                      Deterministic U[0,1) tensor generator (seeded).

                                                                        Deterministic {0,1} mask generator (seeded) with a scalar keep-probability input.

                                                                          Use a parameter in the tape by recording its current value as a leaf.

                                                                          The returned TensorRef is the handle you pass to ops. The leaf id is stored in paramsByLeaf so optimizer steps (e.g. SGD) can update parameters after backward. PyTorch comparison: like using a torch.nn.Parameter in a forward pass (it becomes a leaf in the autograd graph).

                                                                            Record an external input tensor as a leaf on the tape.

                                                                            PyTorch comparison: like introducing a tensor into the autograd graph with a chosen requires_grad flag.

                                                                              Record a non-differentiable Nat input in the session environment.

                                                                              This supports ops that depend on indices/labels that should not receive gradients.

                                                                                Read a previously recorded NatRef.

                                                                                  Overwrite a previously recorded NatRef.

                                                                                    Convert a Tensor Nat (.dim k .scalar) to an Array Nat.

                                                                                    Used to stage NatVecRef values into the session environment.

                                                                                      Record a non-differentiable vector of Nats in the session environment.

                                                                                      Returns a NatVecRef k pointing to the stored block.

                                                                                        Read back the vector stored at NatVecRef k.

                                                                                          Overwrite the stored vector at NatVecRef k.

                                                                                            Tensor ops (eager tape wrappers) #

                                                                                             The following definitions are thin wrappers around Runtime.Autograd.Tape.* primitives. Each one records the corresponding primitive on the session tape and returns a TensorRef to the resulting node.

                                                                                            PyTorch comparison: this is the standard eager autograd mechanism (a dynamic tape of ops).

                                                                                            def Runtime.Autograd.Torch.Internal.EagerSession.dispatchCudaOpt {α β : Type} (s : EagerSession α) (opName : String) (cpu : IO β) (cuda : IO (Option β)) :
                                                                                            IO β

                                                                                            Dispatch an eager op with optional CUDA support.

                                                                                            When Options.device = .cuda, any op whose CUDA implementation returns none will throw.

                                                                                            TorchLean's CUDA eager mode is intentionally "no per-op CPU fallback": either the op is supported by CUDA, or it errors immediately.
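A sketch of how an op wrapper might use the dispatcher, following the signature above: the CPU branch always produces a value, while a CUDA branch that returns `none` makes the call throw when Options.device = .cuda (`countSomething` is a hypothetical op name).

```lean
def countSomething (s : EagerSession Float) : IO Nat :=
  s.dispatchCudaOpt "countSomething"
    (cpu := pure 0)       -- host implementation
    (cuda := pure none)   -- no CUDA implementation yet: errors in CUDA mode
```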

                                                                                              Record elementwise addition a + b. PyTorch: torch.add.

                                                                                                Record elementwise subtraction a - b. PyTorch: torch.sub.

                                                                                                  Record elementwise multiplication a * b. PyTorch: torch.mul.

                                                                                                    Record scaling by a scalar constant. PyTorch: x * c.

                                                                                                      def Runtime.Autograd.Torch.Internal.EagerSession.abs {α : Type} (s : EagerSession α) [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] [DecidableEq Spec.Shape] {sh : Spec.Shape} (x : TensorRef α sh) :
                                                                                                      IO (TensorRef α sh)

                                                                                                      Record elementwise absolute value. PyTorch: torch.abs.

                                                                                                        def Runtime.Autograd.Torch.Internal.EagerSession.sqrt {α : Type} (s : EagerSession α) [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] [DecidableEq Spec.Shape] {sh : Spec.Shape} (x : TensorRef α sh) :
                                                                                                        IO (TensorRef α sh)

                                                                                                        Record elementwise square root. PyTorch: torch.sqrt.

                                                                                                          def Runtime.Autograd.Torch.Internal.EagerSession.clamp {α : Type} [CudaBridge.TensorConv α] (s : EagerSession α) [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] [DecidableEq Spec.Shape] {sh : Spec.Shape} (x : TensorRef α sh) (minVal maxVal : α) :
                                                                                                          IO (TensorRef α sh)

                                                                                                          Record elementwise clamp to [minVal,maxVal]. PyTorch: torch.clamp.

                                                                                                            def Runtime.Autograd.Torch.Internal.EagerSession.max {α : Type} (s : EagerSession α) [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] [DecidableEq Spec.Shape] {sh : Spec.Shape} (a b : TensorRef α sh) :
                                                                                                            IO (TensorRef α sh)

                                                                                                            Record elementwise maximum. PyTorch: torch.maximum.

                                                                                                              def Runtime.Autograd.Torch.Internal.EagerSession.min {α : Type} (s : EagerSession α) [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] [DecidableEq Spec.Shape] {sh : Spec.Shape} (a b : TensorRef α sh) :
                                                                                                              IO (TensorRef α sh)

                                                                                                              Record elementwise minimum. PyTorch: torch.minimum.

                                                                                                                def Runtime.Autograd.Torch.Internal.EagerSession.relu {α : Type} (s : EagerSession α) [Mul α] [Zero α] [Max α] [One α] [LT α] [DecidableRel fun (x1 x2 : α) => x1 > x2] [DecidableEq Spec.Shape] {sh : Spec.Shape} (x : TensorRef α sh) :
                                                                                                                IO (TensorRef α sh)

                                                                                                                Record elementwise ReLU. PyTorch: torch.relu / torch.nn.functional.relu.

                                                                                                                  Record elementwise sigmoid. PyTorch: torch.sigmoid.

                                                                                                                    Record elementwise tanh. PyTorch: torch.tanh.

                                                                                                                      Record softmax (shape-preserving).

                                                                                                                      PyTorch comparison: torch.softmax(x, dim=...) (dimension convention is chosen by the underlying tape op).

                                                                                                                        Record stable log-softmax (shape-preserving, last-axis convention).

                                                                                                                        PyTorch comparison: torch.nn.functional.log_softmax(x, dim=-1).
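The stable form subtracts the running maximum before exponentiating. Over a plain list of floats the computation is (self-contained sketch, independent of the tape):

```lean
def logSoftmax (xs : List Float) : List Float :=
  let m   := xs.foldl max (-(1.0 / 0.0))  -- max element; -inf for empty input
  let lse := Float.log (xs.foldl (fun a x => a + Float.exp (x - m)) 0.0)
  xs.map fun x => x - m - lse             -- xᵢ - m - log Σⱼ exp(xⱼ - m)
```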

                                                                                                                          Record elementwise softplus. PyTorch: torch.nn.functional.softplus.

                                                                                                                            Record elementwise exponential. PyTorch: torch.exp.

                                                                                                                              Record elementwise log. PyTorch: torch.log.

                                                                                                                                Record elementwise inverse 1/x. PyTorch: torch.reciprocal.

                                                                                                                                  Record elementwise log with epsilon guard.

                                                                                                                                  PyTorch comparison: torch.log(torch.clamp(x, min=ε)).
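Equivalently, as a pointwise function (directly mirroring the clamp-below form in the comparison above):

```lean
def logEps (ε x : Float) : Float :=
  Float.log (max x ε)   -- guard: never take log below ε
```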

                                                                                                                                    Sum-reduce all elements to a scalar. PyTorch: x.sum().

                                                                                                                                      Flatten a tensor to a 1D vector. PyTorch: torch.flatten.

                                                                                                                                        Reshape a tensor while preserving total number of elements.

                                                                                                                                        PyTorch comparison: torch.reshape / view (when valid).

                                                                                                                                          Transpose a 2D matrix. PyTorch: x.t() / x.transpose(0,1).

                                                                                                                                            Swap two adjacent axes at a given depth. PyTorch analogue: x.transpose(dim, dim+1).

                                                                                                                                              Swap the last two axes of a 3D tensor (a,b,c) → (a,c,b). PyTorch: x.transpose(1,2).

                                                                                                                                                Broadcast a tensor to a larger shape. PyTorch: implicit broadcasting / expand.

                                                                                                                                                Instances For

                                                                                                                                                  Sum-reduce along axis. PyTorch: torch.sum(x, dim=axis).

                                                                                                                                                  Instances For

                                                                                                                                                    Mean-reduce along axis. PyTorch: torch.mean(x, dim=axis).

                                                                                                                                                    Instances For

                                                                                                                                                      Gather a scalar from a 1D vector with a Fin n index. PyTorch: x[i].

                                                                                                                                                      Instances For

                                                                                                                                                        Gather a row from a 2D tensor with a Fin rows index. PyTorch: x[i] for 2D tensors.

                                                                                                                                                        Instances For

                                                                                                                                                          Gather a scalar from a 1D vector with a raw Nat index (totalized by the tape op).

                                                                                                                                                          Instances For

                                                                                                                                                            Dynamic gather scalar using an index stored in NatRef.

                                                                                                                                                            Instances For

                                                                                                                                                              Dynamic gather row using an index stored in NatRef (out-of-range gives a zero row).

                                                                                                                                                              Instances For

                                                                                                                                                                Gather k scalars using an explicit index tensor. PyTorch analogue: gather / advanced indexing.

                                                                                                                                                                Instances For

                                                                                                                                                                  Gather k rows using an explicit index tensor. PyTorch: index_select(dim=0, index=...).

                                                                                                                                                                  Instances For

                                                                                                                                                                    Gather k scalars using indices stored in the nat-environment (NatVecRef).

                                                                                                                                                                    Instances For

                                                                                                                                                                      Gather k rows using indices stored in the nat-environment (NatVecRef).

                                                                                                                                                                      Instances For

                                                                                                                                                                        Scatter-add into a vector: return a copy of x with x[i] += v.

                                                                                                                                                                        Instances For

                                                                                                                                                                          Scatter-add into a matrix row: return a copy of x with x[i,:] += v.

                                                                                                                                                                          Instances For
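The two scatter-add helpers above are pure: they return a fresh tensor rather than mutating in place. A minimal pure-Python sketch of the intended semantics (hypothetical helper names, plain lists standing in for tensors):

```python
def scatter_add_vec(x, i, v):
    """Return a copy of x with x[i] += v (pure, like the tape op)."""
    out = list(x)
    out[i] += v
    return out

def scatter_add_row(x, i, v):
    """Return a copy of the matrix x with row i incremented elementwise by v."""
    out = [list(row) for row in x]
    out[i] = [a + b for a, b in zip(out[i], v)]
    return out
```

This copy-then-update reading is what makes the ops compatible with a pure semantic core: the original tensor value is still available after the call.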

                                                                                                                                                                            Fully-connected linear layer y = w x + b (matvec).

                                                                                                                                                                            If opts.fastKernels is enabled, uses a runtime-only fast kernel implementation. PyTorch comparison: torch.nn.functional.linear(x, weight=w, bias=b).

                                                                                                                                                                            Instances For
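As a reference point for the `y = w x + b` matvec semantics above, here is an illustrative pure-Python sketch (hypothetical name, lists standing in for tensors; the real op additionally records onto the autograd tape):

```python
def linear(w, x, b):
    """y = w @ x + b for a matrix w, vector x, and bias vector b."""
    return [sum(wi * xi for wi, xi in zip(row, x)) + bi
            for row, bi in zip(w, b)]
```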

                                                                                                                                                                              Mean-squared-error loss returning a scalar.

If opts.fastKernels is enabled, uses a runtime-only fast kernel implementation. PyTorch comparison: torch.nn.functional.mse_loss(..., reduction="mean").

                                                                                                                                                                              Instances For
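The `reduction="mean"` behavior referenced above amounts to averaging the squared differences; an illustrative sketch (hypothetical name, not the TorchLean API):

```python
def mse_loss(pred, target):
    """Mean squared error with reduction='mean': sum((p - t)^2) / n."""
    n = len(pred)
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / n
```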
def Runtime.Autograd.Torch.Internal.EagerSession.layerNorm {α : Type} (s : EagerSession α) [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] [DecidableEq Spec.Shape] {seqLen embedDim : ℕ} (h_seq_pos : seqLen > 0) (h_embed_pos : embedDim > 0) (x : TensorRef α (Spec.Shape.dim seqLen (Spec.Shape.dim embedDim Spec.Shape.scalar))) (gamma beta : TensorRef α (Spec.Shape.dim embedDim Spec.Shape.scalar)) :
IO (TensorRef α (Spec.Shape.dim seqLen (Spec.Shape.dim embedDim Spec.Shape.scalar)))

                                                                                                                                                                                Layer normalization over embedding dimension. PyTorch: nn.LayerNorm / functional.layer_norm.

                                                                                                                                                                                Instances For
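The normalization above acts independently on each sequence position, over the embedding dimension. An illustrative pure-Python sketch of layer-norm semantics (hypothetical name; `eps` is an assumed numerical-stability constant not visible in the signature):

```python
import math

def layer_norm(x, gamma, beta, eps=1e-5):
    """Normalize each row of x over the embedding dim, then scale by
    gamma and shift by beta (both of length embedDim)."""
    out = []
    for row in x:
        m = sum(row) / len(row)
        var = sum((v - m) ** 2 for v in row) / len(row)
        inv = 1.0 / math.sqrt(var + eps)
        out.append([g * (v - m) * inv + b
                    for v, g, b in zip(row, gamma, beta)])
    return out
```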
def Runtime.Autograd.Torch.Internal.EagerSession.batchnormChannelFirst {α : Type} (s : EagerSession α) [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] [DecidableEq Spec.Shape] {channels height width : ℕ} (h_c : channels > 0) (h_h : height > 0) (h_w : width > 0) (x : TensorRef α (Spec.Shape.dim channels (Spec.Shape.dim height (Spec.Shape.dim width Spec.Shape.scalar)))) (gamma beta : TensorRef α (Spec.Shape.dim channels Spec.Shape.scalar)) :
IO (TensorRef α (Spec.Shape.dim channels (Spec.Shape.dim height (Spec.Shape.dim width Spec.Shape.scalar))))

                                                                                                                                                                                  BatchNorm for channel-first images (C,H,W) (no batch axis). PyTorch: nn.BatchNorm2d (conceptually).

                                                                                                                                                                                  Instances For
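Since there is no batch axis, the statistics above are taken over the H×W spatial positions of each channel. An illustrative sketch of that semantics (hypothetical name; `eps` is an assumed stability constant):

```python
import math

def batchnorm_chw(x, gamma, beta, eps=1e-5):
    """Per-channel normalization of a (C,H,W) tensor: each channel's
    plane is normalized over its own H*W values, then scaled/shifted."""
    out = []
    for c, plane in enumerate(x):
        vals = [v for row in plane for v in row]
        m = sum(vals) / len(vals)
        var = sum((v - m) ** 2 for v in vals) / len(vals)
        inv = 1.0 / math.sqrt(var + eps)
        out.append([[gamma[c] * (v - m) * inv + beta[c] for v in row]
                    for row in plane])
    return out
```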
def Runtime.Autograd.Torch.Internal.EagerSession.multiHeadAttention {α : Type} (s : EagerSession α) [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] [DecidableEq Spec.Shape] {n numHeads dModel headDim : ℕ} (h1 : n ≠ 0) (wq wk wv : TensorRef α (Spec.Shape.dim dModel (Spec.Shape.dim (numHeads * headDim) Spec.Shape.scalar))) (wo : TensorRef α (Spec.Shape.dim (numHeads * headDim) (Spec.Shape.dim dModel Spec.Shape.scalar))) (x : TensorRef α (Spec.Shape.dim n (Spec.Shape.dim dModel Spec.Shape.scalar))) (mask : Option (Spec.Tensor Bool (Spec.Shape.dim n (Spec.Shape.dim n Spec.Shape.scalar))) := none) :
IO (TensorRef α (Spec.Shape.dim n (Spec.Shape.dim dModel Spec.Shape.scalar)))

                                                                                                                                                                                    Multi-head self-attention (typed, proof-friendly). PyTorch: nn.MultiheadAttention (conceptually).

                                                                                                                                                                                    Instances For
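The per-head core of the operation above is scaled dot-product attention. A minimal single-head sketch in pure Python (hypothetical names; the `wq`/`wk`/`wv`/`wo` projections and the optional boolean mask from the signature are omitted for brevity):

```python
import math

def softmax(v):
    m = max(v)  # subtract the max for numerical stability
    e = [math.exp(x - m) for x in v]
    s = sum(e)
    return [x / s for x in e]

def attention(q, k, v):
    """Scaled dot-product attention for (n, d) query/key/value matrices:
    out[i] = sum_j softmax_j(q[i].k[j] / sqrt(d)) * v[j]."""
    d = len(q[0])
    scale = 1.0 / math.sqrt(d)
    out = []
    for qi in q:
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        w = softmax(scores)
        out.append([sum(wi * vj[t] for wi, vj in zip(w, v))
                    for t in range(len(v[0]))])
    return out
```

Multi-head attention runs this core once per head on projected slices and recombines the results through the output projection.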
def Runtime.Autograd.Torch.Internal.EagerSession.conv {α : Type} (s : EagerSession α) [Context α] [DecidableEq Spec.Shape] {d inC outC : ℕ} {kernel stride padding inSpatial : Vector ℕ d} {hInC : inC ≠ 0} {hKernel : ∀ (i : Fin d), kernel.get i ≠ 0} (w : TensorRef α (Spec.Shape.ofList (outC :: inC :: kernel.toList))) (b : TensorRef α (Spec.Shape.dim outC Spec.Shape.scalar)) (x : TensorRef α (Spec.Shape.ofList (inC :: inSpatial.toList))) :
                                                                                                                                                                                      IO (TensorRef α (Spec.Shape.ofList (outC :: (Spec.convOutSpatial inSpatial kernel stride padding).toList)))

                                                                                                                                                                                      N-D convolution for channels-first tensors (inC, spatial...) (no batch axis).

This is the generic counterpart to conv2d. PyTorch comparison: torch.nn.functional.conv1d / conv2d / conv3d (depending on the spatial rank d), specialized to a single sample.

                                                                                                                                                                                      Instances For
def Runtime.Autograd.Torch.Internal.EagerSession.conv2d {α : Type} (s : EagerSession α) [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] [DecidableEq Spec.Shape] {inC outC kH kW stride padding inH inW : ℕ} {h1 : inC ≠ 0} {h2 : kH ≠ 0} {h3 : kW ≠ 0} (kernel : TensorRef α (Spec.Shape.dim outC (Spec.Shape.dim inC (Spec.Shape.dim kH (Spec.Shape.dim kW Spec.Shape.scalar))))) (bias : TensorRef α (Spec.Shape.dim outC Spec.Shape.scalar)) (input : TensorRef α (Spec.Shape.dim inC (Spec.Shape.dim inH (Spec.Shape.dim inW Spec.Shape.scalar)))) :
                                                                                                                                                                                        IO (TensorRef α (Spec.Shape.dim outC (Spec.Shape.dim ((inH + 2 * padding - kH) / stride + 1) (Spec.Shape.dim ((inW + 2 * padding - kW) / stride + 1) Spec.Shape.scalar))))

                                                                                                                                                                                        2D convolution for channel-first images (C,H,W) (no batch axis). PyTorch: torch.nn.functional.conv2d.

                                                                                                                                                                                        Instances For
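The output spatial sizes in the return type above follow the standard convolution arithmetic, with Lean's truncating Nat division. An illustrative check (hypothetical helper name):

```python
def conv2d_out(in_size, k, stride, padding):
    """Output spatial size from the conv2d return type:
    (in + 2*padding - k) / stride + 1, with floor division."""
    return (in_size + 2 * padding - k) // stride + 1
```

For example, a 3x3 kernel with stride 1 and padding 1 is shape-preserving, while stride 2 roughly halves each spatial dimension.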
def Runtime.Autograd.Torch.Internal.EagerSession.convTranspose {α : Type} (s : EagerSession α) [Context α] [DecidableEq Spec.Shape] {d inC outC : ℕ} {kernel stride padding inSpatial : Vector ℕ d} {hInC : inC ≠ 0} {hKernel : ∀ (i : Fin d), kernel.get i ≠ 0} (w : TensorRef α (Spec.Shape.ofList (inC :: outC :: kernel.toList))) (b : TensorRef α (Spec.Shape.dim outC Spec.Shape.scalar)) (x : TensorRef α (Spec.Shape.ofList (inC :: inSpatial.toList))) :
                                                                                                                                                                                          IO (TensorRef α (Spec.Shape.ofList (outC :: (Spec.convTransposeOutSpatial inSpatial kernel stride padding).toList)))

                                                                                                                                                                                          N-D transpose convolution for channels-first tensors (inC, spatial...) (no batch axis).

This is the generic counterpart to conv_transpose2d. PyTorch comparison: torch.nn.functional.conv_transpose1d / conv_transpose2d / conv_transpose3d (depending on the spatial rank d), specialized to a single sample.

                                                                                                                                                                                          Instances For
def Runtime.Autograd.Torch.Internal.EagerSession.convTranspose2d {α : Type} (s : EagerSession α) [Context α] [DecidableEq Spec.Shape] {inC outC kH kW stride padding inH inW : ℕ} {h1 : inC ≠ 0} {h2 : kH ≠ 0} {h3 : kW ≠ 0} (kernel : TensorRef α (Spec.Shape.dim inC (Spec.Shape.dim outC (Spec.Shape.dim kH (Spec.Shape.dim kW Spec.Shape.scalar))))) (bias : TensorRef α (Spec.Shape.dim outC Spec.Shape.scalar)) (input : TensorRef α (Spec.Shape.dim inC (Spec.Shape.dim inH (Spec.Shape.dim inW Spec.Shape.scalar)))) :
                                                                                                                                                                                            IO (TensorRef α (Spec.Shape.dim outC (Spec.Shape.dim ((inH - 1) * stride - 2 * padding + kH) (Spec.Shape.dim ((inW - 1) * stride - 2 * padding + kW) Spec.Shape.scalar))))

                                                                                                                                                                                            2D transpose convolution for channel-first images (C,H,W) (no batch axis). PyTorch: torch.nn.functional.conv_transpose2d.

                                                                                                                                                                                            Instances For
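The return type above encodes the usual transpose-convolution shape formula, which inverts the forward convolution's shape arithmetic (up to the floor in the forward direction). An illustrative check (hypothetical helper name):

```python
def conv_transpose2d_out(in_size, k, stride, padding):
    """Output spatial size from the convTranspose2d return type:
    (in - 1) * stride - 2*padding + k."""
    return (in_size - 1) * stride - 2 * padding + k
```

With stride 1, kernel 3, padding 1 it is shape-preserving; with stride 2 it roughly doubles each spatial dimension, as used in upsampling decoders.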
                                                                                                                                                                                              @[reducible, inline]
abbrev Runtime.Autograd.Torch.Internal.EagerSession.conv2dCompat {α : Type} (s : EagerSession α) [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] [DecidableEq Spec.Shape] {inC outC kH kW stride padding inH inW : ℕ} {h1 : inC ≠ 0} {h2 : kH ≠ 0} {h3 : kW ≠ 0} (kernel : TensorRef α (Spec.Shape.dim outC (Spec.Shape.dim inC (Spec.Shape.dim kH (Spec.Shape.dim kW Spec.Shape.scalar))))) (bias : TensorRef α (Spec.Shape.dim outC Spec.Shape.scalar)) (input : TensorRef α (Spec.Shape.dim inC (Spec.Shape.dim inH (Spec.Shape.dim inW Spec.Shape.scalar)))) :
                                                                                                                                                                                              IO (TensorRef α (Spec.Shape.dim outC (Spec.Shape.dim ((inH + 2 * padding - kH) / stride + 1) (Spec.Shape.dim ((inW + 2 * padding - kW) / stride + 1) Spec.Shape.scalar))))

                                                                                                                                                                                              Alias for conv2d (compat shorthand).

                                                                                                                                                                                              Instances For

                                                                                                                                                                                                2D matrix multiplication. PyTorch: torch.matmul for 2D tensors.

                                                                                                                                                                                                Instances For

                                                                                                                                                                                                  Batched matrix multiplication. PyTorch: torch.bmm.

                                                                                                                                                                                                  Instances For

                                                                                                                                                                                                    Concatenate two vectors along dim 0. PyTorch: torch.cat([a,b], dim=0).

                                                                                                                                                                                                    Instances For

                                                                                                                                                                                                      Concatenate along dim 0 for tensors with leading dimension. PyTorch: torch.cat(..., dim=0).

                                                                                                                                                                                                      Instances For
def Runtime.Autograd.Torch.Internal.EagerSession.sliceRange0 {α : Type} (s : EagerSession α) [Zero α] [DecidableEq Spec.Shape] {n : ℕ} {sh : Spec.Shape} (x : TensorRef α (Spec.Shape.dim n sh)) (start len : ℕ) (h : len + start ≤ n) :
IO (TensorRef α (Spec.Shape.dim len sh))

                                                                                                                                                                                                        Slice along dim 0: x[start:start+len]. PyTorch: standard slicing.

                                                                                                                                                                                                        Instances For
def Runtime.Autograd.Torch.Internal.EagerSession.maxPool {α : Type} (s : EagerSession α) [Context α] [DecidableEq Spec.Shape] {d C : ℕ} {inSpatial kernel stride padding : Vector ℕ d} {hKernel : ∀ (i : Fin d), kernel.get i ≠ 0} (x : TensorRef α (Spec.Shape.ofList (C :: inSpatial.toList))) :
                                                                                                                                                                                                          IO (TensorRef α (Spec.Shape.ofList (C :: (Spec.poolOutSpatialPad inSpatial kernel stride padding).toList)))

                                                                                                                                                                                                          N-D max pooling for channels-first tensors (C, spatial...) (no batch axis).

                                                                                                                                                                                                          PyTorch comparison: torch.nn.functional.max_pool1d / max_pool2d / max_pool3d depending on the spatial rank d.

                                                                                                                                                                                                          Instances For
def Runtime.Autograd.Torch.Internal.EagerSession.avgPool {α : Type} (s : EagerSession α) [Context α] [DecidableEq Spec.Shape] {d C : ℕ} {inSpatial kernel stride padding : Vector ℕ d} (hKernel : ∀ (i : Fin d), kernel.get i ≠ 0) (x : TensorRef α (Spec.Shape.ofList (C :: inSpatial.toList))) :
                                                                                                                                                                                                            IO (TensorRef α (Spec.Shape.ofList (C :: (Spec.poolOutSpatialPad inSpatial kernel stride padding).toList)))

                                                                                                                                                                                                            N-D average pooling for channels-first tensors (C, spatial...) (no batch axis).

                                                                                                                                                                                                            PyTorch comparison: torch.nn.functional.avg_pool1d / avg_pool2d / avg_pool3d depending on the spatial rank d.

                                                                                                                                                                                                            Instances For
def Runtime.Autograd.Torch.Internal.EagerSession.smoothMaxPool {α : Type} [CudaBridge.TensorConv α] (s : EagerSession α) [Context α] [DecidableEq Spec.Shape] {d C : ℕ} {inSpatial kernel stride padding : Vector ℕ d} {hKernel : ∀ (i : Fin d), kernel.get i ≠ 0} (x : TensorRef α (Spec.Shape.ofList (C :: inSpatial.toList))) (beta : α) :
                                                                                                                                                                                                              IO (TensorRef α (Spec.Shape.ofList (C :: (Spec.poolOutSpatialPad inSpatial kernel stride padding).toList)))

                                                                                                                                                                                                              N-D smooth max pooling (log-sum-exp surrogate) for channels-first tensors (C, spatial...).

                                                                                                                                                                                                              This is a differentiable approximation to max pooling; PyTorch does not expose it as a single primitive, but it can be emulated with logsumexp over local windows.

                                                                                                                                                                                                              Instances For
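The log-sum-exp surrogate mentioned above replaces the hard (non-differentiable) max over each window with a smooth approximation that sharpens toward the true max as beta grows. An illustrative single-window sketch (hypothetical name; max-subtraction is an assumed stability trick):

```python
import math

def smooth_max(window, beta):
    """Log-sum-exp surrogate for max over one pooling window:
    (1/beta) * log(sum(exp(beta * x))).
    Shifting by the max keeps exp() from overflowing; the result
    approaches max(window) as beta -> infinity."""
    m = max(window)
    return m + (1.0 / beta) * math.log(
        sum(math.exp(beta * (x - m)) for x in window))
```

Smooth max pooling applies this per window, which is why gradients flow to every element of the window rather than only the argmax.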
def Runtime.Autograd.Torch.Internal.EagerSession.maxPool2d {α : Type} (s : EagerSession α) [Context α] [DecidableEq Spec.Shape] {kH kW inH inW inC stride : ℕ} {h1 : kH ≠ 0} {h2 : kW ≠ 0} (x : TensorRef α (Spec.Shape.dim inC (Spec.Shape.dim inH (Spec.Shape.dim inW Spec.Shape.scalar)))) :
                                                                                                                                                                                                                IO (TensorRef α (Spec.Shape.dim inC (Spec.Shape.dim ((inH - kH) / stride + 1) (Spec.Shape.dim ((inW - kW) / stride + 1) Spec.Shape.scalar))))

                                                                                                                                                                                                                2D max-pooling (no batch axis). PyTorch: torch.nn.functional.max_pool2d.

                                                                                                                                                                                                                Instances For
def Runtime.Autograd.Torch.Internal.EagerSession.maxPool2dPad {α : Type} (s : EagerSession α) [Context α] [DecidableEq Spec.Shape] {kH kW inH inW inC stride padding : ℕ} {h1 : kH ≠ 0} {h2 : kW ≠ 0} (x : TensorRef α (Spec.Shape.dim inC (Spec.Shape.dim inH (Spec.Shape.dim inW Spec.Shape.scalar)))) :
                                                                                                                                                                                                                  IO (TensorRef α (Spec.Shape.dim inC (Spec.Shape.dim ((inH + 2 * padding - kH) / stride + 1) (Spec.Shape.dim ((inW + 2 * padding - kW) / stride + 1) Spec.Shape.scalar))))

                                                                                                                                                                                                                  2D max-pooling with padding (no batch axis). PyTorch: max_pool2d(..., padding=...).

                                                                                                                                                                                                                  Instances For
                                                                                                                                                                                                                    @[reducible, inline]
abbrev Runtime.Autograd.Torch.Internal.EagerSession.maxPoolPad {α : Type} (s : EagerSession α) [Context α] [DecidableEq Spec.Shape] {kH kW inH inW inC stride padding : ℕ} {h1 : kH ≠ 0} {h2 : kW ≠ 0} (x : TensorRef α (Spec.Shape.dim inC (Spec.Shape.dim inH (Spec.Shape.dim inW Spec.Shape.scalar)))) :
                                                                                                                                                                                                                    IO (TensorRef α (Spec.Shape.dim inC (Spec.Shape.dim ((inH + 2 * padding - kH) / stride + 1) (Spec.Shape.dim ((inW + 2 * padding - kW) / stride + 1) Spec.Shape.scalar))))

                                                                                                                                                                                                                    Alias for max_pool2d_pad (PyTorch-style shorthand).

                                                                                                                                                                                                                    Instances For
def Runtime.Autograd.Torch.Internal.EagerSession.smoothMaxPool2d {α : Type} [CudaBridge.TensorConv α] (s : EagerSession α) [Context α] [DecidableEq Spec.Shape] {kH kW inH inW inC stride : ℕ} {h1 : kH ≠ 0} {h2 : kW ≠ 0} (x : TensorRef α (Spec.Shape.dim inC (Spec.Shape.dim inH (Spec.Shape.dim inW Spec.Shape.scalar)))) (beta : α) :
                                                                                                                                                                                                                      IO (TensorRef α (Spec.Shape.dim inC (Spec.Shape.dim ((inH - kH) / stride + 1) (Spec.Shape.dim ((inW - kW) / stride + 1) Spec.Shape.scalar))))

                                                                                                                                                                                                                      Smooth max-pooling (softmax pooling). Not a standard PyTorch primitive; see Torch.LinkedSession.smooth_max_pool2d.

                                                                                                                                                                                                                      Instances For
def Runtime.Autograd.Torch.Internal.EagerSession.avgPool2d {α : Type} (s : EagerSession α) [Context α] [DecidableEq Spec.Shape] {kH kW inH inW inC stride : ℕ} (h1 : kH ≠ 0) (h2 : kW ≠ 0) (x : TensorRef α (Spec.Shape.dim inC (Spec.Shape.dim inH (Spec.Shape.dim inW Spec.Shape.scalar)))) :
IO (TensorRef α (Spec.Shape.dim inC (Spec.Shape.dim ((inH - kH) / stride + 1) (Spec.Shape.dim ((inW - kW) / stride + 1) Spec.Shape.scalar))))

                                                                                                                                                                                                                        2D average-pooling (no batch axis). PyTorch: torch.nn.functional.avg_pool2d.

                                                                                                                                                                                                                        Instances For
def Runtime.Autograd.Torch.Internal.EagerSession.avgPool2dPad {α : Type} (s : EagerSession α) [Context α] [DecidableEq Spec.Shape] {kH kW inH inW inC stride padding : ℕ} (h1 : kH ≠ 0) (h2 : kW ≠ 0) (x : TensorRef α (Spec.Shape.dim inC (Spec.Shape.dim inH (Spec.Shape.dim inW Spec.Shape.scalar)))) :
IO (TensorRef α (Spec.Shape.dim inC (Spec.Shape.dim ((inH + 2 * padding - kH) / stride + 1) (Spec.Shape.dim ((inW + 2 * padding - kW) / stride + 1) Spec.Shape.scalar))))

                                                                                                                                                                                                                          2D average-pooling with padding (no batch axis). PyTorch: avg_pool2d(..., padding=...).

                                                                                                                                                                                                                          Instances For
                                                                                                                                                                                                                            @[reducible, inline]
abbrev Runtime.Autograd.Torch.Internal.EagerSession.avgPoolPad {α : Type} (s : EagerSession α) [Context α] [DecidableEq Spec.Shape] {kH kW inH inW inC stride padding : ℕ} (h1 : kH ≠ 0) (h2 : kW ≠ 0) (x : TensorRef α (Spec.Shape.dim inC (Spec.Shape.dim inH (Spec.Shape.dim inW Spec.Shape.scalar)))) :
IO (TensorRef α (Spec.Shape.dim inC (Spec.Shape.dim ((inH + 2 * padding - kH) / stride + 1) (Spec.Shape.dim ((inW + 2 * padding - kW) / stride + 1) Spec.Shape.scalar))))

                                                                                                                                                                                                                            Alias for avg_pool2d_pad (PyTorch-style shorthand).

                                                                                                                                                                                                                            Instances For
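The pooling output dimensions in the signatures above follow the usual integer-division formula. A quick sanity check of the shape arithmetic (plain Nat computation, independent of the library):

```lean
-- Unpadded: (inH - kH) / stride + 1, e.g. a 28×28 input with a 2×2
-- kernel and stride 2 pools down to 14×14.
#eval (28 - 2) / 2 + 1  -- 14

-- Padded: (inH + 2 * padding - kH) / stride + 1, e.g. a 28×28 input
-- with a 3×3 kernel, stride 1, padding 1 keeps the spatial size.
#eval (28 + 2 * 1 - 3) / 1 + 1  -- 28
```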

                                                                                                                                                                                                                              Run reverse-mode backprop on the CUDA tape, returning device gradients for all tape entries.

                                                                                                                                                                                                                              This is the CUDA analogue of backwardDenseAll, but it does not download gradients back to the host. This is primarily useful for implementing GPU-native optimizer steps.

                                                                                                                                                                                                                              Instances For

                                                                                                                                                                                                                                Convenience wrapper for scalar losses on CUDA: backward with seed 1 (device buffers).

                                                                                                                                                                                                                                Instances For

                                                                                                                                                                                                                                  Run reverse-mode backprop and return a dense gradient array for all tape entries.

                                                                                                                                                                                                                                  seed is the upstream gradient for out (like PyTorch's backward(gradient=...)).

                                                                                                                                                                                                                                  Instances For

                                                                                                                                                                                                                                    Convenience wrapper for scalar losses: run backward with seed 1.

                                                                                                                                                                                                                                    PyTorch comparison: loss.backward() for a scalar loss.

                                                                                                                                                                                                                                    Instances For
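The seed acts as the upstream cotangent, so the returned gradient scales linearly in it. A hedged standalone illustration of this seed semantics (plain Lean, not the library API): for `f x = x * x`, the VJP at `x` with seed `s` is `(2 * x) * s`.

```lean
def f (x : Float) : Float := x * x

-- hand-written VJP of f: upstream seed times the local derivative 2*x
def vjpF (x seed : Float) : Float := (2.0 * x) * seed

#eval vjpF 3.0 1.0  -- 6.0 (seed 1: the plain gradient, as in loss.backward())
#eval vjpF 3.0 0.5  -- 3.0 (scales linearly with the seed)
```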

                                                                                                                                                                                                                                      Extract the gradient for a particular TensorRef from a dense gradient array.

                                                                                                                                                                                                                                      Instances For

                                                                                                                                                                                                                                        Apply an SGD update to all parameters recorded via use.

                                                                                                                                                                                                                                        PyTorch comparison: for p in params: p.data -= lr * p.grad.

                                                                                                                                                                                                                                        Instances For
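A minimal sketch of the per-parameter rule the docstring quotes (standalone Lean; the actual session API applies this across all recorded parameters):

```lean
-- p ← p - lr * g, applied independently to every parameter entry
def sgdStep (lr p g : Float) : Float := p - lr * g

#eval sgdStep 0.1 1.0 0.5  -- 0.95
```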

                                                                                                                                                                                                                                          Apply an SGD update to all parameters recorded via use, using CUDA device gradients.

                                                                                                                                                                                                                                          This avoids downloading the full dense gradient array and keeps updated parameters in each Param's CUDA mirror. Host tensors are synchronized later by explicit parameter readback.

                                                                                                                                                                                                                                          Instances For

                                                                                                                                                                                                                                            Device-side Adam moment buffers for one parameter leaf.

                                                                                                                                                                                                                                            Instances For

                                                                                                                                                                                                                                              Apply an Adam update to all parameters recorded via use, using CUDA device gradients.

                                                                                                                                                                                                                                              This is the CUDA analogue of the generic TorchLean.Optim.adam path. It keeps Adam moments as device buffers and keeps updated parameters in each Param's CUDA mirror, so the next CUDA forward can reuse them without a host upload. Host tensors are synchronized later by explicit readback.

                                                                                                                                                                                                                                              Instances For
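The per-entry update the device buffers implement is standard bias-corrected Adam; a hedged scalar sketch (standalone Lean, not this library's API — names and structure are illustrative):

```lean
structure AdamMoments where
  m : Float := 0.0
  v : Float := 0.0

-- one Adam step for a single scalar parameter p with gradient g at step t ≥ 1
def adamStep (lr β1 β2 eps : Float) (t : Nat) (st : AdamMoments) (p g : Float) :
    Float × AdamMoments :=
  let m := β1 * st.m + (1 - β1) * g        -- first-moment EMA
  let v := β2 * st.v + (1 - β2) * g * g    -- second-moment EMA
  let mHat := m / (1 - Float.pow β1 t.toFloat)  -- bias correction
  let vHat := v / (1 - Float.pow β2 t.toFloat)
  (p - lr * mHat / (Float.sqrt vHat + eps), { m, v })
```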

                                                                                                                                                                                                                                                Apply an AdamW update to all parameters recorded via use, using CUDA device gradients.

                                                                                                                                                                                                                                                This mirrors Optim.AdamW.update: moments are formed from the raw gradient, weight decay is applied directly to parameters, then the Adam update is applied. Like adamStepAllCuda, it keeps updated parameter buffers resident on device and only synchronizes the host copy when readback is requested.

                                                                                                                                                                                                                                                Instances For
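The decoupled decay described above differs from L2 regularization only in where the decay enters; a hedged scalar sketch of that step (standalone Lean, illustrative names):

```lean
-- AdamW: moments are formed from the raw gradient g (not g + λ·p);
-- the weight decay is applied directly to the parameter instead
def decoupledDecay (lr weightDecay p : Float) : Float :=
  p - lr * weightDecay * p
```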

                                                                                                                                                                                                                                                  Imperative sessions live in:

                                                                                                                                                                                                                                                  torch.compile-style wrapper (cached static graph).

                                                                                                                                                                                                                                                  This is a thin wrapper around the proof-compiled graph model (GraphData) and its proven-correct reverse-mode accumulator (GraphData.backpropCtx). It compiles once, then you can call forward/backward repeatedly with new inputs.

                                                                                                                                                                                                                                                  Note: this does not cache a Runtime.Autograd.Tape for reuse across different inputs. The current tape compiler bakes the forward context into backward closures, so reusing a single tape across changing inputs would be unsound without redesigning the runtime node API.

                                                                                                                                                                                                                                                  torch.compile-style wrapper for a scalar-valued computation over leaf context Γ.

                                                                                                                                                                                                                                                  This stores a proved node (NodeData) together with the preceding graph prefix so it can be evaluated and differentiated without rebuilding the whole graph each time.

                                                                                                                                                                                                                                                  Instances For
                                                                                                                                                                                                                                                    @[reducible, inline]

                                                                                                                                                                                                                                                    Convenience alias for the proved heterogeneous tensor list over a shape context.

                                                                                                                                                                                                                                                    Instances For

                                                                                                                                                                                                                                                      Evaluate the scalar output for leaf values x.

                                                                                                                                                                                                                                                      Instances For

                                                                                                                                                                                                                                                        Forward-mode Jacobian-vector product (JVP) at x with tangent dx.

                                                                                                                                                                                                                                                        Instances For
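Forward mode can be modeled with dual numbers carrying a value and a tangent; a hedged standalone sketch of what a JVP computes (plain Lean, not the library API):

```lean
-- a dual number carries a value and its tangent
structure Dual where
  val : Float
  tan : Float

instance : Add Dual := ⟨fun a b => ⟨a.val + b.val, a.tan + b.tan⟩⟩
instance : Mul Dual := ⟨fun a b => ⟨a.val * b.val, a.val * b.tan + a.tan * b.val⟩⟩

-- g x = x * x + x has derivative 2*x + 1, so the JVP at x = 3
-- with tangent dx = 1 is 7
def g (x : Dual) : Dual := x * x + x

#eval (g ⟨3.0, 1.0⟩).tan  -- 7.0
```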
                                                                                                                                                                                                                                                          def Runtime.Autograd.Torch.CompiledScalar.backward {α : Type} [Add α] [Zero α] [One α] {Γ : List Spec.Shape} (c : CompiledScalar α Γ) (x : TList α Γ) :
                                                                                                                                                                                                                                                          TList α Γ

                                                                                                                                                                                                                                                          Reverse-mode backprop for a scalar output with implicit seed 1.

                                                                                                                                                                                                                                                          Returns a TList of gradients aligned with the leaf context Γ.

                                                                                                                                                                                                                                                          Instances For
                                                                                                                                                                                                                                                            def Runtime.Autograd.Torch.CompiledScalar.backwardWithSeed {α : Type} [Add α] [Zero α] {Γ : List Spec.Shape} (c : CompiledScalar α Γ) (x : TList α Γ) (seedOut : α) :
                                                                                                                                                                                                                                                            TList α Γ

                                                                                                                                                                                                                                                            Reverse-mode backprop for a scalar output with an explicit scalar seed.

                                                                                                                                                                                                                                                            PyTorch comparison: loss.backward(gradient=seedOut) for a scalar loss.

                                                                                                                                                                                                                                                            Instances For

                                                                                                                                                                                                                                                              torch.compile-style wrapper for tensor-valued outputs.

                                                                                                                                                                                                                                                              This is the same idea as CompiledScalar, but parameterized by an arbitrary output shape τ. It supports:

                                                                                                                                                                                                                                                              torch.compile-style wrapper for a tensor-valued output of shape τ.

                                                                                                                                                                                                                                                              This generalizes CompiledScalar to arbitrary output shapes and provides forward-mode JVP and reverse-mode VJP (with explicit seed).

                                                                                                                                                                                                                                                              Instances For
                                                                                                                                                                                                                                                                @[reducible, inline]

                                                                                                                                                                                                                                                                Convenience alias for the proved heterogeneous tensor list over a shape context.

                                                                                                                                                                                                                                                                Instances For
def Runtime.Autograd.Torch.CompiledOut.forward {α : Type} {Γ : List Spec.Shape} {τ : Spec.Shape} (c : CompiledOut α Γ τ) (x : TList α Γ) :
Spec.Tensor α τ

                                                                                                                                                                                                                                                                  Evaluate the output tensor for leaf values x.

                                                                                                                                                                                                                                                                  Instances For
def Runtime.Autograd.Torch.CompiledOut.jvp {α : Type} {Γ : List Spec.Shape} {τ : Spec.Shape} (c : CompiledOut α Γ τ) (x dx : TList α Γ) :
Spec.Tensor α τ

                                                                                                                                                                                                                                                                    Forward-mode Jacobian-vector product (JVP) at x with tangent dx.

                                                                                                                                                                                                                                                                    Instances For
                                                                                                                                                                                                                                                                      def Runtime.Autograd.Torch.CompiledOut.vjpWithSeed {α : Type} [Add α] [Zero α] {Γ : List Spec.Shape} {τ : Spec.Shape} (c : CompiledOut α Γ τ) (x : TList α Γ) (seedOut : Spec.Tensor α τ) :
                                                                                                                                                                                                                                                                      TList α Γ

                                                                                                                                                                                                                                                                      Reverse-mode vector-Jacobian product (VJP) with an explicit output cotangent seed.

                                                                                                                                                                                                                                                                      This is the tensor-valued analogue of CompiledScalar.backwardWithSeed. PyTorch comparison: out.backward(gradient=seedOut) (for a tensor output).

                                                                                                                                                                                                                                                                      Instances For
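For intuition, the VJP contracts the output cotangent seedOut against the Jacobian of the compiled function. A hedged hand-written example for f (x₁, x₂) = (x₁ + x₂, x₁ * x₂) (plain Lean, not the library API):

```lean
-- Jacobian of f is [[1, 1], [x₂, x₁]]; the VJP is Jᵀ applied to the seed
def vjpPair (x₁ x₂ s₁ s₂ : Float) : Float × Float :=
  (s₁ + s₂ * x₂, s₁ + s₂ * x₁)

#eval vjpPair 2.0 3.0 1.0 1.0  -- (4.0, 3.0)
```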

                                                                                                                                                                                                                                                                        Compile a scalar-output graph builder into a CompiledScalar.

                                                                                                                                                                                                                                                                        The builder is expressed in the Compiled.GraphM monad. We expect it to produce at least one node and return a variable of scalar shape.

                                                                                                                                                                                                                                                                        Instances For

                                                                                                                                                                                                                                                                          Compile a tensor-output graph builder into a CompiledOut.

                                                                                                                                                                                                                                                                          We require that the returned Var τ is the last node produced by the builder, so the wrapper can store the prefix graph and final output node cleanly.

                                                                                                                                                                                                                                                                          Instances For
                                                                                                                                                                                                                                                                            def Runtime.Autograd.Torch.Proofs.Autograd.Algebra.TList.append {α : Type} {ss₁ ss₂ : List Spec.Shape} :
TList α ss₁ → TList α ss₂ → TList α (ss₁ ++ ss₂)

                                                                                                                                                                                                                                                                            Append two TLists.

                                                                                                                                                                                                                                                                            This is a small utility for bridging between curried APIs and list-of-shapes APIs.

                                                                                                                                                                                                                                                                            Instances For
                                                                                                                                                                                                                                                                              def Runtime.Autograd.Torch.Proofs.Autograd.Algebra.TList.splitAppend {α : Type} {ss₁ ss₂ : List Spec.Shape} :
TList α (ss₁ ++ ss₂) → TList α ss₁ × TList α ss₂

                                                                                                                                                                                                                                                                              Split a TList α (ss₁ ++ ss₂) into its left and right parts.

                                                                                                                                                                                                                                                                              This is the inverse of TList.append.

                                                                                                                                                                                                                                                                              Instances For
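A self-contained model of the append/split pair over a simplified index type (hedged sketch: the library's TList is indexed by Spec.Shape, here Nat stands in):

```lean
inductive HList (F : Nat → Type) : List Nat → Type where
  | nil  : HList F []
  | cons : F s → HList F ss → HList F (s :: ss)

def HList.append : HList F ss₁ → HList F ss₂ → HList F (ss₁ ++ ss₂)
  | .nil,       ys => ys
  | .cons x xs, ys => .cons x (xs.append ys)

-- inverse of append: peel off the first |ss₁| entries
def HList.splitAppend :
    {ss₁ : List Nat} → HList F (ss₁ ++ ss₂) → HList F ss₁ × HList F ss₂
  | [],     ys         => (.nil, ys)
  | _ :: _, .cons x xs =>
    let (l, r) := xs.splitAppend
    (.cons x l, r)
```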

Type of a curried function accepting one tensor argument per shape in ss.

For example, Fn α [s₁, s₂] β is Tensor α s₁ → Tensor α s₂ → β.

Instances For

def Runtime.Autograd.Torch.Curried.curry {α β : Type} {ss : List Spec.Shape} :
    (TList α ss → β) → Fn α ss β

Convert a function on TList inputs into its curried form.

Instances For

def Runtime.Autograd.Torch.Curried.uncurry {α β : Type} {ss : List Spec.Shape} :
    Fn α ss β → TList α ss → β

Convert a curried function into a function on TList inputs.

Instances For
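Concretely, for the two-shape case from the `Fn` example above, currying lets a `TList`-consuming function be applied argument by argument (a sketch; the `Tensor` spelling follows the doc text above):

```lean
-- Sketch: `Curried.curry f` has type `Tensor α s₁ → Tensor α s₂ → β`,
-- so it can be applied to the two tensors directly.
example {α β : Type} {s₁ s₂ : Spec.Shape}
    (f : TList α [s₁, s₂] → β)
    (t₁ : Tensor α s₁) (t₂ : Tensor α s₂) : β :=
  Curried.curry f t₁ t₂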

RefList is the reference-analogue of TList: a heterogeneous list of Ref s values indexed by a shape list.

This is used to write backend-generic code over references (e.g. TensorRefs in eager mode, or GraphM.Vars in compiled mode).

Instances For

def Runtime.Autograd.Torch.RefList.append {Ref : Spec.Shape → Type} {ss₁ ss₂ : List Spec.Shape} :
    RefList Ref ss₁ → RefList Ref ss₂ → RefList Ref (ss₁ ++ ss₂)

Append two RefLists.

Instances For

def Runtime.Autograd.Torch.RefList.split {Ref : Spec.Shape → Type} {ss₁ ss₂ : List Spec.Shape} :
    RefList Ref (ss₁ ++ ss₂) → RefList Ref ss₁ × RefList Ref ss₂

Split a RefList Ref (ss₁ ++ ss₂) into its left and right parts.

Instances For

def Runtime.Autograd.Torch.RefList.splitAppend1 {Ref : Spec.Shape → Type} {ss : List Spec.Shape} {τ : Spec.Shape} :
    RefList Ref (ss ++ [τ]) → RefList Ref ss × Ref τ

Split a RefList Ref (ss ++ [τ]) into its prefix and last element.

Instances For
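A typical use of splitAppend1 (a sketch under the signatures above) is peeling the trailing reference off a pack, e.g. separating a model's parameters from its final input:

```lean
-- Sketch: take the last reference of a `RefList` indexed by `ss ++ [τ]`.
-- `.2` selects the `Ref τ` component of the returned pair.
example {Ref : Spec.Shape → Type} {ss : List Spec.Shape} {τ : Spec.Shape}
    (refs : RefList Ref (ss ++ [τ])) : Ref τ :=
  (RefList.splitAppend1 refs).2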

Type of a curried function over references, one Ref s argument per shape in ss.

This mirrors Curried.Fn, but for Ref-valued arguments (e.g. TensorRefs in eager mode or GraphM.Vars in compiled mode).

Instances For

def Runtime.Autograd.Torch.CurriedRef.uncurry {Ref : Spec.Shape → Type} {β : Type} {ss : List Spec.Shape} :
    CurriedRef Ref ss β → RefList Ref ss → β

Uncurry a curried reference function to accept a RefList.

Instances For

def Runtime.Autograd.Torch.CurriedRef.curry {Ref : Spec.Shape → Type} {β : Type} {ss : List Spec.Shape} :
    (RefList Ref ss → β) → CurriedRef Ref ss β

Curry a reference function that consumes a RefList.

Instances For

Apply a curried reference function to a GraphM.VarList.

This is a convenience for the compiled backend, where leaves/inputs are represented as Vars.

Instances For

Backend-generic interface for building and executing tensor programs.

This typeclass lets you write a single model/loss once (polymorphic over Ops m α) and then choose:

• an eager backend that executes immediately on a runtime tape, or
• a compiled backend that records proved IR (GraphM) for later compilation/proofs.

Each method corresponds to a Tensor op; implementations are expected to match the semantics of the corresponding Runtime.Autograd.Tape.* / Compiled.GraphM.* operator.

Instances
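As a sketch of the intended usage pattern (the operator names come from the re-exports below; the specific model, the `[Monad m]` assumption, and the def name are illustrative, not from the library):

```lean
-- Backend-generic "affine + clamp" step written once against `Ops`.
-- With an eager `Ops` instance this executes on the runtime tape
-- immediately; with the compiled instance it records GraphM IR instead.
def affineClamped {m : Type → Type} {α : Type} [Monad m]
    [Context α] [DecidableEq Spec.Shape] [Ops m α] {s : Spec.Shape}
    (w x b : Ref s) (lo hi : α) : m (Ref s) := do
  let wx ← Torch.mul w x     -- elementwise product
  let y  ← Torch.add wx b    -- add bias
  Torch.clamp y lo hi        -- clip to [lo, hi]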
@[reducible, inline]

Reference type for the current Ops instance.

In eager mode this will typically be TensorRef; in compiled mode it will typically be GraphM.Var.

Instances For

def Runtime.Autograd.Torch.const {m : Type → Type} {α : Type} [Context α] [DecidableEq Spec.Shape] [Ops m α] {s : Spec.Shape} (t : Spec.Tensor α s) :
    m (Ref s)

Re-export of Ops.const. PyTorch: torch.tensor(...) / literal constants.

Instances For
def Runtime.Autograd.Torch.add {m : Type → Type} {α : Type} [Context α] [DecidableEq Spec.Shape] [Ops m α] {s : Spec.Shape} (a b : Ref s) :
    m (Ref s)

Re-export of Ops.add. PyTorch: torch.add / +.

Instances For

def Runtime.Autograd.Torch.sub {m : Type → Type} {α : Type} [Context α] [DecidableEq Spec.Shape] [Ops m α] {s : Spec.Shape} (a b : Ref s) :
    m (Ref s)

Re-export of Ops.sub. PyTorch: torch.sub / -.

Instances For

def Runtime.Autograd.Torch.mul {m : Type → Type} {α : Type} [Context α] [DecidableEq Spec.Shape] [Ops m α] {s : Spec.Shape} (a b : Ref s) :
    m (Ref s)

Re-export of Ops.mul. PyTorch: torch.mul / *.

Instances For

def Runtime.Autograd.Torch.scale {m : Type → Type} {α : Type} [Context α] [DecidableEq Spec.Shape] [Ops m α] {s : Spec.Shape} (x : Ref s) (c : α) :
    m (Ref s)

Re-export of Ops.scale. PyTorch: x * c for a scalar c.

Instances For

def Runtime.Autograd.Torch.abs {m : Type → Type} {α : Type} [Context α] [DecidableEq Spec.Shape] [Ops m α] {s : Spec.Shape} (x : Ref s) :
    m (Ref s)

Re-export of Ops.abs. PyTorch: torch.abs.

Instances For
def Runtime.Autograd.Torch.sqrt {m : Type → Type} {α : Type} [Context α] [DecidableEq Spec.Shape] [Ops m α] {s : Spec.Shape} (x : Ref s) :
    m (Ref s)

Re-export of Ops.sqrt. PyTorch: torch.sqrt.

Instances For

def Runtime.Autograd.Torch.clamp {m : Type → Type} {α : Type} [Context α] [DecidableEq Spec.Shape] [Ops m α] {s : Spec.Shape} (x : Ref s) (minVal maxVal : α) :
    m (Ref s)

Re-export of Ops.clamp. PyTorch: torch.clamp.

Instances For

def Runtime.Autograd.Torch.max {m : Type → Type} {α : Type} [Context α] [DecidableEq Spec.Shape] [Ops m α] {s : Spec.Shape} (a b : Ref s) :
    m (Ref s)

Re-export of Ops.max. PyTorch: torch.maximum.

Instances For

def Runtime.Autograd.Torch.min {m : Type → Type} {α : Type} [Context α] [DecidableEq Spec.Shape] [Ops m α] {s : Spec.Shape} (a b : Ref s) :
    m (Ref s)

Re-export of Ops.min. PyTorch: torch.minimum.

Instances For
def Runtime.Autograd.Torch.broadcastTo {m : Type → Type} {α : Type} [Context α] [DecidableEq Spec.Shape] [Ops m α] {s₁ s₂ : Spec.Shape} (cb : s₁.CanBroadcastTo s₂) (x : Ref s₁) :
    m (Ref s₂)

Re-export of Ops.broadcastTo. PyTorch: broadcasting / expand.

Instances For
def Runtime.Autograd.Torch.reshape {m : Type → Type} {α : Type} [Context α] [DecidableEq Spec.Shape] [Ops m α] {s₁ s₂ : Spec.Shape} (x : Ref s₁) (h : s₁.size = s₂.size) :
    m (Ref s₂)

Re-export of Ops.reshape. PyTorch: reshape / view.

Instances For
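As a hedged sketch of how the size side condition is discharged in practice (assuming `Spec.Shape.size` reduces on literal shapes, so the equality closes by `rfl`; the helper name `flatten2x3` is hypothetical and the body assumes `open Runtime.Autograd.Torch`):

```lean
-- Hypothetical sketch: flatten a 2×3 tensor into a length-6 vector.
-- The side condition `s₁.size = s₂.size` (here 6 = 6) is expected to
-- close by `rfl` once `Spec.Shape.size` evaluates on literal shapes.
def flatten2x3 {m : Type → Type} {α : Type} [Context α]
    [DecidableEq Spec.Shape] [Ops m α]
    (x : Ref (Spec.Shape.dim 2 (Spec.Shape.dim 3 Spec.Shape.scalar))) :
    m (Ref (Spec.Shape.dim 6 Spec.Shape.scalar)) :=
  reshape x (by rfl)
```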

Re-export of Ops.transpose2d. PyTorch: x.t() / transpose.

Instances For

Re-export of Ops.transpose3d_first_to_last. PyTorch: permute(1,2,0).

Instances For

Re-export of Ops.transpose3d_last_to_first. PyTorch: permute(2,0,1).

Instances For

Re-export of Ops.transpose3d_last_two. PyTorch: transpose(1,2).

Instances For
def Runtime.Autograd.Torch.swapAdjacentAtDepth {m : Type → Type} {α : Type} [Context α] [DecidableEq Spec.Shape] [Ops m α] {s : Spec.Shape} (depth : ℕ) (x : Ref s) :
    m (Ref (s.swapAdjacentAtDepth depth))

Re-export of Ops.swapAdjacentAtDepth (general adjacent-axis swap).

Instances For

Re-export of Ops.reduce_sum. PyTorch: torch.sum(..., dim=axis).

Instances For

Re-export of Ops.reduce_mean. PyTorch: torch.mean(..., dim=axis).

Instances For

Re-export of Ops.gather_scalar. PyTorch: x[i] (1D).

Instances For

Re-export of Ops.gather_row. PyTorch: x[i] (2D row).

Instances For

Re-export of Ops.gather_scalar_nat (index is a raw Nat).

Instances For

Re-export of Ops.gather_vec_nat (index tensor).

Instances For

Re-export of Ops.gather_rows_nat (index tensor).

Instances For

Re-export of Ops.scatter_add_vec.

Instances For

Re-export of Ops.scatter_add_row.

Instances For

Re-export of Ops.matmul. PyTorch: torch.matmul for 2D tensors.

Instances For

Re-export of Ops.bmm. PyTorch: torch.bmm.

Instances For

Re-export of Ops.concat_vectors. PyTorch: torch.cat([a,b], dim=0) for vectors.

Instances For
def Runtime.Autograd.Torch.concatDim0 {m : Type → Type} {α : Type} [Context α] [DecidableEq Spec.Shape] [Ops m α] {nDim mDim : ℕ} {s : Spec.Shape} (a : Ref (Spec.Shape.dim nDim s)) (b : Ref (Spec.Shape.dim mDim s)) :
    m (Ref (Spec.Shape.dim (nDim + mDim) s))

Re-export of Ops.concat_dim0. PyTorch: torch.cat(..., dim=0).

Instances For
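The leading dimensions add in the result type, which the elaborator can exploit on literals. A hedged sketch (the name `cat3and5` is hypothetical and the body assumes `open Runtime.Autograd.Torch`):

```lean
-- Hypothetical sketch: concatenating leading dims 3 and 5 yields 8.
-- `Spec.Shape.dim 8 s` should unify with `dim (3 + 5) s`, since `3 + 5`
-- reduces to `8` definitionally on natural-number literals.
def cat3and5 {m : Type → Type} {α : Type} [Context α]
    [DecidableEq Spec.Shape] [Ops m α] {s : Spec.Shape}
    (a : Ref (Spec.Shape.dim 3 s)) (b : Ref (Spec.Shape.dim 5 s)) :
    m (Ref (Spec.Shape.dim 8 s)) :=
  concatDim0 a b
```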
def Runtime.Autograd.Torch.sliceRange0 {m : Type → Type} {α : Type} [Context α] [DecidableEq Spec.Shape] [Ops m α] {nDim : ℕ} {s : Spec.Shape} (start len : ℕ) (h : len + start ≤ nDim) (x : Ref (Spec.Shape.dim nDim s)) :
    m (Ref (Spec.Shape.dim len s))

Re-export of Ops.slice_range0. PyTorch: x[start:start+len] on the leading dimension.

Instances For
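A hedged sketch of discharging the bounds proof on literals (the name `middleFour` is hypothetical and the body assumes `open Runtime.Autograd.Torch`):

```lean
-- Hypothetical sketch: take x[3:7] along a length-10 leading dimension.
-- The side condition `len + start ≤ nDim` (here 4 + 3 ≤ 10) is a
-- decidable proposition on literals, so `decide` closes it.
def middleFour {m : Type → Type} {α : Type} [Context α]
    [DecidableEq Spec.Shape] [Ops m α] {s : Spec.Shape}
    (x : Ref (Spec.Shape.dim 10 s)) :
    m (Ref (Spec.Shape.dim 4 s)) :=
  sliceRange0 3 4 (by decide) x
```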
def Runtime.Autograd.Torch.maxPool {m : Type → Type} {α : Type} [Context α] [DecidableEq Spec.Shape] [Ops m α] {d C : ℕ} {inSpatial kernel stride padding : Vector ℕ d} {hKernel : ∀ (i : Fin d), kernel.get i ≠ 0} (x : Ref (Spec.Shape.ofList (C :: inSpatial.toList))) :
    m (Ref (Spec.Shape.ofList (C :: (Spec.poolOutSpatialPad inSpatial kernel stride padding).toList)))

Re-export of Ops.max_pool (generic N-D max pooling, channels-first; no batch axis).

PyTorch comparison: torch.nn.functional.max_pool1d / max_pool2d / max_pool3d depending on the spatial rank d.

Instances For
def Runtime.Autograd.Torch.avgPool {m : Type → Type} {α : Type} [Context α] [DecidableEq Spec.Shape] [Ops m α] {d C : ℕ} {inSpatial kernel stride padding : Vector ℕ d} (hKernel : ∀ (i : Fin d), kernel.get i ≠ 0) (x : Ref (Spec.Shape.ofList (C :: inSpatial.toList))) :
    m (Ref (Spec.Shape.ofList (C :: (Spec.poolOutSpatialPad inSpatial kernel stride padding).toList)))

Re-export of Ops.avg_pool (generic N-D average pooling, channels-first; no batch axis).

PyTorch comparison: torch.nn.functional.avg_pool1d / avg_pool2d / avg_pool3d depending on the spatial rank d.

Instances For
def Runtime.Autograd.Torch.smoothMaxPool {m : Type → Type} {α : Type} [Context α] [DecidableEq Spec.Shape] [Ops m α] {d C : ℕ} {inSpatial kernel stride padding : Vector ℕ d} {hKernel : ∀ (i : Fin d), kernel.get i ≠ 0} (x : Ref (Spec.Shape.ofList (C :: inSpatial.toList))) (beta : α) :
    m (Ref (Spec.Shape.ofList (C :: (Spec.poolOutSpatialPad inSpatial kernel stride padding).toList)))

Re-export of Ops.smooth_max_pool (generic N-D smooth max pooling, channels-first; no batch axis).

This is a differentiable approximation to max pooling; PyTorch does not expose it as a single primitive.

Instances For
def Runtime.Autograd.Torch.maxPool2d {m : Type → Type} {α : Type} [Context α] [DecidableEq Spec.Shape] [Ops m α] {kH kW inH inW inC stride : ℕ} {h1 : kH ≠ 0} {h2 : kW ≠ 0} (x : Ref (Spec.Shape.dim inC (Spec.Shape.dim inH (Spec.Shape.dim inW Spec.Shape.scalar)))) :
    m (Ref (Spec.Shape.dim inC (Spec.Shape.dim ((inH - kH) / stride + 1) (Spec.Shape.dim ((inW - kW) / stride + 1) Spec.Shape.scalar))))

Re-export of Ops.max_pool2d. PyTorch: torch.nn.functional.max_pool2d.

Instances For
def Runtime.Autograd.Torch.maxPool2dPad {m : Type → Type} {α : Type} [Context α] [DecidableEq Spec.Shape] [Ops m α] {kH kW inH inW inC stride padding : ℕ} {h1 : kH ≠ 0} {h2 : kW ≠ 0} (x : Ref (Spec.Shape.dim inC (Spec.Shape.dim inH (Spec.Shape.dim inW Spec.Shape.scalar)))) :
    m (Ref (Spec.Shape.dim inC (Spec.Shape.dim ((inH + 2 * padding - kH) / stride + 1) (Spec.Shape.dim ((inW + 2 * padding - kW) / stride + 1) Spec.Shape.scalar))))

Re-export of Ops.max_pool2d_pad. PyTorch: max_pool2d(..., padding=...).

Instances For
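The output sizes in these result types are ordinary floor arithmetic on naturals: for instance, a 28×28 input with a 2×2 kernel at stride 2 gives 14×14 unpadded, and a 3×3 kernel at stride 2 with padding 1 also gives 14×14. Checking the arithmetic directly:

```lean
-- Output-size arithmetic (natural-number division is floor division):
example : (28 - 2) / 2 + 1 = 14 := rfl          -- 2×2 kernel, stride 2, no padding
example : (28 + 2 * 1 - 3) / 2 + 1 = 14 := rfl  -- 3×3 kernel, stride 2, padding 1
```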
@[reducible, inline]
abbrev Runtime.Autograd.Torch.maxPoolPad {m : Type → Type} {α : Type} [Context α] [DecidableEq Spec.Shape] [Ops m α] {kH kW inH inW inC stride padding : ℕ} {h1 : kH ≠ 0} {h2 : kW ≠ 0} (x : Ref (Spec.Shape.dim inC (Spec.Shape.dim inH (Spec.Shape.dim inW Spec.Shape.scalar)))) :
    m (Ref (Spec.Shape.dim inC (Spec.Shape.dim ((inH + 2 * padding - kH) / stride + 1) (Spec.Shape.dim ((inW + 2 * padding - kW) / stride + 1) Spec.Shape.scalar))))

Alias for max_pool2d_pad (PyTorch-style shorthand).

Instances For
                                                                                                                                                                                                                                                                                                                                                                                    def Runtime.Autograd.Torch.smoothMaxPool2d {m : TypeType} {α : Type} [Context α] [DecidableEq Spec.Shape] [Ops m α] {kH kW inH inW inC stride : } {h1 : kH 0} {h2 : kW 0} (x : Ref (Spec.Shape.dim inC (Spec.Shape.dim inH (Spec.Shape.dim inW Spec.Shape.scalar)))) (beta : α) :
                                                                                                                                                                                                                                                                                                                                                                                    m (Ref (Spec.Shape.dim inC (Spec.Shape.dim ((inH - kH) / stride + 1) (Spec.Shape.dim ((inW - kW) / stride + 1) Spec.Shape.scalar))))

                                                                                                                                                                                                                                                                                                                                                                                    Re-export of Ops.smooth_max_pool2d (softmax pooling).

                                                                                                                                                                                                                                                                                                                                                                                    Instances For
                                                                                                                                                                                                                                                                                                                                                                                      def Runtime.Autograd.Torch.avgPool2d {m : TypeType} {α : Type} [Context α] [DecidableEq Spec.Shape] [Ops m α] {kH kW inH inW inC stride : } (h1 : kH 0) (h2 : kW 0) (x : Ref (Spec.Shape.dim inC (Spec.Shape.dim inH (Spec.Shape.dim inW Spec.Shape.scalar)))) :
                                                                                                                                                                                                                                                                                                                                                                                      m (Ref (Spec.Shape.dim inC (Spec.Shape.dim ((inH - kH) / stride + 1) (Spec.Shape.dim ((inW - kW) / stride + 1) Spec.Shape.scalar))))

                                                                                                                                                                                                                                                                                                                                                                                      Re-export of Ops.avg_pool2d. PyTorch: torch.nn.functional.avg_pool2d.

                                                                                                                                                                                                                                                                                                                                                                                      Instances For
                                                                                                                                                                                                                                                                                                                                                                                        def Runtime.Autograd.Torch.avgPool2dPad {m : TypeType} {α : Type} [Context α] [DecidableEq Spec.Shape] [Ops m α] {kH kW inH inW inC stride padding : } (h1 : kH 0) (h2 : kW 0) (x : Ref (Spec.Shape.dim inC (Spec.Shape.dim inH (Spec.Shape.dim inW Spec.Shape.scalar)))) :
                                                                                                                                                                                                                                                                                                                                                                                        m (Ref (Spec.Shape.dim inC (Spec.Shape.dim ((inH + 2 * padding - kH) / stride + 1) (Spec.Shape.dim ((inW + 2 * padding - kW) / stride + 1) Spec.Shape.scalar))))

                                                                                                                                                                                                                                                                                                                                                                                        Re-export of Ops.avg_pool2d_pad. PyTorch: avg_pool2d(..., padding=...).

                                                                                                                                                                                                                                                                                                                                                                                        Instances For
                                                                                                                                                                                                                                                                                                                                                                                          @[reducible, inline]
                                                                                                                                                                                                                                                                                                                                                                                          abbrev Runtime.Autograd.Torch.avgPoolPad {m : TypeType} {α : Type} [Context α] [DecidableEq Spec.Shape] [Ops m α] {kH kW inH inW inC stride padding : } (h1 : kH 0) (h2 : kW 0) (x : Ref (Spec.Shape.dim inC (Spec.Shape.dim inH (Spec.Shape.dim inW Spec.Shape.scalar)))) :
                                                                                                                                                                                                                                                                                                                                                                                          m (Ref (Spec.Shape.dim inC (Spec.Shape.dim ((inH + 2 * padding - kH) / stride + 1) (Spec.Shape.dim ((inW + 2 * padding - kW) / stride + 1) Spec.Shape.scalar))))

                                                                                                                                                                                                                                                                                                                                                                                          Alias for avg_pool2d_pad (PyTorch-style shorthand).

                                                                                                                                                                                                                                                                                                                                                                                          Instances For
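All of the pooling variants above share the same output-size arithmetic, visible directly in their return types. Written as a formula (a standard identity, with a numeric instance chosen purely for illustration):

```latex
\text{out} = \left\lfloor \frac{\text{in} + 2\,\text{padding} - k}{\text{stride}} \right\rfloor + 1
```

For example, in = 28, k = 2, padding = 0, stride = 2 gives (28 − 2)/2 + 1 = 14. Since the dimensions are natural numbers, Lean's `/` is truncating division, so the floor is built into the shapes above and needs no side condition.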
def Runtime.Autograd.Torch.relu {m : Type → Type} {α : Type} [Context α] [DecidableEq Spec.Shape] [Ops m α] {s : Spec.Shape} (x : Ref s) :
    m (Ref s)

Re-export of Ops.relu.

Instances For

def Runtime.Autograd.Torch.sigmoid {m : Type → Type} {α : Type} [Context α] [DecidableEq Spec.Shape] [Ops m α] {s : Spec.Shape} (x : Ref s) :
    m (Ref s)

Re-export of Ops.sigmoid.

Instances For

def Runtime.Autograd.Torch.tanh {m : Type → Type} {α : Type} [Context α] [DecidableEq Spec.Shape] [Ops m α] {s : Spec.Shape} (x : Ref s) :
    m (Ref s)

Re-export of Ops.tanh.

Instances For

def Runtime.Autograd.Torch.softmax {m : Type → Type} {α : Type} [Context α] [DecidableEq Spec.Shape] [Ops m α] {s : Spec.Shape} (x : Ref s) :
    m (Ref s)

Re-export of Ops.softmax.

Instances For

def Runtime.Autograd.Torch.softplus {m : Type → Type} {α : Type} [Context α] [DecidableEq Spec.Shape] [Ops m α] {s : Spec.Shape} (x : Ref s) :
    m (Ref s)

Re-export of Ops.softplus.

Instances For

def Runtime.Autograd.Torch.exp {m : Type → Type} {α : Type} [Context α] [DecidableEq Spec.Shape] [Ops m α] {s : Spec.Shape} (x : Ref s) :
    m (Ref s)

Re-export of Ops.exp.

Instances For

def Runtime.Autograd.Torch.log {m : Type → Type} {α : Type} [Context α] [DecidableEq Spec.Shape] [Ops m α] {s : Spec.Shape} (x : Ref s) :
    m (Ref s)

Re-export of Ops.log.

Instances For

def Runtime.Autograd.Torch.inv {m : Type → Type} {α : Type} [Context α] [DecidableEq Spec.Shape] [Ops m α] {s : Spec.Shape} (x : Ref s) :
    m (Ref s)

Re-export of Ops.inv (reciprocal).

Instances For

def Runtime.Autograd.Torch.detach {m : Type → Type} {α : Type} [Context α] [DecidableEq Spec.Shape] [Ops m α] {s : Spec.Shape} (x : Ref s) :
    m (Ref s)

Re-export of Ops.detach. PyTorch: x.detach().

Instances For

def Runtime.Autograd.Torch.safeLog {m : Type → Type} {α : Type} [Context α] [DecidableEq Spec.Shape] [Ops m α] {s : Spec.Shape} (x : Ref s) (ε : α := Numbers.epsilon) :
    m (Ref s)

Re-export of Ops.safe_log.

Instances For

def Runtime.Autograd.Torch.randUniform {m : Type → Type} {α : Type} [Context α] [DecidableEq Spec.Shape] [Ops m α] {s : Spec.Shape} (seed : ℕ) :
    m (Ref s)

Re-export of Ops.rand_uniform (deterministic seeded RNG).

Instances For

def Runtime.Autograd.Torch.bernoulliMask {m : Type → Type} {α : Type} [Context α] [DecidableEq Spec.Shape] [Ops m α] {s : Spec.Shape} (keepProb : Ref Spec.Shape.scalar) (seed : ℕ) :
    m (Ref s)

Re-export of Ops.bernoulli_mask (deterministic dropout-style mask).

Instances For
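Because both RNG primitives take an explicit seed, the same seed always reproduces the same tensor. A minimal dropout-style sketch of how they compose (illustrative only; `Ops.mul` is an assumed elementwise-multiply primitive whose actual name in TorchLean may differ):

```lean
open Runtime.Autograd.Torch in
/-- Illustrative sketch, not from the TorchLean sources: deterministic
    dropout built from `bernoulliMask`. `Ops.mul` is a hypothetical
    elementwise multiply. -/
def dropoutSketch {m : Type → Type} {α : Type} [Context α]
    [DecidableEq Spec.Shape] [Monad m] [Ops m α] {s : Spec.Shape}
    (x : Ref s) (keepProb : Ref Spec.Shape.scalar) (seed : ℕ) : m (Ref s) := do
  -- Same `seed` ⇒ same mask: randomness is threaded explicitly, JAX-style.
  let mask ← bernoulliMask (s := s) keepProb seed
  Ops.mul x mask
```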
def Runtime.Autograd.Torch.logSoftmax {m : Type → Type} {α : Type} [Context α] [DecidableEq Spec.Shape] [Ops m α] {s : Spec.Shape} (x : Ref s) (ε : α := Numbers.epsilon) :
    m (Ref s)

Stable log_softmax(x) along the last axis.

This is a backend primitive with the standard max-shifted formulation x - max(x) - log(sum(exp(x - max(x)))), matching PyTorch's numerical intent. The optional ε parameter is kept for source compatibility with existing TorchLean callers and is intentionally ignored; callers that need an epsilon-smoothed logarithm should use safeLog explicitly.

Instances For
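As a numeric illustration of the max-shifted formulation, here is a minimal scalar-level sketch on plain Lean Floats; the helper name logSoftmaxList is hypothetical and is not part of the TorchLean API.

```lean
-- Hypothetical scalar model of the primitive, for intuition only.
def logSoftmaxList (xs : List Float) : List Float :=
  let m := xs.foldl max (-1.0 / 0.0)   -- max(x), starting from -∞
  let shifted := xs.map (· - m)        -- x - max(x); every entry ≤ 0
  let lse := Float.log (shifted.foldl (fun acc v => acc + Float.exp v) 0.0)
  shifted.map (· - lse)                -- x - max(x) - log(sum(exp(x - max(x))))
```

Shifting by the maximum keeps every exponent at most 0, so exp never overflows; this is the same trick behind PyTorch's stable log_softmax.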
def Runtime.Autograd.Torch.silu {m : Type → Type} {α : Type} [Context α] [DecidableEq Spec.Shape] [Monad m] [Ops m α] {s : Spec.Shape} (x : Ref s) :
    m (Ref s)

SiLU / swish: x * sigmoid(x).

Instances For
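A scalar-level sketch of the same definition (the helper names are hypothetical, not part of the API):

```lean
def sigmoidF (x : Float) : Float := 1.0 / (1.0 + Float.exp (-x))

-- silu(x) = x * sigmoid(x): smooth, slightly non-monotone near 0, ≈ x for large x.
def siluF (x : Float) : Float := x * sigmoidF x
```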
def Runtime.Autograd.Torch.gelu {m : Type → Type} {α : Type} [Context α] [DecidableEq Spec.Shape] [Monad m] [Ops m α] {s : Spec.Shape} (x : Ref s) :
    m (Ref s)

GELU (approximation used by many ML frameworks):

0.5 * x * (1 + tanh(√(2/π) * (x + 0.044715 * x^3))).

This is defined using existing primitives (tanh, mul, add, scale), so it works in eager, compiled, and verifier-IR backends without introducing a new opcode.

Instances For
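The tanh approximation quoted above can be written out on a single Float as a sanity check (geluF is a hypothetical illustration, not an exported name):

```lean
-- Hypothetical scalar model of the tanh-approximate GELU above.
def geluF (x : Float) : Float :=
  0.5 * x * (1.0 + Float.tanh (Float.sqrt (2.0 / 3.141592653589793)
    * (x + 0.044715 * x * x * x)))
```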
def Runtime.Autograd.Torch.globalAvgPool2dChw {m : Type → Type} {α : Type} [Context α] [DecidableEq Spec.Shape] [Monad m] [Ops m α] {c h w : ℕ} (h_c_pos : c > 0) (h_h_pos : h > 0) (h_w_pos : w > 0) (x : Ref (Spec.Shape.dim c (Spec.Shape.dim h (Spec.Shape.dim w Spec.Shape.scalar)))) :
    m (Ref (Spec.Shape.dim c Spec.Shape.scalar))

Global average pooling over the last two axes of a C×H×W tensor (channel-first).

Returns a length-C vector, averaging each channel over H×W.

Instances For
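The semantics can be modeled on nested lists: one mean per channel over its H×W grid (gapChw is a hypothetical illustration, not part of the API):

```lean
-- Hypothetical model: x has C channels, each an H×W grid of Floats.
def gapChw (x : List (List (List Float))) : List Float :=
  x.map fun channel =>
    let vals := channel.foldl (· ++ ·) []            -- flatten H×W
    vals.foldl (· + ·) 0.0 / Float.ofNat vals.length -- mean over the channel
```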
def Runtime.Autograd.Torch.globalAvgPool2dNchw {m : Type → Type} {α : Type} [Context α] [DecidableEq Spec.Shape] [Monad m] [Ops m α] {n c h w : ℕ} (h_n_pos : n > 0) (h_c_pos : c > 0) (h_h_pos : h > 0) (h_w_pos : w > 0) (x : Ref (Spec.Shape.dim n (Spec.Shape.dim c (Spec.Shape.dim h (Spec.Shape.dim w Spec.Shape.scalar))))) :
    m (Ref (Spec.Shape.dim n (Spec.Shape.dim c Spec.Shape.scalar)))

Global average pooling over the last two axes of an N×C×H×W tensor (PyTorch default layout).

Returns N×C, averaging each channel over H×W for each batch element.

Instances For

Re-export of Ops.sum. PyTorch: x.sum().

Instances For

Re-export of Ops.flatten. PyTorch: torch.flatten.

Instances For

Re-export of Ops.linear. PyTorch: torch.nn.functional.linear.

Instances For
def Runtime.Autograd.Torch.mseLoss {m : Type → Type} {α : Type} [Context α] [DecidableEq Spec.Shape] [Ops m α] {s : Spec.Shape} (yhat target : Ref s) :
    m (Ref Spec.Shape.scalar)

Re-export of Ops.mse_loss. PyTorch: torch.nn.functional.mse_loss.

Instances For
def Runtime.Autograd.Torch.layerNorm {m : Type → Type} {α : Type} [Context α] [DecidableEq Spec.Shape] [Ops m α] {seqLen embedDim : ℕ} (h_seq_pos : seqLen > 0) (h_embed_pos : embedDim > 0) (x : Ref (Spec.Shape.dim seqLen (Spec.Shape.dim embedDim Spec.Shape.scalar))) (gamma beta : Ref (Spec.Shape.dim embedDim Spec.Shape.scalar)) :
    m (Ref (Spec.Shape.dim seqLen (Spec.Shape.dim embedDim Spec.Shape.scalar)))

Re-export of Ops.layer_norm. PyTorch: nn.LayerNorm / functional.layer_norm.

Instances For
def Runtime.Autograd.Torch.batchnormChannelFirst {m : Type → Type} {α : Type} [Context α] [DecidableEq Spec.Shape] [Ops m α] {channels height width : ℕ} (h_c : channels > 0) (h_h : height > 0) (h_w : width > 0) (x : Ref (Spec.Shape.dim channels (Spec.Shape.dim height (Spec.Shape.dim width Spec.Shape.scalar)))) (gamma beta : Ref (Spec.Shape.dim channels Spec.Shape.scalar)) :
    m (Ref (Spec.Shape.dim channels (Spec.Shape.dim height (Spec.Shape.dim width Spec.Shape.scalar))))

Re-export of Ops.batchnorm_channel_first. PyTorch: nn.BatchNorm2d (conceptually).

Instances For
def Runtime.Autograd.Torch.multiHeadAttention {m : Type → Type} {α : Type} [Context α] [DecidableEq Spec.Shape] [Ops m α] {n numHeads dModel headDim : ℕ} (h1 : n ≠ 0) (wq wk wv : Ref (Spec.Shape.dim dModel (Spec.Shape.dim (numHeads * headDim) Spec.Shape.scalar))) (wo : Ref (Spec.Shape.dim (numHeads * headDim) (Spec.Shape.dim dModel Spec.Shape.scalar))) (x : Ref (Spec.Shape.dim n (Spec.Shape.dim dModel Spec.Shape.scalar))) (mask : Option (Spec.Tensor Bool (Spec.Shape.dim n (Spec.Shape.dim n Spec.Shape.scalar))) := none) :
    m (Ref (Spec.Shape.dim n (Spec.Shape.dim dModel Spec.Shape.scalar)))

Re-export of Ops.multi_head_attention.

Instances For
def Runtime.Autograd.Torch.conv {m : Type → Type} {α : Type} [Context α] [DecidableEq Spec.Shape] [Ops m α] {d inC outC : ℕ} {kernel stride padding inSpatial : Vector ℕ d} {hInC : inC ≠ 0} {hKernel : ∀ (i : Fin d), kernel.get i ≠ 0} (weight : Ref (Spec.Shape.ofList (outC :: inC :: kernel.toList))) (bias : Ref (Spec.Shape.dim outC Spec.Shape.scalar)) (input : Ref (Spec.Shape.ofList (inC :: inSpatial.toList))) :
    m (Ref (Spec.Shape.ofList (outC :: (Spec.convOutSpatial inSpatial kernel stride padding).toList)))

Re-export of Ops.conv (generic N-D convolution, channels-first).

PyTorch comparison: torch.nn.functional.conv{d}d specialized to a single sample (no batch axis).

Instances For
def Runtime.Autograd.Torch.convTranspose {m : Type → Type} {α : Type} [Context α] [DecidableEq Spec.Shape] [Ops m α] {d inC outC : ℕ} {kernel stride padding inSpatial : Vector ℕ d} {hInC : inC ≠ 0} {hKernel : ∀ (i : Fin d), kernel.get i ≠ 0} (weight : Ref (Spec.Shape.ofList (inC :: outC :: kernel.toList))) (bias : Ref (Spec.Shape.dim outC Spec.Shape.scalar)) (input : Ref (Spec.Shape.ofList (inC :: inSpatial.toList))) :
    m (Ref (Spec.Shape.ofList (outC :: (Spec.convTransposeOutSpatial inSpatial kernel stride padding).toList)))

Re-export of Ops.conv_transpose (generic N-D transpose convolution, channels-first).

PyTorch comparison: torch.nn.functional.conv_transpose{d}d specialized to a single sample.

Instances For
def Runtime.Autograd.Torch.conv2d {m : Type → Type} {α : Type} [Context α] [DecidableEq Spec.Shape] [Ops m α] {inC outC kH kW stride padding inH inW : ℕ} {h1 : inC ≠ 0} {h2 : kH ≠ 0} {h3 : kW ≠ 0} (kernel : Ref (Spec.Shape.dim outC (Spec.Shape.dim inC (Spec.Shape.dim kH (Spec.Shape.dim kW Spec.Shape.scalar))))) (bias : Ref (Spec.Shape.dim outC Spec.Shape.scalar)) (input : Ref (Spec.Shape.dim inC (Spec.Shape.dim inH (Spec.Shape.dim inW Spec.Shape.scalar)))) :
    m (Ref (Spec.Shape.dim outC (Spec.Shape.dim ((inH + 2 * padding - kH) / stride + 1) (Spec.Shape.dim ((inW + 2 * padding - kW) / stride + 1) Spec.Shape.scalar))))

Re-export of Ops.conv2d. PyTorch: torch.nn.functional.conv2d (conceptually, no batch axis).

Instances For
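The output spatial dimensions in the return type follow the usual convolution arithmetic; as a quick sanity check, here it is as a plain Nat function (the name conv2dOutDim is hypothetical):

```lean
-- The per-axis output-size formula from the return type above.
def conv2dOutDim (inDim padding k stride : Nat) : Nat :=
  (inDim + 2 * padding - k) / stride + 1

-- A 3×3 kernel with padding 1 and stride 1 preserves spatial size:
#eval conv2dOutDim 32 1 3 1   -- 32
```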
def Runtime.Autograd.Torch.convTranspose2d {m : Type → Type} {α : Type} [Context α] [DecidableEq Spec.Shape] [Ops m α] {inC outC kH kW stride padding inH inW : ℕ} {h1 : inC ≠ 0} {h2 : kH ≠ 0} {h3 : kW ≠ 0} (kernel : Ref (Spec.Shape.dim inC (Spec.Shape.dim outC (Spec.Shape.dim kH (Spec.Shape.dim kW Spec.Shape.scalar))))) (bias : Ref (Spec.Shape.dim outC Spec.Shape.scalar)) (input : Ref (Spec.Shape.dim inC (Spec.Shape.dim inH (Spec.Shape.dim inW Spec.Shape.scalar)))) :
    m (Ref (Spec.Shape.dim outC (Spec.Shape.dim ((inH - 1) * stride - 2 * padding + kH) (Spec.Shape.dim ((inW - 1) * stride - 2 * padding + kW) Spec.Shape.scalar))))

Re-export of Ops.conv_transpose2d. PyTorch: torch.nn.functional.conv_transpose2d.

Instances For
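Likewise, the transpose-convolution output size in the return type can be checked as a plain Nat function (the name convT2dOutDim is hypothetical):

```lean
-- The per-axis output-size formula from the return type above.
def convT2dOutDim (inDim padding k stride : Nat) : Nat :=
  (inDim - 1) * stride - 2 * padding + k

-- The common ×2 upsampling configuration (kernel 4, stride 2, padding 1):
#eval convT2dOutDim 16 1 4 2   -- 32
```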
@[reducible, inline]
abbrev Runtime.Autograd.Torch.conv2dCompat {m : Type → Type} {α : Type} [Context α] [DecidableEq Spec.Shape] [Ops m α] {inC outC kH kW stride padding inH inW : ℕ} {h1 : inC ≠ 0} {h2 : kH ≠ 0} {h3 : kW ≠ 0} (kernel : Ref (Spec.Shape.dim outC (Spec.Shape.dim inC (Spec.Shape.dim kH (Spec.Shape.dim kW Spec.Shape.scalar))))) (bias : Ref (Spec.Shape.dim outC Spec.Shape.scalar)) (input : Ref (Spec.Shape.dim inC (Spec.Shape.dim inH (Spec.Shape.dim inW Spec.Shape.scalar)))) :
    m (Ref (Spec.Shape.dim outC (Spec.Shape.dim ((inH + 2 * padding - kH) / stride + 1) (Spec.Shape.dim ((inW + 2 * padding - kW) / stride + 1) Spec.Shape.scalar))))

Alias for conv2d (compat shorthand).

Instances For
@[reducible, inline]

Monad used for the eager Ops instance: read an Internal.EagerSession α and execute in IO.

This is the backend that makes Ops programs execute immediately by mutating a hidden runtime tape.

Instances For
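As a mental model only (a sketch, not the actual definition in the source): a monad that "reads an `Internal.EagerSession α` and executes in `IO`" has the shape of a `ReaderT`:

```lean
-- Illustrative only: the eager monad reads the hidden session and runs in IO.
-- The actual definition in the source may differ.
abbrev EagerM (α : Type) := ReaderT (Internal.EagerSession α) IO
```

Under this reading, every `Ops` primitive receives the session implicitly and can mutate the tape it holds, which is what makes the front-end feel imperative.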
@[implicit_reducible]

Ops instance for the eager Torch-style runtime.

This interprets Ops primitives by immediately executing them against the hidden mutable tape in the current Internal.EagerSession.

@[implicit_reducible]

Ops instance for the compiled graph-building monad GraphM.

This interprets Ops primitives by recording typed IR nodes (rather than executing immediately). See Runtime.Autograd.Compiled.GraphM and Torch.LinkedSession for how these graphs are later run.

Heterogeneous list of trainable parameters, indexed by a list of shapes.

This is the Torch front-end analogue of "a parameter vector" (like model.parameters() in PyTorch), but with shapes tracked at the type level.

Instances For
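For example, the parameter pack of a small affine model `y = W·x + b` (weight 4×3, bias 4) would have the type below; this is a sketch assuming a `Context Float` instance is in scope:

```lean
-- Shapes live in the list index: PyTorch's model.parameters(),
-- but with static shape information.
example : Type :=
  ParamList Float
    [Spec.Shape.dim 4 (Spec.Shape.dim 3 Spec.Shape.scalar),  -- W : 4×3
     Spec.Shape.dim 4 Spec.Shape.scalar]                     -- b : 4
```

Because the shape list appears in the type, optimizer steps and gradient packs can only be applied when their shapes line up with the parameters.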

Materialize the SGD update `v - lr * g` in a single traversal.

This is used by sgdStep_fast as a runtime-performance optimization to avoid building deep thunk chains when training for many steps.

Instances For
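Elementwise, this is the standard gradient-descent rule, computed strictly:

```latex
% One strict optimizer step per tensor, computed in a single traversal:
\theta_{t+1}[i] \;=\; \theta_t[i] \;-\; \eta \, g_t[i]
```

Forcing the whole right-hand side each step means no chain of unevaluated `v - lr * g` expressions accumulates across thousands of training iterations.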

Allocate a fresh ParamList from an initial TList of parameter tensors.

Each tensor becomes an IO.Ref so it can be updated by optimizer steps.

Instances For

Allocate a fresh ParamList from an initial TList of parameter tensors, with explicit requiresGrad flags.

Returns an error when the flag list length does not match the parameter shape list length.

Instances For

Read the current parameter values as a TList aligned with the shape list.

Instances For

Read parameter values, synchronizing CUDA-resident mirrors first when necessary.

Instances For

Overwrite the current parameter values from a TList aligned with the shape list.

Instances For
def Runtime.Autograd.Torch.ParamList.sgdStep {α : Type} [Context α] {ss : List Spec.Shape} :
    ParamList α ss → (lr : α) → TList α ss → IO Unit

Apply an SGD step p := p - lr * g to each parameter that has requiresGrad = true.

gs must be aligned with the parameter shapes.

Instances For
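A minimal usage sketch, following the signature above (assumes some `Context α` instance and a gradient pack aligned with the parameter shapes):

```lean
open Runtime.Autograd.Torch

variable {α : Type} [Context α] {ss : List Spec.Shape}

-- One in-place SGD step `p := p - lr * g` over the whole pack;
-- only parameters with `requiresGrad = true` are updated.
def oneStep (ps : ParamList α ss) (lr : α) (grads : TList α ss) : IO Unit :=
  ps.sgdStep lr grads
```

The shared shape index `ss` is what guarantees, at elaboration time, that `grads` is aligned with the parameters.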
def Runtime.Autograd.Torch.ParamList.sgdStepFast {α : Type} [Context α] {ss : List Spec.Shape} :
    ParamList α ss → (lr : α) → TList α ss → IO Unit

Like sgdStep, but uses a fully materialized update (subScaleMaterialize) for speed.

This is a runtime performance knob; mathematically it is equivalent to sgdStep.

Instances For
structure Runtime.Autograd.Torch.ScalarTrainer (α : Type) (paramShapes inputShapes : List Spec.Shape) : Type

Bundle a scalar-loss training loop for a fixed parameter pack and input signature.

This is intended for simple demos:

• forward computes a scalar loss,
• backward computes gradients w.r.t. parameters,
• step applies an optimizer update (typically SGD),
• getParams reads current parameter values.

• params : ParamList α paramShapes

  Mutable trainable parameter pack.

• forward : Curried.Fn α inputShapes (IO (Spec.Tensor α Spec.Shape.scalar))

  Compute the scalar loss for a curried input pack.

• backward : Curried.Fn α inputShapes (IO (TList α paramShapes))

  Compute gradients aligned with paramShapes for a curried input pack.

• step : α → Curried.Fn α inputShapes (IO Unit)

  Apply one SGD-style update for a curried input pack.

• adamStep? : Option (α → α → α → α → Curried.Fn α inputShapes (IO Unit))

  Optional Adam update path.

  In eager CUDA mode this is a device-gradient/device-moment update path. Other backends expose none and should use the generic optimizer wrappers.

• adamWStep? : Option (α → α → α → α → α → Curried.Fn α inputShapes (IO Unit))

  Optional AdamW update path.

  In eager CUDA mode this is a device-gradient/device-moment update path with decoupled weight decay. Other backends expose none and should use the generic optimizer wrappers.

• getParams : IO (TList α paramShapes)

  Read current parameter values, synchronizing device mirrors if needed.

Instances For
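A demo-style training loop over these fields might look like the following sketch. It assumes an empty input signature (`inputShapes = []`) so that `Curried.Fn α [] r` unfolds to `r`; with nonempty inputs, `step` would additionally be applied to each input tensor in turn:

```lean
open Runtime.Autograd.Torch

-- Sketch: run `epochs` SGD steps, then read back the parameters.
-- Assumes `Curried.Fn α [] r` reduces to `r` for an empty input pack.
def trainDemo {α : Type} [Context α] {ps : List Spec.Shape}
    (t : ScalarTrainer α ps []) (lr : α) (epochs : Nat) :
    IO (TList α ps) := do
  for _ in [0:epochs] do
    t.step lr
  t.getParams
```

Backends that populate `adamStep?` or `adamWStep?` can swap the `t.step lr` line for the corresponding optional update path; others fall back to the generic optimizer wrappers.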

Extract gradients (as a typed TList) for a list of eager TensorRefs from a dense gradient array.

Instances For

Record all parameters as tape leaves in an eager session, returning their corresponding TensorRefs.

This is the eager analogue of "using" a parameter pack during a forward pass.

Instances For
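A minimal sketch of the eager-session flow described above. This rendering does not show the helper's name, so `recordParams` stands in for it, and the `Session` type is likewise a placeholder:

```lean
-- Hypothetical eager-session sketch. `recordParams` stands in for the
-- helper documented above (its name is elided in this rendering), and
-- `Session` is a placeholder for the eager session type.
def eagerForwardSketch {α : Type} [Context α]
    {paramShapes : List Spec.Shape}
    (sess : Session α) (params : TList α paramShapes) : IO Unit := do
  -- Register every parameter as a leaf on the runtime tape; each entry
  -- of the result is a TensorRef usable in subsequent eager ops.
  let _refs ← recordParams sess params
  pure ()
```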

Record all input tensors as tape leaves in an eager session, returning their corresponding TensorRefs.

Instances For
def Runtime.Autograd.Torch.scalarTrainer {α : Type} [Context α]
    [Internal.CudaBridge.TensorConv α] [DecidableEq Spec.Shape]
    {paramShapes inputShapes : List Spec.Shape}
    (opts : Options := { })
    (initRequiresGrad : List Bool := List.replicate paramShapes.length true)
    (loss : {m : Type → Type} → [Monad m] → [inst : Ops m α] →
      CurriedRef (fun (s : Spec.Shape) => Ops.Ref m α s) (paramShapes ++ inputShapes)
        (m (Ops.Ref m α Spec.Shape.scalar))) :
    Curried.Fn α paramShapes (IO (ScalarTrainer α paramShapes inputShapes))

Build a ScalarTrainer from an initial parameter pack and a backend-generic loss definition.

loss is written once against the Ops interface over a concatenated context paramShapes ++ inputShapes. Depending on opts.backend, we either:

• compile the loss once (compiled backend), or
• execute it eagerly by building a runtime tape each step (eager backend).

Instances For
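As an illustrative sketch of such a backend-generic loss, here is a squared-difference loss over one scalar parameter and one scalar input. The op names `Ops.sub` and `Ops.mul` are assumptions about the Ops interface, not confirmed by this page:

```lean
-- Hypothetical: squared-difference loss over a scalar parameter `w` and
-- a scalar input `x`, written once against the Ops interface so the
-- same definition runs on either backend.
def sqLoss {α : Type} {m : Type → Type} [Monad m] [Ops m α]
    (w x : Ops.Ref m α Spec.Shape.scalar) :
    m (Ops.Ref m α Spec.Shape.scalar) := do
  let d ← Ops.sub w x  -- assumed elementwise subtraction
  Ops.mul d d          -- assumed elementwise multiplication
```

With `paramShapes := [Spec.Shape.scalar]` and `inputShapes := [Spec.Shape.scalar]`, a function of this shape matches the `CurriedRef` type of the `loss` argument above, and the resulting `Curried.Fn` would then be applied to an initial value for the parameter.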