TorchLean API

NN.Spec.Core.Shape

Shapes (Spec.Shape) #

Shape is the type-level “shape descriptor” for tensors in the spec layer.

TorchLean uses shape-indexed tensors:

Tensor α s

so Shape is how we encode the structure of s in a way Lean can use for both computation and proofs.
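For instance (a sketch; the actual `Tensor` definition lives in NN/Spec/Core/Tensor/Core.lean), a 2 × 3 matrix of Floats is a distinct type from a 3 × 2 one:

```lean
-- Sketch: a 2 × 3 matrix type, indexed by its Shape.
-- Shape mismatches become type errors rather than runtime errors.
example : Type :=
  Spec.Tensor Float (Spec.Shape.dim 2 (Spec.Shape.dim 3 Spec.Shape.scalar))
```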

Representation #

Shape is an inductive tree (a .scalar leaf, or .dim n s for a dimension).

This matches the tensor definition in NN/Spec/Core/Tensor/Core.lean.

Common utilities #


Broadcasting and axes #

Broadcasting is encoded via CanBroadcastTo / BroadcastTo.

This is an intentionally asymmetric relation ("broadcast s1 to s2"), because most tensor code is naturally written by choosing the output shape and requiring each input to broadcast to it.

The typeclass wrapper BroadcastTo keeps higher-level specs readable: in many cases Lean can infer the broadcast evidence automatically, so call sites don’t have to manually thread proofs around.

It also defines axis-validity helpers (valid_axis) and a well_formed predicate for “all dimensions are positive”, which is useful when you want to rule out degenerate cases in proofs.

We represent shapes as an inductive tree instead of a bare List Nat because the tree mirrors the recursive structure of spec-level tensors (Spec.Tensor α s).

inductive Spec.Shape : Type

Tensor shape descriptor used to index spec-level tensors (Spec.Tensor α s).

Shape is an outermost-first tree:

  • .scalar for a scalar,
  • .dim n s for a length-n dimension whose entries have shape s.
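For example, using the constructors directly:

```lean
-- Outermost-first: two rows, each of shape (3), i.e. a 2 × 3 matrix.
example : Spec.Shape := .dim 2 (.dim 3 .scalar)
```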
Instances For
    def Spec.instDecidableEqShape.decEq (x✝ x✝¹ : Shape) :
    Decidable (x✝ = x✝¹)
    Instances For
      @[implicit_reducible]

      Build a shape from a list of dimensions (outermost first).

      Instances For
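Assuming the straightforward outermost-first fold, building from a list agrees with the constructors (a sketch, not verified against the actual definition):

```lean
-- ofList [2, 3] should denote the same tree as .dim 2 (.dim 3 .scalar).
example : Spec.Shape.ofList [2, 3] = .dim 2 (.dim 3 .scalar) := rfl
```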

        Internal helper: check that a list of axis indices is duplicate-free.

        Instances For
          def Spec.Shape.getDim! (xs : List ℕ) (i : ℕ) :
          ℕ

          Internal helper: get the i-th entry (0-based) from a list of dimensions, defaulting to 0.

          Instances For

            Pretty-print a Shape for debugging / logs.

            Instances For

              Swap two adjacent dimensions at a given depth (0‑based from the outermost).

              Instances For

                    Swapping adjacent dims at a given depth twice returns the original shape.

                Append a new innermost dimension.

                Instances For

                  Total number of scalar elements (a.k.a. “numel”).

                  Instances For
                    theorem Spec.Shape.size_dim_mul (a b : ℕ) (s : Shape) :
                    (dim a (dim b s)).size = a * b * s.size

                    size for a 2D shape factors as a * b * size s.

                    theorem Spec.Shape.size_appendDim (s : Shape) (n : ℕ) :
                    (s.appendDim n).size = s.size * n

                    appendDim multiplies the number of scalar elements by the appended dimension.

                    This lemma is the standard justification for reshape tricks where we:

                    • treat a tensor of shape s.appendDim n as a matrix of shape (size s) × n, or
                    • append an extra singleton dimension (n = 1) without changing size.
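On a concrete shape the identity computes directly (a sketch; `rfl` assumes size and appendDim reduce on closed terms):

```lean
-- (2 × 3).appendDim 4 is a 2 × 3 × 4 shape with 2 * 3 * 4 = 24 entries.
example :
    ((Spec.Shape.dim 2 (Spec.Shape.dim 3 Spec.Shape.scalar)).appendDim 4).size = 24 := rfl
```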
                    theorem Spec.Shape.size_eq_of_dModel_eq_numHeads_mul_headDim (seqLen numHeads dModel headDim : ℕ) (h : dModel = numHeads * headDim) :
                    (dim seqLen (dim dModel scalar)).size = (dim numHeads (dim seqLen (dim headDim scalar))).size

                    Shape-size identity used in Transformer attention reshapes.

                    If dModel = numHeads * headDim, then: (seqLen × dModel) has the same size as (numHeads × seqLen × headDim).
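A concrete instance (sketch): with seqLen = 4, numHeads = 2, headDim = 8, and dModel = 2 * 8 = 16, both layouts hold 64 scalars:

```lean
example :
    (Spec.Shape.dim 4 (Spec.Shape.dim 16 Spec.Shape.scalar)).size
      = (Spec.Shape.dim 2 (Spec.Shape.dim 4 (Spec.Shape.dim 8 Spec.Shape.scalar))).size :=
  Spec.Shape.size_eq_of_dModel_eq_numHeads_mul_headDim 4 2 16 8 rfl
```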

                    Size of the outermost dimension (or 1 for scalar).

                    Instances For

                      Size of the innermost dimension (or 1 for scalar).

                      Instances For

                        Convert to a list of dimensions (outermost first).

                        Instances For
                          @[simp]

                          ofList is a left inverse of toList.

                          @[simp]

                          toList is a right inverse of ofList.

                          Convert to an array of dimensions (outermost first).

                          Instances For

                            Boolean equality test for shapes (structural).

                            Instances For
                              @[implicit_reducible]

                              BEq Shape uses the explicit structural boolean test Shape.areEqual.

                              @[implicit_reducible]

                              Default shape is scalar.

                              Check if shape is a matrix (m × n).

                              Instances For

                                Check if shape is a vector (n).

                                Instances For

                                  Check if shape is scalar.

                                  Instances For

                                    Get dimension at index i (0‑based), or none if out of bounds.

                                    Instances For

                                      Typeclass-friendly broadcasting (BroadcastTo) #

                                      The CanBroadcastTo relation is asymmetric (“broadcast s₁ to s₂”), matching how most operations are written: we pick a target shape and require each operand to broadcast to it.

                                      The BroadcastTo wrapper lets Lean search for a broadcast proof automatically, which is convenient for higher-level specs (layers/models) where the broadcasting details are not the point.


                                      Evidence that shape s₁ can be broadcast to shape s₂ (PyTorch-style broadcasting).

                                      Instances For
                                        Instances For
                                          @[implicit_reducible]
                                          instance Spec.Shape.instReprCanBroadcastTo {a✝ a✝¹ : Shape} :
                                          Repr (a✝.CanBroadcastTo a✝¹)
                                           class Spec.Shape.BroadcastTo (s₁ s₂ : Shape) :
                                           Prop

                                          Typeclass wrapper for CanBroadcastTo so broadcast proofs can be inferred.

                                          Instances
                                            @[implicit_reducible]

                                            Scalar broadcasts to any shape (analogue of "prepend 1s and expand").

                                            @[implicit_reducible]
                                             instance Spec.Shape.broadcastToDimEq {n : ℕ} {s₁ s₂ : Shape} [bc : s₁.BroadcastTo s₂] :
                                            (dim n s₁).BroadcastTo (dim n s₂)

                                            Broadcasting preserves equal leading dimensions when the tails broadcast.

                                            @[implicit_reducible]
                                             instance Spec.Shape.broadcastToDim1ToN {n : ℕ} {s₁ s₂ : Shape} [bc : s₁.BroadcastTo s₂] :
                                            (dim 1 s₁).BroadcastTo (dim n s₂)

                                            Dimension 1 can broadcast to any n (PyTorch's main broadcast rule).

                                            @[implicit_reducible]
                                             instance Spec.Shape.broadcastToExpandDims {n : ℕ} {s₁ s₂ : Shape} [bc : s₁.BroadcastTo s₂] :
                                            s₁.BroadcastTo (dim n s₂)

                                            Prepend an outer dimension (the "expand_dims" step used to align ranks).
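Together these instances let Lean discharge broadcast obligations automatically; a sketch (assuming instance search composes them as described):

```lean
-- (1 × 3) broadcasts to (2 × 3): broadcastToDim1ToN expands the outer
-- axis, and the (3) tail broadcasts to itself via broadcastToDimEq
-- over the scalar base case.
example : Spec.Shape.BroadcastTo (.dim 1 (.dim 3 .scalar)) (.dim 2 (.dim 3 .scalar)) :=
  inferInstance
```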

                                            true iff two shapes have the same number of elements.

                                            Instances For

                                              Rank = number of dimensions (scalar has rank 0).

                                              Instances For

                                                Friendly aliases (PyTorch-style) #

                                                We keep the canonical names (toList, rank, size, well_formed) because they show up throughout the spec/proof code.

                                                For docs and examples, these aliases read more like PyTorch.

                                                @[reducible, inline]

                                                PyTorch-style name for Shape.toList.

                                                Instances For
                                                  @[reducible, inline]
                                                  abbrev Spec.Shape.ndim (s : Shape) :
                                                  ℕ

                                                  PyTorch-style name for Shape.rank.

                                                  Instances For
                                                    @[reducible, inline]
                                                    abbrev Spec.Shape.numel (s : Shape) :
                                                    ℕ

                                                    PyTorch-style name for Shape.size ("numel").

                                                    Instances For
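Usage sketch (assuming these reduce by rfl on closed shapes):

```lean
example : (Spec.Shape.ofList [2, 3, 4]).numel = 24 := rfl  -- PyTorch: t.numel()
example : (Spec.Shape.ofList [2, 3, 4]).ndim = 3 := rfl    -- PyTorch: t.ndim
```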

                                                      Permute axes of a shape using a runtime permutation list (0-based). Returns none if invalid.

                                                      Instances For

                                                        Axis utilities #


                                                        Evidence that reducing along axis is well-defined for a shape.

                                                        This is a small helper predicate used to rule out degenerate 0-length dimensions when stating laws about reductions.

                                                        Instances For
                                                          @[simp]
                                                          theorem Spec.Shape.reducibleAlong_head {n : ℕ} {s : Shape} :
                                                          reducibleAlong 0 (dim (n + 1) s)

                                                          simp lemma: axis 0 is reducible for any positive outer dimension.

                                                          @[simp]
                                                          theorem Spec.Shape.reducibleAlong_tail {n : ℕ} {s : Shape} {k : ℕ} (h : reducibleAlong k s) :
                                                          reducibleAlong (k + 1) (dim (n + 1) s)

                                                          simp lemma: reducibility for inner axis lifts to the next outer axis.
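Chaining the two lemmas discharges concrete goals; a sketch (assuming the numeric literals unify with the n + 1 patterns):

```lean
-- Axis 1 of a (2 × 3) shape: lift the head lemma through one tail step.
example : Spec.Shape.reducibleAlong 1 (.dim 2 (.dim 3 .scalar)) :=
  Spec.Shape.reducibleAlong_tail Spec.Shape.reducibleAlong_head
```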

                                                          valid_axis axis s means that axis is a valid reduction axis for s.

                                                          We use a Prop + typeclass wrapper (valid_axis_inst) so proofs can be synthesized by typeclass resolution in downstream code.

                                                          inductive Spec.Shape.valid_axis :
                                                          ℕ → Shape → Prop

                                                          Axis validity predicate for reduction ops (0-based axis in Nat).

                                                          Instances For

                                                            Typeclass wrapper for valid_axis so common axis proofs can be inferred.

                                                            Instances
                                                              instance Spec.Shape.validAxisInstZero {n : ℕ} {s : Shape} :
                                                              valid_axis_inst 0 (dim (n + 1) s)

                                                              Instance: axis 0 is valid for any positive outer dimension.

                                                              instance Spec.Shape.validAxisInstZeroAlt {n : ℕ} {s : Shape} (h : n ≠ 0) :
                                                              valid_axis_inst 0 (dim n s)

                                                              Instance: axis 0 is valid for a nonzero outer dimension n.

                                                              This is a convenience wrapper that turns n ≠ 0 into the n+1 form expected by valid_axis.

                                                              instance Spec.Shape.validAxisInstOne {n1 n2 : ℕ} {s : Shape} (h₁ : n1 ≠ 0) (h₂ : n2 ≠ 0) :
                                                              valid_axis_inst 1 (dim n1 (dim n2 s))

                                                              Instance: axis 1 is valid for a 2D shape when both outer dims are nonzero.

                                                              instance Spec.Shape.validAxisInstSucc {n : ℕ} {s : Shape} {k : ℕ} [inst : valid_axis_inst k s] :
                                                              valid_axis_inst (k + 1) (dim (n + 1) s)

                                                              Instance: if k is a valid axis for s, then k+1 is a valid axis for .dim (n+1) s.

                                                              instance Spec.Shape.validAxisInstZeroAlt2 {n : ℕ} {s : Shape} (h : n > 0) :
                                                              valid_axis_inst 0 (dim n s)

                                                              Instance: axis 0 is valid if you have a positivity proof n > 0 (converted to n ≠ 0).

                                                              theorem Spec.Shape.gt_pos_to_ne_zero {n : ℕ} (h : n > 0) :
                                                              n ≠ 0

                                                              Helper lemma: a positive natural is not zero.

                                                              Well-formedness (well_formed) #

                                                              well_formed s means "all dimensions are positive".

                                                              Why this matters: requiring positive dimensions rules out degenerate (empty) cases once, rather than threading per-dimension positivity hypotheses through every lemma.

                                                              This is a pragmatic "make the common case pleasant" choice: proofs and specs are shorter, and runtime checks can still handle edge cases separately.

                                                              well_formed s means "all dimensions of s are positive" (recursively).

                                                              Instances For

                                                                Size positivity #

                                                                If all dimensions of a shape are positive, then the total number of scalar elements is positive.

                                                                This is a small but useful bridge lemma: many reductions are only defined for nonempty dimensions, and WellFormed is our standard way of expressing that assumption.

                                                                If s.well_formed, then Shape.size s > 0.

                                                                instance Spec.Shape.validAxisLastInst {s : Shape} (h : s.rank > 0) (hw : s.wellFormed) :
                                                                valid_axis_inst (s.rank - 1) s

                                                                If rank s > 0 and s is well-formed, then the last axis rank s - 1 is valid.

                                                                This powers many "reduce over last dimension" specs where the axis is computed as rank s - 1.

                                                                Typeclass wrapper for Shape.well_formed.

                                                                We use a typeclass (instead of passing a well_formed proof everywhere) because it mirrors how other "side conditions" are handled in the library: call sites stay clean, and instances can be provided locally (e.g. letI : Shape.WellFormed s := ...) when needed.

                                                                Instances

                                                                  Scalars are always well-formed.

                                                                  instance Spec.Shape.instWellFormedDimOfGtNatOfNat {n : ℕ} {s : Shape} [s.WellFormed] (h : n > 0) :
                                                                  (dim n s).WellFormed

                                                                  If s is well-formed and n > 0, then .dim n s is well-formed.

                                                                  Convenience instance: .dim 1 s is well-formed when s is.

                                                                  Convenience instance: .dim 2 s is well-formed when s is.

                                                                  If a Fact (n > 0) is in scope, lift it to a Shape.WellFormed (.dim n s) instance.
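For example (a sketch, assuming the instances above are found by inferInstance):

```lean
-- A Fact (n > 0) hypothesis lifts to WellFormed (.dim n .scalar);
-- the scalar base case is always well-formed.
example {n : ℕ} [Fact (n > 0)] : Spec.Shape.WellFormed (.dim n .scalar) :=
  inferInstance
```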

                                                                  validAxisLastAuto is a convenience instance for the most common reduction axis: "reduce over the last dimension".

                                                                  In PyTorch this is dim=-1 (after normalization). Here we stay in Nat, so the last axis is rank s - 1, and we require rank s > 0 plus well-formedness so the proof is meaningful.

                                                                  instance Spec.Shape.validAxisLastAuto {s : Shape} [h_wf : s.WellFormed] (h : s.rank > 0) :
                                                                  valid_axis_inst (s.rank - 1) s

                                                                  Convenience instance: infer valid_axis_inst (rank s - 1) s from WellFormed s and rank s > 0.

                                                                  Bridge lemma: turn a valid_axis proof into a reducibleAlong proof.

                                                                  Both views exist because valid_axis is the typeclass-friendly statement (proofs can be synthesized by instance search), while reducibleAlong is the structurally convenient form used in proofs. This function is the adapter between the two views.

                                                                  def Spec.Shape.proveReducibleAlong (axis : ℕ) (s : Shape) (h : valid_axis axis s) :
                                                                  reducibleAlong axis s

                                                                  Convert a valid_axis proof into a structurally convenient reducibleAlong proof.

                                                                  Instances For

                                                                    padLeft n s prepends n singleton dimensions to a shape.

                                                                    PyTorch analogy: unsqueeze(0) repeated n times (or equivalently viewing a tensor as having extra leading dimensions of size 1). This is also the "prepend 1s" step you see in broadcasting.

                                                                    Prepend n leading singleton dimensions (size 1) to a shape.

                                                                    Instances For
                                                                      theorem Spec.Shape.padLeft_rank (n : ℕ) (s : Shape) :
                                                                      (padLeft n s).rank = n + s.rank

                                                                      padLeft n s increases the rank by exactly n.
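For instance (a sketch; rfl assumes padLeft unfolds on literals):

```lean
-- Two leading singleton dims: (3) becomes (1 × 1 × 3),
-- and the rank grows from 1 to 3, matching padLeft_rank.
example : Spec.Shape.padLeft 2 (.dim 3 .scalar)
    = .dim 1 (.dim 1 (.dim 3 .scalar)) := rfl
```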