TorchLean API

NN.Runtime.Autograd.Compiled.GraphM

GraphM #

Proof-compiled graph authoring API.

Proofs.Autograd.Algebra.GraphData is an executable SSA/DAG graph used by the proof-compiled pipeline (Runtime.Autograd.Compiled). Constructing it directly exposes dependent indices (Idx) into the graph context and is therefore fairly low-level.

This module defines a small StateT builder (GraphM) that:

• tracks the list of intermediate shapes ss together with the executable GraphData payload in its state,
• hands out typed Var handles instead of raw dependent Idx values, and
• provides op combinators (arg, const, add, …) that append nodes and return a fresh Var for each result.

Reading map #

@[reducible, inline]

Shorthand for the underlying executable SSA graph type from Proofs.Autograd.Algebra.

Instances For
    @[reducible, inline]

    Executable node payload for the proof-compiled SSA graph (GraphData).

    Instances For

      A typed handle to a value in the growing compiled context.

      Var s carries its expected Shape at the type level, while id is the runtime index into the concatenated context Γ ++ ss.

• id : ℕ

        Runtime id of the value inside the concatenated context Γ ++ ss.

        The shape index on Var s is the static guarantee; this numeric id is the executable handle used when constructing Idx proofs for GraphData nodes.

      Instances For
        @[implicit_reducible]

        GraphM.arg is correct but a little noisy for examples (you must repeat the index and shape).

        VarList + args give a small convenience layer: args returns one Var per entry in Γ, in order, without spelling indices.

        Dependent list of typed variables, aligned with a list of shapes.

        VarList Γ contains exactly one Var s for each s ∈ Γ, in order.

        Instances For

          First variable in a nonempty VarList.

          Instances For

            Tail variables in a nonempty VarList.

            Instances For
              @[reducible, inline]

              State for the GraphM builder.

              It is a sigma pair of:

              • the list of intermediate shapes ss produced so far, and
              • the corresponding executable SSA graph payload GraphData α Δ Γ ss.
              Instances For
                @[reducible, inline]

                Default GraphM state with no extra environment (Δ := Unit).

                Instances For
                  @[reducible, inline]

                  StateT builder monad for authoring a GraphData program, with explicit environment Δ.

                  Instances For
                    @[reducible, inline]

                    Default GraphM builder monad with Δ := Unit.

                    Instances For

                      Empty builder state (no intermediate nodes yet).

                      Instances For

                        Empty builder state for an explicit environment type Δ.

                        Instances For
                          def Runtime.Autograd.Compiled.GraphM.run {α : Type} {Γ : List Spec.Shape} {β : Type} (m : M α Γ β) :
                          Result (β × State α Γ)

                          Run a GraphM program from an empty state.

                          Instances For

                            Build a GraphData by running a GraphM program.

                            This is the usual entry point: write a do-block that constructs the graph using arg, ops, and returns Unit; get back the finalized builder state containing ss and the graph.
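A minimal sketch of such a do-block (assumptions: the namespace is opened, arg takes the input's index and shape in that order as the docstrings below suggest, and the usual numeric instances for Float are in scope):

open Runtime.Autograd.Compiled.GraphM in
def tiny : Result (Unit × State Float [Spec.Shape.scalar, Spec.Shape.scalar]) :=
  run do
    let x ← arg 0 Spec.Shape.scalar   -- handle for input 0
    let y ← arg 1 Spec.Shape.scalar   -- handle for input 1
    let s ← add x y                   -- node: s = x + y
    let _ ← square s                  -- node: (x + y) ⊙ (x + y)
    pure ()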

                            Instances For

                              Length of the current context Γ ++ ss (inputs + intermediates).

                              Instances For

                                Convert a Var s into a dependent Idx (Γ ++ ss) s.

                                This performs bounds checking and a runtime shape check, returning a structured error if the variable points outside the current context or has the wrong shape.

                                Instances For
                                  def Runtime.Autograd.Compiled.GraphM.push {α Δ : Type} {Γ ss : List Spec.Shape} {s : Spec.Shape} (g : PGraphData α Δ Γ ss) (node : PNodeData α Δ (Γ ++ ss) s) :
                                  MWith α Δ Γ (Var s)

                                  Append a node to the graph state and return a fresh Var pointing to its output.

                                  The returned variable id is Γ.length + ss.length, i.e. it points at the newly appended entry.

                                  Instances For

                                    Forward-mode JVP availability for a compiled graph builder op.

                                    • implemented : JvpAvailability

                                      The op supplies a real forward-mode JVP rule.

                                    • reverseOnly (op : String) : JvpAvailability

                                      The op supplies reverse-mode VJP only. Forward-mode requests fail loudly.

                                    Instances For

                                      Compiled ops that provide VJP for training but no forward-mode JVP rule.

                                      Keeping the list executable gives callers a stable preflight hook instead of discovering the gap only after a directional-derivative run reaches the node. The list is intentionally empty when all compiled builder ops have concrete JVP rules.
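As a sketch of such a preflight, assuming reverseOnlyJvpOps : List String (the element type is implied, not shown, above) and a caller that has already collected the graph's op names:

open Runtime.Autograd.Compiled.GraphM in
def preflightJvp (opNames : List String) : Except String Unit :=
  match opNames.find? reverseOnlyJvpOps.contains with
  | some op => .error s!"op '{op}' is reverse-only: no forward-mode JVP rule"
  | none    => .ok ()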

                                      Instances For

                                        Return the JVP status for a named compiled op.

                                        Instances For

                                          Human-readable message for reverse-only compiled ops.

                                          Instances For

                                            Fail-fast marker for compiled nodes whose forward-mode JVP rule is intentionally absent.

                                            Returning a zero tangent here would silently corrupt forward-mode autodiff. Reverse-mode users are unaffected because these nodes still provide real vjp implementations. Forward-mode callers get a loud error, and reverseOnlyJvpOps provides a preflight list for tools that want to reject such graphs before running a JVP.

                                            Instances For

                                              Reference an input variable from the initial context Γ.

                                              This checks that the provided index is within bounds and that the requested shape matches the shape at that position in Γ.

                                              PyTorch comparison: this is like naming a graph input tensor in a traced graph.

                                              Instances For

                                                Pure helper to build VarList Γ starting at a given id offset.

                                                Instances For

                                                  Return one Var per entry of Γ, in order.

                                                  This is a convenience wrapper around arg that avoids manually writing indices in examples.
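A sketch, assuming the VarList accessors are spelled head and tail as documented above:

open Runtime.Autograd.Compiled.GraphM in
def sumTwo (s : Spec.Shape) : M Float [s, s] (Var s) := do
  let vs ← args              -- one Var per entry of Γ = [s, s]
  add vs.head vs.tail.head   -- no explicit indices needed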

                                                  Instances For
                                                    def Runtime.Autograd.Compiled.GraphM.const {α Δ : Type} [Zero α] {Γ : List Spec.Shape} {s : Spec.Shape} (t : Spec.Tensor α s) :
                                                    MWith α Δ Γ (Var s)

                                                    Embed a constant tensor as a node in the compiled graph.

                                                    This node has no input dependencies (vjp = 0, jvp = 0), i.e. it is treated as a constant with respect to the graph inputs.

                                                    PyTorch comparison: a constant literal captured into a traced/compiled graph.

                                                    Instances For
def Runtime.Autograd.Compiled.GraphM.randUniform {α : Type} [Context α] {Δ : Type} {Γ : List Spec.Shape} {s : Spec.Shape} (seed : ℕ) :
                                                      MWith α Δ Γ (Var s)

                                                      Deterministic U[0,1) tensor generator (seeded, pure).

                                                      Instances For

                                                        Deterministic {0,1} mask generator (seeded, pure).

                                                        Note: for differentiation purposes, this node is treated as a stop-gradient op: jvp = 0 and vjp = 0 for all inputs (including keepProb). This matches the intended use in dropout where the probability is a hyperparameter (not differentiated), while keeping execution deterministic in the .compiled backend.

                                                        Instances For

                                                          Stop-gradient boundary.

                                                          Forward semantics: identity (detach(x) = x). Backward semantics: no gradient flows to x (treated as constant w.r.t. the graph inputs).
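A sketch (the name detach is inferred from the docstring's detach(x) = x, not from a shown signature):

open Runtime.Autograd.Compiled.GraphM in
def frozenSquare (s : Spec.Shape) : M Float [s] (Var s) := do
  let vs ← args
  let xd ← detach vs.head   -- forward: xd = x; backward: zero gradient to x
  square xd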

                                                          Instances For

                                                            JVP vs VJP in this module

Each compiled node stores both:

• a reverse-mode vjp rule (used for training and compilation to the eager tape), and
• a forward-mode jvp rule (used for directional derivatives).

The .compiled runtime path is primarily exercised via reverse-mode (VJP) and compilation to the eager tape. Basic elementwise/bilinear ops provide real JVP rules; shape-structural ops (for example slice/concat) apply the same transformation to the tangent; heavier ops should expose named spec-layer JVP helpers before being wired here. A reverse-only op must be listed in reverseOnlyJvpOps and must call unsupportedJvp rather than returning a silent zero tangent.

                                                            Forward-mode coverage is expanded by adding concrete jvp rules next to the corresponding forward and vjp definitions.

                                                            def Runtime.Autograd.Compiled.GraphM.add {α Δ : Type} [Add α] [Zero α] [DecidableEq Spec.Shape] {Γ : List Spec.Shape} {s : Spec.Shape} (a b : Var s) :
                                                            MWith α Δ Γ (Var s)

                                                            Elementwise addition node (y = a + b).

                                                            PyTorch comparison: torch.add(a, b).

                                                            Instances For
                                                              def Runtime.Autograd.Compiled.GraphM.sub {α Δ : Type} [Sub α] [Add α] [Zero α] [DecidableEq Spec.Shape] {Γ : List Spec.Shape} {s : Spec.Shape} (a b : Var s) :
                                                              MWith α Δ Γ (Var s)

                                                              Elementwise subtraction node (y = a - b).

                                                              PyTorch comparison: torch.sub(a, b).

                                                              Instances For
                                                                def Runtime.Autograd.Compiled.GraphM.mul {α Δ : Type} [Mul α] [Add α] [Zero α] [DecidableEq Spec.Shape] {Γ : List Spec.Shape} {s : Spec.Shape} (a b : Var s) :
                                                                MWith α Δ Γ (Var s)

                                                                Elementwise multiplication node (y = a ⊙ b).

                                                                PyTorch comparison: torch.mul(a, b).

                                                                Instances For
                                                                  def Runtime.Autograd.Compiled.GraphM.square {α Δ : Type} [Mul α] [Add α] [Zero α] [DecidableEq Spec.Shape] {Γ : List Spec.Shape} {s : Spec.Shape} (x : Var s) :
                                                                  MWith α Δ Γ (Var s)

                                                                  Square x ↦ x ⊙ x.

                                                                  Instances For
                                                                    def Runtime.Autograd.Compiled.GraphM.scale {α Δ : Type} [Mul α] [Add α] [Zero α] [DecidableEq Spec.Shape] {Γ : List Spec.Shape} {s : Spec.Shape} (x : Var s) (c : α) :
                                                                    MWith α Δ Γ (Var s)

                                                                    Scale a tensor by a scalar constant c (y = c * x).

                                                                    PyTorch comparison: c * x / torch.mul(x, c).

                                                                    Instances For
                                                                      def Runtime.Autograd.Compiled.GraphM.abs {α : Type} [Context α] [Add α] [Zero α] [DecidableEq Spec.Shape] [DecidableRel fun (x1 x2 : α) => x1 > x2] {Δ : Type} {Γ : List Spec.Shape} {s : Spec.Shape} (x : Var s) :
                                                                      MWith α Δ Γ (Var s)

                                                                      Elementwise absolute value.

                                                                      PyTorch comparison: torch.abs(x).

                                                                      Instances For
                                                                        def Runtime.Autograd.Compiled.GraphM.sqrt {α : Type} [Context α] [Add α] [Zero α] [DecidableEq Spec.Shape] [DecidableRel fun (x1 x2 : α) => x1 > x2] {Δ : Type} {Γ : List Spec.Shape} {s : Spec.Shape} (x : Var s) :
                                                                        MWith α Δ Γ (Var s)

                                                                        Elementwise square root.

                                                                        PyTorch comparison: torch.sqrt(x).

                                                                        Instances For
                                                                          def Runtime.Autograd.Compiled.GraphM.clamp {α : Type} [Context α] [Add α] [Zero α] [DecidableEq Spec.Shape] [DecidableRel fun (x1 x2 : α) => x1 > x2] {Δ : Type} {Γ : List Spec.Shape} {s : Spec.Shape} (x : Var s) (minVal maxVal : α) :
                                                                          MWith α Δ Γ (Var s)

                                                                          Elementwise clamp to [minVal, maxVal].

                                                                          PyTorch comparison: torch.clamp(x, min=minVal, max=maxVal).

                                                                          Instances For
                                                                            def Runtime.Autograd.Compiled.GraphM.max {α : Type} [Context α] [Add α] [Zero α] [DecidableEq Spec.Shape] [DecidableRel fun (x1 x2 : α) => x1 > x2] {Δ : Type} {Γ : List Spec.Shape} {s : Spec.Shape} (a b : Var s) :
                                                                            MWith α Δ Γ (Var s)

                                                                            Elementwise maximum.

                                                                            At ties we split the gradient equally (0.5 / 0.5), matching the tie-handling documented in the eager tape (NN.Runtime.Autograd.Engine.Core).

                                                                            PyTorch comparison: torch.maximum(a, b).

                                                                            Instances For
                                                                              def Runtime.Autograd.Compiled.GraphM.min {α : Type} [Context α] [Add α] [Zero α] [DecidableEq Spec.Shape] [DecidableRel fun (x1 x2 : α) => x1 > x2] {Δ : Type} {Γ : List Spec.Shape} {s : Spec.Shape} (a b : Var s) :
                                                                              MWith α Δ Γ (Var s)

                                                                              Elementwise minimum.

                                                                              At ties we split the gradient equally (0.5 / 0.5).

                                                                              PyTorch comparison: torch.minimum(a, b).

                                                                              Instances For
                                                                                def Runtime.Autograd.Compiled.GraphM.relu {α : Type} [Mul α] [Add α] [Zero α] [Max α] [One α] [LT α] [DecidableRel fun (x1 x2 : α) => x1 > x2] [DecidableEq Spec.Shape] {Δ : Type} {Γ : List Spec.Shape} {s : Spec.Shape} (x : Var s) :
                                                                                MWith α Δ Γ (Var s)

                                                                                Elementwise ReLU.

                                                                                PyTorch comparison: torch.nn.functional.relu(x).

                                                                                Instances For
                                                                                  def Runtime.Autograd.Compiled.GraphM.sigmoid {α : Type} [Context α] [Add α] [Zero α] [DecidableEq Spec.Shape] {Δ : Type} {Γ : List Spec.Shape} {s : Spec.Shape} (x : Var s) :
                                                                                  MWith α Δ Γ (Var s)

                                                                                  Elementwise sigmoid. PyTorch comparison: torch.sigmoid(x).

                                                                                  Instances For
                                                                                    def Runtime.Autograd.Compiled.GraphM.tanh {α : Type} [Context α] [Add α] [Zero α] [DecidableEq Spec.Shape] {Δ : Type} {Γ : List Spec.Shape} {s : Spec.Shape} (x : Var s) :
                                                                                    MWith α Δ Γ (Var s)

                                                                                    Elementwise tanh. PyTorch comparison: torch.tanh(x).

                                                                                    Instances For
                                                                                      def Runtime.Autograd.Compiled.GraphM.softmax {α : Type} [Context α] [Add α] [Zero α] [DecidableEq Spec.Shape] {Δ : Type} {Γ : List Spec.Shape} {s : Spec.Shape} (x : Var s) :
                                                                                      MWith α Δ Γ (Var s)

                                                                                      Softmax along the last axis (recursing over outer dimensions).

                                                                                      PyTorch comparison: torch.softmax(x, dim=-1).

                                                                                      Instances For
                                                                                        def Runtime.Autograd.Compiled.GraphM.logSoftmax {α : Type} [Context α] [Add α] [Zero α] [DecidableEq Spec.Shape] {Δ : Type} {Γ : List Spec.Shape} {s : Spec.Shape} (x : Var s) :
                                                                                        MWith α Δ Γ (Var s)

                                                                                        Stable log-softmax along the last axis.

This is intentionally a primitive in the compiled graph, not the composition log ∘ softmax, so proof/IR execution and eager CUDA share the same PyTorch-style numerical contract.

                                                                                        Instances For
                                                                                          def Runtime.Autograd.Compiled.GraphM.softplus {α : Type} [Context α] [Add α] [Zero α] [DecidableEq Spec.Shape] {Δ : Type} {Γ : List Spec.Shape} {s : Spec.Shape} (x : Var s) :
                                                                                          MWith α Δ Γ (Var s)

                                                                                          Elementwise softplus. PyTorch comparison: torch.nn.functional.softplus(x).

                                                                                          Instances For
                                                                                            def Runtime.Autograd.Compiled.GraphM.exp {α : Type} [Context α] [Add α] [Zero α] [DecidableEq Spec.Shape] {Δ : Type} {Γ : List Spec.Shape} {s : Spec.Shape} (x : Var s) :
                                                                                            MWith α Δ Γ (Var s)

                                                                                            Elementwise exponential. PyTorch comparison: torch.exp(x).

                                                                                            Instances For
                                                                                              def Runtime.Autograd.Compiled.GraphM.log {α : Type} [Context α] [Add α] [Zero α] [DecidableEq Spec.Shape] {Δ : Type} {Γ : List Spec.Shape} {s : Spec.Shape} (x : Var s) :
                                                                                              MWith α Δ Γ (Var s)

                                                                                              Elementwise natural logarithm. PyTorch comparison: torch.log(x).

                                                                                              Instances For
                                                                                                def Runtime.Autograd.Compiled.GraphM.inv {α : Type} [Context α] [Mul α] [Add α] [Zero α] [DecidableEq Spec.Shape] {Δ : Type} {Γ : List Spec.Shape} {s : Spec.Shape} (x : Var s) :
                                                                                                MWith α Δ Γ (Var s)

                                                                                                Elementwise reciprocal x ↦ 1/x. PyTorch comparison: torch.reciprocal(x).

                                                                                                Instances For
                                                                                                  def Runtime.Autograd.Compiled.GraphM.safeLog {α : Type} [Context α] [Add α] [Zero α] [DecidableEq Spec.Shape] {Δ : Type} {Γ : List Spec.Shape} {s : Spec.Shape} (x : Var s) (ε : α := Numbers.epsilon) :
                                                                                                  MWith α Δ Γ (Var s)

Elementwise numerically-stable log (adds an ε offset, Numbers.epsilon by default).

                                                                                                  PyTorch comparison: commonly written torch.log(x + eps).

                                                                                                  Instances For

                                                                                                    Reduce-sum over all entries, producing a scalar.

                                                                                                    PyTorch comparison: torch.sum(x).

                                                                                                    Instances For
def Runtime.Autograd.Compiled.GraphM.mseLoss {α : Type} [Add α] [Sub α] [Mul α] [Div α] [Zero α] [One α] [Coe ℕ α] [DecidableEq Spec.Shape] {Δ : Type} {Γ : List Spec.Shape} {s : Spec.Shape} (yhat target : Var s) :
MWith α Δ Γ (Var Spec.Shape.scalar)

                                                                                                      Mean-squared error loss with "mean" reduction, producing a scalar.

PyTorch comparison: torch.nn.functional.mse_loss(yhat, target, reduction="mean").
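A sketch wiring a prediction input and a target input of the same shape into the loss (instance requirements for Float assumed available):

open Runtime.Autograd.Compiled.GraphM in
def lossGraph (s : Spec.Shape) : M Float [s, s] (Var Spec.Shape.scalar) := do
  let vs ← args
  mseLoss vs.head vs.tail.head   -- scalar, mean-reduced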

                                                                                                      Instances For

                                                                                                        Affine layer y = W x + b in the compiled graph.

                                                                                                        PyTorch comparison: torch.nn.functional.linear / torch.nn.Linear.

                                                                                                        The JVP is the usual product rule: d(Wx+b) = dW*x + W*dx + db.

                                                                                                        Instances For

                                                                                                          Matrix multiplication ((m×n) @ (n×p) → (m×p)).

                                                                                                          PyTorch comparison: torch.matmul.

                                                                                                          The JVP is the bilinear product rule d(A @ B) = dA @ B + A @ dB.

                                                                                                          Instances For

                                                                                                            Batched matrix multiplication (batch×m×n with batch×n×p).

                                                                                                            PyTorch comparison: torch.bmm.

                                                                                                            The JVP is the batched bilinear product rule d(A @ B) = dA @ B + A @ dB.

                                                                                                            Instances For

                                                                                                              Concatenate two vectors (dim-0 concat).

                                                                                                              PyTorch comparison: torch.cat([a, b], dim=0) for 1D tensors.

                                                                                                              Instances For
def Runtime.Autograd.Compiled.GraphM.concatDim0 {α Δ : Type} [Add α] [Zero α] [DecidableEq Spec.Shape] {Γ : List Spec.Shape} {n m : ℕ} {s : Spec.Shape} (a : Var (Spec.Shape.dim n s)) (b : Var (Spec.Shape.dim m s)) :
                                                                                                                MWith α Δ Γ (Var (Spec.Shape.dim (n + m) s))

                                                                                                                Concatenate along the leading dimension (dim=0) for tensors of shape .dim n s.

                                                                                                                PyTorch comparison: torch.cat([a, b], dim=0).

                                                                                                                Instances For
def Runtime.Autograd.Compiled.GraphM.sliceRange0 {α Δ : Type} [Zero α] [DecidableEq Spec.Shape] {Γ : List Spec.Shape} {n : ℕ} {s : Spec.Shape} (x : Var (Spec.Shape.dim n s)) (start len : ℕ) (h : len + start ≤ n) :
                                                                                                                  MWith α Δ Γ (Var (Spec.Shape.dim len s))

                                                                                                                  Slice a contiguous range along dim=0.

                                                                                                                  PyTorch comparison: x[start : start+len] for tensors where the leading dimension is indexed.
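A sketch for a concrete slice, discharging the len + start ≤ n side condition with omega:

open Runtime.Autograd.Compiled.GraphM in
example : M Float [Spec.Shape.dim 8 Spec.Shape.scalar]
    (Var (Spec.Shape.dim 3 Spec.Shape.scalar)) := do
  let vs ← args
  sliceRange0 vs.head 2 3 (by omega)   -- x[2:5]; 3 + 2 ≤ 8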

                                                                                                                  Instances For
def Runtime.Autograd.Compiled.GraphM.maxPool {α Δ : Type} [Context α] [Add α] [Zero α] [DecidableEq Spec.Shape] {Γ : List Spec.Shape} {d C : ℕ} {inSpatial kernel stride padding : Vector ℕ d} {hKernel : ∀ (i : Fin d), kernel.get i ≠ 0} (x : Var (Spec.Shape.ofList (C :: inSpatial.toList))) :
                                                                                                                    MWith α Δ Γ (Var (Spec.Shape.ofList (C :: (Spec.poolOutSpatialPad inSpatial kernel stride padding).toList)))

                                                                                                                    N-D max pooling (channels-first) on a single sample tensor (no batch axis).

                                                                                                                    PyTorch comparison: torch.nn.functional.max_pool1d / max_pool2d / max_pool3d depending on the spatial rank d.

                                                                                                                    Forward-mode status: implemented. The JVP follows the primal argmax selected by Spec.maxPoolJvpSpec, including the documented first-winner tie convention.

                                                                                                                    Instances For
def Runtime.Autograd.Compiled.GraphM.avgPool {α Δ : Type} [Context α] [Add α] [Zero α] [DecidableEq Spec.Shape] {Γ : List Spec.Shape} {d C : ℕ} {inSpatial kernel stride padding : Vector ℕ d} (hKernel : ∀ (i : Fin d), kernel.get i ≠ 0) (x : Var (Spec.Shape.ofList (C :: inSpatial.toList))) :
                                                                                                                      MWith α Δ Γ (Var (Spec.Shape.ofList (C :: (Spec.poolOutSpatialPad inSpatial kernel stride padding).toList)))

                                                                                                                      N-D average pooling (channels-first) on a single sample tensor (no batch axis).

                                                                                                                      PyTorch comparison: torch.nn.functional.avg_pool1d / avg_pool2d / avg_pool3d depending on the spatial rank d.

                                                                                                                      Forward-mode status: implemented. Average pooling is linear, so the JVP is the same average-pool map applied to the input tangent.

                                                                                                                      Instances For
def Runtime.Autograd.Compiled.GraphM.smoothMaxPool {α Δ : Type} [Context α] [Add α] [Zero α] [DecidableEq Spec.Shape] {Γ : List Spec.Shape} {d C : ℕ} {inSpatial kernel stride padding : Vector ℕ d} {hKernel : ∀ (i : Fin d), kernel.get i ≠ 0} (x : Var (Spec.Shape.ofList (C :: inSpatial.toList))) (beta : α) :
                                                                                                                        MWith α Δ Γ (Var (Spec.Shape.ofList (C :: (Spec.poolOutSpatialPad inSpatial kernel stride padding).toList)))

                                                                                                                        N-D smooth max pooling (log-sum-exp surrogate) on a single sample tensor (no batch axis).

                                                                                                                        PyTorch comparison: there is no direct primitive; this is a differentiable approximation to max pooling.

                                                                                                                        Forward-mode status: implemented. The JVP is the softmax-weighted tangent of the log-sum-exp pooling window.

                                                                                                                        Instances For
def Runtime.Autograd.Compiled.GraphM.maxPool2d {α Δ : Type} [Context α] [Add α] [Zero α] [DecidableEq Spec.Shape] {Γ : List Spec.Shape} {kH kW inH inW inC stride : ℕ} {h1 : kH ≠ 0} {h2 : kW ≠ 0} (x : Var (Spec.Shape.dim inC (Spec.Shape.dim inH (Spec.Shape.dim inW Spec.Shape.scalar)))) :
                                                                                                                          MWith α Δ Γ (Var (Spec.Shape.dim inC (Spec.Shape.dim ((inH - kH) / stride + 1) (Spec.Shape.dim ((inW - kW) / stride + 1) Spec.Shape.scalar))))

                                                                                                                          2D max-pooling (channel-first) on a single image tensor.

                                                                                                                          PyTorch comparison: torch.nn.functional.max_pool2d (without a batch dimension).

                                                                                                                          Forward-mode status: implemented. The JVP routes each output tangent through the argmax selected by the primal input.

                                                                                                                          Instances For
def Runtime.Autograd.Compiled.GraphM.maxPool2dPad {α Δ : Type} [Context α] [Add α] [Zero α] [DecidableEq Spec.Shape] {Γ : List Spec.Shape} {kH kW inH inW inC stride padding : ℕ} {h1 : kH ≠ 0} {h2 : kW ≠ 0} (x : Var (Spec.Shape.dim inC (Spec.Shape.dim inH (Spec.Shape.dim inW Spec.Shape.scalar)))) :
                                                                                                                            MWith α Δ Γ (Var (Spec.Shape.dim inC (Spec.Shape.dim ((inH + 2 * padding - kH) / stride + 1) (Spec.Shape.dim ((inW + 2 * padding - kW) / stride + 1) Spec.Shape.scalar))))

                                                                                                                            2D max-pooling with explicit padding.

                                                                                                                            PyTorch comparison: torch.nn.functional.max_pool2d with padding.

                                                                                                                            Forward-mode status: implemented. Padding is fixed and the JVP follows the real primal winner, ignoring padded cells just like the forward pass.

                                                                                                                            Instances For
def Runtime.Autograd.Compiled.GraphM.smoothMaxPool2d {α Δ : Type} [Context α] [Add α] [Zero α] [DecidableEq Spec.Shape] {Γ : List Spec.Shape} {kH kW inH inW inC stride : ℕ} {h1 : kH ≠ 0} {h2 : kW ≠ 0} (x : Var (Spec.Shape.dim inC (Spec.Shape.dim inH (Spec.Shape.dim inW Spec.Shape.scalar)))) (beta : α) :
                                                                                                                              MWith α Δ Γ (Var (Spec.Shape.dim inC (Spec.Shape.dim ((inH - kH) / stride + 1) (Spec.Shape.dim ((inW - kW) / stride + 1) Spec.Shape.scalar))))

                                                                                                                              Smooth (soft) max-pooling, controlled by beta.

                                                                                                                              This is a differentiable approximation to max-pooling.

                                                                                                                              Forward-mode status: implemented. The JVP is the softmax-weighted tangent of the log-sum-exp pooling window.

                                                                                                                              Instances For
def Runtime.Autograd.Compiled.GraphM.avgPool2d {α Δ : Type} [Context α] [Add α] [Zero α] [DecidableEq Spec.Shape] {Γ : List Spec.Shape} {kH kW inH inW inC stride : ℕ} (h1 : kH ≠ 0) (h2 : kW ≠ 0) (x : Var (Spec.Shape.dim inC (Spec.Shape.dim inH (Spec.Shape.dim inW Spec.Shape.scalar)))) :
                                                                                                                                MWith α Δ Γ (Var (Spec.Shape.dim inC (Spec.Shape.dim ((inH - kH) / stride + 1) (Spec.Shape.dim ((inW - kW) / stride + 1) Spec.Shape.scalar))))

                                                                                                                                Average pooling (channel-first) on a single image tensor.

                                                                                                                                PyTorch comparison: torch.nn.functional.avg_pool2d (without a batch dimension).

                                                                                                                                Forward-mode status: implemented. Average pooling is linear, so the JVP is average pooling of the input tangent.

                                                                                                                                Instances For
def Runtime.Autograd.Compiled.GraphM.avgPool2dPad {α Δ : Type} [Context α] [Add α] [Zero α] [DecidableEq Spec.Shape] {Γ : List Spec.Shape} {kH kW inH inW inC stride padding : ℕ} (h1 : kH ≠ 0) (h2 : kW ≠ 0) (x : Var (Spec.Shape.dim inC (Spec.Shape.dim inH (Spec.Shape.dim inW Spec.Shape.scalar)))) :
                                                                                                                                  MWith α Δ Γ (Var (Spec.Shape.dim inC (Spec.Shape.dim ((inH + 2 * padding - kH) / stride + 1) (Spec.Shape.dim ((inW + 2 * padding - kW) / stride + 1) Spec.Shape.scalar))))

                                                                                                                                  Average pooling with explicit padding.

                                                                                                                                  PyTorch comparison: torch.nn.functional.avg_pool2d with padding.

                                                                                                                                  Forward-mode status: implemented. Padding is fixed and average pooling is linear, so the JVP is the padded average-pool map applied to the input tangent.

                                                                                                                                  Instances For

                                                                                                                                    Flatten a tensor to a 1D vector (preserving total size).

                                                                                                                                    PyTorch comparison: torch.flatten(x) (for a single tensor value).

                                                                                                                                    Instances For
                                                                                                                                      def Runtime.Autograd.Compiled.GraphM.reshape {α Δ : Type} [Inhabited α] [Zero α] [DecidableEq Spec.Shape] {Γ : List Spec.Shape} {s₁ s₂ : Spec.Shape} (x : Var s₁) (h : s₁.size = s₂.size) :
                                                                                                                                      MWith α Δ Γ (Var s₂)

                                                                                                                                      Reshape a tensor, given a proof that the total sizes match.

                                                                                                                                      PyTorch comparison: torch.reshape(x, new_shape).
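A sketch for a concrete reshape; for closed shapes the size equality is a decidable Nat fact, so decide (or rfl) should close it:

open Runtime.Autograd.Compiled.GraphM in
example : M Float [Spec.Shape.dim 2 (Spec.Shape.dim 3 Spec.Shape.scalar)]
    (Var (Spec.Shape.dim 6 Spec.Shape.scalar)) := do
  let vs ← args
  reshape vs.head (by decide)   -- both shapes hold 6 scalars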

                                                                                                                                      Instances For

                                                                                                                                        Transpose a 2D matrix. PyTorch comparison: x.transpose(0, 1) / x.T for matrices.

                                                                                                                                        Instances For

                                                                                                                                          Transpose a rank-3 tensor by moving the first axis to the last ((a,b,c) → (b,c,a)).

                                                                                                                                          PyTorch comparison: x.permute(1, 2, 0).

                                                                                                                                          Instances For

                                                                                                                                            Transpose a rank-3 tensor by moving the last axis to the first ((a,b,c) → (c,a,b)).

                                                                                                                                            PyTorch comparison: x.permute(2, 0, 1).

                                                                                                                                            Instances For

                                                                                                                                              Swap the last two axes of a rank-3 tensor ((a,b,c) → (a,c,b)).

                                                                                                                                              PyTorch comparison: x.transpose(1, 2) for a 3D tensor.

                                                                                                                                              Instances For

                                                                                                                                                Swap two adjacent axes at a given nesting depth.

                                                                                                                                                This is the compiled-graph analogue of the eager Tape.swapAdjacentAtDepth. PyTorch comparison: a permute that swaps two neighboring dimensions.

                                                                                                                                                Instances For
                                                                                                                                                  def Runtime.Autograd.Compiled.GraphM.broadcastTo {α Δ : Type} [Inhabited α] [Add α] [Zero α] [DecidableEq Spec.Shape] {Γ : List Spec.Shape} {s₁ s₂ : Spec.Shape} (cb : s₁.CanBroadcastTo s₂) (x : Var s₁) :
                                                                                                                                                  MWith α Δ Γ (Var s₂)

                                                                                                                                                  Broadcast x : s₁ to a larger shape s₂ (given a CanBroadcastTo witness).

                                                                                                                                                  PyTorch comparison: x.expand(...) / broadcasting semantics in elementwise ops.

                                                                                                                                                  Instances For
def Runtime.Autograd.Compiled.GraphM.reduceSum {α Δ : Type} [Add α] [Zero α] [Inhabited α] [DecidableEq Spec.Shape] {Γ : List Spec.Shape} {s : Spec.Shape} (axis : ℕ) [valid : Spec.Shape.valid_axis_inst axis s] [wf : s.WellFormed] (x : Var s) :
                                                                                                                                                    MWith α Δ Γ (Var (Spec.Tensor.shapeAfterSum s axis))

                                                                                                                                                    Reduce-sum along a given axis.

                                                                                                                                                    PyTorch comparison: torch.sum(x, dim=axis).

                                                                                                                                                    Instances For

                                                                                                                                                      Reduce-mean along a given axis.

                                                                                                                                                      PyTorch comparison: torch.mean(x, dim=axis).

                                                                                                                                                      Instances For

                                                                                                                                                        Gather a single scalar from a vector at a known-in-bounds index.

                                                                                                                                                        PyTorch comparison: x[i] for a 1D tensor.

                                                                                                                                                        Instances For

                                                                                                                                                          Gather a row from a matrix at a known-in-bounds row index.

                                                                                                                                                          PyTorch comparison: x[i, :] for a 2D tensor.

                                                                                                                                                          Instances For

                                                                                                                                                            Gather a scalar from a vector at a runtime Nat index.

                                                                                                                                                            If i is out of bounds we return 0 and propagate no gradient (matching the forward choice).

                                                                                                                                                            Instances For

                                                                                                                                                              Gather a vector of length k from a length-n vector using an index tensor of Nats.

                                                                                                                                                              Out-of-bounds indices yield 0 at the corresponding output position.

                                                                                                                                                              PyTorch comparison: torch.gather for 1D inputs, with explicit bounds handling.

                                                                                                                                                              Instances For

                                                                                                                                                                Gather k rows from a (rows×cols) matrix using an index vector of Nats.

                                                                                                                                                                Out-of-bounds indices yield a zero row.

                                                                                                                                                                PyTorch comparison: torch.index_select(x, dim=0, index=idx) with explicit bounds handling.

                                                                                                                                                                Instances For

                                                                                                                                                                  Scatter-add into a vector at a single in-bounds index.

                                                                                                                                                                  scatter_add_vec x v i adds the scalar v into x[i].

                                                                                                                                                                  PyTorch comparison: x.index_add_(dim=0, index=[i], source=[v]) (conceptually).

                                                                                                                                                                  Instances For

                                                                                                                                                                    Scatter-add into a matrix at a single in-bounds row index.

                                                                                                                                                                    scatter_add_row x v i adds the row vector v into x[i, :].

                                                                                                                                                                    PyTorch comparison: x.index_add_(dim=0, index=[i], source=v.unsqueeze(0)) (conceptually).

                                                                                                                                                                    Instances For
def Runtime.Autograd.Compiled.GraphM.layerNorm {α Δ : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] [DecidableEq Spec.Shape] {Γ : List Spec.Shape} {seqLen embedDim : ℕ} (h_seq_pos : seqLen > 0) (h_embed_pos : embedDim > 0) (x : Var (Spec.Shape.dim seqLen (Spec.Shape.dim embedDim Spec.Shape.scalar))) (gamma beta : Var (Spec.Shape.dim embedDim Spec.Shape.scalar)) :
MWith α Δ Γ (Var (Spec.Shape.dim seqLen (Spec.Shape.dim embedDim Spec.Shape.scalar)))

                                                                                                                                                                      Layer normalization (sequence-first), producing the same shape as the input.

                                                                                                                                                                      PyTorch comparison: torch.nn.LayerNorm / torch.nn.functional.layer_norm (modulo exact layout).

                                                                                                                                                                      Forward-mode status: implemented by Spec.layerNormJvp, including parameter tangents for gamma and beta.

                                                                                                                                                                      Instances For
def Runtime.Autograd.Compiled.GraphM.batchnormChannelFirst {α Δ : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] [DecidableEq Spec.Shape] {Γ : List Spec.Shape} {channels height width : ℕ} (h_c : channels > 0) (h_h : height > 0) (h_w : width > 0) (x : Var (Spec.Shape.dim channels (Spec.Shape.dim height (Spec.Shape.dim width Spec.Shape.scalar)))) (gamma beta : Var (Spec.Shape.dim channels Spec.Shape.scalar)) :
MWith α Δ Γ (Var (Spec.Shape.dim channels (Spec.Shape.dim height (Spec.Shape.dim width Spec.Shape.scalar))))

                                                                                                                                                                        Batch normalization in channel-first layout (no running statistics; spec-level functional form).

                                                                                                                                                                        PyTorch comparison: torch.nn.BatchNorm2d in NCHW layout (modulo exact semantics/parameters).

                                                                                                                                                                        Forward-mode status: implemented by Spec.batchNorm2dJvp, including parameter tangents for gamma and beta.

                                                                                                                                                                        Instances For
def Runtime.Autograd.Compiled.GraphM.multiHeadAttention {α Δ : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] [DecidableEq Spec.Shape] {Γ : List Spec.Shape} {n numHeads dModel headDim : ℕ} (h1 : n ≠ 0) (wq wk wv : Var (Spec.Shape.dim dModel (Spec.Shape.dim (numHeads * headDim) Spec.Shape.scalar))) (wo : Var (Spec.Shape.dim (numHeads * headDim) (Spec.Shape.dim dModel Spec.Shape.scalar))) (x : Var (Spec.Shape.dim n (Spec.Shape.dim dModel Spec.Shape.scalar))) (mask : Option (Spec.Tensor Bool (Spec.Shape.dim n (Spec.Shape.dim n Spec.Shape.scalar))) := none) :
MWith α Δ Γ (Var (Spec.Shape.dim n (Spec.Shape.dim dModel Spec.Shape.scalar)))

                                                                                                                                                                          Multi-head attention primitive (shape-specialized).

                                                                                                                                                                          PyTorch comparison: torch.nn.MultiheadAttention / scaled dot-product attention.

                                                                                                                                                                          Forward-mode status: implemented by Spec.MultiHeadAttentionJvp, including tangents for the input and all four projection matrices.

                                                                                                                                                                          Instances For
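A hypothetical fragment with n = 10, dModel = 64, numHeads = 8, headDim = 8 (so numHeads * headDim = dModel); the mask argument is left at its none default:

  -- wq, wk, wv : Var (64, 64); wo : Var (64, 64); x : Var (10, 64)
  let y ← multiHeadAttention (by decide) wq wk wv wo x
  -- y : Var (Spec.Shape.dim 10 (Spec.Shape.dim 64 Spec.Shape.scalar))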
def Runtime.Autograd.Compiled.GraphM.conv {α Δ : Type} [Context α] [DecidableEq Spec.Shape] {Γ : List Spec.Shape} {d inC outC : ℕ} {kernel stride padding inSpatial : Vector ℕ d} {hInC : inC ≠ 0} {hKernel : ∀ (i : Fin d), kernel.get i ≠ 0} (w : Var (Spec.Shape.ofList (outC :: inC :: kernel.toList))) (b : Var (Spec.Shape.dim outC Spec.Shape.scalar)) (x : Var (Spec.Shape.ofList (inC :: inSpatial.toList))) :
                                                                                                                                                                            MWith α Δ Γ (Var (Spec.Shape.ofList (outC :: (Spec.convOutSpatial inSpatial kernel stride padding).toList)))

                                                                                                                                                                            N-D convolution (channels-first) on a single sample tensor (no batch axis).

                                                                                                                                                                            Conventions:

                                                                                                                                                                            • input shape is (inC, spatial...),
                                                                                                                                                                            • kernel shape is (outC, inC, kernelSpatial...),
                                                                                                                                                                            • bias shape is (outC),
• output spatial sizes use the usual PyTorch-style floor-division formula: out[a] = (in[a] + 2*padding[a] - kernel[a]) / stride[a] + 1.

PyTorch comparison: torch.nn.functional.conv{d}d (e.g. conv2d when d = 2), specialized to a single sample.

                                                                                                                                                                            Forward-mode JVP uses bilinearity: d(conv(k,b,x)) = conv(k,0,dx) + conv(dk,db,x).

                                                                                                                                                                            Instances For
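A hypothetical fragment for d = 2 with outC = 16, inC = 3, and a 3×3 kernel on a (3, 32, 32) input. Since stride, padding, and the nonzero hypotheses are implicit and not determined by the argument shapes, they are supplied by name here; the #v[...] vector-literal syntax is also an assumption:

  -- w : Var (16, 3, 3, 3); b : Var (16); x : Var (3, 32, 32)
  let y ← conv (stride := #v[1, 1]) (padding := #v[1, 1])
               (hInC := by decide) (hKernel := by decide) w b x
  -- each output spatial size: (32 + 2*1 - 3) / 1 + 1 = 32, so y : (16, 32, 32)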
def Runtime.Autograd.Compiled.GraphM.convTranspose {α Δ : Type} [Context α] [DecidableEq Spec.Shape] {Γ : List Spec.Shape} {d inC outC : ℕ} {kernel stride padding inSpatial : Vector ℕ d} {hInC : inC ≠ 0} {hKernel : ∀ (i : Fin d), kernel.get i ≠ 0} (w : Var (Spec.Shape.ofList (inC :: outC :: kernel.toList))) (b : Var (Spec.Shape.dim outC Spec.Shape.scalar)) (x : Var (Spec.Shape.ofList (inC :: inSpatial.toList))) :
                                                                                                                                                                              MWith α Δ Γ (Var (Spec.Shape.ofList (outC :: (Spec.convTransposeOutSpatial inSpatial kernel stride padding).toList)))

                                                                                                                                                                              N-D transpose convolution (channels-first) on a single sample tensor (no batch axis).

                                                                                                                                                                              Conventions:

                                                                                                                                                                              • input shape is (inC, spatial...),
                                                                                                                                                                              • kernel shape is (inC, outC, kernelSpatial...) (PyTorch layout),
                                                                                                                                                                              • bias shape is (outC),
• output spatial sizes use: out[a] = (in[a] - 1) * stride[a] - 2*padding[a] + kernel[a] (with output_padding = 0); for example, in = 16, stride = 2, padding = 1, kernel = 3 gives (16 - 1) * 2 - 2 + 3 = 31.

PyTorch comparison: torch.nn.functional.conv_transpose{d}d (e.g. conv_transpose2d when d = 2), specialized to a single sample.

                                                                                                                                                                              Forward-mode JVP uses bilinearity: d(convTranspose(k,b,x)) = convTranspose(k,0,dx) + convTranspose(dk,db,x).

                                                                                                                                                                              Instances For
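A hypothetical fragment mirroring the conv sketch above, with inC = 3, outC = 16, a 3×3 kernel, stride 2, and padding 1 (again assuming the #v[...] vector-literal syntax and by-name implicits):

  -- w : Var (3, 16, 3, 3) (inC, outC, kH, kW); b : Var (16); x : Var (3, 16, 16)
  let y ← convTranspose (stride := #v[2, 2]) (padding := #v[1, 1])
                        (hInC := by decide) (hKernel := by decide) w b x
  -- each output spatial size: (16 - 1) * 2 - 2*1 + 3 = 31, so y : (16, 31, 31)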
def Runtime.Autograd.Compiled.GraphM.conv2d {α Δ : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] [DecidableEq Spec.Shape] {Γ : List Spec.Shape} {inC outC kH kW stride padding inH inW : ℕ} {h1 : inC ≠ 0} {h2 : kH ≠ 0} {h3 : kW ≠ 0} (kernel : Var (Spec.Shape.dim outC (Spec.Shape.dim inC (Spec.Shape.dim kH (Spec.Shape.dim kW Spec.Shape.scalar))))) (bias : Var (Spec.Shape.dim outC Spec.Shape.scalar)) (input : Var (Spec.Shape.dim inC (Spec.Shape.dim inH (Spec.Shape.dim inW Spec.Shape.scalar)))) :
                                                                                                                                                                                MWith α Δ Γ (Var (Spec.Shape.dim outC (Spec.Shape.dim ((inH + 2 * padding - kH) / stride + 1) (Spec.Shape.dim ((inW + 2 * padding - kW) / stride + 1) Spec.Shape.scalar))))

                                                                                                                                                                                2D convolution (channel-first) on a single image tensor.

                                                                                                                                                                                PyTorch comparison: torch.nn.functional.conv2d (without a batch dimension).

                                                                                                                                                                                Forward-mode JVP uses bilinearity: d(conv2d(k,b,x)) = conv2d(k,0,dx) + conv2d(dk,db,x).

                                                                                                                                                                                Instances For
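A hypothetical fragment matching the return type above (inH = inW = 32, kH = kW = 3, stride = 1, padding = 1); the implicit naturals and nonzero hypotheses are supplied by name as an assumption about how they are discharged:

  -- kernel : Var (16, 3, 3, 3); bias : Var (16); input : Var (3, 32, 32)
  let y ← conv2d (stride := 1) (padding := 1)
                 (h1 := by decide) (h2 := by decide) (h3 := by decide)
                 kernel bias input
  -- output: (16, (32 + 2*1 - 3)/1 + 1, (32 + 2*1 - 3)/1 + 1) = (16, 32, 32)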
def Runtime.Autograd.Compiled.GraphM.convTranspose2d {α Δ : Type} [Context α] [DecidableEq Spec.Shape] {Γ : List Spec.Shape} {inC outC kH kW stride padding inH inW : ℕ} {h1 : inC ≠ 0} {h2 : kH ≠ 0} {h3 : kW ≠ 0} (kernel : Var (Spec.Shape.dim inC (Spec.Shape.dim outC (Spec.Shape.dim kH (Spec.Shape.dim kW Spec.Shape.scalar))))) (bias : Var (Spec.Shape.dim outC Spec.Shape.scalar)) (input : Var (Spec.Shape.dim inC (Spec.Shape.dim inH (Spec.Shape.dim inW Spec.Shape.scalar)))) :
                                                                                                                                                                                  MWith α Δ Γ (Var (Spec.Shape.dim outC (Spec.Shape.dim ((inH - 1) * stride - 2 * padding + kH) (Spec.Shape.dim ((inW - 1) * stride - 2 * padding + kW) Spec.Shape.scalar))))

                                                                                                                                                                                  2D transpose convolution (channel-first) on a single image tensor.

                                                                                                                                                                                  PyTorch comparison: torch.nn.functional.conv_transpose2d (without a batch dimension).

                                                                                                                                                                                  Forward-mode JVP uses bilinearity: d(convTranspose2d(k,b,x)) = convTranspose2d(k,0,dx) + convTranspose2d(dk,db,x).

                                                                                                                                                                                  Instances For
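A hypothetical fragment matching the return type above (inH = inW = 16, kH = kW = 3, stride = 2, padding = 1), with the same by-name treatment of implicits as in the conv2d sketch:

  -- kernel : Var (3, 16, 3, 3); bias : Var (16); input : Var (3, 16, 16)
  let y ← convTranspose2d (stride := 2) (padding := 1)
                          (h1 := by decide) (h2 := by decide) (h3 := by decide)
                          kernel bias input
  -- output: (16, (16 - 1)*2 - 2*1 + 3, (16 - 1)*2 - 2*1 + 3) = (16, 31, 31)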