LinkedSession #
Proof-linked imperative Session (eager-style API, proved IR under the hood).
Background:
- Runtime.Autograd.TorchLean.Session provides a unified imperative API for training/debugging (eager) and verification-friendly execution (compiled).
- Proofs.Autograd.Algebra.GraphData is the proved/typed SSA (DAG) IR used by the proof-compiled pipeline (Proofs.Autograd.Algebra.Graph.compileAuxData), and NN/Proofs/Autograd/Runtime/Link.lean proves that running the runtime reverse-mode loop on the compiled tape matches GraphData.backpropAllCtx.
This file provides a session-style API that records a GraphData (well-typed IR) as you call
ops imperatively, and then runs the standard runtime tape loop on the compiled tape.
Key guarantee (pure theorem, no IO reasoning needed):
- If the session snapshot is (g, x), then Tape.backwardDenseFrom (compileAuxData g x) equals GraphData.backpropAllCtx g x (via backwardDenseFrom_compileAuxData_eq_backpropAllCtx).
Practical note:
- This session enforces a simple invariant: all leaf tensors are created before any op node. This matches the standard training pattern (reset → add leaves → forward → backward); see the sketch after this list. const is available as a graph node, so you can still introduce literal constants mid-graph.
- This is the fully proof-linked variant used by TorchLean.Session when opts.backend := .compiled.
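For orientation, here is a minimal sketch of one training step in this style. Only resetTape, use, input, const, backwardScalarDenseAll, and sgdStepAll are names taken from this file; the op names linear and mse, the AnyParam typing of the parameter arguments, the shape-literal syntax, and the exact sgdStepAll signature are illustrative assumptions, not the verified API.

```lean
-- Sketch of one training step: leaves first, then ops, then backward + SGD.
-- Names marked "assumed" are illustrative, not verified against this module.
def demoStep (s : SessionIR Float)
    (w b : AnyParam Float)                         -- parameters created beforehand
    (xVal target : Tensor Float (.dim 4 .scalar))  -- shape literal is assumed syntax
    : IO Unit := do
  s.resetTape                                 -- new recording phase
  let wRef ← s.use w                          -- parameters become leaves in Γ ...
  let bRef ← s.use b
  let xRef ← s.input xVal                     -- ... and so do external inputs
  let tRef ← s.const target                   -- recorded as a leaf here (no ops yet)
  let yRef ← s.linear wRef bRef xRef          -- assumed op name: fully-connected layer
  let lRef ← s.mse yRef tRef                  -- assumed op name: scalar MSE loss
  let grads ← s.backwardScalarDenseAll lRef   -- runtime tape backward on the compiled tape
  s.sgdStepAll grads 0.01                     -- assumed (grads, lr) argument order
```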
Convenience: turn a Result α into IO α by throwing IO.userError on .error.
This mirrors the common pattern in the eager runtime front-end (Torch.Core).
Instances For
Non-differentiable external environment for the proved graph: a small array of Nat inputs.
Instances For
Internal proof-linked session state (a well-typed GraphData plus its leaf values).
- Γ : List Spec.Shape
Leaf shapes (inputs/parameters), in creation order.
- x : Proofs.Autograd.Algebra.TList α self.Γ
Leaf values, aligned with Γ.
- nat : NatEnv
Non-differentiable external inputs (e.g. class labels/indices).
- ss : List Spec.Shape
Internal node shapes, in creation order.
- g : Proofs.Autograd.Algebra.GraphData α NatEnv self.Γ self.ss
SSA/DAG graph nodes (one per entry in ss).
Instances For
Empty session state: no leaves, no nodes, empty nat-environment.
Instances For
SessionIR is an imperative session that records a GraphData (proved IR) as it runs.
It is "eager-style" (you call ops imperatively), but it is proof-linked: the recorded graph can be
compiled and then the runtime tape backward loop is provably equal to GraphData.backpropAllCtx.
- opts : Options
Session options shared with the eager front-end.
- st : IO.Ref (SessionIRState α)
Mutable proof-linked graph snapshot.
- paramsByLeaf : IO.Ref (Std.HashMap ℕ (AnyParam α))
Map from graph leaf ids to mutable parameter objects.
Instances For
Create a new proof-linked session.
This allocates IO.Refs for the session snapshot (SessionIRState) and the leaf-id→parameter map.
Call resetTape to start a new "graph recording" phase.
Instances For
Reset the session to an empty snapshot.
Important invariant: this session requires that all leaves are created before any op node.
resetTape is the intended boundary between training steps/forwards.
Instances For
Create a mutable parameter object (not yet part of the recorded graph).
To use the parameter in the recorded graph, call use, which reads its current value and records
it as a leaf in Γ.
PyTorch comparison: analogous to creating a torch.nn.Parameter and then using it in a forward.
Instances For
Enforce the session invariant: leaves must be created before any op node.
This keeps the GraphData context split Γ ++ ss easy to reason about and matches the typical
training pattern: resetTape → add leaves → forward ops → backward.
Instances For
Record a new differentiable leaf tensor in the session context Γ.
This is the primitive used by use (parameters) and input (external inputs).
Instances For
Use a Param in the recorded graph by reading its current value and recording it as a leaf.
The returned TensorRef is the graph handle you pass to subsequent ops. The session also remembers
which leaf-id corresponds to which parameter, so sgdStepAll can update parameters after backward.
PyTorch comparison: like referencing a torch.nn.Parameter in the forward; the parameter's value
is treated as a leaf for autograd.
Instances For
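The sketch below illustrates how a parameter persists across recording phases; mkParam is a hypothetical constructor name (the actual constructor is documented above as "Create a mutable parameter object"), and only use, resetTape, backwardScalarDenseAll, and sgdStepAll are names taken from this file.

```lean
-- Sketch: the parameter object is created once; each step re-records its
-- *current* value as a fresh leaf via `use` (`mkParam` is a hypothetical name).
def trainLoop (s : SessionIR Float)
    (w0 : Tensor Float (.dim 4 .scalar)) (steps : Nat) : IO Unit := do
  let w ← s.mkParam w0                -- lives outside the recorded graph
  for _ in [0:steps] do
    s.resetTape                       -- boundary between training steps
    let _wRef ← s.use w               -- leaf value = w's value after the previous update
    -- forward ops on _wRef, backwardScalarDenseAll, and sgdStepAll go here
    pure ()
```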
Record an external differentiable input tensor as a leaf.
name and requiresGrad are accepted for API parity with the eager session, but this proof-linked
session always records the input in Γ (a leaf) and uses typing/invariants to determine what
gradients are meaningful.
Instances For
Record a non-differentiable Nat input in the external environment.
This is used for "index-like" inputs (labels, gather indices, etc.) that should not receive gradients. PyTorch comparison: like passing an integer tensor / index to an op; indices are not differentiable.
Instances For
Convert a small Tensor Nat (.dim k .scalar) into an Array Nat.
This is used to stage NatVecRef inputs into the session nat-environment.
Instances For
Record a non-differentiable vector of Nat inputs.
Returns a NatVecRef k which points into the nat-environment. This is useful for "runtime gather"
style ops where indices are supplied externally (and are not differentiable).
Instances For
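The following sketch stages class labels as non-differentiable indices and hints at their use with a runtime gather; natVecInput and gatherRowsDyn are assumed names for the operations described here and further below.

```lean
-- Sketch: class labels live in the nat-environment, not in Γ, so they never
-- receive gradients (`natVecInput` / `gatherRowsDyn` are assumed op names).
def stageLabels (s : SessionIR Float)
    (labels : Tensor Nat (.dim 8 .scalar)) : IO Unit := do
  let _idx ← s.natVecInput labels     -- _idx : NatVecRef 8 into the nat-environment
  -- given logits : TensorRef for an (8, numClasses) tensor built by earlier ops:
  --   let picked ← s.gatherRowsDyn logits _idx   -- gradients flow to logits only
  pure ()
```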
Read back the k-vector stored at a NatVecRef k.
Instances For
Overwrite the nat-environment segment referenced by NatVecRef k.
Instances For
Build a typed index into the current context Γ ++ ss from a raw numeric id and expected shape.
This is the main "dynamic check" used by getValue (and by a few index-driven nodes): it ensures
that the Nat id points to an existing tensor in the session context and that the shape matches.
Instances For
Evaluate the recorded graph and return the value of a TensorRef.
This is a pure graph evaluation (GraphData.eval) using the recorded leaf values and
nat-environment. It does not run the runtime tape or mutate session state.
Instances For
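A small debugging sketch; the single-shape indexing of TensorRef and the Repr instance used for printing are assumptions.

```lean
-- Sketch: inspect an intermediate value without touching the tape or gradients.
def debugPrint (s : SessionIR Float)
    (y : TensorRef (.dim 4 .scalar)) : IO Unit := do   -- TensorRef indexing assumed
  let v ← s.getValue y              -- pure GraphData.eval over recorded leaves
  IO.println s!"y = {repr v}"       -- assumes a Repr instance on Tensor
```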
Graph-node ops (implemented by reusing Compiled.GraphM) #
Run a Compiled.GraphM computation against the current (ss, g) pair.
Compiled.GraphM is the builder monad used by the proof-friendly compiled pipeline; reusing it
here ensures this eager-style API records the same typed IR that the compiler expects.
Instances For
Atomically apply a graph-building update to the session snapshot.
This is the central adapter used by each op wrapper below: it reads s.st, runs a builder that
returns an updated SessionIRState, stores it back into s.st, and returns the op result.
Instances For
Record a constant tensor.
Subtlety: if no op nodes have been created yet (ss = []), we record const as a leaf to match
the eager session's leaf-collection behavior. Once op nodes exist, we emit an explicit constant node
so users can introduce literal constants mid-graph.
PyTorch comparison: like torch.tensor(...) (a leaf) vs inserting a literal constant into the
graph; constants are treated as non-requires-grad.
Instances For
Record elementwise addition a + b.
PyTorch comparison: torch.add(a, b) / the + operator.
Instances For
Record elementwise subtraction a - b.
PyTorch comparison: torch.sub(a, b) / the - operator.
Instances For
Record elementwise multiplication a * b.
PyTorch comparison: torch.mul(a, b) / the * operator.
Instances For
Record scaling by a scalar constant: x * c.
PyTorch comparison: like x * c (where c is a Python scalar).
Instances For
Record elementwise absolute value.
PyTorch comparison: torch.abs(x).
Instances For
Stop-gradient boundary.
Forward semantics: identity.
Backward semantics: no gradient flows to the input.
PyTorch comparison: x.detach().
Instances For
Record elementwise square root.
PyTorch comparison: torch.sqrt(x).
Instances For
Record elementwise clamp to the interval [minVal, maxVal].
PyTorch comparison: torch.clamp(x, min=minVal, max=maxVal).
Instances For
Record elementwise maximum of a and b.
PyTorch comparison: torch.maximum(a, b).
Instances For
Record elementwise minimum of a and b.
PyTorch comparison: torch.minimum(a, b).
Instances For
Record 2D matrix multiplication.
PyTorch comparison: torch.matmul(a, b) for 2D tensors.
Instances For
Record batched matrix multiplication.
PyTorch comparison: torch.bmm(a, b) for 3D tensors of shape (batch, m, n) and (batch, n, p).
Instances For
Concatenate two 1D vectors along dimension 0.
PyTorch comparison: torch.cat([a, b], dim=0) for 1D tensors.
Instances For
Concatenate two tensors along dimension 0.
PyTorch comparison: torch.cat([a, b], dim=0).
Instances For
Slice a tensor along dimension 0.
This returns x[start : start+len]. The proof argument h enforces bounds.
PyTorch comparison: x[start:start+len] for tensors with a leading dimension.
Instances For
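For literal indices the bounds side-condition can usually be discharged by a decision procedure; in the sketch below, slice, its named arguments, and the exact statement of h are assumptions.

```lean
-- Sketch: x : (10,) vector, take x[2:7]; the bounds proof is closed by `omega`.
def middle (s : SessionIR Float)
    (x : TensorRef (.dim 10 .scalar)) : IO (TensorRef (.dim 5 .scalar)) := do
  s.slice x (start := 2) (len := 5) (h := by omega)   -- e.g. h : 2 + 5 ≤ 10
```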
N-D max-pooling for channels-first tensors (C, spatial...) (no batch axis).
PyTorch comparison: torch.nn.functional.max_pool1d / max_pool2d / max_pool3d depending on the
spatial rank d.
Instances For
N-D smooth max-pooling (log-sum-exp surrogate) for channels-first tensors (C, spatial...).
This is a differentiable approximation of max-pooling; there is no direct PyTorch primitive.
Instances For
N-D average-pooling for channels-first tensors (C, spatial...) (no batch axis).
PyTorch comparison: torch.nn.functional.avg_pool1d / avg_pool2d / avg_pool3d depending on the
spatial rank d.
Instances For
2D max-pooling for channel-first images.
PyTorch comparison: torch.nn.functional.max_pool2d (for NCHW-like layouts, here without batch).
Instances For
Smooth approximation of max-pooling (softmax pooling) for channel-first images.
This is not a standard PyTorch primitive; conceptually it behaves like applying a softmax over each
pooling window with inverse-temperature beta and returning the expected value.
Instances For
2D average-pooling for channel-first images.
PyTorch comparison: torch.nn.functional.avg_pool2d (for NCHW-like layouts, here without batch).
Instances For
Record elementwise ReLU.
PyTorch comparison: torch.relu(x) / torch.nn.functional.relu(x).
Instances For
Flatten a tensor into a 1D vector of length Shape.size sh.
PyTorch comparison: torch.flatten(x) (with default start_dim=0).
Instances For
Reshape a tensor while preserving total number of elements.
The proof argument h enforces Shape.size sh1 = Shape.size sh2.
PyTorch comparison: torch.reshape(x, new_shape) / x.view(new_shape) (when contiguous).
Instances For
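For literal shapes the size-preservation proof is typically discharged by computation; in the sketch below, reshape, its named argument h, and the assumption that Shape.size reduces under decide are all illustrative.

```lean
-- Sketch: reshape a (2, 3) tensor into a flat (6,) vector.
def flatten6 (s : SessionIR Float)
    (x : TensorRef (.dim 2 (.dim 3 .scalar)))
    : IO (TensorRef (.dim 6 .scalar)) := do
  s.reshape x (h := by decide)      -- h : Shape.size (2,3) = Shape.size (6,)
```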
Transpose a 2D matrix (swap the two axes).
PyTorch comparison: x.t() for 2D tensors, or x.transpose(0, 1).
Instances For
Permute a 3D tensor by moving the first axis to the end: (a,b,c) → (b,c,a).
PyTorch comparison: x.permute(1,2,0) for a 3D tensor.
Instances For
Permute a 3D tensor by moving the last axis to the front: (a,b,c) → (c,a,b).
PyTorch comparison: x.permute(2,0,1) for a 3D tensor.
Instances For
Swap the last two axes of a 3D tensor: (a,b,c) → (a,c,b).
PyTorch comparison: x.transpose(1,2) for a 3D tensor.
Instances For
Swap two adjacent axes at a given depth inside the shape.
This is a more general permutation helper used in some shape-manipulating models.
PyTorch comparison: like x.transpose(dim, dim+1) for a suitably chosen dim.
Instances For
Broadcast a tensor to a larger shape.
The witness cb : Shape.CanBroadcastTo sh1 sh2 encodes the broadcasting compatibility proof.
PyTorch comparison: x.expand(...) / implicit broadcasting.
Instances For
Sum-reduce along axis.
PyTorch comparison: torch.sum(x, dim=axis).
Instances For
Mean-reduce along axis.
PyTorch comparison: torch.mean(x, dim=axis).
Instances For
Gather a single scalar x[i] from a 1D vector, with a compile-time Fin n index.
PyTorch comparison: x[i] for a 1D tensor.
Instances For
Gather a row x[i] from a 2D tensor, with a compile-time Fin rows index.
PyTorch comparison: x[i] for a 2D tensor (row indexing).
Instances For
Dynamic gather of a scalar from a 1D vector using a runtime NatRef index.
Out-of-range indices produce 0 instead of raising.
PyTorch comparison: similar to x[i] where i is a Python integer, except PyTorch raises on
out-of-range while this definition totalizes the behavior for ease of reasoning.
Instances For
Dynamic gather of a row from a 2D tensor using a runtime NatRef index.
Out-of-range indices yield a zero row.
PyTorch comparison: similar to x[i] for 2D tensors with runtime i, but PyTorch raises on
out-of-range whereas this definition is totalized for ease of reasoning.
Instances For
Dynamic gather of k scalars from a 1D tensor using a runtime NatVecRef k of indices.
Out-of-range indices yield 0. In the VJP, gradients are accumulated for repeated indices
(i.e. it behaves like a gather followed by a scatter-add back into the source vector).
PyTorch comparison: related to torch.gather / advanced indexing, but with totalized out-of-range
behavior.
Instances For
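A small worked example of the accumulation behavior (values are illustrative):

```lean
-- Sketch of the VJP accumulation for repeated indices (illustrative numbers).
-- Forward: x : (4,), indices = #[1, 1, 3]  ⇒  y = #[x[1], x[1], x[3]].
-- Backward with upstream dy = #[a, b, c]:
--   dx = #[0, a + b, 0, c]      -- repeated index 1 accumulates a + b (scatter-add)
```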
Dynamic gather of k rows from a 2D tensor using a runtime NatVecRef k of row indices.
Out-of-range indices yield zero rows. In the VJP, gradients are accumulated into the selected
rows (scatter-add semantics), including accumulation for repeated indices.
PyTorch comparison: similar to torch.index_select(x, dim=0, index=...) or advanced indexing on
the first dimension, but with totalized out-of-range behavior.
Instances For
Gather a scalar from a 1D vector using a raw Nat index.
PyTorch comparison: like x[i] with an integer index, but this operation is recorded into the
proved IR (so it is stable for compilation/verification).
Instances For
Gather k scalars from a 1D vector using an explicit index tensor.
PyTorch comparison: related to torch.gather / advanced indexing with an integer index tensor.
Instances For
Gather k rows from a 2D tensor using an explicit index tensor.
PyTorch comparison: similar to torch.index_select(x, dim=0, index=...) or advanced indexing.
Instances For
Scatter-add into a vector: return a copy of x with x[i] += v.
PyTorch comparison: similar to x.scatter_add_(dim=0, index=..., src=...) in spirit, but this is
functional (returns a new tensor) and uses a single Fin n index.
Instances For
Scatter-add into a matrix row: return a copy of x with x[i, :] += v.
PyTorch comparison: like adding a row vector into a selected row (functional analogue of an in-place indexed add).
Instances For
Record elementwise logistic sigmoid.
PyTorch comparison: torch.sigmoid(x).
Instances For
Record elementwise hyperbolic tangent.
PyTorch comparison: torch.tanh(x).
Instances For
Record softmax (shape-preserving).
PyTorch comparison: torch.softmax(x, dim=...). This helper uses the convention baked into the
underlying GraphM.softmax implementation.
Instances For
Record stable log-softmax in the linked compiled session.
This commits a single GraphM.logSoftmax node instead of expanding to softmax followed by
log, so compiled execution keeps the same stable semantics as eager CPU/CUDA.
Instances For
Record elementwise softplus.
PyTorch comparison: torch.nn.functional.softplus(x).
Instances For
Record elementwise exponential.
PyTorch comparison: torch.exp(x).
Instances For
Record elementwise natural logarithm.
PyTorch comparison: torch.log(x).
Instances For
Record elementwise log with epsilon guard.
This is intended for numerically stable losses; it corresponds roughly to log(max(x, ε)).
PyTorch comparison: torch.log(torch.clamp(x, min=ε)).
Instances For
Sum-reduce all elements to a scalar.
PyTorch comparison: x.sum().
Instances For
Record a fully-connected linear layer: y = w • x + b.
Type-level shapes enforce w : (outDim, inDim), b : (outDim,), and x : (inDim,).
PyTorch comparison: torch.nn.functional.linear(x, weight=w, bias=b) (with the same weight layout).
Instances For
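A shape-focused sketch of a dense layer followed by ReLU; linear and relu are assumed op names, and the nested .dim shape literals follow the 1D form used elsewhere in this file.

```lean
-- Sketch: y = relu (w • x + b) with the documented weight layout.
def denseRelu (s : SessionIR Float)
    (w : TensorRef (.dim 8 (.dim 4 .scalar)))   -- (outDim, inDim) = (8, 4)
    (b : TensorRef (.dim 8 .scalar))            -- (outDim,)
    (x : TensorRef (.dim 4 .scalar))            -- (inDim,)
    : IO (TensorRef (.dim 8 .scalar)) := do
  s.relu (← s.linear w b x)
```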
Mean-squared-error loss returning a scalar.
PyTorch comparison: torch.nn.functional.mse_loss(yhat, target, reduction="mean").
Instances For
Layer normalization over the trailing embedding dimension.
This variant is specialized to 2D tensors of shape (seqLen, embedDim) and expects positive
dimensions for numerical stability and well-formedness.
PyTorch comparison: torch.nn.LayerNorm(embedDim) (applied per token), or
torch.nn.functional.layer_norm.
Instances For
Batch normalization for a channel-first image (C,H,W) (no batch axis).
gamma and beta are per-channel scale/shift parameters.
PyTorch comparison: torch.nn.BatchNorm2d(C) (conceptually), or torch.nn.functional.batch_norm
specialized to a single "batch element" with NCHW layout.
Instances For
N-D convolution for channels-first tensors (inC, spatial...) (no batch axis).
Kernel layout is (outC, inC, kernelSpatial...), bias is (outC).
PyTorch comparison: torch.nn.functional.conv1d / conv2d / conv3d (depending on the spatial rank d), specialized to a single sample.
Instances For
N-D transpose convolution for channels-first tensors (inC, spatial...) (no batch axis).
Kernel layout is (inC, outC, kernelSpatial...) (PyTorch convention), bias is (outC).
PyTorch comparison: torch.nn.functional.conv_transpose1d / conv_transpose2d / conv_transpose3d (depending on the spatial rank d), specialized to a single sample.
Instances For
2D convolution for channel-first images (inC, inH, inW) (no batch axis).
Type-level shapes fix the kernel layout (outC, inC, kH, kW) and output spatial dimensions derived
from stride and padding.
PyTorch comparison: torch.nn.functional.conv2d (conceptually), specialized to a single image.
Instances For
2D transpose convolution for channel-first images (inC, inH, inW) (no batch axis).
Kernel layout matches the spec/PyTorch convention (inC, outC, kH, kW).
PyTorch comparison: torch.nn.functional.conv_transpose2d specialized to a single image.
Instances For
Multi-head self-attention.
This is a shape-specialized attention primitive used by some demo transformer-style models:
- input x has shape (n, dModel)
- wq, wk, wv map dModel → numHeads*headDim
- wo maps numHeads*headDim → dModel
- optional mask is a boolean (n, n) attention mask
PyTorch comparison: similar to torch.nn.MultiheadAttention / scaled dot-product attention, but
encoded in a fully typed IR for compilation/proof linkage.
Instances For
Backward + SGD (runtime tape loop on the compiled tape) #
Compile the recorded proved graph into a runtime tape.
This uses Graph.compileAuxData (the same compiler used by the proof pipeline) and extracts the
runtime tape component.
Instances For
Run reverse-mode backprop for the whole recorded context and return a dense gradient array.
seed is the upstream gradient for out (same convention as PyTorch's
loss.backward(gradient=...)).
Instances For
Convenience wrapper for scalar losses: run backward with seed 1.
PyTorch comparison: loss.backward() for a scalar loss.
Instances For
Extract the gradient tensor for a particular TensorRef from a dense gradient array.
This is the typed analogue of looking up grads[x.id] and casting it to the expected shape.
Instances For
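Putting the backward pieces together; gradOf is an assumed name for the extraction helper described above, and the sgdStepAll argument order is likewise an assumption.

```lean
-- Sketch: scalar backward, per-ref gradient inspection, then SGD.
def backwardAndStep (s : SessionIR Float)
    (loss : TensorRef .scalar)
    (w : TensorRef (.dim 4 .scalar)) : IO Unit := do
  let grads ← s.backwardScalarDenseAll loss   -- seed 1, runtime tape loop
  let gw ← s.gradOf grads w                   -- typed lookup of grads[w.id]
  IO.println s!"grad w = {repr gw}"           -- assumes Repr on Tensor
  s.sgdStepAll grads 0.01                     -- assumed (grads, lr) order
```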
Forward-mode: JVP (compiled only) #
Like mkIdxOrThrow, but restricted to leaves Γ only.
Instances For
Convert a dense tangent array (aligned with leaf creation order) into a typed TList α Γ.
This is the main adapter needed to call the proved GraphData.jvpCtx forward-mode routine.
Instances For
Instances For
Jacobian-vector product for the current session snapshot.
dxs is a dense array of tangents for leaf tensors, aligned with leaf creation order.
Instances For
JVP for a single leaf: tangent is nonzero only at x.
Instances For
Scalar-loss JVP for a single leaf.
Instances For
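A hedged sketch of the scalar-loss forward-mode entry point; jvpScalarLeaf, its argument order, and the plain Float return type are assumptions.

```lean
-- Sketch: directional derivative of the recorded scalar loss along a tangent dx
-- at a single leaf x (`jvpScalarLeaf` and its argument order are assumptions).
def dirDeriv (s : SessionIR Float)
    (loss : TensorRef .scalar)
    (x : TensorRef (.dim 4 .scalar))
    (dx : Tensor Float (.dim 4 .scalar)) : IO Float := do
  s.jvpScalarLeaf loss x dx
```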
Apply an SGD update to all parameters recorded via use.
grads is expected to be the dense gradient array returned by backwardDenseAll /
backwardScalarDenseAll. Only entries corresponding to parameters (leaves that were produced by
use) are used to update Param.value.
PyTorch comparison: like iterating params and doing p.data -= lr * p.grad.
Instances For
Pure correctness hook: session snapshot ↔ proved IR backprop #
Core proof-link: running the runtime reverse-mode loop on the compiled tape equals proved backprop.
This theorem is the "hook" that lets a session-style API be backed by the proved IR:
compileAuxData produces a tape, and Tape.backwardDenseFrom is shown equal to
GraphData.backpropAllCtx (up to the TList.toAnyArray representation change).
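Schematically (mirroring the statement described above; the seed argument and exact binders are elided):

```lean
-- Schematic form of the hook (not a verbatim statement):
--   Tape.backwardDenseFrom (compileAuxData g x) = GraphData.backpropAllCtx g x
--   up to the TList.toAnyArray representation change on the right-hand side.
```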
Public re-exports (stable names for docs) #
Public alias for the proof-linked session state (internal definition re-export).
Instances For
Public alias for the proof-linked session object (internal definition re-export).
Instances For
Compute dense gradients for all tracked refs w.r.t. an output tensor and a seed.
This mirrors the "backward with custom seed" pattern in tensor AD systems.
Instances For
Dense gradients for all tracked refs w.r.t. a scalar loss (seed is implicitly 1).
Instances For
Extract the gradient tensor for a specific ref from a dense gradient array.
Instances For
Public proof hook: the runtime reverse-mode loop on the compiled tape equals proved IR backprop.
This is a re-export of the internal theorem so downstream users can cite a stable name.