TorchLean API

NN.GraphSpec.Core

Sequential GraphSpec #

This file defines the sequential authoring surface for GraphSpec.

The important design decision is that Graph is sequential-only authoring sugar over the general DAG representation. So Graph is not a competing graph IR; it is a pleasant way to write:

Linear >>> ReLU >>> Linear
Conv >>> ReLU >>> Pool >>> Flatten >>> Linear

and then lower that chain to the general DAG representation when downstream tooling wants one model shape for everything.

GraphSpec as a whole is a typed DSL for describing neural-network computations, with the explicit goal of being usable in two complementary ways:

  1. Reference / proof semantics: interpret the graph as a pure Lean function on tensors (Interp.spec). This is the semantics we want to reason about: shape safety, algebraic identities, equivalence of model refactorings, etc.
  2. Executable semantics: compile the same graph into a backend-generic TorchLean.Program (Compile.torchProgram) so it can run on the TorchLean runtime (which can target eager or compiled execution backends).

Shapes and parameter shapes are part of the graph type. Concretely, a graph is indexed by:

  • ps : List Shape, the shapes of its parameter tensors, in order,
  • σ : Shape, the input tensor shape, and
  • τ : Shape, the output tensor shape.

That “parameter interface” is not a convention (like “whatever state_dict() happens to return”); it is baked into the model type. Sequential composition concatenates parameter lists, and evaluation splits them canonically.

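As a small sketch of that ABI (a hypothetical example: it assumes the linear/relu graph constructors defined later in this file, the Vec/Mat shape constructors, and that layer dimensions are inferred from the ascribed type):

-- A two-layer MLP, written inside the NN.GraphSpec namespace.
-- Its parameter ABI is the concatenation of each layer's parameters:
-- weight and bias of layer 1, then weight and bias of layer 2.
def mlp (i h o : Nat) :
    Graph [Mat h i, Vec h, Mat o h, Vec o] (Vec i) (Vec o) :=
  linear >>> relu >>> linear

(relu contributes the empty parameter list, and [Mat h i, Vec h] ++ [] ++ [Mat o h, Vec o] reduces to the flat list in the type ascription.)
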
Why do this if PyTorch already exists? #

PyTorch is excellent at running and training neural networks: fast kernels, autograd, optimizers, and a mature ecosystem.

GraphSpec is not trying to replace any of that. Instead, it focuses on the pieces PyTorch does not give us inside Lean: a shape-indexed model syntax, a pure reference semantics to prove things against, and a fixed, type-level parameter ABI.

In practice, the expected workflow is: author the model once as a Graph, reason about it through Interp.spec, and execute it through Compile.torchProgram.

Why do this if TorchLean already exists? #

TorchLean is the runtime and operator layer: it gives us typed tensors, a backend interface, and executable programs (TorchLean.Program) that can run under the autograd/training runtime.

GraphSpec is the architecture/specification layer: it gives us a small typed syntax for model structure that comes with two linked meanings: the pure reference semantics (Interp.spec) and the executable TorchLean program (Compile.torchProgram).

You can write models directly in TorchLean, but then the “thing you reason about” is already in the executable world (monadic references + backend ops). For many proofs, it is much cleaner to reason about a pure function Params → Tensor → Tensor and separately prove that compilation to the runtime preserves that meaning.

In other words: TorchLean is where models run; GraphSpec is where models are stated and reasoned about, with a compiler connecting the two.

Mathematical View For Sequential Chains #

For g : Graph ps σ τ, think of g as denoting a function

⟦g⟧ : Params(ps) → Tensor σ → Tensor τ.

In this file, that semantics is implemented by Interp.spec, and it is defined structurally, with one clause per Graph constructor.

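A sketch of those clauses, writing ⟦g⟧ for Interp.spec g and leaving the canonical split of the concatenated parameter list implicit:

⟦id s⟧ [] x = x
⟦prim p⟧ ps x = p.specFwd ps x
⟦g₁ >>> g₂⟧ (ps₁ ++ ps₂) x = ⟦g₂⟧ ps₂ (⟦g₁⟧ ps₁ x)
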
The compiler Compile.torchProgram follows the same structure, but targets a monadic Torch interface and expects arguments as params ++ [input] (matching TorchLean.NN.Seq.program).

Scope of Core.lean #

This file defines only the sequential core: the Graph type (id, seq/>>>, prim), the Primitive record with an initial op set (linear, relu, softmax), the reference interpreter Interp.spec, the compiler Compile.torchProgram, and a structural lowering to the DAG language.

For skip connections, shared intermediates, residual adds, or other multi-input nodes, use NN.GraphSpec.DAG directly.

Direction #

GraphSpec is intended to grow into a hygienic “write once, run/prove many” layer: one model description shared by execution, training, and proof.

Core graph language #

structure NN.GraphSpec.Primitive (ps : List Shape) (σ τ : Shape) :

A primitive node in the GraphSpec language.

GraphSpec primitives package both sides of the “spec vs runtime” interface:

  • a pure spec forward function (specFwd) used by the reference interpreter, and
  • a TorchLean program (torchProgram) used by the compiler.

Optionally, a primitive may also provide a lowering to a TorchLean LayerDef (used to build a TorchLean.NN.Seq for training ergonomics + deterministic parameter initialization). Not every primitive needs this (e.g. control-flow-ish nodes kept outside the sequential layer).

Why a record?

  • It lets us grow the op set by adding new primitives in new files (rather than editing a single global inductive just to extend the vocabulary).
  • It keeps the “spec vs TorchLean” linkage explicit: when you add an op, you must define both interpretations side-by-side.

Type indices:

  • ps : List Shape are the parameter tensor shapes this primitive expects, in order.
  • σ τ : Shape are input/output tensor shapes.
inductive NN.GraphSpec.Graph :
List Shape → Shape → Shape → Type 2

Graph ps σ τ is a (restricted) model that:

• takes an input tensor of shape σ,
• produces an output tensor of shape τ,
• and uses parameters whose shapes are listed in ps (in order).

This is a sequential (chain) graph language: the only composition operator is seq (>>>). For sharing/skip connections, use NN.GraphSpec.DAG.

Implementation note:

• We encode the parameter list at the type level so composition automatically concatenates parameter lists (ps := ps₁ ++ ps₂).
• This means every graph has a canonical “ABI” for parameters: a single typed list TList α ps. When composing g₁ : Graph ps₁ σ τ and g₂ : Graph ps₂ τ υ, the composite graph expects a parameter list with shapes ps₁ ++ ps₂, and evaluation splits that list into the pieces needed by each subgraph.

Constructors:

• id (s : Shape) : Graph [] s s

  Identity graph: passes the input through unchanged and requires no parameters.

• seq {ps₁ ps₂ : List Shape} {σ τ υ : Shape} : Graph ps₁ σ τ → Graph ps₂ τ υ → Graph (ps₁ ++ ps₂) σ υ

  Sequential composition. Parameter lists concatenate.

• prim {ps : List Shape} {σ τ : Shape} : Primitive ps σ τ → Graph ps σ τ

  Embed a single primitive node as a graph.

Standard primitives (initial op set) #

Primitive linear layer.

Mathematical semantics (vector case):

Let x : Vec inDim, W : Mat outDim inDim, and b : Vec outDim. Then:

linear(W, b, x) = W * x + b.

This matches the standard dense layer as in PyTorch torch.nn.Linear / torch.nn.functional.linear (up to the usual row/column convention; here the shape indices make the intended dimensions explicit).

Type-level parameter interface:

• parameter shapes are [Mat outDim inDim, Vec outDim],
• input shape is Vec inDim,
• output shape is Vec outDim.

So a graph containing a linear node forces you to supply exactly a weight matrix and bias vector of the right shapes, and it fixes their ordering in the model’s parameter list.

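For instance (a hypothetical instantiation, assuming the linear graph constructor defined below and that its dimensions are inferred from the expected type):

-- A dense layer from ℝ³ to ℝ²; its parameter ABI is exactly
-- [Mat 2 3, Vec 2]: weight first, then bias.
example : Graph [Mat 2 3, Vec 2] (Vec 3) (Vec 2) := linear
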
References:

• Dense layers are standard; for PyTorch behavior see the torch.nn.Linear documentation.
• For the semantics used by the spec interpreter, see NN.Spec.Module.Linear (Spec.linear_spec).

Initialization semantics:

• we attach a TorchLean LayerDef so graphs can be lowered to TorchLean.NN.Seq,
• and we seed W, b deterministically from the layer-occurrence index i:
  • seedW = 2*i, seedB = 2*i + 1.

The deterministic occurrence-index rule keeps end-to-end examples reproducible while preserving a single GraphSpec → TorchLean → training path.

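As a worked example: in linear >>> relu >>> linear, assuming relu has countsAsLayer = false (it contributes no parameters), the first linear is occurrence i = 0 and gets seedW = 0, seedB = 1; the second is occurrence i = 1 and gets seedW = 2, seedB = 3.
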
ReLU activation (parameter-free).

Mathematical semantics: elementwise relu(x) = max(x, 0).

This is shape-preserving and parameter-free, so its parameter list is [] and its input/output shape indices are both s.

References:

• Nair & Hinton (2010), “Rectified Linear Units Improve Restricted Boltzmann Machines”.
• Spec definition: Activation.relu_spec in NN.Spec.Layers.Activation.

Last-axis softmax (parameter-free).

Softmax turns “logits” into a probability distribution along the last axis:

softmax(x)_i = exp(x_i) / (∑_j exp(x_j)).

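As a small worked example: softmax([0, ln 3]) = [1/(1+3), 3/(1+3)] = [0.25, 0.75], since exp(0) = 1 and exp(ln 3) = 3, and the entries are normalized by their sum 4.
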
In TorchLean’s spec layer, this is implemented as a genuine last-axis tensor softmax (recursing over outer dimensions), analogous to torch.softmax(x, dim=-1) in PyTorch.

Notes:

• Softmax is not elementwise; it normalizes across an axis, so it is a canonical example of a non-pointwise activation.
• For numerical stability, practical implementations often rewrite softmax using logsumexp. The spec semantics here follows the dedicated Activation.softmax_spec.

Graph constructor for Primitive.linear.

Graph constructor for Primitive.relu.

Graph constructor for Primitive.softmax.

Lowering: sequential Graph → DAG term/model #

GraphSpec has two surface syntaxes: the sequential Graph language in this file, and the general DAG term/model language in NN.GraphSpec.DAG.

The DAG term/model language is GraphSpec’s “general graph” core: it is the representation that can express sharing and skip connections.

Sequential Graph exists because it is the clearest way to write pipelines, and it has its own direct Spec semantics (Interp.spec) and compiler (Compile.torchProgram).

This lowering is still useful whenever you want to embed a sequential pipeline into the DAG world (e.g. to reuse DAG-only tooling, or to keep a single GraphSpec example surface that can export DAG models).

This section provides a structural lowering: Graph.toDAGTerm for terms and Graph.toDAGModelZeroInit for models.

Lowering internals #

The definitions in this section (castTerm, toTerm, …) are internal adapters for the structural lowering. The intended public API is Graph.toDAGTerm / Graph.toDAGModelZeroInit.

def NN.GraphSpec.LowerToDAG.castTerm {Γ : List Shape} {s t : Shape} (h : s = t) :
DAG.Term Γ s → DAG.Term Γ t

Cast a DAG.Term across a proven equality of output shapes.

def NN.GraphSpec.LowerToDAG.castEnvTerm {Γ Γ' : List Shape} {τ : Shape} (h : Γ = Γ') :
DAG.Term Γ τ → DAG.Term Γ' τ

Cast the environment of a DAG.Term across a proven equality of environments.

def NN.GraphSpec.LowerToDAG.castEnvArgs {Γ Γ' ins : List Shape} (h : Γ = Γ') :
DAG.Args Γ ins → DAG.Args Γ' ins

Cast the environment of DAG.Args across a proven equality of environments.

List.get lemmas (small, self-contained) #

theorem NN.GraphSpec.LowerToDAG.get_append_left_nat {α : Type} (as bs : List α) (i : ℕ) (hi : i < as.length) :
(as ++ bs).get ⟨i, ⋯⟩ = as.get ⟨i, hi⟩

List.get into as is unchanged by appending a right list (Nat-index form).

theorem NN.GraphSpec.LowerToDAG.get_append_right_offset_nat {α : Type} (as bs : List α) (j : ℕ) (hj : as.length + j < (as ++ bs).length) :
(as ++ bs).get ⟨as.length + j, hj⟩ = bs.get ⟨j, ⋯⟩

List.get into the right list after appending, using an explicit offset as.length + j (Nat-index form).

theorem NN.GraphSpec.LowerToDAG.get_append_last {α : Type} (xs : List α) (x : α) :
(xs ++ [x]).get ⟨xs.length, ⋯⟩ = x

List.get of the last element after appending a singleton list.

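A minimal usage sketch for the last lemma (an assumed illustration: the bound proof written as by simp should be accepted up to definitional proof irrelevance):

-- ([10, 20] ++ [30]).get ⟨2, _⟩ picks out the appended element.
open NN.GraphSpec.LowerToDAG in
example : ([10, 20] ++ [30]).get ⟨2, by simp⟩ = 30 :=
  get_append_last [10, 20] 30
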
Primitive embedding: Primitive → DAG.PrimOp #

Embed a sequential GraphSpec primitive as a DAG primitive op.

The resulting op has input shapes ps ++ [σ] (parameters followed by the data input).

Building well-typed DAG arguments for a primitive call #

theorem NN.GraphSpec.LowerToDAG.get_succ {α : Type} (a : α) (as : List α) (i : Fin as.length) :
(a :: as).get ⟨↑i + 1, ⋯⟩ = as.get i

def NN.GraphSpec.LowerToDAG.argsOfFn {Γ ins : List Shape} :
((i : Fin ins.length) → DAG.Term Γ (ins.get i)) → DAG.Args Γ ins

Build a typed DAG.Args list from an index-based family of argument terms.

This is the bridge from “arguments as a function of Fin ins.length” to the inductive DAG.Args encoding used by DAG.Term.op.

def NN.GraphSpec.LowerToDAG.Args.append1 {Γ ps : List Shape} {σ : Shape} :
DAG.Args Γ ps → DAG.Term Γ σ → DAG.Args Γ (ps ++ [σ])

Append one final term to a typed DAG argument list.

def NN.GraphSpec.LowerToDAG.mkParamTerm {pre ps post extra : List Shape} (i : Fin ps.length) :
DAG.Term (pre ++ ps ++ post ++ extra) (ps.get i)

Reference the ith parameter block inside a larger environment layout.

The surrounding environment is split as pre ++ ps ++ post ++ extra; this helper returns the term that points at parameter i : Fin ps.length while keeping the full ambient environment explicit.

def NN.GraphSpec.LowerToDAG.primCall {pre ps post extra : List Shape} {σ τ : Shape} (p : Primitive ps σ τ) (x : DAG.Term (pre ++ ps ++ post ++ extra) σ) :
DAG.Term (pre ++ ps ++ post ++ extra) τ

Lower a unary Primitive application into the DAG term language.

Parameters are read from the middle ps segment of the ambient environment, in the same order as the primitive's parameter ABI, and the final data input is supplied by x.

Graph lowering #

def NN.GraphSpec.LowerToDAG.toTerm {pre ps post extra : List Shape} {σ τ : Shape} (g : Graph ps σ τ) (x : DAG.Term (pre ++ ps ++ post ++ extra) σ) :
DAG.Term (pre ++ ps ++ post ++ extra) τ

Lower a sequential Graph to an SSA-style DAG.Term, with parameters read from the environment.

Public API #

Initialize a parameter list by filling every tensor with zeros (useful for proofs and examples).

Deterministic initialization for sequential graphs #

Graph.toDAGModelZeroInit is total, but its parameters are all-zero tensors, which is convenient for proofs and shape-only examples but not representative of training setups.

For graphs whose primitives provide Primitive.toLayerDefM?, we can reuse TorchLean’s deterministic initializers (e.g. Xavier init for linear weights) in a way that matches ToTorchLean.toSeq.

We expose this as Graph.toDAGModelDetInit? : Except String (DAG.Model ...); it fails if any primitive lacks a toLayerDefM? lowering.

Compute deterministic initialization tensors for a sequential Graph, threading a “layer occurrence index”.

This matches ToTorchLean.toSeq’s notion of “occurrence”: only primitives with countsAsLayer = true advance the counter.

Convenience wrapper: start the occurrence index at 0 and discard its final value.

def NN.GraphSpec.LowerToDAG.Graph.toDAGTerm {ps : List Shape} {σ τ : Shape} (g : Graph ps σ τ) :
DAG.Term (ps ++ [σ]) τ

Lower a sequential Graph to a DAG term with environment ps ++ [σ].

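For example (a sketch): for a single dense layer g : Graph [Mat o i, Vec o] (Vec i) (Vec o), the lowering yields

g.toDAGTerm : DAG.Term [Mat o i, Vec o, Vec i] (Vec o)

so environment variables 0 and 1 are the weight and bias, and the last variable is the data input, matching the params ++ [input] calling convention used by Compile.torchProgram.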

Lower a sequential Graph to a DAG Model with a simple default init (all zeros).

This is mainly a convenience for GraphSpec example organization; for training-oriented init, see NN.GraphSpec.ToTorchLean (Seq lowering) and/or provide your own initializer.

Lower a sequential Graph to a DAG Model, using deterministic initialization.

This is the DAG analogue of ToTorchLean.toSeq’s initialization semantics: it uses each primitive’s toLayerDefM? to obtain a TorchLean LayerDef, then reuses the LayerDef.initParams.

This returns Except String because not every primitive necessarily admits a LayerDef lowering.

Semantics (sequential core) #

The sequential DSL (Graph with >>>) has direct semantics: a pure interpreter (Interp.spec) and a compiler (Compile.torchProgram), each defined by structural recursion on the chain.

Even though a sequential graph is semantically a path-shaped DAG, we keep the sequential interpreter/compiler direct for two pragmatic reasons:

1. Proof ergonomics. For chain graphs, definitional reduction is much simpler when we evaluate directly rather than going through an SSA lowering.
2. Engineering clarity. The sequential and DAG languages have different invariants (parameter concatenation vs explicit let1 sharing). Keeping each semantics close to its syntax makes the code easier to audit.

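A sketch of point 1 (hypothetical statement, assuming we are inside the NN.GraphSpec namespace, that the Params abbreviation lives in the Interp namespace, and that Interp.spec reduces definitionally on the id constructor):

-- Evaluating the identity graph is x, by rfl: no SSA environment
-- bookkeeping stands between the syntax and the value.
example {α : Type} [Context α] {s : Shape}
    (p : Interp.Params α []) (x : Tensor.Tensor α s) :
    Interp.spec (Graph.id s) p x = x := rfl
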
We still provide a structural lowering LowerToDAG.Graph.toDAGModelZeroInit so that DAG-only tooling can consume sequential models. The DAG path becomes the canonical execution route when a caller wants explicit sharing together with the corresponding simp lemmas / proof infrastructure.

@[reducible, inline]

A typed list of parameter tensors matching the parameter-shape ABI ps.

Split a typed parameter list for a sequential composition.

If ps = ps₁ ++ ps₂, then a value of type TList α ps can be split into the prefix parameters for the left subgraph and the remaining parameters for the right subgraph.

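A minimal sketch of such a splitter (hypothetical name splitAppend; it assumes TList f ps is an inductive heterogeneous list with nil/cons constructors, one element of type f s per shape s in ps; in this file f would be instantiated with tensors over the scalar type α):

def splitAppend {f : Shape → Type} :
    (ps₁ ps₂ : List Shape) → TList f (ps₁ ++ ps₂) → TList f ps₁ × TList f ps₂
  | [], _, xs => (.nil, xs)
  | _ :: ps₁, ps₂, .cons x xs =>
    let (l, r) := splitAppend ps₁ ps₂ xs
    (.cons x l, r)
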
def NN.GraphSpec.Interp.spec {ps : List Shape} {σ τ : Shape} (g : Graph ps σ τ) {α : Type} [Context α] :
Params α ps → Tensor.Tensor α σ → Tensor.Tensor α τ

Pure Spec semantics of a sequential Graph.

Compile a sequential Graph to a backend-generic TorchLean Program.

def NN.GraphSpec.Compile.torchProgram.compileRefList {α : Type} [Context α] [DecidableEq Shape] {m : Type → Type} (_instM : Monad m) (_instOps : Runtime.Autograd.Torch.Ops m α) {ps : List Shape} {σ τ : Shape} (g : Graph ps σ τ) (rs : Runtime.Autograd.Torch.RefList (fun (s : Spec.Shape) => Runtime.Autograd.Torch.Ops.Ref m α s) (ps ++ [σ])) :
have Ref := fun (s : Spec.Shape) => Runtime.Autograd.Torch.Ops.Ref m α s; m (Ref τ)