TorchLean API

NN.IR.Semantics

Semantics #

Denotational semantics for NN.IR.Graph.

This file defines an evaluator for the current IR fragment:

The evaluator is total on well-formed, well-shaped graphs and returns an Except String error on malformed graphs or missing payloads.

Softmax and layer norm:

How this relates to PyTorch:

References / related systems:

Parameter payloads #

The IR graph stores OpKind and outShape, but it does not embed tensor values for parameters. Instead, evaluation is parameterized by a Payload:

This matches how most graph formats work in practice: structure is one artifact, parameters are another.
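
For intuition, here is a minimal, self-contained sketch of that split. All names here (ToyNode, ToyPayload, lookupLinear) are hypothetical, not the actual TorchLean definitions: parameters are looked up by node id at evaluation time, much like a PyTorch state_dict or an ONNX initializer table.

structure ToyNode where
  id       : Nat
  op       : String      -- stand-in for OpKind
  outShape : List Nat    -- stand-in for Shape

/-- Parameters live outside the graph, keyed by node id. -/
structure ToyPayload where
  linearW : Nat → Option (List (List Float))
  linearB : Nat → Option (List Float)

def lookupLinear (p : ToyPayload) (id : Nat) :
    Except String (List (List Float) × List Float) := do
  let some w := p.linearW id | throw s!"missing weight for node {id}"
  let some b := p.linearB id | throw s!"missing bias for node {id}"
  return (w, b)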

structure NN.IR.ConstFlat (α : Type) [Context α] :

Payload record for a const node.

Constants are stored in a “flat” (1-D) representation so backends can keep a uniform container (e.g. an array). During evaluation we check the flat length against Shape.size and then unflatten to the requested Shape.
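
As a rough sketch of that length check (shapeSize and checkFlat are illustrative names, not the library's definitions):

def shapeSize (s : List Nat) : Nat := s.foldl (· * ·) 1

def checkFlat (s : List Nat) (flat : Array Float) : Except String (Array Float) :=
  if flat.size = shapeSize s then
    .ok flat    -- a real implementation would now unflatten `flat` to shape `s`
  else
    .error s!"const: shape {s} needs {shapeSize s} elements, got {flat.size}"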

Instances For
    structure NN.IR.LinearWB (α : Type) [Context α] :

    Payload record for a linear node: weight matrix W and bias vector b.

    The node's input x comes from the graph edge; W and b live in the external Payload (similar to ONNX initializers or a PyTorch state_dict).
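
The affine map itself, on plain lists, looks roughly as follows (dot and affine are illustrative names; the real evaluator works on spec-layer tensors):

def dot (u v : List Float) : Float :=
  (List.zip u v).foldl (fun acc (a, b) => acc + a * b) 0.0

def affine (W : List (List Float)) (b : List Float) (x : List Float) : List Float :=
  (List.zip W b).map (fun (row, bi) => dot row x + bi)

#eval affine [[1.0, 0.0], [0.0, 2.0]] [0.5, -0.5] [3.0, 4.0]  -- [3.5, 7.5]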

    Instances For
      structure NN.IR.Conv2DParams (α : Type) [Context α] :

      Payload record for a conv2d node.

      We store the spec-layer Conv2DSpec together with the dimension parameters needed to reconstruct it. The nonzero proofs are required by the spec-layer definition and ensure the convolution is well-formed.

      • inC : ℕ

        Input channels.

      • outC : ℕ

        Output channels.

      • kH : ℕ

        Kernel height.

      • kW : ℕ

        Kernel width.

      • stride : ℕ

        Stride.

      • padding : ℕ

        Padding size.

      • inH : ℕ

        Input height.

      • inW : ℕ

        Input width.

      • hIn : self.inC ≠ 0

        Proof that the input channel count is nonzero, required by the spec convolution layer.

      • hKH : self.kH ≠ 0

        Proof that the kernel height is nonzero.

      • hKW : self.kW ≠ 0

        Proof that the kernel width is nonzero.

      • spec : Spec.Conv2DSpec self.inC self.outC self.kH self.kW self.stride self.padding α

        Spec-layer convolution package containing weights, bias, and convolution metadata.

      Instances For
        structure NN.IR.Payload (α : Type) [Context α] :

        External parameter payloads keyed by IR node id.

        This is deliberately small; different backends may store parameters differently.

        Instances For

          Dynamic (shape-tagged) values #

          During evaluation we keep values in a dependent pair Σ s, Tensor α s so we can store a heterogeneous table of intermediate tensors while still recovering precise shapes when we need them.

          @[reducible, inline]
          abbrev NN.IR.DVal (α : Type) [Context α] :

          Dynamic (shape-tagged) tensor value used by the IR evaluator.

          This is a dependent pair Σ s, Tensor α s, which lets us store heterogeneously-shaped intermediate values in one table while still recovering exact shapes when needed.
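
A self-contained toy model of the idea (toySize, ToyTensor, and ToyDVal are hypothetical stand-ins for the actual Spec.Tensor and DVal definitions):

def toySize (s : List Nat) : Nat := s.foldl (· * ·) 1

abbrev ToyTensor (s : List Nat) := Fin (toySize s) → Float   -- stand-in for Spec.Tensor α s

abbrev ToyDVal := (s : List Nat) × ToyTensor s               -- the dependent pair

def ToyDVal.shape  (v : ToyDVal) : List Nat      := v.1
def ToyDVal.tensor (v : ToyDVal) : ToyTensor v.1 := v.2

-- Differently-shaped values live in one ordinary list:
def table : List ToyDVal :=
  [⟨[2], fun _ => 1.0⟩, ⟨[2, 3], fun i => Float.ofNat i.val⟩]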

          Instances For
            def NN.IR.DVal.shape {α : Type} [Context α] (v : DVal α) :

            The shape tag carried by a dynamic value.

            Instances For
              def NN.IR.DVal.tensor {α : Type} [Context α] (v : DVal α) :

              The underlying tensor, with its shape recovered from the dependent pair.

              Instances For
                def NN.IR.DVal.mk {α : Type} [Context α] (s : Spec.Shape) (t : Spec.Tensor α s) :
                DVal α

                Construct a dynamic value from a shape and a tensor of that shape.

                Instances For

                  Small proof-helpers used by evaluation #

                  The evaluator frequently needs evidence that an axis is valid or that a broadcast is legal so it can call the spec-layer operations, which are typed with these preconditions. We build these witnesses from runtime data (Nat axis values and shapes) using Option:

                  Build a witness that axis is a valid axis for shape s.

                  Many spec-layer ops (e.g. reductions, softmax, layernorm) are typed with a Shape.valid_axis precondition. Since the IR stores axes as raw Nat, we reconstruct the witness at runtime.

                  Returns none when axis is out of bounds.
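
The pattern is simply "decide the proposition at runtime and keep the proof", as in this hypothetical sketch (validAxis and axisWitness? are illustrative names standing in for Shape.valid_axis and the library's helper):

abbrev validAxis (s : List Nat) (axis : Nat) : Prop := axis < s.length

def axisWitness? (s : List Nat) (axis : Nat) : Option (validAxis s axis) :=
  if h : axis < s.length then some h else none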

                  Instances For

                    Build a witness that s₁ can be broadcast to s₂ (NumPy/PyTorch-style broadcasting).

                    The spec-layer broadcasting operator is typed with Shape.CanBroadcastTo. Since the IR stores only runtime shapes, we reconstruct this witness on demand.

                    Returns none when broadcasting is not possible.
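
The compatibility rule itself is the usual trailing-dimension check. A hedged sketch of that rule as a Bool (canBroadcast is an illustrative name; the real definition produces a proof of Shape.CanBroadcastTo rather than a Bool):

def canBroadcast (s₁ s₂ : List Nat) : Bool :=
  go s₁.reverse s₂.reverse
where
  go : List Nat → List Nat → Bool
    | [],      _       => true
    | _ :: _,  []      => false
    | a :: xs, b :: ys => (a == b || a == 1) && go xs ys

#eval canBroadcast [1, 3] [2, 5, 3]  -- true
#eval canBroadcast [2, 3] [2, 5, 3]  -- false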

                    Instances For

                      Return the index of the first occurrence of x in xs (or none if absent).

                      Instances For
                          def NN.IR.Graph.listGet? {α : Type} :
                          List α → ℕ → Option α

                          Safe list indexing: listGet? xs n returns some xs[n] when in bounds.
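
A plausible definition matching this docstring (a sketch; the library's actual code may differ):

def listGet? {α : Type} : List α → Nat → Option α
  | [],      _     => none
  | x :: _,  0     => some x
  | _ :: xs, n + 1 => listGet? xs n

#eval listGet? [10, 20, 30] 1  -- some 20
#eval listGet? [10, 20, 30] 5  -- none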

                          Instances For

                            Swap the adjacent entries at positions d and d+1 (no-op when out of range).

                            Instances For

                              Compute a sequence of adjacent swaps that realizes a target permutation.

                              This is used to implement .permute by repeatedly applying swapAdjacentAtDepth, which is already available in the spec tensor library. If the permutation is ill-formed, this returns an error explaining what went wrong.
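
One standard way to compute such a sequence is to bubble-sort the permutation and record the depth of every adjacent swap performed. A self-contained sketch under that assumption (bubblePass and swapsFor are hypothetical names; error handling and direction conventions live in the actual library code):

/-- One bubble pass: swap adjacent out-of-order entries, recording their depths. -/
partial def bubblePass : List Nat → Nat → List Nat × List Nat
  | a :: b :: rest, d =>
    if b < a then
      let (rest', swaps) := bubblePass (a :: rest) (d + 1)
      (b :: rest', d :: swaps)
    else
      let (rest', swaps) := bubblePass (b :: rest) (d + 1)
      (a :: rest', swaps)
  | xs, _ => (xs, [])

/-- Repeat passes until sorted; the concatenated swap depths realize the permutation. -/
partial def swapsFor (perm : List Nat) : List Nat :=
  let (perm', swaps) := bubblePass perm 0
  if swaps.isEmpty then [] else swaps ++ swapsFor perm'

#eval swapsFor [2, 0, 1]  -- [0, 1]: swap at depth 0, then at depth 1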

                              Instances For
                                def NN.IR.Graph.applySwapDepth {α : Type} [Context α] (v : DVal α) (d : ℕ) :
                                DVal α

                                Apply one adjacent-swap-at-depth to a dynamic tensor value.

                                This is the execution-level building block used to implement .permute in terms of repeated adjacent swaps, reusing the spec tensor library's swapAdjacentAtDepth.

                                Instances For
                                  def NN.IR.Graph.permuteDVal {α : Type} [Context α] (v : DVal α) (perm : List ℕ) :

                                  Permute a dynamic tensor value according to perm.

                                  This checks that perm is a valid permutation for the input shape (using Shape.permute?), then lowers it to a sequence of adjacent swaps and applies them to the tensor.

                                  Instances For

                                    Evaluation helpers #

                                    The evaluator itself (evalAt / denoteAll) is a fold over nodes. These helpers keep the fold readable:

                                    def NN.IR.Graph.expectShape {α : Type} [Context α] [DecidableEq Spec.Shape] (expected : Spec.Shape) (v : DVal α) :
                                    Except String (Spec.Tensor α expected)

                                    Check a dynamic value has the expected shape and return it as a statically-typed tensor.
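
The underlying pattern is "compare the runtime shapes, then transport the tensor's type along the equality". A toy sketch with list-valued shapes (the primed names are hypothetical):

abbrev Shape' := List Nat
abbrev Tensor' (s : Shape') := Fin (s.foldl (· * ·) 1) → Float

def expectShape' (expected : Shape') (v : (s : Shape') × Tensor' s) :
    Except String (Tensor' expected) :=
  if h : v.1 = expected then
    .ok (h ▸ v.2)   -- rewrite the tensor's type along the runtime equality
  else
    .error s!"expected shape {expected}, got {v.1}"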

                                    Instances For
                                      def NN.IR.Graph.mseLossDVal {α : Type} [Context α] [DecidableEq Spec.Shape] (i : ℕ) (yVal tVal : DVal α) :

                                      Evaluate MSE loss on two dynamic values, checking that their runtime shapes agree.

                                      Instances For

                                        Transport a Tensor α (dim n scalar) across an equality n = n' (helper for payload casts).

                                        Instances For
                                          def NN.IR.Graph.evalConst {α : Type} [Context α] [Inhabited α] (payload : Payload α) (id : ℕ) (s : Spec.Shape) :

                                          Evaluate a const node from the external payload.

                                          Constants are stored “flat” (1D) for convenience, so we check the flattened length matches Shape.size s and then unflatten to the requested shape.

                                          Instances For
                                            def NN.IR.Graph.evalLinear {α : Type} [Context α] [DecidableEq Spec.Shape] (payload : Payload α) (id : ℕ) (x : DVal α) (outShape : Spec.Shape) :

                                            Evaluate a linear node from the external payload.

                                            We enforce:

                                            • the input dynamic value has shape (inDim), and
                                            • the node's declared outShape matches (outDim).

                                            The actual math is the usual affine map: y = W·x + b.

                                            Instances For
                                              def NN.IR.Graph.evalConv2D {α : Type} [Context α] [DecidableEq Spec.Shape] (payload : Payload α) (id : ℕ) (x : DVal α) :

                                              Evaluate a conv2d node from the external payload.

                                              The output shape is computed with the standard (no dilation) formula: out = ⌊(in + 2*pad - k)/stride⌋ + 1 for each spatial dimension.
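
Spelled out as a tiny helper (convOutDim is an illustrative name, not part of the library):

def convOutDim (inDim k stride pad : Nat) : Nat :=
  (inDim + 2 * pad - k) / stride + 1

#eval convOutDim 32 3 1 1  -- 32: a 3×3 kernel with padding 1, stride 1 preserves the size
#eval convOutDim 28 5 2 0  -- 12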

                                              Instances For

                                                Deterministic LayerNorm used by the IR evaluator (gamma=1, beta=0).
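
As a rough sketch of what "gamma = 1, beta = 0" means, on a flat list of Floats (illustrative only; the evaluator's version is defined on spec-layer tensors, and the eps handling here is an assumption):

def layerNorm (eps : Float) (xs : List Float) : List Float :=
  let n    := Float.ofNat xs.length
  let mean := xs.foldl (· + ·) 0.0 / n
  let var  := (xs.map (fun x => (x - mean) * (x - mean))).foldl (· + ·) 0.0 / n
  xs.map (fun x => (x - mean) / Float.sqrt (var + eps))

#eval layerNorm 1e-5 [1.0, 2.0, 3.0]  -- ≈ [-1.2247, 0.0, 1.2247]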

                                                Instances For
                                                  def NN.IR.Graph.evalAt {α : Type} [Context α] [Inhabited α] [DecidableEq Spec.Shape] (g : Graph) (payload : Payload α) (input : DVal α) (vals : Array (DVal α)) (i : ℕ) :

                                                  Evaluate node i given already computed parent values vals.

                                                  This is the core “one step” of the denotational semantics:

                                                  • lookup the node,
                                                  • read its parent values from vals (using the topo/id invariant),
                                                  • apply the corresponding spec-layer operation,
                                                  • enforce that the produced shape matches the node’s declared outShape.

                                                  This function assumes the graph is structurally well-formed (ids are in bounds and parents are strictly smaller ids). denoteAll performs that check up front.

                                                  Instances For
                                                    @[irreducible]
                                                    def NN.IR.Graph.denoteAllFrom {α : Type} [Context α] [Inhabited α] [DecidableEq Spec.Shape] (g : Graph) (payload : Payload α) (input : DVal α) (i : ℕ) (vals : Array (DVal α)) :

                                                    Evaluate nodes i, i+1, ... given already computed prefix values vals.

                                                    This is written as a structurally recursive function so it is easy to reason about in proofs (evaluation is “a simple loop over node ids”).
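
In outline the loop is a fold over node ids. A hypothetical, self-contained sketch (evalLoop and the Float stand-in for DVal are illustrative; the real definition presumably recurses on the number of remaining nodes, which plays the role of the explicit fuel argument here):

def evalLoop (step : Nat → Array Float → Except String Float) :
    Nat → Nat → Array Float → Except String (Array Float)
  | _, 0,        vals => .ok vals
  | i, fuel + 1, vals => do
    let v ← step i vals
    evalLoop step (i + 1) fuel (vals.push v)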

                                                    Instances For
                                                      def NN.IR.Graph.denoteAll {α : Type} [Context α] [Inhabited α] [DecidableEq Spec.Shape] (g : Graph) (payload : Payload α) (input : DVal α) :

                                                      Evaluate a graph to a table of node values.

                                                      This returns an array vals of length g.size where vals[i] is the value of node i.

                                                      We do a structural well-formedness check once up front (ids/arity/topology). For compiler-produced graphs, the boolean Graph.wellFormed check is a fast path; if it fails we fall back to the exception-producing Graph.checkWellFormed so callers get a readable error message.

                                                      The evaluator is total in the sense that it always returns either:

                                                      • .ok vals (all nodes evaluated successfully), or
                                                      • .error msg describing the first failure (malformed IR, missing payload, or a local shape error).
                                                      Instances For

                                                        Scoped notation #

                                                        Scoped notation for evaluating a graph to all node values.

                                                        Use with:

                                                        open scoped IR
                                                        g⟦payload, input⟧
                                                        
                                                        Instances For

                                                          ASCII alternative to g⟦payload, input⟧.

                                                          Instances For
                                                            def NN.IR.Graph.denote {α : Type} [Context α] [Inhabited α] [DecidableEq Spec.Shape] (g : Graph) (payload : Payload α) (input : DVal α) (outputId : ℕ) :

                                                            Evaluate the graph and return the value at outputId.

                                                            Instances For