Reductions, flatten/unflatten, and shape-changing helpers #
This module is where “shape-aware” operations live:
- flattenSpec / unflattenSpec (convert between a tensor and a flat vector of length Shape.size)
- broadcasting maps that change the output shape
Because shapes are indexed in types, many of these definitions necessarily carry equalities like
Shape.size s = ... under the hood.
Tip: when you need to transport a tensor across a proved shape equality, use:
Tensor.cast_shape (defined in NN/Spec/Core/Tensor/Core.lean)
Prefer abbrevs in user-facing code so common shape equalities remain definitional rather than
requiring transport proofs.
PyTorch mental model:
- flattenSpec / unflattenSpec correspond to torch.flatten and view/reshape on a contiguous tensor.
- broadcasting (broadcastTo / broadcastMapTo) corresponds to expand / broadcast_to plus elementwise ops.
- reductions (reduceSum, reduceMean, reduceVar, reduceMax, and reduceMin) correspond to sum/mean/var/amax/amin along a chosen axis.
The difference is that our shapes live in types, so the spec definitions must be explicit about:
- what the target/output shape is,
- and why the axis is valid / reducible.
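For concreteness, here is that mental model written out on the PyTorch side (a sketch in plain torch only; the spec names above are Lean definitions and do not appear in the snippet, and the shapes are arbitrary example values):

```python
import torch

x = torch.arange(6.0).reshape(2, 3)        # shape (2, 3)

# flattenSpec / unflattenSpec analogue: flatten, then view back.
flat = torch.flatten(x)                    # shape (6,)
back = flat.view(2, 3)                     # shape (2, 3), same values

# broadcastTo / broadcastMapTo analogue: broadcast to an explicit shape, then an elementwise op.
row = torch.tensor([10.0, 20.0, 30.0])     # shape (3,)
y = x + torch.broadcast_to(row, (2, 3))    # explicit target shape (2, 3)

# Reduction analogues, each along a chosen axis (here axis 1):
s = torch.sum(x, dim=1)                    # reduceSum   -> shape (2,)
m = torch.mean(x, dim=1)                   # reduceMean
v = torch.var(x, dim=1, unbiased=False)    # reduceVar (population variance, divides by n)
hi = torch.amax(x, dim=1)                  # reduceMax
lo = torch.amin(x, dim=1)                  # reduceMin
```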
Naming note (sequence concatenation):
- This file defines Spec.Tensor.concatSequenceSpec for concatenating along the time axis (axis 0), producing a longer sequence.
- NN.Spec.Core.Sequence defines Spec.concatSequenceSpec for concatenating along the feature axis (inner axis) for same-length sequences.
- The names are intentionally similar, but they are different operations living in different namespaces (Spec.Tensor vs Spec).
References / analogies (shape intuition, not semantics):
- PyTorch torch.flatten: https://pytorch.org/docs/stable/generated/torch.flatten.html
- PyTorch torch.Tensor.reshape: https://pytorch.org/docs/stable/generated/torch.Tensor.reshape.html
- PyTorch torch.Tensor.view: https://pytorch.org/docs/stable/generated/torch.Tensor.view.html
- PyTorch torch.Tensor.expand: https://pytorch.org/docs/stable/generated/torch.Tensor.expand.html
- PyTorch torch.sum: https://pytorch.org/docs/stable/generated/torch.sum.html
- PyTorch torch.mean: https://pytorch.org/docs/stable/generated/torch.mean.html
Flatten a tensor into a 1‑D vector (length = Shape.size s).
The order is outermost‑dimension major (row‑major w.r.t. the shape tree).
For proofs, the key invariant is that the output length matches Shape.size.
Why this exists: a lot of shape-changing ops are easiest to specify as "flatten, then rebuild", and this is also the bridge we use for some runtime interop where we want a plain sequence of scalars (e.g. importing weights or serializing test vectors).
Instances For
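As a quick PyTorch reference point for the ordering convention (outermost-dimension major, i.e. row-major), this sketch just shows what torch.flatten produces on a small example:

```python
import torch

x = torch.tensor([[0, 1, 2],
                  [3, 4, 5]])
print(torch.flatten(x))   # tensor([0, 1, 2, 3, 4, 5]) -- row-major order
```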
Unflatten a 1‑D vector back into a tensor of a given shape.
PyTorch analogy: flat.view(shape) (assuming the element count matches).
This is the inverse of flattenSpec up to the ordering convention.
Instances For
flattenSpec / unflattenSpec round-trip lemmas #
These are shape-transport facts: they justify treating flattenSpec/unflattenSpec like
reshape/view in PyTorch, provided you keep the element count consistent.
PyTorch references:
- torch.flatten: https://pytorch.org/docs/stable/generated/torch.flatten.html
- Tensor.view / torch.reshape: https://pytorch.org/docs/stable/generated/torch.Tensor.view.html
- torch.reshape: https://pytorch.org/docs/stable/generated/torch.reshape.html
Important nuance:
- PyTorch allows zero-sized dimensions, and its reshape/flatten semantics remain total.
- Our spec definitions are also total (they use Inhabited.default for unreachable branches), which keeps everything executable, but can make “inverse” proofs a bit index-heavy.

The theorems below show that the round-trips do work for the spec definitions as written.
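A small PyTorch sketch of the zero-sized-dimension point above (this only illustrates the PyTorch claim; it says nothing about the Lean definitions themselves):

```python
import torch

x = torch.empty(0, 3)          # a (0, 3) tensor: the element count is 0
flat = torch.flatten(x)        # shape (0,) -- still well-defined
back = flat.reshape(0, 3)      # reshaping back also works (0 elements on both sides)
print(flat.shape, back.shape)  # torch.Size([0]) torch.Size([0, 3])
```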
If a shape has Shape.size s = 0, then it contains no scalar leaves (it has a 0-length
dimension somewhere). In that case, there is essentially only one possible tensor value of shape
s (up to definitional equality), because at the 0-length dimension the indexing function has
domain Fin 0.
We use this as a “vacuity” lemma to avoid needing division/modulo arithmetic when Shape.size s = 0.
Round-trip flatten ∘ unflatten = id.
This is the spec-layer analogue of flattening a reshaped/viewed tensor in PyTorch.
Convenience corollary: the unflatten ∘ flatten round-trip in the common well-formed regime.
Broadcasting #
Broadcast a tensor along a Shape.CanBroadcastTo proof (spec-level analogue of
torch.broadcast_to).
Instances For
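PyTorch sketch of the analogous operation (plain torch.broadcast_to / Tensor.expand, not the spec definition; shapes are example values):

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0])     # shape (3,)
y = torch.broadcast_to(x, (4, 3))     # shape (4, 3): rows are copies of x
z = x.expand(4, 3)                    # same result, as a non-copying view
```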
Broadcasted maps #
Broadcast a scalar tensor to match a template tensor's shape.
This is a small convenience wrapper used by specs that want "like" broadcasting without spelling
out the Shape.CanBroadcastTo evidence.
Instances For
Binary element-wise operation with broadcasting to an explicit target shape.
This is the helper you typically want in spec code:
- pick the output shape t,
- broadcast each operand to t,
- then map2_spec the pointwise operation.
PyTorch analogy: f(x, y) where x and/or y are broadcastable to a common shape.
We make the common shape explicit instead of "discovering" it, because at the spec layer we want:
- predictable typing,
- a single source of truth for what the output shape is.
Instances For
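The same recipe written out in PyTorch, where broadcasting is normally implicit (a sketch; the target shape (4, 3) and the operands are arbitrary examples):

```python
import torch

x = torch.randn(4, 1)
y = torch.randn(1, 3)

# Implicit: PyTorch "discovers" the common shape (4, 3).
implicit = x * y

# Explicit: fix the target shape first, broadcast each operand, then apply the pointwise op.
t = (4, 3)
explicit = torch.broadcast_to(x, t) * torch.broadcast_to(y, t)

assert torch.equal(implicit, explicit)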
Reductions #
Left fold over all tensor elements.
Instances For
Right fold over all tensor elements.
Instances For
Output shape after summing along axis (drops that dimension).
Instances For
simp lemma: dropping axis 1 from a 2D (nQ+1)×(nK+1) shape yields (nQ+1).
simp lemma: dropping axis 1 from a 2D nQ×nK shape yields nQ.
simp lemma: dropping axis 3 from a 4D b×h×w×c shape yields b×h×w.
simp lemma: dropping axis 0 from a positive .dim (n+1) s yields s.
simp lemma: dropping axis k+1 recurses into the tail shape.
simp lemma: dropping axis 0 from a 2D (kH+1)×(kW+1) yields (kW+1).
simp lemma: dropping axis 0 from .dim n inner yields inner (even when n=0).
Reflexive broadcast proof (s can broadcast to itself).
Instances For
Build a broadcast proof from the reduced shape back to the original shape.
We use this when a backward pass computes something in the reduced shape (e.g. a mean/variance) and we need to broadcast it back to match the original tensor shape.
Instances For
Reduce a tensor of shape (n, innerShape) by applying f across the first axis.
This is the basic “reduce over axis 0” primitive that we reuse to implement broadcast-adjoints and multi-axis reducers.
Instances For
Reduce a gradient from a broadcast target shape back to the original input shape.
This is the adjoint of broadcastTo for sum-reduction: broadcast duplicates values, so the
backward pass sums contributions across broadcasted dimensions.
PyTorch analogy: this is the logic behind "sum over broadcasted dimensions" that happens in
autograd for expand + elementwise ops.
Adjoint of broadcastTo under sum-reduction: collapse broadcasted dimensions by summing.
Instances For
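A PyTorch/autograd sketch of why the adjoint sums over broadcasted dimensions (illustration only, not the spec definition; shapes are example values):

```python
import torch

x = torch.randn(3, requires_grad=True)   # original shape (3,)
y = x.expand(4, 3)                        # broadcast to (4, 3)
upstream = torch.ones(4, 3)               # pretend dL/dy is all ones
y.backward(upstream)

# The gradient w.r.t. x sums contributions over the broadcasted dimension:
print(x.grad)                              # tensor([4., 4., 4.])
print(upstream.sum(dim=0))                 # same thing, computed by hand
```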
Generic reduction along a (provably reducible) axis.
reduce_dim f axis x applies f to the slices along axis, and returns a tensor whose shape is
shape_after_sum s axis (i.e. that axis is dropped).
Instances For
Sum-reduction along a given axis.
Instances For
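PyTorch analogue of the shape bookkeeping (cf. shape_after_sum): summing along an axis drops that axis. A minimal sketch with an example shape:

```python
import torch

x = torch.randn(2, 5, 7)
print(torch.sum(x, dim=1).shape)   # torch.Size([2, 7]) -- axis 1 dropped
print(torch.sum(x, dim=0).shape)   # torch.Size([5, 7]) -- axis 0 dropped
```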
Sum-reduction along axis, with axis validity inferred via valid_axis_inst.
Instances For
Product-reduction along a given axis.
Instances For
Product-reduction along axis when you already have a valid_axis proof.
Instances For
Mean-reduction along a given axis.
Instances For
Mean-reduction along axis, with axis validity provided as a typeclass argument.
Instances For
Sum of squares reduced along an axis (helper for variance).
Instances For
Variance-reduction along a given axis (population variance, divides by n).
Instances For
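The "divides by n" convention matches torch.var with unbiased=False (a sketch with a concrete row, not the spec definition):

```python
import torch

x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
pop = torch.var(x, dim=1, unbiased=False)     # divides by n     -> 1.25
sample = torch.var(x, dim=1, unbiased=True)   # divides by n - 1 -> 1.6667
print(pop, sample)
```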
Variance-reduction along axis, with axis validity provided as a typeclass argument.
Instances For
Min-reduction along a given axis.
Instances For
Max-reduction along a given axis.
Instances For
Max-reduction along axis, with axis validity inferred via valid_axis_inst.
Instances For
Reduce along the last axis of s (i.e. axis rank s - 1).
Instances For
Like reduce_last_dim, but infers axis validity via valid_axis_inst.
Instances For
Mean-reduce along the last axis.
Instances For
Sum-reduce along the last axis of a 2D tensor (seqLen, embedDim).
Instances For
Product-reduce along the last axis of a 2D tensor (seqLen, embedDim).
Instances For
Max-reduce along the last axis.
Instances For
Min-reduce along the last axis.
Instances For
Mean-reduce along the last axis (with axis validity as a typeclass argument).
Instances For
Mean-reduce along the last axis, specialized for proofs that assume well-formedness.
Instances For
Sum-reduce along the last axis (with axis validity inferred via valid_axis_inst).
Instances For
Transpose a matrix (m×n) into (n×m).
PyTorch analogy: A.transpose(0, 1) or A.T for 2D tensors.
Instances For
Permute a 3D tensor from (a,b,c) to (b,c,a).
Instances For
Permute a 3D tensor from (a,b,c) to (c,a,b).
Instances For
Swap the last two axes of a 3D tensor: (a,b,c) to (a,c,b).
Instances For
Helper for swapping adjacent dims at a given depth (see Shape.swapAdjacentAtDepth).
Instances For
Backward pass for matrix multiplication: returns (dA, dB) given dC.
PyTorch analogy: if C = A @ B, then:
- dA = dC @ Bᵀ
- dB = Aᵀ @ dC
Instances For
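A PyTorch/autograd sanity check of those two formulas (a sketch; the shapes and the upstream gradient are arbitrary examples):

```python
import torch

A = torch.randn(2, 3, requires_grad=True)
B = torch.randn(3, 4, requires_grad=True)
C = A @ B
dC = torch.randn(2, 4)          # upstream gradient for C

C.backward(dC)
assert torch.allclose(A.grad, dC @ B.detach().T)   # dA = dC @ Bᵀ
assert torch.allclose(B.grad, A.detach().T @ dC)   # dB = Aᵀ @ dC
```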
Batched matrix multiplication: [batch,m,n] × [batch,n,p] → [batch,m,p].
Instances For
Backward pass for batched matrix multiplication.
Instances For
Concatenate a list of (n,d) tensors along the last axis, producing (n, headCount*d).
This is mainly used by attention blocks that split/merge heads.
PyTorch analogy: torch.cat(heads, dim=-1) after splitting heads, followed by a reshape.
Instances For
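PyTorch sketch of merging heads with plain torch.cat (the sequence length, head count, and head width are example values):

```python
import torch

n, head_count, d = 5, 4, 16
heads = [torch.randn(n, d) for _ in range(head_count)]
merged = torch.cat(heads, dim=-1)      # shape (n, head_count * d)
print(merged.shape)                    # torch.Size([5, 64])
```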
Concatenate two vectors by appending v2 after v1.
Instances For
Slicing / concatenation on the leading axis #
concat_dim0_spec is the "append on axis 0" primitive that powers many higher-level utilities
(sequence concatenation, channel skip connections, etc.).
For backprop and for "undoing" concatenations, it is convenient to have an explicit slice operation. We keep the API compact and index-safe:
- slice_range0_spec start len selects len consecutive entries starting at start along axis 0.
- concat_dim0_backward_spec is the adjoint of concat_dim0_spec (splits a gradient tensor).
Slice len entries along axis 0, starting at start.
This is the simplest "range slice" one typically needs to express:
- taking the first n channels/tokens,
- extracting the skip-connection half after a concat,
- implementing take/drop without changing the inner shape.
The proof len + start ≤ n makes the slice total (no out-of-bounds behavior).
Instances For
Backward (adjoint) of concat_dim0_spec.
If y = concat_dim0_spec x1 x2, then in reverse-mode we split the upstream gradient δy into:
- δx1 = the first n entries of δy,
- δx2 = the last m entries of δy.
Instances For
Backward (adjoint) of slice_range0_spec.
If y = slice_range0_spec start len x, then slice_range0_backward_spec start len δy re-inserts
the gradient into the original shape and fills everything outside the slice with zeros.
Instances For
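PyTorch sketch of both adjoints using ordinary indexing (illustration only; n, m, start, and len are arbitrary example values):

```python
import torch

n, m, d = 3, 2, 4
x1, x2 = torch.randn(n, d), torch.randn(m, d)
y = torch.cat([x1, x2], dim=0)            # concat on axis 0, shape (5, 4)

# Adjoint of concat on axis 0: split the upstream gradient back into two pieces.
dy = torch.randn(n + m, d)
dx1, dx2 = dy[:n], dy[n:]                 # first n rows / last m rows

# Adjoint of a range slice on axis 0: re-insert the slice gradient,
# zero-filling everything outside the sliced range.
start, length = 1, 2
dslice = torch.randn(length, d)           # gradient of y[start:start+length]
dx = torch.zeros_like(y)
dx[start:start + length] = dslice
```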
Concatenate two sequences along time (axis 0), producing a longer sequence.
If seq1 : (seqLen1 x hidden) and seq2 : (seqLen2 x hidden), this returns
(seqLen1 + seqLen2) x hidden by appending seq2 after seq1.
Do not confuse this with Spec.concatSequenceSpec (defined in NN.Spec.Core.Sequence), which
concatenates along the feature dimension for same-length sequences.
Instances For
Concatenate two sequences along the feature dimension (inner axis).
Instances For
Same as expand_to_col_spec, specialized to vectors.
Instances For
Same as squeeze_col_spec, specialized to vectors.
Instances For
Unsqueeze (insert a singleton dim). Currently implemented as expand_to_col_spec.
Core uses singleton insertion mainly for column vectors, so this operation is specialized to that use case. General axis insertion can extend this definition.
Instances For
Turn a vector (n) into a batch of size 1: (1,n).
Instances For
Convert channel-first images (b,c,h,w) into channel-last (b,h,w,c).
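PyTorch sketch of the layout change via Tensor.permute (the .contiguous() call just materializes the new memory layout; shape values are examples):

```python
import torch

x = torch.randn(8, 3, 32, 32)           # (b, c, h, w), channels-first
y = x.permute(0, 2, 3, 1).contiguous()  # (b, h, w, c), channels-last
print(y.shape)                           # torch.Size([8, 32, 32, 3])
```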