Shapes (Spec.Shape) #
Shape is the type-level “shape descriptor” for tensors in the spec layer.
TorchLean uses shape-indexed tensors:
Tensor α s
so Shape is how we encode the structure of s in a way Lean can use for both computation and
proofs.
Representation #
Shape is an inductive tree:
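A minimal sketch of the idea (the constructor names scalar and dim are the ones used by CanBroadcastTo below; the authoritative definition lives in the spec sources):

```lean
-- Sketch only, not the library's code.
inductive Shape where
  | scalar : Shape                -- rank-0 shape (a single scalar)
  | dim : Nat → Shape → Shape     -- one outer dimension wrapped around an inner shape
```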
This matches the tensor definition in NN/Spec/Core/Tensor/Core.lean.
Common utilities #
- Shape.size : Shape → Nat is the total number of scalar elements (“numel”).
- Shape.toList : Shape → List Nat is a convenient runtime view used by front-ends and bridges.
PyTorch analogy:
- Shape.toList s corresponds to tensor.shape (a tuple of dimensions).
- Shape.rank s corresponds to tensor.ndim.
- Shape.size s corresponds to tensor.numel().
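Continuing the sketch above, the utilities are plain structural recursion. These definitions are illustrative, not the library's code, but the names match the documented API:

```lean
def Shape.size : Shape → Nat
  | .scalar => 1
  | .dim n s => n * Shape.size s      -- numel: product of all dimensions

def Shape.rank : Shape → Nat
  | .scalar => 0
  | .dim _ s => 1 + Shape.rank s      -- ndim: nesting depth of `dim`

def Shape.toList : Shape → List Nat
  | .scalar => []
  | .dim n s => n :: Shape.toList s   -- outermost dimension first

-- A 2 × 3 × 4 shape, written outermost-first.
def ex : Shape := .dim 2 (.dim 3 (.dim 4 .scalar))

#eval ex.toList   -- [2, 3, 4]  (like tensor.shape)
#eval ex.rank     -- 3          (like tensor.ndim)
#eval ex.size     -- 24         (like tensor.numel())
```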
Broadcasting and axes #
Broadcasting is encoded via CanBroadcastTo / BroadcastTo.
This is an intentionally asymmetric relation ("broadcast s1 to s2"), because most tensor code
is naturally written by choosing the output shape and requiring each input to broadcast to it.
The typeclass wrapper BroadcastTo keeps higher-level specs readable: in many cases Lean can infer
the broadcast evidence automatically, so call sites don’t have to manually thread proofs around.
It also defines axis-validity helpers (valid_axis) and a well_formed predicate for “all
dimensions are positive”, which is useful when you want to rule out degenerate cases in proofs.
We represent shapes as an inductive tree instead of a bare List Nat because:
- it matches the tensor representation (Tensor α s) structurally, so many definitions are simple structural recursion,
- it keeps "scalar vs dim" cases explicit (important for proofs),
- it gives definitional equalities that are friendlier than lists in many places.
Build a shape from a list of dimensions (outermost first).
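As an illustration, such a builder is a simple fold over the list; the name ofList used here is an assumption, so check the sources for the actual declaration:

```lean
-- Hypothetical name `ofList`; semantics as documented (outermost first).
def Shape.ofList : List Nat → Shape
  | [] => .scalar
  | n :: rest => .dim n (Shape.ofList rest)

#eval (Shape.ofList [2, 3, 4]).toList   -- [2, 3, 4]
```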
Internal helper: check that a list of axis indices is duplicate-free.
Pretty-print a Shape for debugging / logs.
Swap two adjacent dimensions at a given depth (0‑based from the outermost).
Swapping adjacent dims at a given depth twice returns the original shape.
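One way such a swap could be written on the tree is sketched below; the name and argument order are assumptions, and the identity fallback is what makes the involution lemma above hold in the degenerate cases:

```lean
-- Hypothetical helper: swap the dimensions at depths `depth` and `depth + 1`,
-- returning the shape unchanged when there is nothing to swap.
def Shape.swapAdjacent : Nat → Shape → Shape
  | 0, .dim m (.dim n s) => .dim n (.dim m s)
  | d + 1, .dim m s => .dim m (Shape.swapAdjacent d s)
  | _, s => s
```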
Append a new innermost dimension.
appendDim multiplies the number of scalar elements by the appended dimension.
This lemma is the standard justification for reshape tricks where we split or merge dimensions without changing the total number of elements.
Shape-size identity used in Transformer attention reshapes.
If dModel = numHeads * headDim, then:
(seqLen × dModel) has the same size as (numHeads × seqLen × headDim).
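Stated on the sketch's Shape.size, with dModel written as numHeads * headDim, the identity is ordinary Nat arithmetic:

```lean
-- Toy statement of the attention-reshape identity, using the sketch's `Shape.size`.
example (seqLen numHeads headDim : Nat) :
    (Shape.dim seqLen (.dim (numHeads * headDim) .scalar)).size
      = (Shape.dim numHeads (.dim seqLen (.dim headDim .scalar))).size := by
  simp only [Shape.size, Nat.mul_one]
  -- goal: seqLen * (numHeads * headDim) = numHeads * (seqLen * headDim)
  rw [← Nat.mul_assoc, Nat.mul_comm seqLen numHeads, Nat.mul_assoc]
```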
Size of the outermost dimension (or 1 for scalar).
Size of the innermost dimension (or 1 for scalar).
Convert to a list of dimensions (outermost first).
Convert to an array of dimensions (outermost first).
Boolean equality test for shapes (structural).
BEq Shape uses the explicit structural boolean test Shape.areEqual.
Default shape is scalar.
Typeclass-friendly broadcasting (BroadcastTo) #
The CanBroadcastTo relation is asymmetric (“broadcast s₁ to s₂”), matching how most
operations are written: we pick a target shape and require each operand to broadcast to it.
The BroadcastTo wrapper lets Lean search for a broadcast proof automatically, which is convenient
for higher-level specs (layers/models) where the broadcasting details are not the point.
PyTorch analogy:
- PyTorch broadcasting aligns shapes from the trailing dimensions by implicitly prepending 1s to the shorter shape.
- Our Shape is an outermost-first tree, so the corresponding operation is expand_dims: it inserts leading/outer dimensions to reach the target rank (this is the "prepend 1s" step).
- dim_1_to_n corresponds to PyTorch's "dimension 1 can expand to n" rule.
Evidence that shape s₁ can be broadcast to shape s₂ (PyTorch-style broadcasting).
- scalar_to_any (s : Shape) : scalar.CanBroadcastTo s
- dim_eq {n : ℕ} {s₁ s₂ : Shape} (tail : s₁.CanBroadcastTo s₂) : (dim n s₁).CanBroadcastTo (dim n s₂)
- dim_1_to_n {n : ℕ} {s₁ s₂ : Shape} (tail : s₁.CanBroadcastTo s₂) : (dim 1 s₁).CanBroadcastTo (dim n s₂)
- expand_dims {n : ℕ} {s₁ s₂ : Shape} (tail : s₁.CanBroadcastTo s₂) : s₁.CanBroadcastTo (dim n s₂)
Typeclass wrapper for CanBroadcastTo so broadcast proofs can be inferred.
- proof : s₁.CanBroadcastTo s₂
Scalar broadcasts to any shape (analogue of "prepend 1s and expand").
Broadcasting preserves equal leading dimensions when the tails broadcast.
Dimension 1 can broadcast to any n (PyTorch's main broadcast rule).
Prepend an outer dimension (the "expand_dims" step used to align ranks).
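Continuing the sketch, here is a toy restatement of the relation together with a concrete piece of evidence: broadcasting a 1 × 3 shape to 2 × 3 uses dim_1_to_n on the leading axis and dim_eq on the matching one.

```lean
-- Toy restatement of the documented constructors (sketch, not the library's code).
inductive Shape.CanBroadcastTo : Shape → Shape → Prop where
  | scalar_to_any (s : Shape) : Shape.CanBroadcastTo Shape.scalar s
  | dim_eq {n : Nat} {s₁ s₂ : Shape} (tail : Shape.CanBroadcastTo s₁ s₂) :
      Shape.CanBroadcastTo (Shape.dim n s₁) (Shape.dim n s₂)
  | dim_1_to_n {n : Nat} {s₁ s₂ : Shape} (tail : Shape.CanBroadcastTo s₁ s₂) :
      Shape.CanBroadcastTo (Shape.dim 1 s₁) (Shape.dim n s₂)
  | expand_dims {n : Nat} {s₁ s₂ : Shape} (tail : Shape.CanBroadcastTo s₁ s₂) :
      Shape.CanBroadcastTo s₁ (Shape.dim n s₂)

-- A (1 × 3) shape broadcasts to (2 × 3): expand the size-1 dim, keep the matching dim.
example : Shape.CanBroadcastTo (.dim 1 (.dim 3 .scalar)) (.dim 2 (.dim 3 .scalar)) :=
  .dim_1_to_n (.dim_eq (.scalar_to_any _))
```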
true iff two shapes have the same number of elements.
Friendly aliases (PyTorch-style) #
We keep the canonical names (toList, rank, size, well_formed) because they show up
throughout the spec/proof code.
For docs and examples, these aliases read more like PyTorch.
Axis utilities #
Why these exist:
- Reduction ops (reduce_sum, reduce_mean, etc.) need an axis argument.
- In executable code we want to reject invalid axes early, but in spec/proof code we want the axis validity to be available as evidence that can be carried through lemmas.
So we provide:
- valid_axis axis s : Prop as the core definition, and
- valid_axis_inst axis s as a typeclass wrapper so the common cases can be inferred.
PyTorch differences:
- PyTorch allows negative axes (e.g. dim=-1); here we use Nat axes only (0-based). A typical translation is: "last axis" = Shape.rank s - 1 (when rank s > 0).
Evidence that reducing along axis is well-defined for a shape.
This is a small helper predicate used to rule out degenerate 0-length dimensions when stating
laws about reductions.
- head {n : ℕ} {s : Shape} : reducibleAlong 0 (dim (n + 1) s)
- tail {n : ℕ} {s : Shape} {k : ℕ} : reducibleAlong k s → reducibleAlong (k + 1) (dim (n + 1) s)
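Restated on the sketch, with a small worked term (illustrative only; the library's declaration is the one documented above):

```lean
-- Toy restatement: axis 0 needs a positive outer dim; axis k+1 defers to the tail.
inductive Shape.reducibleAlong : Nat → Shape → Prop where
  | head {n : Nat} {s : Shape} : Shape.reducibleAlong 0 (Shape.dim (n + 1) s)
  | tail {n k : Nat} {s : Shape} :
      Shape.reducibleAlong k s → Shape.reducibleAlong (k + 1) (Shape.dim (n + 1) s)

-- Axis 1 is reducible for a 2 × 3 shape: step past the outer dim, then hit axis 0.
example : Shape.reducibleAlong 1 (.dim 2 (.dim 3 .scalar)) := .tail .head
```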
simp lemma: axis 0 is reducible for any positive outer dimension.
simp lemma: reducibility for inner axis lifts to the next outer axis.
valid_axis axis s means that axis is a valid reduction axis for s.
We use a Prop + typeclass wrapper (valid_axis_inst) so proofs can be synthesized by typeclass
resolution in downstream code.
Axis validity predicate for reduction ops (0-based axis in Nat).
- valid_zero {n : ℕ} {s : Shape} : valid_axis 0 (dim (n + 1) s)
- valid_succ {n : ℕ} {s : Shape} {k : ℕ} (h : valid_axis k s) : valid_axis (k + 1) (dim (n + 1) s)
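Restated on the sketch, together with the evidence that axis 1 is valid for a 2 × 3 shape:

```lean
-- Toy restatement of the two constructors above (sketch, not the library's code).
inductive Shape.valid_axis : Nat → Shape → Prop where
  | valid_zero {n : Nat} {s : Shape} : Shape.valid_axis 0 (Shape.dim (n + 1) s)
  | valid_succ {n k : Nat} {s : Shape} (h : Shape.valid_axis k s) :
      Shape.valid_axis (k + 1) (Shape.dim (n + 1) s)

-- Axis 1 is valid for a 2 × 3 shape.
example : Shape.valid_axis 1 (.dim 2 (.dim 3 .scalar)) :=
  .valid_succ .valid_zero
```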
Typeclass wrapper for valid_axis so common axis proofs can be inferred.
- proof : valid_axis axis s
Instance: axis 0 is valid for any positive outer dimension.
Instance: axis 0 is valid for a nonzero outer dimension n.
This is a convenience wrapper that turns n ≠ 0 into the n+1 form expected by valid_axis.
Instance: axis 1 is valid for a 2D shape when both outer dims are nonzero.
Instance: if k is a valid axis for s, then k+1 is a valid axis for .dim (n+1) s.
Instance: axis 0 is valid if you have a positivity proof n > 0 (converted to n ≠ 0).
Helper lemma: a positive natural is not zero.
Well-formedness (well_formed) #
well_formed s means "all dimensions are positive".
Why this matters (and why we designed it this way):
- Many definitions use Fin n indexing; if n = 0, there is no index and you end up with either vacuous truths or extra cases that obscure the intent of the lemma.
- Some common ops become awkward or partial at n = 0. For example, a mean typically divides by the number of elements, so n = 0 needs special-case semantics.
- PyTorch does allow zero-sized dimensions, and most ops define a sensible result for them. We intentionally keep that complexity out of the core spec layer because it makes proofs much more case-heavy. When we need zero-dimension tensors, we introduce them with explicit semantics instead of relying on incidental behavior.
This is a pragmatic "make the common case pleasant" choice: proofs and specs are shorter, and runtime checks can still handle edge cases separately.
well_formed s means "all dimensions of s are positive" (recursively).
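On the sketch, the predicate is the evident recursion (illustrative; the library's phrasing may differ):

```lean
-- All dimensions positive, checked recursively down the tree.
def Shape.well_formed : Shape → Prop
  | .scalar => True
  | .dim n s => 0 < n ∧ Shape.well_formed s
```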
Size positivity #
If all dimensions of a shape are positive, then the total number of scalar elements is positive.
This is a small but useful bridge lemma: many reductions are only defined for nonempty dimensions,
and WellFormed is our standard way of expressing that assumption.
If s.well_formed, then Shape.size s > 0.
If rank s > 0 and s is well-formed, then the last axis rank s - 1 is valid.
This powers many "reduce over last dimension" specs where the axis is computed as rank s - 1.
Typeclass wrapper for Shape.well_formed.
We use a typeclass (instead of passing a well_formed proof everywhere) because it mirrors how
other "side conditions" are handled in the library: call sites stay clean, and instances can be
provided locally (e.g. letI : Shape.WellFormed s := ...) when needed.
- proof : s.wellFormed
If s is well-formed and n > 0, then .dim n s is well-formed.
Convenience instance: .dim 1 s is well-formed when s is.
Convenience instance: .dim 2 s is well-formed when s is.
If a Fact (n > 0) is in scope, lift it to a Shape.WellFormed (.dim n s) instance.
validAxisLastAuto is a convenience instance for the most common reduction axis:
"reduce over the last dimension".
In PyTorch this is dim=-1 (after normalization). Here we stay in Nat, so the last axis is
rank s - 1, and we require rank s > 0 plus well-formedness so the proof is meaningful.
Convenience instance: infer valid_axis_inst (rank s - 1) s from WellFormed s and rank s > 0.
Bridge lemma: turn a valid_axis proof into a reducibleAlong proof.
Why both exist:
- valid_axis is the semantic "this axis makes sense" predicate used in public APIs.
- reducibleAlong is a structurally convenient predicate for recursion over tensor shapes (it lines up with how Tensor.dim is constructed).
This function is the adapter between the two views.
Convert a valid_axis proof into a structurally convenient reducibleAlong proof.
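On the toy restatements above, the adapter is a short induction over the valid_axis evidence; the name used here is an assumption:

```lean
-- Map valid_zero ↦ head and valid_succ ↦ tail, recursing on the evidence.
theorem Shape.valid_axis.toReducible {axis : Nat} {s : Shape}
    (h : Shape.valid_axis axis s) : Shape.reducibleAlong axis s := by
  induction h
  case valid_zero => exact .head
  case valid_succ ih => exact .tail ih
```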
padLeft n s prepends n singleton dimensions to a shape.
PyTorch analogy: unsqueeze(0) repeated n times (or equivalently viewing a tensor as having
extra leading dimensions of size 1). This is also the "prepend 1s" step you see in broadcasting.
Prepend n leading singleton dimensions (size 1) to a shape.
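A sketch of the recursion (the argument order is an assumption):

```lean
-- Prepend `n` singleton dimensions: padLeft 2 (dim 5 scalar) = dim 1 (dim 1 (dim 5 scalar)).
def Shape.padLeft : Nat → Shape → Shape
  | 0, s => s
  | n + 1, s => .dim 1 (Shape.padLeft n s)
```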