GraphM #
Proof-compiled graph authoring API.
Proofs.Autograd.Algebra.GraphData is an executable SSA/DAG graph used by the proof-compiled
pipeline (Runtime.Autograd.Compiled). Constructing it directly exposes dependent indices (Idx)
into the graph context and is therefore fairly low-level.
This module defines a small StateT builder (GraphM) that:
- hides Idx bookkeeping behind typed variables (Var s);
- tracks the growing context automatically;
- emits GraphData nodes with runtime shape checks for safety.
Reading map #
- GraphM.arg and GraphM.args name inputs from the initial context.
- GraphM.const and GraphM.rand_uniform add constant and deterministic runtime nodes.
- The remaining op builders mirror the Runtime.Autograd.TorchLean.Backend surface one-for-one.
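As a quick orientation, here is a hedged sketch of what authoring looks like. The op names follow the docs below, but the exact signatures (argument order, implicit shape and environment binders) are assumptions and may differ from the real API.

```lean
-- Hypothetical sketch (signatures assumed): add two same-shaped inputs and
-- apply ReLU, letting the builder track the growing context Γ ++ ss.
example {α Δ : Type} {s : Shape} : GraphM α Δ [s, s] Unit := do
  let a ← GraphM.arg 0 s     -- typed handle; bounds and shape are checked
  let b ← GraphM.arg 1 s
  let y ← GraphM.add a b     -- emits a GraphData node, returns a fresh Var
  let _ ← GraphM.relu y
  pure ()
```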
Shorthand for the underlying executable SSA graph type from Proofs.Autograd.Algebra.
Executable node payload for the proof-compiled SSA graph (GraphData).
A typed handle to a value in the growing compiled context.
Var s carries its expected Shape at the type level, while id is the runtime index into the
concatenated context Γ ++ ss.
- id : ℕ
  Runtime id of the value inside the concatenated context Γ ++ ss. The shape index on Var s is the static guarantee; this numeric id is the executable handle used when constructing Idx proofs for GraphData nodes.
GraphM.arg is correct but a little noisy for examples (you must repeat the index and shape).
VarList + args give a small convenience layer: args returns one Var per entry in Γ,
in order, without spelling indices.
First variable in a nonempty VarList.
Tail variables in a nonempty VarList.
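A hedged sketch of the convenience layer in use (binder details are assumptions):

```lean
-- Hypothetical sketch: with Γ = [s, s], `args` avoids spelling indices;
-- `head`/`tail` peel one Var per input off the returned VarList, in order.
example {α Δ : Type} {s : Shape} : GraphM α Δ [s, s] Unit := do
  let vs ← GraphM.args
  let a := vs.head        -- Var for input 0
  let b := vs.tail.head   -- Var for input 1
  let _ ← GraphM.mul a b
  pure ()
```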
State for the GraphM builder.
It is a sigma pair of:
- the list of intermediate shapes ss produced so far, and
- the corresponding executable SSA graph payload GraphData α Δ Γ ss.
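A minimal sketch of that Σ-pair shape, assuming Shape and GraphData as documented above; the real definition may differ in detail.

```lean
-- Illustrative only: bundle the intermediate shapes with the graph payload
-- indexed by them, mirroring the description above.
def GraphStateSketch (α Δ : Type) (Γ : List Shape) : Type :=
  (ss : List Shape) × GraphData α Δ Γ ss
```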
StateT builder monad for authoring a GraphData program, with explicit environment Δ.
Empty builder state (no intermediate nodes yet).
Empty builder state for an explicit environment type Δ.
Run a GraphM program from an empty state.
Build a GraphData by running a GraphM program.
This is the usual entry point: write a do-block that constructs the graph using arg and the op builders and returns Unit; you get back the finalized builder state containing ss and the graph.
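A hedged sketch of the entry point in use; the runner name buildGraph, the state name GraphState, and the op name sum_all are stand-ins for the definitions documented here, not confirmed signatures.

```lean
-- Hypothetical sketch: author a small graph and get back the finalized
-- builder state ⟨ss, graph⟩.
example {α Δ : Type} {s : Shape} : GraphState α Δ [s] := buildGraph do
  let x ← GraphM.arg 0 s
  let y ← GraphM.mul x x    -- x ⊙ x
  let _ ← GraphM.sum_all y  -- reduce-sum to a scalar (name assumed)
  pure ()
```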
Length of the current context Γ ++ ss (inputs + intermediates).
Convert a Var s into a dependent Idx (Γ ++ ss) s.
This performs bounds checking and a runtime shape check, returning a structured error if the variable points outside the current context or has the wrong shape.
Append a node to the graph state and return a fresh Var pointing to its output.
The returned variable id is Γ.length + ss.length, i.e. it points at the newly appended entry.
Forward-mode JVP availability for a compiled graph builder op.
- implemented : JvpAvailability
The op supplies a real forward-mode JVP rule.
- reverseOnly (op : String) : JvpAvailability
The op supplies reverse-mode VJP only. Forward-mode requests fail loudly.
Compiled ops that provide VJP for training but no forward-mode JVP rule.
Keeping the list executable gives callers a stable preflight hook instead of discovering the gap only after a directional-derivative run reaches the node. The list is intentionally empty when all compiled builder ops have concrete JVP rules.
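For example, a tool could preflight a graph like this. A sketch: opNames (the list of op names appearing in a graph) is an assumed input, and reverseOnlyJvpOps is assumed to be a List String as described above.

```lean
-- Reject a graph before a forward-mode (JVP) run if it touches any
-- reverse-only op.
def jvpPreflight (opNames : List String) : Except String Unit :=
  match opNames.filter (reverseOnlyJvpOps.contains ·) with
  | []  => .ok ()
  | bad => .error s!"forward-mode JVP unavailable for: {bad}"
```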
Return the JVP status for a named compiled op.
Human-readable message for reverse-only compiled ops.
Fail-fast marker for compiled nodes whose forward-mode JVP rule is intentionally absent.
Returning a zero tangent here would silently corrupt forward-mode autodiff. Reverse-mode users are
unaffected because these nodes still provide real vjp implementations. Forward-mode callers get a
loud error, and reverseOnlyJvpOps provides a preflight list for tools that want to reject such
graphs before running a JVP.
Reference an input variable from the initial context Γ.
This checks that the provided index is within bounds and that the requested shape matches the
shape at that position in Γ.
PyTorch comparison: this is like naming a graph input tensor in a traced graph.
Pure helper to build VarList Γ starting at a given id offset.
Return one Var per entry of Γ, in order.
This is a convenience wrapper around arg that avoids manually writing indices in examples.
Embed a constant tensor as a node in the compiled graph.
This node has no input dependencies (vjp = 0, jvp = 0), i.e. it is treated as a constant
with respect to the graph inputs.
PyTorch comparison: a constant literal captured into a traced/compiled graph.
Deterministic U[0,1) tensor generator (seeded, pure).
Deterministic {0,1} mask generator (seeded, pure).
Note: for differentiation purposes, this node is treated as a stop-gradient op:
jvp = 0 and vjp = 0 for all inputs (including keepProb). This matches the intended use in
dropout where the probability is a hyperparameter (not differentiated), while keeping execution
deterministic in the .compiled backend.
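A hedged dropout sketch built on this node; the op names rand_bernoulli and scale and their argument orders are assumptions, and only the documented stop-gradient behavior is taken from the spec above.

```lean
-- Hypothetical inverted-dropout sketch: the mask node is a stop-gradient
-- (jvp = 0, vjp = 0), so gradient reaches x only through the mul.
example {α Δ : Type} {s : Shape} : GraphM α Δ [s] Unit := do
  let x    ← GraphM.arg 0 s
  let mask ← GraphM.rand_bernoulli s 0.9 42  -- keepProb = 0.9, seed = 42
  let y    ← GraphM.mul x mask               -- gradient flows via x only
  let _    ← GraphM.scale (1.0 / 0.9) y      -- rescale by 1/keepProb
  pure ()
```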
Stop-gradient boundary.
Forward semantics: identity (detach(x) = x).
Backward semantics: no gradient flows to x (treated as constant w.r.t. the graph inputs).
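A hedged sketch of the usual use: freeze a target branch so the loss gradient reaches only the prediction. mse refers to the loss documented further below; signatures are assumptions.

```lean
-- Hypothetical stop-gradient sketch, as in bootstrapped/TD-style targets.
example {α Δ : Type} {s : Shape} : GraphM α Δ [s, s] Unit := do
  let pred ← GraphM.arg 0 s
  let tgt  ← GraphM.arg 1 s
  let tgt' ← GraphM.detach tgt     -- forward: identity; backward: no grad
  let _    ← GraphM.mse pred tgt'  -- gradient reaches pred only
  pure ()
```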
JVP vs VJP in this module #
Each compiled node stores both:
- vjp: reverse-mode vector-Jacobian product (used by backprop), and
- jvp: forward-mode Jacobian-vector product (directional derivative).
The .compiled runtime path is primarily exercised via reverse-mode (VJP) and compilation to the
eager tape. Basic elementwise/bilinear ops provide real JVP rules, shape-structural ops (for
example slice/concat) apply the same transformation to the tangent, and heavier ops should expose
named spec-layer JVP helpers before being wired here. A reverse-only op must be listed in
reverseOnlyJvpOps and must call unsupportedJvp rather than returning a silent zero tangent.
Forward-mode coverage is expanded by adding concrete jvp rules next to the corresponding
forward and vjp definitions.
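In standard notation, for a node computing f with Jacobian J_f(x), the two stored rules are the two ways of applying that Jacobian:

```latex
\mathrm{jvp} : (x, v) \mapsto J_f(x)\, v
\qquad
\mathrm{vjp} : (x, w) \mapsto J_f(x)^{\top} w
```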
Elementwise addition node (y = a + b).
PyTorch comparison: torch.add(a, b).
Elementwise subtraction node (y = a - b).
PyTorch comparison: torch.sub(a, b).
Elementwise multiplication node (y = a ⊙ b).
PyTorch comparison: torch.mul(a, b).
Square x ↦ x ⊙ x.
Scale a tensor by a scalar constant c (y = c * x).
PyTorch comparison: c * x / torch.mul(x, c).
Elementwise absolute value.
PyTorch comparison: torch.abs(x).
Elementwise square root.
PyTorch comparison: torch.sqrt(x).
Elementwise clamp to [minVal, maxVal].
PyTorch comparison: torch.clamp(x, min=minVal, max=maxVal).
Elementwise maximum.
At ties we split the gradient equally (0.5 / 0.5), matching the tie-handling documented in
the eager tape (NN.Runtime.Autograd.Engine.Core).
PyTorch comparison: torch.maximum(a, b).
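Written out, the tie convention used here (and symmetrically for minimum below) is the subgradient choice:

```latex
\frac{\partial}{\partial a}\,\max(a, b) =
\begin{cases}
1   & a > b \\
0.5 & a = b \\
0   & a < b
\end{cases}
```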
Elementwise minimum.
At ties we split the gradient equally (0.5 / 0.5).
PyTorch comparison: torch.minimum(a, b).
Elementwise ReLU.
PyTorch comparison: torch.nn.functional.relu(x).
Elementwise sigmoid. PyTorch comparison: torch.sigmoid(x).
Elementwise tanh. PyTorch comparison: torch.tanh(x).
Softmax along the last axis (recursing over outer dimensions).
PyTorch comparison: torch.softmax(x, dim=-1).
Stable log-softmax along the last axis.
This is intentionally a primitive in the compiled graph, not the composition
log ∘ softmax, so proof/IR execution and eager CUDA share the same PyTorch-style numerical
contract.
Elementwise softplus. PyTorch comparison: torch.nn.functional.softplus(x).
Elementwise exponential. PyTorch comparison: torch.exp(x).
Elementwise natural logarithm. PyTorch comparison: torch.log(x).
Elementwise reciprocal x ↦ 1/x. PyTorch comparison: torch.reciprocal(x).
Elementwise numerically-stable log (uses an internal ε).
PyTorch comparison: commonly written torch.log(x + eps).
Reduce-sum over all entries, producing a scalar.
PyTorch comparison: torch.sum(x).
Mean-squared error loss with "mean" reduction, producing a scalar.
PyTorch comparison: torch.nn.functional.mse_loss(yhat, target, reduction="mean").
Affine layer y = W x + b in the compiled graph.
PyTorch comparison: torch.nn.functional.linear / torch.nn.Linear.
The JVP is the usual product rule:
d(Wx+b) = dW*x + W*dx + db.
Matrix multiplication ((m×n) @ (n×p) → (m×p)).
PyTorch comparison: torch.matmul.
The JVP is the bilinear product rule d(A @ B) = dA @ B + A @ dB.
Batched matrix multiplication (batch×m×n with batch×n×p).
PyTorch comparison: torch.bmm.
The JVP is the batched bilinear product rule d(A @ B) = dA @ B + A @ dB.
Concatenate two vectors (dim-0 concat).
PyTorch comparison: torch.cat([a, b], dim=0) for 1D tensors.
Concatenate along the leading dimension (dim=0) for tensors of shape .dim n s.
PyTorch comparison: torch.cat([a, b], dim=0).
Slice a contiguous range along dim=0.
PyTorch comparison: x[start : start+len] for tensors where the leading dimension is indexed.
N-D max pooling (channels-first) on a single sample tensor (no batch axis).
PyTorch comparison: torch.nn.functional.max_pool1d / max_pool2d / max_pool3d depending on
the spatial rank d.
Forward-mode status: implemented. The JVP follows the primal argmax selected by
Spec.maxPoolJvpSpec, including the documented first-winner tie convention.
N-D average pooling (channels-first) on a single sample tensor (no batch axis).
PyTorch comparison: torch.nn.functional.avg_pool1d / avg_pool2d / avg_pool3d depending on
the spatial rank d.
Forward-mode status: implemented. Average pooling is linear, so the JVP is the same average-pool map applied to the input tangent.
N-D smooth max pooling (log-sum-exp surrogate) on a single sample tensor (no batch axis).
PyTorch comparison: there is no direct primitive; this is a differentiable approximation to max pooling.
Forward-mode status: implemented. The JVP is the softmax-weighted tangent of the log-sum-exp pooling window.
2D max-pooling (channel-first) on a single image tensor.
PyTorch comparison: torch.nn.functional.max_pool2d (without a batch dimension).
Forward-mode status: implemented. The JVP routes each output tangent through the argmax selected by the primal input.
2D max-pooling with explicit padding.
PyTorch comparison: torch.nn.functional.max_pool2d with padding.
Forward-mode status: implemented. Padding is fixed and the JVP follows the real primal winner, ignoring padded cells just like the forward pass.
Smooth (soft) max-pooling, controlled by beta.
This is a differentiable approximation to max-pooling.
Forward-mode status: implemented. The JVP is the softmax-weighted tangent of the log-sum-exp pooling window.
Average pooling (channel-first) on a single image tensor.
PyTorch comparison: torch.nn.functional.avg_pool2d (without a batch dimension).
Forward-mode status: implemented. Average pooling is linear, so the JVP is average pooling of the input tangent.
Average pooling with explicit padding.
PyTorch comparison: torch.nn.functional.avg_pool2d with padding.
Forward-mode status: implemented. Padding is fixed and average pooling is linear, so the JVP is the padded average-pool map applied to the input tangent.
Flatten a tensor to a 1D vector (preserving total size).
PyTorch comparison: torch.flatten(x) (for a single tensor value).
Reshape a tensor, given a proof that the total sizes match.
PyTorch comparison: torch.reshape(x, new_shape).
Transpose a 2D matrix. PyTorch comparison: x.transpose(0, 1) / x.T for matrices.
Transpose a rank-3 tensor by moving the first axis to the last ((a,b,c) → (b,c,a)).
PyTorch comparison: x.permute(1, 2, 0).
Transpose a rank-3 tensor by moving the last axis to the first ((a,b,c) → (c,a,b)).
PyTorch comparison: x.permute(2, 0, 1).
Swap the last two axes of a rank-3 tensor ((a,b,c) → (a,c,b)).
PyTorch comparison: x.transpose(1, 2) for a 3D tensor.
Swap two adjacent axes at a given nesting depth.
This is the compiled-graph analogue of the eager Tape.swapAdjacentAtDepth.
PyTorch comparison: a permute that swaps two neighboring dimensions.
Broadcast x : s₁ to a larger shape s₂ (given a CanBroadcastTo witness).
PyTorch comparison: x.expand(...) / broadcasting semantics in elementwise ops.
Reduce-sum along a given axis.
PyTorch comparison: torch.sum(x, dim=axis).
Reduce-mean along a given axis.
PyTorch comparison: torch.mean(x, dim=axis).
Gather a single scalar from a vector at a known-in-bounds index.
PyTorch comparison: x[i] for a 1D tensor.
Gather a row from a matrix at a known-in-bounds row index.
PyTorch comparison: x[i, :] for a 2D tensor.
Gather a scalar from a vector at a runtime Nat index.
If i is out of bounds we return 0 and propagate no gradient (matching the forward choice).
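A plain-Lean sketch of that convention (an illustrative helper, not the real compiled op):

```lean
-- Out-of-bounds indices read forward as 0; since the forward value is
-- constant there, no gradient flows back for that index.
def gatherOrZero (x : Array Float) (i : Nat) : Float :=
  if h : i < x.size then x[i] else 0
```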
Gather a vector of length k from a length-n vector using an index tensor of Nats.
Out-of-bounds indices yield 0 at the corresponding output position.
PyTorch comparison: torch.gather for 1D inputs, with explicit bounds handling.
Gather k rows from a (rows×cols) matrix using an index vector of Nats.
Out-of-bounds indices yield a zero row.
PyTorch comparison: torch.index_select(x, dim=0, index=idx) with explicit bounds handling.
Scatter-add into a vector at a single in-bounds index.
scatter_add_vec x v i adds the scalar v into x[i].
PyTorch comparison: x.index_add_(dim=0, index=[i], source=[v]) (conceptually).
Scatter-add into a matrix at a single in-bounds row index.
scatter_add_row x v i adds the row vector v into x[i, :].
PyTorch comparison: x.index_add_(dim=0, index=[i], source=v.unsqueeze(0)) (conceptually).
Layer normalization (sequence-first), producing the same shape as the input.
PyTorch comparison: torch.nn.LayerNorm / torch.nn.functional.layer_norm (modulo exact layout).
Forward-mode status: implemented by Spec.layerNormJvp, including parameter tangents for
gamma and beta.
Batch normalization in channel-first layout (no running statistics; spec-level functional form).
PyTorch comparison: torch.nn.BatchNorm2d in NCHW layout (modulo exact semantics/parameters).
Forward-mode status: implemented by Spec.batchNorm2dJvp, including parameter tangents for
gamma and beta.
Multi-head attention primitive (shape-specialized).
PyTorch comparison: torch.nn.MultiheadAttention / scaled dot-product attention.
Forward-mode status: implemented by Spec.MultiHeadAttentionJvp, including tangents for the
input and all four projection matrices.
N-D convolution (channels-first) on a single sample tensor (no batch axis).
Conventions:
- input shape is (inC, spatial...),
- kernel shape is (outC, inC, kernelSpatial...),
- bias shape is (outC),
- output spatial sizes use the usual PyTorch-style formula (floor division; see below).
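For reference, with dilation fixed at 1 that standard formula is:

```latex
\mathrm{out}[a] = \left\lfloor \frac{\mathrm{in}[a] + 2\,\mathrm{padding}[a] - \mathrm{kernel}[a]}{\mathrm{stride}[a]} \right\rfloor + 1
```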
PyTorch comparison: torch.nn.functional.conv{d}d, specialized to a single sample.
Forward-mode JVP uses bilinearity:
d(conv(k,b,x)) = conv(k,0,dx) + conv(dk,db,x).
N-D transpose convolution (channels-first) on a single sample tensor (no batch axis).
Conventions:
- input shape is (inC, spatial...),
- kernel shape is (inC, outC, kernelSpatial...) (PyTorch layout),
- bias shape is (outC),
- output spatial sizes use out[a] = (in[a] - 1) * stride[a] - 2*padding[a] + kernel[a] (with output_padding = 0).
PyTorch comparison: torch.nn.functional.conv_transpose{d}d, specialized to a single sample.
Forward-mode JVP uses bilinearity:
d(convTranspose(k,b,x)) = convTranspose(k,0,dx) + convTranspose(dk,db,x).
2D convolution (channel-first) on a single image tensor.
PyTorch comparison: torch.nn.functional.conv2d (without a batch dimension).
Forward-mode JVP uses bilinearity:
d(conv2d(k,b,x)) = conv2d(k,0,dx) + conv2d(dk,db,x).
2D transpose convolution (channel-first) on a single image tensor.
PyTorch comparison: torch.nn.functional.conv_transpose2d (without a batch dimension).
Forward-mode JVP uses bilinearity:
d(convTranspose2d(k,b,x)) = convTranspose2d(k,0,dx) + convTranspose2d(dk,db,x).