TorchLean API

NN.Runtime.Autograd.Compiled.IRExec

IRExec #

IR → executable SSA graph bridge.

This module lets us run an op-tagged NN.IR.Graph by compiling it into an executable Proofs.Autograd.Algebra.GraphData (the SSA/DAG representation used by the proof-compiled runtime).

Why this exists:

Important:

PyTorch intuition #

If you’re coming from PyTorch:

Reading map #

Main definitions #

Implementation notes #

References #

Tags #

ir, compiler, runtime, graphdata, forward-semantics

@[simp]
theorem Runtime.Autograd.Compiled.Except.ok_bind {ε α β : Type} (a : α) (f : α → Except ε β) :
Except.ok a >>= f = f a

simp rule for Except-do chains: binding an .ok value is just function application.

This keeps proof scripts around buildFrom readable when large do blocks are unfolded.
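
A minimal standalone illustration of the rewrite (a sketch; this instance even holds by `rfl`, since `Except.bind` on `.ok` reduces definitionally, but the simp form matters when the bind is buried inside a large unfolded `do` block):

```lean
-- Binding an `.ok` value is just function application.
example (f : Nat → Except String Nat) :
    (Except.ok 2 >>= f) = f 2 := rfl
```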

@[simp]
theorem Runtime.Autograd.Compiled.Except.error_bind {ε α β : Type} (e : ε) (f : α → Except ε β) :
Except.error e >>= f = Except.error e

simp rule for Except-do chains: binding an .error short-circuits.

Used heavily when discharging impossible branches in compilation correctness proofs.
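
The short-circuit behaviour can be seen on a standalone toy goal (a sketch; as with the `.ok` case, this instance reduces definitionally):

```lean
-- Binding through an `.error` never calls the continuation.
example (f : Nat → Except String Nat) :
    ((Except.error "bad" : Except String Nat) >>= f) = Except.error "bad" := rfl
```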

@[simp]
theorem Runtime.Autograd.Compiled.Except.bind_ok {ε α β : Type} (a : α) (f : α → Except ε β) :
(Except.ok a).bind f = f a

Definitional simplification for Except.bind on .ok.

@[simp]
theorem Runtime.Autograd.Compiled.Except.bind_error {ε α β : Type} (e : ε) (f : α → Except ε β) :
(Except.error e).bind f = Except.error e

Definitional simplification for Except.bind on .error.

A forward-executable SSA graph derived from an NN.IR.Graph.

The compiled graph stores:

  • one distinguished input shape (inShape),
  • one shape per compiled node (ss, corresponding to IR node ids 1..n-1),
  • and executable node closures (g) consumed by GraphData.eval.

    Evaluate the compiled executable SSA graph on a concrete input tensor.

    The result is the full typed runtime context [inShape] ++ ss, i.e. input followed by every compiled node value in topological order.
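
One consequence of the `[inShape] ++ ss` layout, stated on a toy model with shapes flattened to `Nat` (a sketch, not this module's actual types): the evaluated context always has one entry per compiled node plus one for the input.

```lean
-- The context over `[inShape] ++ ss` has length `ss.length + 1`.
example (inShape : Nat) (ss : List Nat) :
    ([inShape] ++ ss).length = ss.length + 1 := by
  simp
```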


      Denotation Table Helper #

      ExecGraphData.eval produces a typed runtime context TList α ([inShape] ++ ss).

      For debugging and for the forward-correctness development in NN.Runtime.Autograd.Compiled.IRExec.Correctness, we provide a helper that erases this context into an IR-style value table Array (NN.IR.DVal α) in node-id order.

      Convert a runtime AnyTensor (shape carried as a field) into an IR denotation value DVal.


        Convert a typed runtime context TList α ss into an IR-style value table.

        This is phrased in terms of Array (DVal α) because the IR denotation functions (denoteAll*) are array-based, while the compiled runtime evaluates into a typed context (TList).


          Convert the full evaluated context into an IR-style value table (one DVal per node id).

          This is the concrete bridge used in semantic equivalence statements that compare compiled evaluation against NN.IR.Graph.denoteAll*.
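
The statements in `IRExec.Correctness` then take roughly the following shape (a hedged, schematic sketch; the exact hypotheses and the precise names of the `denoteAll*` variants are not reproduced here):

```lean
-- Schematic only: compiled evaluation, erased to a DVal table, agrees
-- with the shared IR denotation.
-- theorem forward_agrees (h : ...) :
--     toDValTable (eg.eval x) = g.denoteAll payload x := ...
```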

            @[reducible, inline]

            Internal compilation state used by buildFrom.

            It is a dependent pair of:

            • ss: shapes of already-compiled IR nodes,
            • GraphData α Unit [inShape] ss: executable closures for exactly that shape list.

              Build a typed runtime index (Idx) for a numeric IR parent id.

              The compiled runtime context is typed by a list of shapes [inShape] ++ ss. mkIdx checks that:

              • id is in bounds, and
              • the context shape at that position matches the expected shape s.

              On failure, this returns a descriptive error string used directly by buildFrom.
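
A simplified model of the two checks, with shapes flattened to `Nat` and the typed `Idx` replaced by a `Fin` (a sketch only; the real `mkIdx` works over `List Spec.Shape` and produces a `Proofs.Autograd.Algebra.Idx`):

```lean
def mkIdxSketch (Γ : List Nat) (id : Nat) (s : Nat) :
    Except String (Fin Γ.length) :=
  if h : id < Γ.length then
    -- in bounds: now check the context shape matches the expected one
    if Γ[id] = s then
      .ok ⟨id, h⟩
    else
      .error s!"node {id}: shape mismatch"
  else
    .error s!"node {id}: parent id out of bounds"
```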


                Construct a NodeData for forward execution only.

                The compiled runtime GraphData expects each node to supply forward, jvp, and vjp. For this IR bridge we only care about forward correctness, so jvp/vjp are populated with forward-only sentinels that panic! if called.

                This is intentional: IRExec closes the forward semantics gap; full gradient behavior is handled by other runtime/autograd layers. Using panic! here is a safety measure: it prevents silently wrong gradients if someone accidentally routes differentiation through an IRExec-compiled graph.
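
The sentinel pattern described above can be rendered on a toy structure (field names are illustrative; the real `NodeData` is shape-indexed and consumed by `GraphData`):

```lean
structure FwdOnlyNode (α : Type) [Inhabited α] where
  forward : α → α
  -- forward-only sentinels: fail loudly instead of returning wrong gradients
  jvp : α → α := fun _ => panic! "IRExec node: jvp not supported"
  vjp : α → α := fun _ => panic! "IRExec node: vjp not supported"
```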

                  @[simp]

                  Forward projection for mkFwdNode.

                  The JVP/VJP fields are sentinels in this bridge, but the forward field is exactly the function passed to the constructor. This small simp lemma is used by the IR semantic-equivalence proof.

                  Apply a list of adjacent swaps (specified by swap depths) to a shape.

                  This is the shape-level companion of applySwapsTensor, and mirrors IR permutation lowering.
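
On a toy model with shapes as `List Nat`, the "list of adjacent swaps given by depth" operation looks as follows (a sketch of the idea only; the real functions act on `Spec.Shape` and `Tensor` values):

```lean
-- Swap the entries at positions `d` and `d + 1`, if both exist.
def swapAdjacent : Nat → List Nat → List Nat
  | 0, a :: b :: rest => b :: a :: rest
  | d + 1, a :: rest => a :: swapAdjacent d rest
  | _, xs => xs

-- Apply each swap depth in order.
def applySwapsSketch (depths shape : List Nat) : List Nat :=
  depths.foldl (fun s d => swapAdjacent d s) shape

#eval applySwapsSketch [0, 1] [2, 3, 4]  -- [3, 4, 2]
```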


                    Apply the same swap sequence as swapShapeBySwaps, but to a tensor value.

                    This uses Tensor.swap_at_depth_helper repeatedly; it is the runtime companion of the IR-side swapDepthsForPerm lowering used by .permute.


                      Concatenate a list of tensors (all with shape .dim nP rest) along dimension 0.

                      The input list is expressed as typed indices into the runtime context Γ; the result tracks the total concatenated size as a sigma.

                      This helper supports lowering of IR concat-style operators while preserving shape information.

                        theorem Runtime.Autograd.Compiled.IRExec.concatDim0FromInfos_fst_eq_sum {α : Type} [Context α] {Γ : List Spec.Shape} {rest : Spec.Shape} (ctx : Proofs.Autograd.Algebra.TList α Γ) (infos : List ((nP : ℕ) × Proofs.Autograd.Algebra.Idx Γ (Spec.Shape.dim nP rest))) :
                        (concatDim0FromInfos ctx infos).fst = List.foldl (fun (acc : ℕ) (info : (nP : ℕ) × Proofs.Autograd.Algebra.Idx Γ (Spec.Shape.dim nP rest)) => acc + info.fst) 0 infos

                        The concatenated size reported by concatDim0FromInfos is the sum of the input sizes.

                        This theorem is used to justify the output-shape side conditions in concat lowering branches.
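
As a concrete instance (toy numbers; the theorem's fold over sigma-packed infos is specialized to plain `Nat` sizes here):

```lean
-- Leading sizes 2, 3, 4 concatenated along dim 0 give total size 9.
example : List.foldl (fun acc n => acc + n) 0 [2, 3, 4] = 9 := rfl
```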

                        @[irreducible]
                        def Runtime.Autograd.Compiled.IRExec.buildFrom {α : Type} [Context α] [DecidableEq Spec.Shape] (g : NN.IR.Graph) (payload : NN.IR.Payload α) (inShape : Spec.Shape) (i : ℕ) (st : State α inShape) :
                        Except String (State α inShape)

                        Compile the IR graph starting at node index i, extending the current SSA State.

                        This is the main compiler loop:

                        • it checks i < g.nodes.size,
                        • compiles node i into a NodeData.forward closure (rejecting unsupported ops/shapes), and
                        • snocs the resulting node into the accumulating GraphData.

                        The public entrypoint execGraphOfIR handles node 0 and calls buildFrom starting at i = 1.

                        Operationally, buildFrom is a checked compiler:

                        • success means every visited node had well-typed parents and a supported lowering case,
                        • failure returns a concrete error explaining the first unsupported/malformed node.
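
The control flow of this checked compiler loop can be sketched with untyped placeholders (`Op`, `Closure`, and `lowerNode` are invented stand-ins; the real `buildFrom` threads a dependent `State` and a typed `GraphData`):

```lean
inductive Op | input | relu | unsupported

abbrev Closure := Int → Int  -- toy stand-in for a NodeData forward closure

def lowerNode (i : Nat) : Op → Except String Closure
  | .relu => .ok fun x => max x 0
  | .input => .error s!"node {i}: .input is only allowed at node 0"
  | .unsupported => .error s!"node {i}: unsupported op"

partial def buildFromSketch (nodes : Array Op) (i : Nat)
    (st : List Closure) : Except String (List Closure) :=
  if h : i < nodes.size then
    match lowerNode i nodes[i] with
    | .ok node => buildFromSketch nodes (i + 1) (st ++ [node])  -- snoc
    | .error e => .error e
  else
    .ok st  -- every node compiled successfully
```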

                          Compile an op-tagged IR graph into an executable SSA graph (GraphData) for forward evaluation.

                          Requirements:

                          • Node id 0 must be .input.
                          • The graph must satisfy Graph.checkWellFormed.
                          • The external payload must contain entries for every .const/.linear/.conv2d node id.

                          This returns an ExecGraphData whose eval computes all node values in topo order.

                          This is the main API consumed by runtime callers that want executable evaluation while remaining aligned with the shared NN.IR.Graph semantics.
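
End-to-end, the intended call shape is roughly the following (schematic comments only; `g`, `payload`, `inShape`, and `x` are placeholders for caller-supplied values):

```lean
-- match execGraphOfIR g payload inShape with
-- | .ok eg     => eg.eval x  -- all node values, input first, topo order
-- | .error msg => ...        -- first unsupported/malformed node, as a string
```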
