TorchLean API

NN.Runtime.Autograd.Torch.LinkedSession

LinkedSession #

Proof-linked imperative Session (eager-style API, proved IR under the hood).

Background:

This file provides a session-style API that records a GraphData (well-typed IR) as you call ops imperatively, and then runs the standard runtime tape loop on the compiled tape.

Key guarantee (pure theorem, no IO reasoning needed): the recorded graph can be compiled, and the runtime tape backward loop on the compiled tape is provably equal to GraphData.backpropAllCtx.

@[reducible, inline]

Convenience: turn a Result α into IO α by throwing IO.userError on .error.

This mirrors the common pattern in the eager runtime front-end (Torch.Core).
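
A minimal sketch of the pattern, assuming Result α exposes .ok and .error constructors with a printable error payload (the actual constructor names may differ):

  -- Sketch only; the constructor names are assumptions.
  def resultToIO {α : Type} : Result α → IO α
    | .ok a    => pure a
    | .error e => throw (IO.userError (toString e))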

    @[reducible, inline]

    Non-differentiable external environment for the proved graph: a small array of Nat inputs.


      Internal proof-linked session state (a well-typed GraphData plus its leaf values).


        Empty session state: no leaves, no nodes, empty nat-environment.


          SessionIR is an imperative session that records a GraphData (proved IR) as it runs.

          It is "eager-style" (you call ops imperatively), but it is proof-linked: the recorded graph can be compiled and then the runtime tape backward loop is provably equal to GraphData.backpropAllCtx.


            Create a new proof-linked session.

            This allocates IO.Refs for the session snapshot (SessionIRState) and the leaf-id→parameter map. Call resetTape to start a new "graph recording" phase.


              Reset the session to an empty snapshot.

              Important invariant: this session requires that all leaves are created before any op node. resetTape is the intended boundary between training steps/forwards.
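
A sketch of the intended per-step recording phase, assuming Float tensors and that the elementwise op wrappers follow the SessionIR.<op> naming of the signatures below (mul is an assumed name here):

  -- Hypothetical sketch: reset, then leaves, then op nodes.
  def recordForward {sh : Spec.Shape} [DecidableEq Spec.Shape]
      (s : SessionIR Float) (x0 : Spec.Tensor Float sh) (w : Param Float sh) :
      IO (TensorRef Float sh) := do
    s.resetTape              -- boundary: start from an empty snapshot
    let xRef ← s.input x0    -- leaves first (session invariant)
    let wRef ← s.use w       -- parameters are read and recorded as leaves
    s.mul wRef xRef          -- only after all leaves: op nodes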

                def Runtime.Autograd.Torch.Internal.SessionIR.param {α : Type} (s : SessionIR α) {sh : Spec.Shape} (init : Spec.Tensor α sh) (name : Option String := none) (requiresGrad : Option Bool := none) :
                IO (Param α sh)

                Create a mutable parameter object (not yet part of the recorded graph).

                To use the parameter in the recorded graph, call use, which reads its current value and records it as a leaf in Γ. PyTorch comparison: analogous to creating a torch.nn.Parameter and then using it in a forward.
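
A sketch of the create-once / stage-per-step split inside a do block (use's exact signature is not shown here; the obvious Param → IO TensorRef shape is assumed):

  -- Create the Param once, outside the recorded graph ...
  let w ← s.param w0 (name := some "w")
  -- ... then stage its current value as a leaf wherever it is used.
  let wRef ← s.use w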


                  Enforce the session invariant: leaves must be created before any op node.

This keeps the GraphData context split Γ ++ ss easy to reason about and matches the typical training pattern: resetTape → add leaves → forward ops → backward.


                    Record a new differentiable leaf tensor in the session context Γ.

                    This is the primitive used by use (parameters) and input (external inputs).


                      Use a Param in the recorded graph by reading its current value and recording it as a leaf.

                      The returned TensorRef is the graph handle you pass to subsequent ops. The session also remembers which leaf-id corresponds to which parameter, so sgdStepAll can update parameters after backward. PyTorch comparison: like referencing a torch.nn.Parameter in the forward; the parameter's value is treated as a leaf for autograd.

                        def Runtime.Autograd.Torch.Internal.SessionIR.input {α : Type} (s : SessionIR α) {sh : Spec.Shape} [DecidableEq Spec.Shape] (v : Spec.Tensor α sh) (name : Option String := none) (requiresGrad : Bool := false) :
                        IO (TensorRef α sh)

                        Record an external differentiable input tensor as a leaf.

                        name and requiresGrad are accepted for API parity with the eager session, but this proof-linked session always records the input in Γ (a leaf) and uses typing/invariants to determine what gradients are meaningful.


                          Record a non-differentiable Nat input in the external environment.

                          This is used for "index-like" inputs (labels, gather indices, etc.) that should not receive gradients. PyTorch comparison: like passing an integer tensor / index to an op; indices are not differentiable.
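
A sketch contrasting the two input kinds inside a do block (natInput is an assumed name for the Nat-input op described here):

  let xRef ← s.input img       -- differentiable leaf, recorded in Γ
  let lbl  ← s.natInput label  -- non-differentiable Nat, stored in the nat-environment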


                            Read a previously recorded NatRef.


                              Overwrite a previously recorded NatRef.


                                Convert a small Tensor Nat (.dim k .scalar) into an Array Nat.

                                This is used to stage NatVecRef inputs into the session nat-environment.


                                  Record a non-differentiable vector of Nat inputs.

                                  Returns a NatVecRef k which points into the nat-environment. This is useful for "runtime gather" style ops where indices are supplied externally (and are not differentiable).


                                    Read back the k-vector stored at a NatVecRef k.


                                      Overwrite the nat-environment segment referenced by NatVecRef k.


                                        Build a typed index into the current context Γ ++ ss from a raw numeric id and expected shape.

                                        This is the main "dynamic check" used by getValue (and by a few index-driven nodes): it ensures that the Nat id points to an existing tensor in the session context and that the shape matches.


                                          Evaluate the recorded graph and return the value of a TensorRef.

                                          This is a pure graph evaluation (GraphData.eval) using the recorded leaf values and nat-environment. It does not run the runtime tape or mutate session state.


                                            Graph-node ops (implemented by reusing Compiled.GraphM) #

                                            Run a Compiled.GraphM computation against the current (ss, g) pair.

                                            Compiled.GraphM is the builder monad used by the proof-friendly compiled pipeline; reusing it here ensures this eager-style API records the same typed IR that the compiler expects.


                                              Atomically apply a graph-building update to the session snapshot.

                                              This is the central adapter used by each op wrapper below: it reads s.st, runs a builder that returns an updated SessionIRState, stores it back into s.st, and returns the op result.


                                                Record a constant tensor.

                                                Subtlety: if no op nodes have been created yet (ss = []), we record const as a leaf to match the eager session's leaf-collection behavior. Once op nodes exist, we emit an explicit constant node so users can introduce literal constants mid-graph. PyTorch comparison: like torch.tensor(...) (a leaf) vs inserting a literal constant into the graph; constants are treated as non-requires-grad.
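
A sketch of the subtlety inside a do block (const and add are assumed wrapper names):

  s.resetTape
  let c1 ← s.const t1    -- ss = []: recorded as a leaf
  let y  ← s.add c1 c1   -- first op node
  let c2 ← s.const t2    -- op nodes now exist: emitted as an explicit constant node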


                                                  Record elementwise addition a + b.

                                                  PyTorch comparison: torch.add(a, b) / the + operator.


                                                    Record elementwise subtraction a - b.

                                                    PyTorch comparison: torch.sub(a, b) / the - operator.


                                                      Record elementwise multiplication a * b.

                                                      PyTorch comparison: torch.mul(a, b) / the * operator.

                                                        def Runtime.Autograd.Torch.Internal.SessionIR.scale {α : Type} (s : SessionIR α) [Mul α] [Add α] [Zero α] [DecidableEq Spec.Shape] {sh : Spec.Shape} (x : TensorRef α sh) (c : α) :
                                                        IO (TensorRef α sh)

                                                        Record scaling by a scalar constant: x * c.

                                                        PyTorch comparison: like x * c (where c is a Python scalar).

                                                          def Runtime.Autograd.Torch.Internal.SessionIR.abs {α : Type} (s : SessionIR α) [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] [DecidableEq Spec.Shape] {sh : Spec.Shape} (x : TensorRef α sh) :
                                                          IO (TensorRef α sh)

                                                          Record elementwise absolute value.

                                                          PyTorch comparison: torch.abs(x).


                                                            Stop-gradient boundary.

                                                            Forward semantics: identity. Backward semantics: no gradient flows to the input. PyTorch comparison: x.detach().

                                                              def Runtime.Autograd.Torch.Internal.SessionIR.sqrt {α : Type} (s : SessionIR α) [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] [DecidableEq Spec.Shape] {sh : Spec.Shape} (x : TensorRef α sh) :
                                                              IO (TensorRef α sh)

                                                              Record elementwise square root.

                                                              PyTorch comparison: torch.sqrt(x).

                                                                def Runtime.Autograd.Torch.Internal.SessionIR.clamp {α : Type} (s : SessionIR α) [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] [DecidableEq Spec.Shape] {sh : Spec.Shape} (x : TensorRef α sh) (minVal maxVal : α) :
                                                                IO (TensorRef α sh)

                                                                Record elementwise clamp to the interval [minVal, maxVal].

                                                                PyTorch comparison: torch.clamp(x, min=minVal, max=maxVal).

                                                                  def Runtime.Autograd.Torch.Internal.SessionIR.max {α : Type} (s : SessionIR α) [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] [DecidableEq Spec.Shape] {sh : Spec.Shape} (a b : TensorRef α sh) :
                                                                  IO (TensorRef α sh)

                                                                  Record elementwise maximum of a and b.

                                                                  PyTorch comparison: torch.maximum(a, b).

                                                                    def Runtime.Autograd.Torch.Internal.SessionIR.min {α : Type} (s : SessionIR α) [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] [DecidableEq Spec.Shape] {sh : Spec.Shape} (a b : TensorRef α sh) :
                                                                    IO (TensorRef α sh)

                                                                    Record elementwise minimum of a and b.

                                                                    PyTorch comparison: torch.minimum(a, b).


                                                                      Record 2D matrix multiplication.

                                                                      PyTorch comparison: torch.matmul(a, b) for 2D tensors.


                                                                        Record batched matrix multiplication.

                                                                        PyTorch comparison: torch.bmm(a, b) for 3D tensors of shape (batch, m, n) and (batch, n, p).


                                                                          Concatenate two 1D vectors along dimension 0.

                                                                          PyTorch comparison: torch.cat([a, b], dim=0) for 1D tensors.


                                                                            Concatenate two tensors along dimension 0.

                                                                            PyTorch comparison: torch.cat([a, b], dim=0).

def Runtime.Autograd.Torch.Internal.SessionIR.sliceRange0 {α : Type} (s : SessionIR α) [Zero α] [DecidableEq Spec.Shape] {n : ℕ} {sh : Spec.Shape} (x : TensorRef α (Spec.Shape.dim n sh)) (start len : ℕ) (h : len + start ≤ n) :
IO (TensorRef α (Spec.Shape.dim len sh))

                                                                              Slice a tensor along dimension 0.

                                                                              This returns x[start : start+len]. The proof argument h enforces bounds. PyTorch comparison: x[start:start+len] for tensors with a leading dimension.
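
For example, taking elements 2, 3, 4 of a length-10 vector; the bound obligation 3 + 2 ≤ 10 is closed by omega:

  let y ← s.sliceRange0 x 2 3 (by omega)  -- y : TensorRef α (Spec.Shape.dim 3 sh)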

def Runtime.Autograd.Torch.Internal.SessionIR.maxPool {α : Type} (s : SessionIR α) [Context α] [DecidableEq Spec.Shape] {d C : ℕ} {inSpatial kernel stride padding : Vector ℕ d} {hKernel : ∀ (i : Fin d), kernel.get i ≠ 0} (x : TensorRef α (Spec.Shape.ofList (C :: inSpatial.toList))) :
IO (TensorRef α (Spec.Shape.ofList (C :: (Spec.poolOutSpatialPad inSpatial kernel stride padding).toList)))

                                                                                N-D max-pooling for channels-first tensors (C, spatial...) (no batch axis).

                                                                                PyTorch comparison: torch.nn.functional.max_pool1d / max_pool2d / max_pool3d depending on the spatial rank d.

def Runtime.Autograd.Torch.Internal.SessionIR.smoothMaxPool {α : Type} (s : SessionIR α) [Context α] [DecidableEq Spec.Shape] {d C : ℕ} {inSpatial kernel stride padding : Vector ℕ d} {hKernel : ∀ (i : Fin d), kernel.get i ≠ 0} (x : TensorRef α (Spec.Shape.ofList (C :: inSpatial.toList))) (beta : α) :
IO (TensorRef α (Spec.Shape.ofList (C :: (Spec.poolOutSpatialPad inSpatial kernel stride padding).toList)))

                                                                                  N-D smooth max-pooling (log-sum-exp surrogate) for channels-first tensors (C, spatial...).

                                                                                  This is a differentiable approximation of max-pooling; there is no direct PyTorch primitive.

def Runtime.Autograd.Torch.Internal.SessionIR.avgPool {α : Type} (s : SessionIR α) [Context α] [DecidableEq Spec.Shape] {d C : ℕ} {inSpatial kernel stride padding : Vector ℕ d} (hKernel : ∀ (i : Fin d), kernel.get i ≠ 0) (x : TensorRef α (Spec.Shape.ofList (C :: inSpatial.toList))) :
IO (TensorRef α (Spec.Shape.ofList (C :: (Spec.poolOutSpatialPad inSpatial kernel stride padding).toList)))

                                                                                    N-D average-pooling for channels-first tensors (C, spatial...) (no batch axis).

                                                                                    PyTorch comparison: torch.nn.functional.avg_pool1d / avg_pool2d / avg_pool3d depending on the spatial rank d.

def Runtime.Autograd.Torch.Internal.SessionIR.maxPool2d {α : Type} (s : SessionIR α) [Context α] [DecidableEq Spec.Shape] {kH kW inH inW inC stride : ℕ} {h1 : kH ≠ 0} {h2 : kW ≠ 0} (x : TensorRef α (Spec.Shape.dim inC (Spec.Shape.dim inH (Spec.Shape.dim inW Spec.Shape.scalar)))) :
IO (TensorRef α (Spec.Shape.dim inC (Spec.Shape.dim ((inH - kH) / stride + 1) (Spec.Shape.dim ((inW - kW) / stride + 1) Spec.Shape.scalar))))

                                                                                      2D max-pooling for channel-first images.

                                                                                      PyTorch comparison: torch.nn.functional.max_pool2d (for NCHW-like layouts, here without batch).

def Runtime.Autograd.Torch.Internal.SessionIR.smoothMaxPool2d {α : Type} (s : SessionIR α) [Context α] [DecidableEq Spec.Shape] {kH kW inH inW inC stride : ℕ} {h1 : kH ≠ 0} {h2 : kW ≠ 0} (x : TensorRef α (Spec.Shape.dim inC (Spec.Shape.dim inH (Spec.Shape.dim inW Spec.Shape.scalar)))) (beta : α) :
IO (TensorRef α (Spec.Shape.dim inC (Spec.Shape.dim ((inH - kH) / stride + 1) (Spec.Shape.dim ((inW - kW) / stride + 1) Spec.Shape.scalar))))

                                                                                        Smooth approximation of max-pooling (softmax pooling) for channel-first images.

                                                                                        This is not a standard PyTorch primitive; conceptually it behaves like applying a softmax over each pooling window with inverse-temperature beta and returning the expected value.

def Runtime.Autograd.Torch.Internal.SessionIR.avgPool2d {α : Type} (s : SessionIR α) [Context α] [DecidableEq Spec.Shape] {kH kW inH inW inC stride : ℕ} (h1 : kH ≠ 0) (h2 : kW ≠ 0) (x : TensorRef α (Spec.Shape.dim inC (Spec.Shape.dim inH (Spec.Shape.dim inW Spec.Shape.scalar)))) :
IO (TensorRef α (Spec.Shape.dim inC (Spec.Shape.dim ((inH - kH) / stride + 1) (Spec.Shape.dim ((inW - kW) / stride + 1) Spec.Shape.scalar))))

                                                                                          2D average-pooling for channel-first images.

                                                                                          PyTorch comparison: torch.nn.functional.avg_pool2d (for NCHW-like layouts, here without batch).

                                                                                            def Runtime.Autograd.Torch.Internal.SessionIR.relu {α : Type} (s : SessionIR α) [Mul α] [Add α] [Zero α] [Max α] [One α] [LT α] [DecidableRel fun (x1 x2 : α) => x1 > x2] [DecidableEq Spec.Shape] {sh : Spec.Shape} (x : TensorRef α sh) :
                                                                                            IO (TensorRef α sh)

                                                                                            Record elementwise ReLU.

                                                                                            PyTorch comparison: torch.relu(x) / torch.nn.functional.relu(x).


                                                                                              Flatten a tensor into a 1D vector of length Shape.size sh.

                                                                                              PyTorch comparison: torch.flatten(x) (with default start_dim=0).

                                                                                                def Runtime.Autograd.Torch.Internal.SessionIR.reshape {α : Type} (s : SessionIR α) [Inhabited α] [Zero α] [DecidableEq Spec.Shape] {sh1 sh2 : Spec.Shape} (x : TensorRef α sh1) (h : sh1.size = sh2.size) :
                                                                                                IO (TensorRef α sh2)

                                                                                                Reshape a tensor while preserving total number of elements.

                                                                                                The proof argument h enforces Shape.size sh1 = Shape.size sh2. PyTorch comparison: torch.reshape(x, new_shape) / x.view(new_shape) (when contiguous).
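
For example, flattening a 2×3 matrix into a 6-vector (whether by decide discharges the size equality depends on how Spec.Shape.size reduces; rfl or simp may be needed instead):

  let y ← s.reshape (sh2 := Spec.Shape.dim 6 Spec.Shape.scalar) x (by decide)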


                                                                                                  Transpose a 2D matrix (swap the two axes).

                                                                                                  PyTorch comparison: x.t() for 2D tensors, or x.transpose(0, 1).


                                                                                                    Permute a 3D tensor by moving the first axis to the end: (a,b,c) → (b,c,a).

                                                                                                    PyTorch comparison: x.permute(1,2,0) for a 3D tensor.


                                                                                                      Permute a 3D tensor by moving the last axis to the front: (a,b,c) → (c,a,b).

                                                                                                      PyTorch comparison: x.permute(2,0,1) for a 3D tensor.


                                                                                                        Swap the last two axes of a 3D tensor: (a,b,c) → (a,c,b).

                                                                                                        PyTorch comparison: x.transpose(1,2) for a 3D tensor.


                                                                                                          Swap two adjacent axes at a given depth inside the shape.

                                                                                                          This is a more general permutation helper used in some shape-manipulating models. PyTorch comparison: like x.transpose(dim, dim+1) for a suitably chosen dim.

                                                                                                            def Runtime.Autograd.Torch.Internal.SessionIR.broadcastTo {α : Type} (s : SessionIR α) [Inhabited α] [Add α] [Zero α] [DecidableEq Spec.Shape] {sh1 sh2 : Spec.Shape} (cb : sh1.CanBroadcastTo sh2) (x : TensorRef α sh1) :
                                                                                                            IO (TensorRef α sh2)

                                                                                                            Broadcast a tensor to a larger shape.

                                                                                                            The witness cb : Shape.CanBroadcastTo sh1 sh2 encodes the broadcasting compatibility proof. PyTorch comparison: x.expand(...) / implicit broadcasting.


                                                                                                              Sum-reduce along axis.

                                                                                                              PyTorch comparison: torch.sum(x, dim=axis).


                                                                                                                Mean-reduce along axis.

                                                                                                                PyTorch comparison: torch.mean(x, dim=axis).


                                                                                                                  Gather a single scalar x[i] from a 1D vector, with a compile-time Fin n index.

                                                                                                                  PyTorch comparison: x[i] for a 1D tensor.


                                                                                                                    Gather a row x[i] from a 2D tensor, with a compile-time Fin rows index.

                                                                                                                    PyTorch comparison: x[i] for a 2D tensor (row indexing).


                                                                                                                      Read a Nat from the nat-environment.

                                                                                                                      Out-of-bounds reads return 0 (total function), which is convenient for modeling "possibly invalid" indices without throwing.


                                                                                                                        Read a length-k vector of Nats starting at start from the nat-environment.

                                                                                                                        Out-of-bounds reads fall back to 0 elementwise via natAt.


                                                                                                                          Dynamic gather of a scalar from a 1D vector using a runtime NatRef index.

                                                                                                                          Out-of-range indices produce 0 instead of raising. PyTorch comparison: similar to x[i] where i is a Python integer, except PyTorch raises on out-of-range while this definition totalizes the behavior for ease of reasoning.


                                                                                                                            Dynamic gather of a row from a 2D tensor using a runtime NatRef index.

                                                                                                                            Out-of-range indices yield a zero row. PyTorch comparison: similar to x[i] for 2D tensors with runtime i, but PyTorch raises on out-of-range whereas this definition is totalized for ease of reasoning.


                                                                                                                              Dynamic gather of k scalars from a 1D tensor using a runtime NatVecRef k of indices.

                                                                                                                              Out-of-range indices yield 0. In the VJP, gradients are accumulated for repeated indices (i.e. it behaves like a gather followed by a scatter-add back into the source vector). PyTorch comparison: related to torch.gather / advanced indexing, but with totalized out-of-range behavior.
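
For example, with x = [x₀, x₁, x₂, x₃] and indices [0, 2, 0], the forward gather returns [x₀, x₂, x₀]; given upstream gradients [g₀, g₁, g₂], the VJP scatter-adds them back as dx = [g₀ + g₂, 0, g₁, 0], accumulating both contributions to index 0.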


                                                                                                                                Dynamic gather of k rows from a 2D tensor using a runtime NatVecRef k of row indices.

                                                                                                                                Out-of-range indices yield zero rows. In the VJP, gradients are accumulated into the selected rows (scatter-add semantics), including accumulation for repeated indices. PyTorch comparison: similar to torch.index_select(x, dim=0, index=...) or advanced indexing on the first dimension, but with totalized out-of-range behavior.


                                                                                                                                  Gather a scalar from a 1D vector using a raw Nat index.

                                                                                                                                  PyTorch comparison: like x[i] with an integer index, but this operation is recorded into the proved IR (so it is stable for compilation/verification).


                                                                                                                                    Gather k scalars from a 1D vector using an explicit index tensor.

                                                                                                                                    PyTorch comparison: related to torch.gather / advanced indexing with an integer index tensor.


                                                                                                                                      Gather k rows from a 2D tensor using an explicit index tensor.

                                                                                                                                      PyTorch comparison: similar to torch.index_select(x, dim=0, index=...) or advanced indexing.


                                                                                                                                        Scatter-add into a vector: return a copy of x with x[i] += v.

                                                                                                                                        PyTorch comparison: similar to x.scatter_add_(dim=0, index=..., src=...) in spirit, but this is functional (returns a new tensor) and uses a single Fin n index.


                                                                                                                                          Scatter-add into a matrix row: return a copy of x with x[i, :] += v.

                                                                                                                                          PyTorch comparison: like adding a row vector into a selected row (functional analogue of an in-place indexed add).


                                                                                                                                            Record elementwise logistic sigmoid.

                                                                                                                                            PyTorch comparison: torch.sigmoid(x).


                                                                                                                                              Record elementwise hyperbolic tangent.

                                                                                                                                              PyTorch comparison: torch.tanh(x).


                                                                                                                                                Record softmax (shape-preserving).

                                                                                                                                                PyTorch comparison: torch.softmax(x, dim=...). This helper uses the convention baked into the underlying GraphM.softmax implementation.


                                                                                                                                                  Record stable log-softmax in the linked compiled session.

                                                                                                                                                  This commits a single GraphM.logSoftmax node instead of expanding to softmax followed by log, so compiled execution keeps the same stable semantics as eager CPU/CUDA.


                                                                                                                                                    Record elementwise softplus.

                                                                                                                                                    PyTorch comparison: torch.nn.functional.softplus(x).


                                                                                                                                                      Record elementwise exponential.

                                                                                                                                                      PyTorch comparison: torch.exp(x).


                                                                                                                                                        Record elementwise natural logarithm.

                                                                                                                                                        PyTorch comparison: torch.log(x).


                                                                                                                                                          Record elementwise log with epsilon guard.

                                                                                                                                                          This is intended for numerically stable losses; it corresponds roughly to log(max(x, ε)). PyTorch comparison: torch.log(torch.clamp(x, min=ε)).


                                                                                                                                                            Sum-reduce all elements to a scalar.

                                                                                                                                                            PyTorch comparison: x.sum().


                                                                                                                                                              Record a fully-connected linear layer: y = w • x + b.

                                                                                                                                                              Type-level shapes enforce w : (outDim, inDim), b : (outDim,), and x : (inDim,). PyTorch comparison: torch.nn.functional.linear(x, weight=w, bias=b) (with the same weight layout).
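
A sketch of a two-layer MLP forward inside a do block (linear is the layer described here; the wrapper name and the w, b, x argument order are assumptions based on the stated shapes):

  let h ← s.linear w1 b1 xRef  -- (hidden, in) weight, (hidden,) bias
  let a ← s.relu h
  let y ← s.linear w2 b2 a     -- (out, hidden) weight, (out,) bias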

def Runtime.Autograd.Torch.Internal.SessionIR.mseLoss {α : Type} (s : SessionIR α) [Add α] [Sub α] [Mul α] [Div α] [Zero α] [One α] [Coe ℕ α] [DecidableEq Spec.Shape] {sh : Spec.Shape} (yhat target : TensorRef α sh) :
IO (TensorRef α Spec.Shape.scalar)

                                                                                                                                                                Mean-squared-error loss returning a scalar.

                                                                                                                                                                PyTorch comparison: torch.nn.functional.mse_loss(yhat, target, reduction="mean").

def Runtime.Autograd.Torch.Internal.SessionIR.layerNorm {α : Type} (s : SessionIR α) [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] [DecidableEq Spec.Shape] {seqLen embedDim : ℕ} (h_seq_pos : seqLen > 0) (h_embed_pos : embedDim > 0) (x : TensorRef α (Spec.Shape.dim seqLen (Spec.Shape.dim embedDim Spec.Shape.scalar))) (gamma beta : TensorRef α (Spec.Shape.dim embedDim Spec.Shape.scalar)) :
IO (TensorRef α (Spec.Shape.dim seqLen (Spec.Shape.dim embedDim Spec.Shape.scalar)))

                                                                                                                                                                  Layer normalization over the trailing embedding dimension.

                                                                                                                                                                  This variant is specialized to 2D tensors of shape (seqLen, embedDim) and expects positive dimensions for numerical stability and well-formedness. PyTorch comparison: torch.nn.LayerNorm(embedDim) (applied per token), or torch.nn.functional.layer_norm.

def Runtime.Autograd.Torch.Internal.SessionIR.batchnormChannelFirst {α : Type} (s : SessionIR α) [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] [DecidableEq Spec.Shape] {channels height width : ℕ} (h_c : channels > 0) (h_h : height > 0) (h_w : width > 0) (x : TensorRef α (Spec.Shape.dim channels (Spec.Shape.dim height (Spec.Shape.dim width Spec.Shape.scalar)))) (gamma beta : TensorRef α (Spec.Shape.dim channels Spec.Shape.scalar)) :
IO (TensorRef α (Spec.Shape.dim channels (Spec.Shape.dim height (Spec.Shape.dim width Spec.Shape.scalar))))

                                                                                                                                                                    Batch normalization for a channel-first image (C,H,W) (no batch axis).

                                                                                                                                                                    gamma and beta are per-channel scale/shift parameters. PyTorch comparison: torch.nn.BatchNorm2d(C) (conceptually), or torch.nn.functional.batch_norm specialized to a single "batch element" with NCHW layout.

def Runtime.Autograd.Torch.Internal.SessionIR.conv {α : Type} (s : SessionIR α) [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] [DecidableEq Spec.Shape] {d inC outC : ℕ} {kernel stride padding inSpatial : Vector ℕ d} {hInC : inC ≠ 0} {hKernel : ∀ (i : Fin d), kernel.get i ≠ 0} (w : TensorRef α (Spec.Shape.ofList (outC :: inC :: kernel.toList))) (b : TensorRef α (Spec.Shape.dim outC Spec.Shape.scalar)) (x : TensorRef α (Spec.Shape.ofList (inC :: inSpatial.toList))) :
IO (TensorRef α (Spec.Shape.ofList (outC :: (Spec.convOutSpatial inSpatial kernel stride padding).toList)))

                                                                                                                                                                      N-D convolution for channels-first tensors (inC, spatial...) (no batch axis).

                                                                                                                                                                      Kernel layout is (outC, inC, kernelSpatial...), bias is (outC).

PyTorch comparison: torch.nn.functional.conv1d / conv2d / conv3d depending on the spatial rank d, specialized to a single sample.

def Runtime.Autograd.Torch.Internal.SessionIR.convTranspose {α : Type} (s : SessionIR α) [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] [DecidableEq Spec.Shape] {d inC outC : ℕ} {kernel stride padding inSpatial : Vector ℕ d} {hInC : inC ≠ 0} {hKernel : ∀ (i : Fin d), kernel.get i ≠ 0} (w : TensorRef α (Spec.Shape.ofList (inC :: outC :: kernel.toList))) (b : TensorRef α (Spec.Shape.dim outC Spec.Shape.scalar)) (x : TensorRef α (Spec.Shape.ofList (inC :: inSpatial.toList))) :
IO (TensorRef α (Spec.Shape.ofList (outC :: (Spec.convTransposeOutSpatial inSpatial kernel stride padding).toList)))

                                                                                                                                                                        N-D transpose convolution for channels-first tensors (inC, spatial...) (no batch axis).

                                                                                                                                                                        Kernel layout is (inC, outC, kernelSpatial...) (PyTorch convention), bias is (outC).

PyTorch comparison: torch.nn.functional.conv_transpose1d / conv_transpose2d / conv_transpose3d depending on the spatial rank d, specialized to a single sample.

def Runtime.Autograd.Torch.Internal.SessionIR.conv2d {α : Type} (s : SessionIR α) [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] [DecidableEq Spec.Shape] {inC outC kH kW stride padding inH inW : ℕ} {h1 : inC ≠ 0} {h2 : kH ≠ 0} {h3 : kW ≠ 0} (kernel : TensorRef α (Spec.Shape.dim outC (Spec.Shape.dim inC (Spec.Shape.dim kH (Spec.Shape.dim kW Spec.Shape.scalar))))) (bias : TensorRef α (Spec.Shape.dim outC Spec.Shape.scalar)) (input : TensorRef α (Spec.Shape.dim inC (Spec.Shape.dim inH (Spec.Shape.dim inW Spec.Shape.scalar)))) :
IO (TensorRef α (Spec.Shape.dim outC (Spec.Shape.dim ((inH + 2 * padding - kH) / stride + 1) (Spec.Shape.dim ((inW + 2 * padding - kW) / stride + 1) Spec.Shape.scalar))))

                                                                                                                                                                          2D convolution for channel-first images (inC, inH, inW) (no batch axis).

                                                                                                                                                                          Type-level shapes fix the kernel layout (outC, inC, kH, kW) and output spatial dimensions derived from stride and padding. PyTorch comparison: torch.nn.functional.conv2d (conceptually), specialized to a single image.

def Runtime.Autograd.Torch.Internal.SessionIR.convTranspose2d {α : Type} (s : SessionIR α) [Context α] [DecidableEq Spec.Shape] {inC outC kH kW stride padding inH inW : ℕ} {h1 : inC ≠ 0} {h2 : kH ≠ 0} {h3 : kW ≠ 0} (kernel : TensorRef α (Spec.Shape.dim inC (Spec.Shape.dim outC (Spec.Shape.dim kH (Spec.Shape.dim kW Spec.Shape.scalar))))) (bias : TensorRef α (Spec.Shape.dim outC Spec.Shape.scalar)) (input : TensorRef α (Spec.Shape.dim inC (Spec.Shape.dim inH (Spec.Shape.dim inW Spec.Shape.scalar)))) :
IO (TensorRef α (Spec.Shape.dim outC (Spec.Shape.dim ((inH - 1) * stride - 2 * padding + kH) (Spec.Shape.dim ((inW - 1) * stride - 2 * padding + kW) Spec.Shape.scalar))))

                                                                                                                                                                            2D transpose convolution for channel-first images (inC, inH, inW) (no batch axis).

                                                                                                                                                                            Kernel layout matches the spec/PyTorch convention (inC, outC, kH, kW). PyTorch comparison: torch.nn.functional.conv_transpose2d specialized to a single image.

def Runtime.Autograd.Torch.Internal.SessionIR.multiHeadAttention {α : Type} (s : SessionIR α) [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] [DecidableEq Spec.Shape] {n numHeads dModel headDim : ℕ} (h1 : n ≠ 0) (wq wk wv : TensorRef α (Spec.Shape.dim dModel (Spec.Shape.dim (numHeads * headDim) Spec.Shape.scalar))) (wo : TensorRef α (Spec.Shape.dim (numHeads * headDim) (Spec.Shape.dim dModel Spec.Shape.scalar))) (x : TensorRef α (Spec.Shape.dim n (Spec.Shape.dim dModel Spec.Shape.scalar))) (mask : Option (Spec.Tensor Bool (Spec.Shape.dim n (Spec.Shape.dim n Spec.Shape.scalar))) := none) :
IO (TensorRef α (Spec.Shape.dim n (Spec.Shape.dim dModel Spec.Shape.scalar)))

                                                                                                                                                                              Multi-head self-attention.

                                                                                                                                                                              This is a shape-specialized attention primitive used by some demo transformer-style models:

                                                                                                                                                                              • input x has shape (n, dModel)
                                                                                                                                                                              • wq, wk, wv map dModel → numHeads*headDim
                                                                                                                                                                              • wo maps numHeads*headDim → dModel
                                                                                                                                                                              • optional mask is a boolean (n,n) attention mask

                                                                                                                                                                              PyTorch comparison: similar to torch.nn.MultiheadAttention / scaled dot-product attention, but encoded in a fully typed IR for compilation/proof linkage.
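
For example, unmasked self-attention over a length-8 sequence (numHeads, dModel, and headDim are inferred from the weight refs' types; 8 ≠ 0 is closed by decide):

  let y ← s.multiHeadAttention (by decide) wq wk wv wo xRef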


                                                                                                                                                                                Backward + SGD (runtime tape loop on the compiled tape) #

                                                                                                                                                                                Compile the recorded proved graph into a runtime tape.

                                                                                                                                                                                This uses Graph.compileAuxData (the same compiler used by the proof pipeline) and extracts the runtime tape component.


                                                                                                                                                                                  Run reverse-mode backprop for the whole recorded context and return a dense gradient array.

                                                                                                                                                                                  seed is the upstream gradient for out (same convention as PyTorch's loss.backward(gradient=...)).


                                                                                                                                                                                    Convenience wrapper for scalar losses: run backward with seed 1.

                                                                                                                                                                                    PyTorch comparison: loss.backward() for a scalar loss.

                                                                                                                                                                                    Instances For

                                                                                                                                                                                      Extract the gradient tensor for a particular TensorRef from a dense gradient array.

                                                                                                                                                                                      This is the typed analogue of looking up grads[x.id] and casting it to the expected shape.

                                                                                                                                                                                      Instances For
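
A minimal sketch combining scalar backward with typed gradient extraction. Here `gradOf` is a stand-in name for the extraction function documented above (its actual identifier is not shown in this excerpt), and `backwardScalarDenseAll` is assumed to return the dense gradient array that `sgdStepAll` consumes:

```lean
-- Sketch only; assumes `open Runtime.Autograd.Torch` plus the instances
-- required by the signatures above. `gradOf` is a hypothetical name.
def gradOfLoss {shL shW : Spec.Shape} [DecidableEq Spec.Shape]
    (s : SessionIR Float)
    (loss : TensorRef Float shL)       -- a scalar-shaped loss in practice
    (w : TensorRef Float shW) :
    IO (Spec.Tensor Float shW) := do
  let grads ← s.backwardScalarDenseAll loss  -- seed = 1, like loss.backward()
  s.gradOf grads w                           -- typed lookup of grads[w.id]
```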

Forward-mode: JVP (compiled only) #

Like mkIdxOrThrow, but restricted to leaves Γ only.

Instances For

Convert a dense tangent array (aligned with leaf creation order) into a typed TList α Γ.

This is the main adapter needed to call the proved GraphData.jvpCtx forward-mode routine.

Instances For

Jacobian-vector product for the current session snapshot.

dxs is a dense array of tangents for leaf tensors, aligned with leaf creation order.

Instances For
def Runtime.Autograd.Torch.Internal.SessionIR.jvpLeaf {α : Type} (s : SessionIR α) [Zero α] [DecidableEq Spec.Shape] {shOut shX : Spec.Shape} (out : TensorRef α shOut) (x : TensorRef α shX) (dx : Spec.Tensor α shX) :
IO (Spec.Tensor α shOut)

JVP for a single leaf: the tangent is dx at x and zero at every other leaf.

Instances For
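
Because jvpLeaf's full signature is rendered above, a direct usage sketch is possible (the session and the recorded forward producing `out` and `x` are assumed to exist already):

```lean
-- Directional derivative of `out` along `dx` at the single leaf `x`;
-- every other leaf's tangent is zero. Assumes
-- `open Runtime.Autograd.Torch.Internal`.
def directionalDeriv {shOut shX : Spec.Shape} [DecidableEq Spec.Shape]
    (s : SessionIR Float) (out : TensorRef Float shOut)
    (x : TensorRef Float shX) (dx : Spec.Tensor Float shX) :
    IO (Spec.Tensor Float shOut) :=
  s.jvpLeaf out x dx
```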

Scalar-loss JVP for a single leaf.

Instances For
def Runtime.Autograd.Torch.Internal.SessionIR.sgdStepAll {α : Type} (s : SessionIR α) [Sub α] [Mul α] [Add α] [Zero α] [DecidableEq Spec.Shape] (lr : α) (grads : Array (AnyTensor α)) :

Apply an SGD update to all parameters recorded via use.

grads is expected to be the dense gradient array returned by backwardDenseAll / backwardScalarDenseAll. Only entries corresponding to parameters (leaves that were produced by use) are used to update Param.value. PyTorch comparison: like iterating params and doing p.data -= lr * p.grad. A full training-step sketch follows below.

Instances For
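
A hedged end-to-end training step, respecting the session invariant (resetTape, then leaves, then ops). Only `resetTape`, `use`, `backwardScalarDenseAll`, and `sgdStepAll` come from this file; `forwardLoss` stands in for whatever recorded ops build the loss, and `sgdStepAll` is assumed to return `IO Unit` (its rendered signature above is truncated):

```lean
-- One SGD step, sketched. Assumes `open Runtime.Autograd.Torch.Internal`.
def trainStep {sh shL : Spec.Shape} [DecidableEq Spec.Shape]
    (s : SessionIR Float) (w : Param Float sh)
    (forwardLoss : TensorRef Float sh → IO (TensorRef Float shL)) :
    IO Unit := do
  s.resetTape                                -- start a new recording phase
  let wRef ← s.use w                         -- leaf: current parameter value
  let loss ← forwardLoss wRef                -- recorded forward ops (scalar loss)
  let grads ← s.backwardScalarDenseAll loss  -- dense grads, seed = 1
  s.sgdStepAll 0.1 grads                     -- p.value := p.value - lr * grad
```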

Pure correctness hook: session snapshot ↔ proved IR backprop #

Core proof-link: running the runtime reverse-mode loop on the compiled tape equals proved backprop.

This theorem is the "hook" that lets a session-style API be backed by the proved IR: compileAuxData produces a tape, and Tape.backwardDenseFrom is shown equal to GraphData.backpropAllCtx (up to the TList.toAnyArray representation change).
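
Schematically, the statement has roughly the following shape (binders, hypotheses, and exact argument lists are elided; this is a paraphrase of the description above, not the theorem's actual text):

```lean
-- Schematic paraphrase only, not the real statement:
--
--   theorem tape_backward_eq_proved_backprop … :
--     ((Graph.compileAuxData g).tape).backwardDenseFrom … seed
--       = TList.toAnyArray (g.backpropAllCtx … seed) := …
```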

Public re-exports (stable names for docs) #

@[reducible, inline]

Public alias for the proof-linked session state (internal definition re-export).

Instances For

@[reducible, inline]

Empty SessionIRState (no parameters/graph recorded yet).

Instances For

@[reducible, inline]

Public alias for the proof-linked session object (internal definition re-export).

Instances For

@[reducible, inline]

Create a new proof-linked session (records a graph + supports proved backprop hook).

Instances For
@[reducible, inline]
abbrev Runtime.Autograd.Torch.SessionIR.backwardDenseAll {α : Type} (s : SessionIR α) [Add α] [Zero α] [DecidableEq Spec.Shape] {sh : Spec.Shape} (out : TensorRef α sh) (seed : Spec.Tensor α sh) :

Compute dense gradients for all tracked refs w.r.t. an output tensor and a seed.

This mirrors the "backward with custom seed" pattern in tensor AD systems. A usage sketch follows below.

Instances For
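
A short usage sketch of the custom-seed variant (the rendered signature above is truncated; the return type is assumed, based on sgdStepAll's grads parameter, to be the dense gradient array under IO):

```lean
-- Like PyTorch's out.backward(gradient = seed), sketched.
-- Assumes `open Runtime.Autograd.Torch`; the
-- IO (Array (AnyTensor Float)) return type is an assumption.
def backwardWithSeed {sh : Spec.Shape} [DecidableEq Spec.Shape]
    (s : SessionIR Float) (out : TensorRef Float sh)
    (seed : Spec.Tensor Float sh) : IO (Array (AnyTensor Float)) :=
  s.backwardDenseAll out seed
```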
@[reducible, inline]

Dense gradients for all tracked refs w.r.t. a scalar loss (seed is implicitly 1).

Instances For

@[reducible, inline]

Extract the gradient tensor for a specific ref from a dense gradient array.

Instances For

Public proof hook: the runtime reverse-mode loop on the compiled tape equals proved IR backprop.

This is a re-export of the internal theorem so downstream users can cite a stable name.