TorchLean API

NN.IR.OpContracts

Operation Contracts #

Shared operation contracts for NN.IR.Graph.

Several IR passes (Infer, Check, and Semantics) need to agree on the same small set of “shape contracts”.

The point of this file is to keep shape arithmetic out of individual passes. If an op has nontrivial shape behavior (concat, matmul, pooling, convolution, LayerNorm flattening, axis moves), define the contract here first and call it from inference/semantics instead of copying the formula.

Small shape utilities #

These helpers are used by multiple IR passes, especially Infer and Semantics.

The output shape of flattening a tensor of shape s to a 1D vector.


If s has rank ≥ 2, return the shape obtained by swapping its first two axes.

Example: (a, b, rest) becomes (b, a, rest).
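A minimal sketch of this contract, assuming shapes are plain dimension lists (`List Nat`); the name `swapFirstTwo` is illustrative, not the library's:

```lean
-- Illustrative only: swap the first two axes of a shape, if rank ≥ 2.
def swapFirstTwo : List Nat → Option (List Nat)
  | a :: b :: rest => some (b :: a :: rest)
  | _ => none

#eval swapFirstTwo [2, 3, 4]  -- some [3, 2, 4]
```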


If s has shape (a, b, c) (rank=3 with scalar base), return (a, c, b).

This is the common “transpose the last two axes” pattern for batched matrices.
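The rank-3 case can be sketched the same way (name and `List Nat` representation assumed):

```lean
-- Illustrative only: (a, b, c) ↦ (a, c, b), transposing the last two axes.
def transposeLast2 : List Nat → Option (List Nat)
  | [a, b, c] => some [a, c, b]
  | _ => none

#eval transposeLast2 [5, 2, 3]  -- some [5, 3, 2]
```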


Generic contract helpers #

These functions live outside any particular pass (Infer/Check/Semantics) so they can be reused without introducing import cycles.

Check that an axis is in-bounds for a given shape.


Check that a natural-number op parameter is nonzero.


Compute the (seqLen, embedDim) pair used to interpret layernorm axis.

TorchLean’s IR stores LayerNorm as an axis : Nat instead of a full normalized_shape tuple. We interpret this the same way the PyTorch exporter does:

normalized_shape = dims.drop axis

That is, we normalize over the suffix of dimensions starting at axis. To reuse the current spec primitive (Spec.layerNorm), we flatten the input shape s into a 2D view:

• seqLen = product of dimensions before axis (dims.take axis)
• embedDim = product of dimensions from axis onward (dims.drop axis)

Then we run 2D last-axis LayerNorm on a (seqLen × embedDim) tensor and reshape back.
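For instance, with dims = [2, 3, 4] and axis = 1, the 2D view is (2, 12). A sketch of the computation (the helper name is assumed, not the library's):

```lean
-- Illustrative only: product of dims before `axis`, and from `axis` onward.
def layerNormView (dims : List Nat) (axis : Nat) : Nat × Nat :=
  ((dims.take axis).foldl (· * ·) 1, (dims.drop axis).foldl (· * ·) 1)

#eval layerNormView [2, 3, 4] 1  -- (2, 12)
```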


Check that axis refers to the last axis of s.

This is a convenience predicate for passes/backends that restrict an op to last-axis behavior. For example, some verification bounds are implemented only for last-axis softmax/layernorm and use this check to fail fast with a readable error.


Compute the inverse of a permutation list.

If perm is a permutation of [0,1,...,r-1] (where r = perm.length), then the inverse inv satisfies inv[perm[i]] = i.
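A compact (if quadratic) way to state the same inverse; the name `invPerm` is illustrative, and the result is only meaningful when the input really is a permutation:

```lean
-- Illustrative only: inv[j] is the position of j in perm, so inv[perm[i]] = i.
def invPerm (perm : List Nat) : List Nat :=
  (List.range perm.length).map fun j => perm.indexOf j

#eval invPerm [0, 2, 3, 1]  -- [0, 3, 1, 2]
```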

def NN.IR.OpContracts.inversePerm.setOnce (perm : List ℕ) (r : ℕ) (xs : List (Option ℕ)) (axis j val : ℕ) :

Permutation (0-based axes) that moves axis to the last position, preserving the relative order of the other axes.

Example: rank=4 and axis=1 yields [0,2,3,1].
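A sketch of this permutation (name assumed), using Boolean `!=` on Nat:

```lean
-- Illustrative only: all other axes in order, then `axis` last.
def moveAxisLastPerm (rank axis : Nat) : List Nat :=
  (List.range rank).filter (· != axis) ++ [axis]

#eval moveAxisLastPerm 4 1  -- [0, 2, 3, 1]
```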


Permutation (0-based axes) that moves axis to the first position, preserving the relative order of the other axes.

Example: rank=4 and axis=2 yields [2,0,1,3].
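The mirror-image sketch (name assumed):

```lean
-- Illustrative only: `axis` first, then all other axes in order.
def moveAxisFirstPerm (rank axis : Nat) : List Nat :=
  axis :: (List.range rank).filter (· != axis)

#eval moveAxisFirstPerm 4 2  -- [2, 0, 1, 3]
```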


Infer the output shape for matmul from the two parent shapes.

Supported cases:

• 2D: (m×n) · (n×p) → (m×p)
• limited 3D “batched matmul”: (b×m×n) · (b×n×p) → (b×m×p)
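The two supported cases can be sketched as follows (the name and `Option`-based error handling are assumptions; the real contract may report failures differently):

```lean
-- Illustrative only: 2D and limited batched-3D matmul output shapes.
def matmulOutShape : List Nat → List Nat → Option (List Nat)
  | [m, n], [n', p] => if n == n' then some [m, p] else none
  | [b, m, n], [b', n', p] =>
      if b == b' && n == n' then some [b, m, p] else none
  | _, _ => none

#eval matmulOutShape [2, 3] [3, 4]  -- some [2, 4]
```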

Infer the output shape for concat from the parent shapes.

All parents must:

• have the same rank,
• agree on every dimension except axis, and
• have axis in bounds.

The output shape matches the parents except at axis, where the dimension is the sum of the input dimensions.

PyTorch analogy: torch.cat(xs, dim=axis) for a list xs of tensors.
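The three conditions and the summed axis can be sketched as (names assumed):

```lean
-- Illustrative only: same rank, agreement off-axis, axis in bounds;
-- the output dimension at `axis` is the sum over all parents.
def concatOutShape (axis : Nat) (shapes : List (List Nat)) : Option (List Nat) :=
  match shapes with
  | [] => none
  | s :: rest =>
    let agrees (t : List Nat) : Bool :=
      t.length == s.length &&
      (List.range s.length).all fun i => i == axis || t.getD i 0 == s.getD i 0
    if decide (axis < s.length) && rest.all agrees then
      some ((List.range s.length).map fun i =>
        if i == axis then shapes.foldl (fun acc t => acc + t.getD axis 0) 0
        else s.getD i 0)
    else none

#eval concatOutShape 1 [[2, 3, 4], [2, 5, 4]]  -- some [2, 8, 4]
```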


Pooling/Conv2D shape arithmetic (CHW-only) #

These formulas mirror the spec/runtime conventions (CHW tensors, no dilation, symmetric padding). Centralizing them gives inference, evaluation, verification, and export code a shared convention for convolution and pooling shapes.

def NN.IR.OpContracts.slideOut (inLen k stride : ℕ) : ℕ

Output length for a 1D sliding-window op without padding: ⌊(in - k)/stride⌋ + 1.

def NN.IR.OpContracts.slideOutPad (inLen k stride padding : ℕ) : ℕ

Output length for a 1D sliding-window op with symmetric padding: ⌊(in + 2*pad - k)/stride⌋ + 1.
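As a worked instance of the padded formula: inLen = 28, k = 3, stride = 1, padding = 1 gives ⌊(28 + 2 - 3)/1⌋ + 1 = 28, the familiar "same-size" convolution. Note that on Nat, `in + 2*pad - k` truncates at zero when the window exceeds the padded input:

```lean
-- Illustrative only: the padded sliding-window formula on Nat.
#eval (28 + 2 * 1 - 3) / 1 + 1  -- 28
```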

def NN.IR.OpContracts.pool2dCHWOutShape (c inH inW kH kW stride : ℕ) :

Output shape for CHW pooling without padding.

def NN.IR.OpContracts.pool2dCHWOutShapePad (c inH inW kH kW stride padding : ℕ) :

Output shape for CHW pooling with symmetric padding.

def NN.IR.OpContracts.conv2dCHWOutShape (outC inH inW kH kW stride padding : ℕ) :

Output shape for CHW conv2d (single-image, no batch dim).


Infer the output shape for CHW pooling without padding, from a parent shape.


Infer the output shape for CHW pooling with padding, from a parent shape.

def NN.IR.OpContracts.inferConv2dCHWOutShape (inC outC kH kW stride padding : ℕ) (parent : Spec.Shape) :

Infer the output shape for CHW Conv2D, checking the declared inC against the parent shape.
