Operation Contracts #
Shared operation contracts for NN.IR.Graph.
Several IR passes need to agree on the same small set of “shape contracts”:
- NN.IR.Infer: recompute output shapes from op parameters + parent shapes.
- NN.IR.Check: expose the documented Graph.checkShapes wrapper.
- NN.IR.Semantics: evaluate nodes and reject ill-shaped graphs with readable error messages.
The point of this file is to keep shape arithmetic out of individual passes. If an op has nontrivial shape behavior (concat, matmul, pooling, convolution, LayerNorm flattening, axis moves), define the contract here first and call it from inference/semantics instead of copying the formula.
Small shape utilities #
These helpers are used by multiple IR passes, especially Infer and Semantics.
The output shape of flattening a tensor of shape s to a 1D vector.
If s has rank ≥ 2, return the shape obtained by swapping its first two axes.
Example: (a, b, rest) becomes (b, a, rest).
If s has shape (a, b, c) (rank=3 with scalar base), return (a, c, b).
This is the common “transpose the last two axes” pattern for batched matrices.
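A minimal sketch of these three helpers over plain List Nat shapes (all names here are hypothetical stand-ins, not the module's actual definitions):

```lean
/-- Flatten to 1-D: a single dimension holding the product of all dims. -/
def flatten1DShape (s : List Nat) : List Nat :=
  [s.foldl (· * ·) 1]

/-- Swap the first two axes when the rank is ≥ 2. -/
def swapFirstTwo : List Nat → Option (List Nat)
  | a :: b :: rest => some (b :: a :: rest)
  | _              => none

/-- Rank-3 “transpose the last two axes”: (a, b, c) ↦ (a, c, b). -/
def transposeLastTwo3 : List Nat → Option (List Nat)
  | [a, b, c] => some [a, c, b]
  | _         => none

#eval flatten1DShape [2, 3, 4]    -- [24]
#eval swapFirstTwo [2, 3, 4]      -- some [3, 2, 4]
#eval transposeLastTwo3 [2, 3, 4] -- some [2, 4, 3]
```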
Generic contract helpers #
These functions live outside any particular pass (Infer/Check/Semantics) so they can be
reused without introducing import cycles.
Check that an axis is in-bounds for a given shape.
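For intuition, a one-line sketch over List Nat shapes (hypothetical name):

```lean
/-- An axis is in-bounds exactly when it indexes an existing dimension. -/
def axisInBounds (s : List Nat) (axis : Nat) : Bool :=
  decide (axis < s.length)

#eval axisInBounds [2, 3, 4] 2  -- true
#eval axisInBounds [2, 3, 4] 3  -- false
```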
Compute the (seqLen, embedDim) pair used to interpret layernorm axis.
TorchLean’s IR stores LayerNorm as an axis : Nat instead of a full normalized_shape tuple.
We interpret this in the same way the PyTorch exporter does:
normalized_shape = dims.drop axis
That is, we normalize over the suffix of dimensions starting at axis. To reuse the current
spec primitive (Spec.layerNorm), we flatten the input shape s into a 2D view:
- seqLen = product of dimensions before axis (dims.take axis)
- embedDim = product of dimensions from axis onward (dims.drop axis)
Then we run 2D last-axis LayerNorm on a (seqLen × embedDim) tensor and reshape back.
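A minimal sketch of the (seqLen, embedDim) split, directly mirroring the dims.take axis / dims.drop axis description above (hypothetical name):

```lean
/-- Split a shape at `axis` into (seqLen, embedDim) by taking products
    of the prefix and suffix dimensions. -/
def layerNormView (s : List Nat) (axis : Nat) : Nat × Nat :=
  let prod (xs : List Nat) : Nat := xs.foldl (· * ·) 1
  (prod (s.take axis), prod (s.drop axis))

-- s = [2, 3, 4], axis = 1: normalize over (3, 4), i.e. a (2 × 12) view.
#eval layerNormView [2, 3, 4] 1  -- (2, 12)
```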
Check that axis refers to the last axis of s.
This is a convenience predicate for passes/backends that restrict an op to last-axis behavior.
For example, some verification bounds are implemented only for last-axis softmax/layernorm and
use this check to fail fast with a readable error.
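A sketch of the predicate (hypothetical name):

```lean
/-- `axis` is the last axis exactly when it indexes the final dimension. -/
def isLastAxis (s : List Nat) (axis : Nat) : Bool :=
  axis + 1 == s.length

#eval isLastAxis [2, 3, 4] 2  -- true
#eval isLastAxis [2, 3, 4] 1  -- false
```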
Permutation (0-based axes) that moves axis to the last position, preserving the relative
order of the other axes.
Example: rank=4 and axis=1 yields [0,2,3,1].
Permutation (0-based axes) that moves axis to the first position, preserving the relative
order of the other axes.
Example: rank=4 and axis=2 yields [2,0,1,3].
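Both permutations share the same construction: filter axis out of List.range rank and reattach it at the desired end. A sketch with hypothetical names:

```lean
/-- Move `axis` to the last position, keeping the other axes in order. -/
def permToLast (rank axis : Nat) : List Nat :=
  (List.range rank).filter (· != axis) ++ [axis]

/-- Move `axis` to the first position, keeping the other axes in order. -/
def permToFirst (rank axis : Nat) : List Nat :=
  axis :: (List.range rank).filter (· != axis)

#eval permToLast 4 1   -- [0, 2, 3, 1]
#eval permToFirst 4 2  -- [2, 0, 1, 3]
```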
Infer the output shape for matmul from the two parent shapes.
Supported cases:
- 2D: (m×n) · (n×p) → (m×p)
- limited 3D “batched matmul”: (b×m×n) · (b×n×p) → (b×m×p)
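A minimal sketch of this contract (hypothetical name), returning none on rank or dimension mismatch:

```lean
/-- The two supported matmul cases; `none` signals a contract violation. -/
def matmulOutShape : List Nat → List Nat → Option (List Nat)
  | [m, n], [n', p] =>
    if n == n' then some [m, p] else none
  | [b, m, n], [b', n', p] =>
    if b == b' && n == n' then some [b, m, p] else none
  | _, _ => none

#eval matmulOutShape [2, 3] [3, 4]       -- some [2, 4]
#eval matmulOutShape [5, 2, 3] [5, 3, 4] -- some [5, 2, 4]
#eval matmulOutShape [2, 3] [4, 5]       -- none
```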
Infer the output shape for concat from the parent shapes.
All parents must:
- have the same rank,
- agree on every dimension except axis, and
- have axis in bounds.
The output shape matches the parents except at axis, where the dimension is the sum of the input
dimensions.
PyTorch analogy: torch.cat(xs, dim=axis) for a list xs of tensors.
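A minimal sketch of the three checks and the output formula (hypothetical name):

```lean
/-- Concat contract: same rank, equal dims off-axis, in-bounds axis;
    the output sums the `axis` dimension across all parents. -/
def concatOutShape (parents : List (List Nat)) (axis : Nat) : Option (List Nat) :=
  match parents with
  | [] => none
  | first :: rest =>
    let rank := first.length
    let sameOffAxis (s : List Nat) : Bool :=
      s.length == rank
        && (List.range rank).all fun i =>
             i == axis || s.getD i 0 == first.getD i 0
    if decide (axis < rank) && rest.all sameOffAxis then
      some (first.set axis (parents.foldl (fun acc s => acc + s.getD axis 0) 0))
    else
      none

#eval concatOutShape [[2, 3, 4], [2, 5, 4]] 1  -- some [2, 8, 4]
```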
Pooling/Conv2D shape arithmetic (CHW-only) #
These formulas mirror the spec/runtime conventions (CHW tensors, no dilation, symmetric padding). Centralizing them gives inference, evaluation, verification, and export code a shared convention for convolution and pooling shapes.
Output length for a 1D sliding-window op without padding: ⌊(in - k)/stride⌋ + 1.
Output length for a 1D sliding-window op with symmetric padding: ⌊(in + 2*pad - k)/stride⌋ + 1.
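Minimal sketches of both formulas (hypothetical names). Note that Nat subtraction truncates at zero, so callers should separately ensure the (padded) input is at least as long as the kernel:

```lean
/-- Unpadded window count: ⌊(in - k)/stride⌋ + 1. -/
def convOutLen (len k stride : Nat) : Nat :=
  (len - k) / stride + 1

/-- Symmetrically padded window count: ⌊(in + 2*pad - k)/stride⌋ + 1. -/
def convOutLenPadded (len k stride pad : Nat) : Nat :=
  (len + 2 * pad - k) / stride + 1

#eval convOutLen 32 3 1          -- 30
#eval convOutLenPadded 32 3 1 1  -- 32
```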
Output shape for CHW pooling without padding.
Output shape for CHW pooling with symmetric padding.
Output shape for CHW conv2d (single-image, no batch dim).
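These three shape functions are the 1-D window formulas applied per spatial axis, with Conv2D swapping the channel count for outC. A sketch over (C, H, W) triples, reusing the helpers sketched above (hypothetical names and parameter layout):

```lean
/-- CHW pooling keeps the channel count and shrinks H and W. -/
def poolOutShape (c h w k stride : Nat) : Nat × Nat × Nat :=
  (c, convOutLen h k stride, convOutLen w k stride)

def poolOutShapePadded (c h w k stride pad : Nat) : Nat × Nat × Nat :=
  (c, convOutLenPadded h k stride pad, convOutLenPadded w k stride pad)

/-- CHW Conv2D replaces the channel count with `outC`. -/
def conv2dOutShape (outC h w k stride pad : Nat) : Nat × Nat × Nat :=
  (outC, convOutLenPadded h k stride pad, convOutLenPadded w k stride pad)

#eval poolOutShape 3 32 32 2 2     -- (3, 16, 16)
#eval conv2dOutShape 8 32 32 3 1 1 -- (8, 32, 32)
```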
Infer the output shape for CHW pooling without padding, from a parent shape.
Infer the output shape for CHW pooling with padding, from a parent shape.
Infer the output shape for CHW Conv2D, checking the declared inC against the parent shape.
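A minimal sketch of that wrapper (hypothetical name and error style), reusing the arithmetic sketched above: the parent must be a rank-3 CHW shape whose channel count matches the declared inC.

```lean
/-- Infer a Conv2D output shape from a parent shape, checking `inC`. -/
def inferConv2d (parent : List Nat) (inC outC k stride pad : Nat) :
    Except String (List Nat) :=
  match parent with
  | [c, h, w] =>
    if c == inC then
      .ok [outC, convOutLenPadded h k stride pad, convOutLenPadded w k stride pad]
    else
      .error s!"conv2d: expected {inC} input channels, got {c}"
  | _ =>
    .error s!"conv2d: expected a rank-3 CHW parent, got rank {parent.length}"

#eval inferConv2d [3, 32, 32] 3 8 3 1 1  -- Except.ok [8, 32, 32]
```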