Shape Inference
Shape inference and consistency checking for NN.IR.Graph.
NN.IR.Node stores an outShape field because many consumers want shape metadata to be available
without re-running inference (pretty printers, exporters, verifiers, etc.).
This module provides an independent shape inference/checking procedure that recomputes the expected output shape of each node from:
- the node's OpKind payload (when present), and
- the parent nodes' output shapes.
For parameterized ops whose output shape depends on external parameters (notably OpKind.linear),
we treat the node's declared outShape as an input to the checker and validate only the local contracts
that can be checked without those parameters (e.g. that the input and output are both vectors).
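A minimal Python sketch of the linear contract check, assuming shapes are plain lists of ints and that a linear node has exactly one parent. `check_linear` is a hypothetical name, not the module's API; it returns `None` when the contract holds and an error message otherwise. The declared output width is accepted as-is, since it comes from the weight matrix the checker does not read.

```python
from typing import List, Optional

Shape = List[int]


def check_linear(parent_shapes: List[Shape], declared_out: Shape) -> Optional[str]:
    """Contract check for a linear node (hypothetical helper).

    The output width depends on the weight matrix, which is not read here,
    so the declared output shape is trusted and only the local contract
    is validated: one parent, vector input, vector output.
    """
    if len(parent_shapes) != 1:
        return f"linear expects 1 parent, got {len(parent_shapes)}"
    inp = parent_shapes[0]
    if len(inp) != 1:
        return f"linear input must be a vector, got shape {inp}"
    if len(declared_out) != 1:
        return f"linear output must be a vector, got shape {declared_out}"
    return None  # contract holds; declared_out is accepted as the output shape


print(check_linear([[784]], [128]))     # None
print(check_linear([[28, 28]], [128]))  # linear input must be a vector, got shape [28, 28]
```

The design choice is that a checker which cannot see parameters should still reject structurally impossible graphs (a matrix fed to a vector op) without pretending to know the output width.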
This module is the single source of truth for Graph.checkShapes: when adding a new OpKind, extend the match in
inferNodeOutShape first; the semantics, export, and verification passes can then rely on the same contract.
PyTorch analogy:
- inferNodeOutShape corresponds to the shape propagation used when validating an FX graph.
- Where the true output shape depends on parameters, this module performs contract checking rather than attempting to read those parameters.
References / related systems:
- PyTorch FX (graph representation): https://pytorch.org/docs/stable/fx.html
- ONNX shape inference: https://onnx.ai/onnx/shape_inference.html
Node-local inference
Most IR ops are “shape transparent” (elementwise, permute, etc.). A few need special handling:
- matmul has rank-sensitive rules (2D and a limited 3D batched case),
- concat needs to merge multiple parents along an axis,
- pooling/conv ops use centralized CHW arithmetic from OpContracts.
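The three special cases can be sketched in Python as follows. The function names are hypothetical, and two details are assumptions rather than facts about OpContracts: the 3D matmul case is taken to require matching batch sizes on both operands, and conv/pool use the usual floor formula `(size + 2*padding - kernel) // stride + 1` without dilation.

```python
from typing import List, Sequence

Shape = List[int]


def matmul_shape(a: Shape, b: Shape) -> Shape:
    """2D: [m, k] @ [k, n] -> [m, n]; 3D (assumed): [B, m, k] @ [B, k, n] -> [B, m, n]."""
    if len(a) != len(b) or len(a) not in (2, 3):
        raise ValueError(f"unsupported matmul ranks: {len(a)} and {len(b)}")
    if a[:-2] != b[:-2]:  # batch prefix is empty in the 2D case
        raise ValueError(f"matmul batch dims differ: {a[:-2]} vs {b[:-2]}")
    if a[-1] != b[-2]:
        raise ValueError(f"matmul inner dims differ: {a[-1]} vs {b[-2]}")
    return a[:-2] + [a[-2], b[-1]]


def concat_shape(parents: Sequence[Shape], axis: int) -> Shape:
    """All parents share rank and every dim except `axis`; `axis` dims are summed."""
    if not parents:
        raise ValueError("concat needs at least one parent")
    out = list(parents[0])
    if not 0 <= axis < len(out):
        raise ValueError(f"concat axis {axis} out of range for rank {len(out)}")
    for p in parents[1:]:
        if len(p) != len(out):
            raise ValueError(f"concat rank mismatch: {p} vs {parents[0]}")
        if any(p[i] != out[i] for i in range(len(out)) if i != axis):
            raise ValueError(f"concat non-axis dims differ: {p} vs {parents[0]}")
        out[axis] += p[axis]
    return out


def spatial_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Floor formula for one spatial dim (no dilation; assumed, not taken from OpContracts)."""
    if stride <= 0:
        raise ValueError(f"stride must be positive, got {stride}")
    if size + 2 * padding < kernel:
        raise ValueError(f"window {kernel} does not fit size {size} with padding {padding}")
    return (size + 2 * padding - kernel) // stride + 1


def conv2d_chw(x: Shape, out_channels: int, kernel: int, stride: int = 1, padding: int = 0) -> Shape:
    """Conv: output channels come from the op payload; H and W follow spatial_out."""
    if len(x) != 3:
        raise ValueError(f"conv expects CHW input, got {x}")
    _, h, w = x
    return [out_channels, spatial_out(h, kernel, stride, padding), spatial_out(w, kernel, stride, padding)]


def pool2d_chw(x: Shape, kernel: int, stride: int, padding: int = 0) -> Shape:
    """Pooling keeps C and shrinks H and W."""
    if len(x) != 3:
        raise ValueError(f"pool expects CHW input, got {x}")
    c, h, w = x
    return [c, spatial_out(h, kernel, stride, padding), spatial_out(w, kernel, stride, padding)]


print(matmul_shape([2, 3], [3, 4]))                      # [2, 4]
print(concat_shape([[2, 3], [2, 5]], axis=1))            # [2, 8]
print(conv2d_chw([3, 32, 32], 16, kernel=3, padding=1))  # [16, 32, 32]
print(pool2d_chw([16, 32, 32], kernel=2, stride=2))      # [16, 16, 16]
```

Keeping `spatial_out` as one shared helper mirrors the idea of centralizing CHW arithmetic: conv and pool cannot drift apart on the formula.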
Infer the output shape of a node from its kind and parent shapes.
This function is used by Graph.checkInferredShapes below.
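A sketch of the per-node dispatch in Python. The op names (`relu`, `add`, `permute`, `linear`), the `OpKind` dataclass, and the `perm` attribute are illustrative stand-ins, not the real `OpKind` constructors; binary elementwise ops are assumed to require identical shapes (no broadcasting). The matmul, concat, and conv/pool rules would be further branches of the same match.

```python
from dataclasses import dataclass, field
from typing import Dict, List

Shape = List[int]


@dataclass
class OpKind:
    """Illustrative stand-in for the IR's OpKind: an op name plus its payload."""
    name: str
    attrs: Dict[str, object] = field(default_factory=dict)


UNARY_ELEMENTWISE = {"relu", "sigmoid", "tanh"}  # hypothetical op names
BINARY_ELEMENTWISE = {"add", "mul"}


def _expect_parents(kind: OpKind, parents: List[Shape], n: int) -> None:
    if len(parents) != n:
        raise ValueError(f"{kind.name} expects {n} parent(s), got {len(parents)}")


def infer_node_out_shape(kind: OpKind, parents: List[Shape], declared_out: Shape) -> Shape:
    """Return the expected output shape, or raise ValueError on a contract violation."""
    if kind.name in UNARY_ELEMENTWISE:
        _expect_parents(kind, parents, 1)
        return list(parents[0])  # shape passes through unchanged
    if kind.name in BINARY_ELEMENTWISE:
        _expect_parents(kind, parents, 2)
        if parents[0] != parents[1]:  # assumed: no broadcasting
            raise ValueError(f"{kind.name} shape mismatch: {parents[0]} vs {parents[1]}")
        return list(parents[0])
    if kind.name == "permute":
        _expect_parents(kind, parents, 1)
        perm = list(kind.attrs["perm"])
        x = parents[0]
        if sorted(perm) != list(range(len(x))):
            raise ValueError(f"invalid permutation {perm} for rank {len(x)}")
        return [x[i] for i in perm]  # output dim j is input dim perm[j]
    if kind.name == "linear":
        _expect_parents(kind, parents, 1)
        if len(parents[0]) != 1 or len(declared_out) != 1:
            raise ValueError("linear input and output must be vectors")
        return list(declared_out)  # width comes from the weights: trust the declaration
    raise ValueError(f"no shape rule for op {kind.name!r}")


print(infer_node_out_shape(OpKind("permute", {"perm": [2, 0, 1]}), [[32, 32, 3]], []))  # [3, 32, 32]
```

Raising on an unknown op, rather than defaulting to the declared shape, keeps the "extend this match first" rule honest: a new op cannot pass the check until it has a rule.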
Infer shapes for every node (in topo/id order) and check that each node's declared Node.outShape matches the inferred shape.
This is meant as a compiler/back-end sanity check and as a clean IR invariant for the docs: well-formed graphs have self-consistent declared shapes.
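A sketch of the whole-graph check in Python, assuming node ids are already in topological order (every parent id is smaller than its child's id) and that source nodes of a hypothetical `input` kind simply declare their shape. `Node`, `check_inferred_shapes`, and `toy_infer` are illustrative names; the real check may report errors differently. Each node is inferred from its parents' declared shapes, so the first mismatch is reported at the node where it arises rather than cascading downstream.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional

Shape = List[int]


@dataclass
class Node:
    """Illustrative stand-in for NN.IR.Node."""
    id: int
    kind: str
    parents: List[int]
    out_shape: Shape


InferFn = Callable[[str, List[Shape], Shape], Shape]


def toy_infer(kind: str, parents: List[Shape], declared_out: Shape) -> Shape:
    """Tiny rule set for the demo: sources declare their shape, relu is transparent."""
    if kind == "input":
        return list(declared_out)
    if kind == "relu" and len(parents) == 1:
        return list(parents[0])
    raise ValueError(f"no shape rule for {kind!r} with {len(parents)} parent(s)")


def check_inferred_shapes(nodes: List[Node], infer: InferFn) -> Optional[str]:
    """Return None if every declared out_shape matches inference, else the first error."""
    declared: Dict[int, Shape] = {}
    for node in sorted(nodes, key=lambda n: n.id):  # id order doubles as topo order
        missing = [p for p in node.parents if p not in declared]
        if missing:
            return f"node {node.id}: parents {missing} do not precede it"
        parent_shapes = [declared[p] for p in node.parents]
        try:
            expected = infer(node.kind, parent_shapes, node.out_shape)
        except ValueError as err:
            return f"node {node.id} ({node.kind}): {err}"
        if expected != node.out_shape:
            return f"node {node.id} ({node.kind}): declared {node.out_shape}, inferred {expected}"
        declared[node.id] = node.out_shape
    return None


good = [Node(0, "input", [], [3, 4]), Node(1, "relu", [0], [3, 4])]
bad = [Node(0, "input", [], [3, 4]), Node(1, "relu", [0], [4, 3])]
print(check_inferred_shapes(good, toy_infer))  # None
print(check_inferred_shapes(bad, toy_infer))   # node 1 (relu): declared [4, 3], inferred [3, 4]
```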