TorchLean API

NN.Spec.Core.TensorReductionShape

Reductions, flatten/unflatten, and shape-changing helpers #

This module is where “shape-aware” operations live:

Because shapes are indexed in types, many of these definitions necessarily carry equalities like Shape.size s = ... under the hood.

Tip: when you need to transport a tensor across a proved shape equality, use:

Prefer abbrevs in user-facing code so common shape equalities remain definitional rather than requiring transport proofs.

PyTorch mental model:

The difference is that our shapes live in types, so the spec definitions must be explicit about:

Naming note (sequence concatenation):

References / analogies (shape intuition, not semantics):

Flatten a tensor into a 1‑D vector (length = Shape.size s).

The order is outermost‑dimension major (row‑major w.r.t. the shape tree). For proofs, the key invariant is that the output length matches Shape.size.

Why this exists: a lot of shape-changing ops are easiest to specify as "flatten, then rebuild", and this is also the bridge we use for some runtime interop where we want a plain sequence of scalars (e.g. importing weights or serializing test vectors).
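As a sketch of the "flatten, then rebuild" pattern (this helper is hypothetical; the exact flat-sequence type returned by `flattenSpec` and the argument order of `unflattenSpec` are assumed here):

```lean
-- Hypothetical sketch: apply a length-preserving transformation `g` to the
-- flat element sequence, then restore the original shape. We assume the
-- sequence produced by `flattenSpec` is exactly what `unflattenSpec` consumes.
def mapFlatSpec {α : Type} [Inhabited α] {s : Shape}
    (g : ∀ {m : ℕ}, Vector α m → Vector α m)
    (t : Tensor α s) : Tensor α s :=
  Spec.Tensor.unflattenSpec s (g (Spec.Tensor.flattenSpec t))
```

Because `g` preserves length, the element count stays `Shape.size s` and no transport proof is needed.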

Unflatten a 1‑D vector back into a tensor of a given shape.

PyTorch analogy: flat.view(shape) (assuming the element count matches). This is the inverse of flattenSpec up to the ordering convention.

flattenSpec / unflattenSpec round-trip lemmas #

These are shape-transport facts: they justify treating flattenSpec/unflattenSpec like reshape/view in PyTorch, provided you keep the element count consistent.
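Stated as they are used in practice (a sketch: the argument conventions and the proof terms are assumed, and the statements are the two identities documented below):

```lean
-- Round-trip identity sketches; the proofs are the lemmas in this module.
example {α : Type} [Inhabited α] {s : Shape} (t : Tensor α s) :
    Spec.Tensor.unflattenSpec s (Spec.Tensor.flattenSpec t) = t :=
  sorry -- discharged by the unflatten ∘ flatten round-trip lemma below
```

The other direction (flatten ∘ unflatten = id) additionally requires the input vector to have length Shape.size s.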

PyTorch references:

Important nuance:

If a shape has Shape.size s = 0, then it contains no scalar leaves (it has a 0-length dimension somewhere). In that case, there is essentially only one tensor value of shape s (up to definitional equality), because at the 0-length dimension the indexing function has domain Fin 0.

We use this as a “vacuity” lemma to avoid needing division/modulo arithmetic when Shape.size s = 0.

Round-trip unflatten ∘ flatten = id.

This is the spec-layer analogue of reshape/view round-tripping in PyTorch when the element count matches.

Round-trip flatten ∘ unflatten = id.

This is the spec-layer analogue of flattening a reshaped/viewed tensor in PyTorch.

Convenience corollary: the unflatten ∘ flatten round-trip in the common well-formed regime.

def Spec.Tensor.reshapeSpec {α : Type} [Inhabited α] {s₁ s₂ : Shape} (t : Tensor α s₁) (h : s₁.size = s₂.size) :
    Tensor α s₂

Reshape a tensor, given a proof that the number of elements matches.
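A usage sketch with literal shapes (assuming `Shape.size` computes on literals, so the size equality 2·3 = 6 should close by `decide` or `rfl`):

```lean
-- Sketch: reshape a 2×3 tensor into a flat 6-vector.
example {α : Type} [Inhabited α]
    (t : Tensor α (Shape.dim 2 (Shape.dim 3 Shape.scalar))) :
    Tensor α (Shape.dim 6 Shape.scalar) :=
  Spec.Tensor.reshapeSpec t (by decide)
```

For shapes with variable dimensions the equality must be proved by hand (or kept definitional via the abbrev advice above).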

def Spec.Tensor.reshapeExplicitSpec {α : Type} [Inhabited α] {s₁ s₂ : Shape} (t : Tensor α s₁) (h : s₁.size = s₂.size) :
    Tensor α s₂

Reshape with an explicit equality rewrite (sometimes easier for the elaborator).
def Spec.Tensor.sequenceFin {α : Type} {s : Shape} {n : ℕ} (f : Fin n → Option (Tensor α s)) :
    Option (Tensor α (Shape.dim n s))

Given a partial function Fin n → Option (Tensor α s), build a tensor if all succeed.

def Spec.Tensor.broadcastFill {α : Type} [Inhabited α] (s : Shape) :
    α → Tensor α s

Build a tensor filled with a constant, without using fill (used in broadcasts).

Broadcasting #

def Spec.Tensor.broadcastTo {α : Type} [Inhabited α] {s₁ s₂ : Shape} :
    s₁.CanBroadcastTo s₂ → Tensor α s₁ → Tensor α s₂

Broadcast a tensor along a Shape.CanBroadcastTo proof (spec-level analogue of torch.broadcast_to).
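A usage sketch, broadcasting a (1, n) row to (m, n) (the CanBroadcastTo evidence is taken as a hypothesis; how it is constructed depends on the Shape API):

```lean
-- Sketch: given broadcast evidence, lift a (1, n) tensor to (m, n).
example {α : Type} [Inhabited α] {m n : ℕ}
    (h : (Shape.dim 1 (Shape.dim n Shape.scalar)).CanBroadcastTo
           (Shape.dim m (Shape.dim n Shape.scalar)))
    (t : Tensor α (Shape.dim 1 (Shape.dim n Shape.scalar))) :
    Tensor α (Shape.dim m (Shape.dim n Shape.scalar)) :=
  Spec.Tensor.broadcastTo h t
```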

Broadcasted maps #

def Spec.Tensor.broadcastLike {α : Type} [Inhabited α] {s : Shape} (_template : Tensor α s) (t : Tensor α Shape.scalar) :
    Tensor α s

Broadcast a scalar tensor to match a template tensor's shape.

This is a small convenience wrapper used by specs that want "like" broadcasting without spelling out the Shape.CanBroadcastTo evidence.

def Spec.Tensor.mapScalarLeft {α : Type} (f : α → α → α) (x : α) {s : Shape} :
    Tensor α s → Tensor α s

Helper: map a scalar on the left over any tensor shape.

def Spec.Tensor.mapScalarRight {α : Type} (f : α → α → α) (y : α) {s : Shape} :
    Tensor α s → Tensor α s

Helper: map a scalar on the right over any tensor shape.

def Spec.Tensor.broadcastMapTo {α : Type} [Inhabited α] (f : α → α → α) {s₁ s₂ t : Shape} (cbx : s₁.CanBroadcastTo t) (cby : s₂.CanBroadcastTo t) :
    Tensor α s₁ → Tensor α s₂ → Tensor α t

Binary element-wise operation with broadcasting to an explicit target shape.

This is the helper you typically want in spec code:

• pick the output shape t,
• broadcast each operand to t,
• then map2_spec the pointwise operation.

PyTorch analogy: f(x, y) where x and/or y are broadcastable to a common shape. We make the common shape explicit instead of "discovering" it, because at the spec layer we want:

• predictable typing,
• a single source of truth for what the output shape is.
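The steps above can be sketched as a broadcasted addition (the `Add α` instance is an extra assumption here, used only to have a concrete `f`):

```lean
-- Sketch: broadcasted addition to an explicit target shape `t`.
example {α : Type} [Inhabited α] [Add α] {s₁ s₂ t : Shape}
    (cbx : s₁.CanBroadcastTo t) (cby : s₂.CanBroadcastTo t)
    (x : Tensor α s₁) (y : Tensor α s₂) : Tensor α t :=
  Spec.Tensor.broadcastMapTo (· + ·) cbx cby x y
```

Picking `t` explicitly is what gives the predictable typing mentioned above: the output shape is whatever you wrote, not whatever a broadcast-inference algorithm discovered.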
Reductions #

@[irreducible]
def Spec.Tensor.tensorFoldlSpec {α β : Type} (f : β → α → β) (init : β) {s : Shape} :
    Tensor α s → β

Left fold over all tensor elements.

@[irreducible]
def Spec.Tensor.tensorFoldlSpec.go {α β : Type} (f : β → α → β) (n : ℕ) (s : Shape) (values : Fin n → Tensor α s) (i : ℕ) (acc : β) :
    β

@[irreducible]
def Spec.Tensor.tensorFoldrSpec {α β : Type} (f : α → β → β) (init : β) {s : Shape} :
    Tensor α s → β

Right fold over all tensor elements.

@[irreducible]
def Spec.Tensor.tensorFoldrSpec.go {α β : Type} (f : α → β → β) (n : ℕ) (s : Shape) (values : Fin n → Tensor α s) (i : ℕ) (acc : β) :
    β
def Spec.Tensor.sumSpec {α : Type} [Add α] [Zero α] {s : Shape} (t : Tensor α s) :
    α

Sum all elements of a tensor.

def Spec.Tensor.prodSpec {α : Type} [Context α] {s : Shape} (t : Tensor α s) :
    α

Product of all elements of a tensor.

@[reducible, inline]
abbrev Spec.Tensor.productSpec {α : Type} [Context α] {s : Shape} (t : Tensor α s) :
    α

Short name for prodSpec.

def Spec.Tensor.countSpec {α : Type} {s : Shape} (t : Tensor α s) :
    ℕ

Count the number of scalar entries in a tensor (= Shape.size).

def Spec.Tensor.anySpec {α : Type} {s : Shape} (p : α → Bool) (t : Tensor α s) :
    Bool

true if any entry satisfies p.

def Spec.Tensor.allSpec {α : Type} {s : Shape} (p : α → Bool) (t : Tensor α s) :
    Bool

true if all entries satisfy p.

def Spec.Tensor.dotSpec {α : Type} [Context α] {s : Shape} (a b : Tensor α s) :
    α

Dot product: sum (a ⊙ b).
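The intended equation, as a statement sketch (`map2_spec` is the pointwise map referenced earlier in this module; the `Add`/`Zero`/`Mul` obligations are assumed to come from `Context α`, and whether the equation holds definitionally or needs a small unfolding lemma depends on the spec):

```lean
-- Sketch: dotSpec is "multiply pointwise, then sum".
example {α : Type} [Context α] {s : Shape} (a b : Tensor α s) :
    Spec.Tensor.dotSpec a b = Spec.Tensor.sumSpec (map2_spec (· * ·) a b) :=
  sorry -- definitional or a one-step unfolding, depending on the definitions
```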

def Spec.Tensor.meanSpec {α : Type} [Context α] {s : Shape} :
    Tensor α s → α

Mean of all elements (treats nested dims as one big collection).

def Spec.Tensor.varianceSpec {α : Type} [Context α] {s : Shape} :
    Tensor α s → α

Variance of all elements (population variance, divides by n).
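Spelled out, with n = Shape.size s and t₁, …, tₙ the scalar entries of the tensor:

```latex
\mu \;=\; \frac{1}{n}\sum_{i=1}^{n} t_i,
\qquad
\operatorname{Var}(t) \;=\; \frac{1}{n}\sum_{i=1}^{n}\bigl(t_i - \mu\bigr)^2
```

Note the divisor is n, not n − 1: this is the population variance, matching PyTorch's var(..., correction=0) rather than the default sample variance.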

Output shape after summing along axis (drops that dimension).

@[simp]

simp lemma: dropping axis 1 from a 2D (nQ+1)×(nK+1) shape yields (nQ+1).

@[simp]

simp lemma: dropping axis 1 from a 2D nQ×nK shape yields nQ.

@[simp]

simp lemma: dropping axis 3 from a 4D b×h×w×c shape yields b×h×w.

@[simp]

simp lemma: dropping axis 0 from a positive .dim (n+1) s yields s.

@[simp]
theorem Spec.Tensor.shape_after_sum_succ {n : ℕ} {s : Shape} {k : ℕ} :
    shapeAfterSum (Shape.dim (n + 1) s) (k + 1) = Shape.dim (n + 1) (shapeAfterSum s k)

simp lemma: dropping axis k+1 recurses into the tail shape.

@[simp]

simp lemma: dropping axis 0 from a 2D (kH+1)×(kW+1) yields (kW+1).

@[simp]
theorem Spec.Tensor.shape_after_sum_zero_alt (n : ℕ) (inner : Shape) :
    shapeAfterSum (Shape.dim n inner) 0 = inner

simp lemma: dropping axis 0 from .dim n inner yields inner (even when n=0).

Reflexive broadcast proof (s can broadcast to itself).

Build a broadcast proof from the reduced shape back to the original shape.

We use this when a backward pass computes something in the reduced shape (e.g. a mean/variance) and we need to broadcast it back to match the original tensor shape.

def Spec.Tensor.reduceFirstDim {α : Type} {innerShape : Shape} {n : ℕ} (f : {sliceShape : Shape} → Tensor α sliceShape → α) (t : Tensor α (Shape.dim n innerShape)) :
    Tensor α innerShape

Reduce a tensor of shape (n, innerShape) by applying f across the first axis.

This is the basic “reduce over axis 0” primitive that we reuse to implement broadcast-adjoints and multi-axis reducers.

Reduce a gradient from a broadcast target shape back to the original input shape.

This is the adjoint of broadcastTo for sum-reduction: broadcast duplicates values, so the backward pass sums contributions across broadcasted dimensions.

PyTorch analogy: this is the logic behind "sum over broadcasted dimensions" that happens in autograd for expand + elementwise ops.
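As a sketch, the vector-Jacobian-product pattern looks like this: if a value was produced by `broadcastTo h x` and `g` is the upstream gradient in the broadcast shape, the gradient for `x` sums over the duplicated dimensions (variable names here are illustrative):

```lean
-- Sketch: backward of a broadcast. `g : Tensor α s₂` is the upstream
-- gradient for `broadcastTo h x`; the input gradient lives in `s₁`.
example {α : Type} [Inhabited α] [Add α] [Zero α] {s₁ s₂ : Shape}
    (h : s₁.CanBroadcastTo s₂) (g : Tensor α s₂) : Tensor α s₁ :=
  Spec.Tensor.reduceFromBroadcastTo h g
```

The same `CanBroadcastTo` evidence drives both directions, which is what makes the forward/backward shapes line up by construction.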

def Spec.Tensor.reduceFromBroadcastTo {α : Type} [Add α] [Zero α] {s₁ s₂ : Shape} :
    s₁.CanBroadcastTo s₂ → Tensor α s₂ → Tensor α s₁

Adjoint of broadcastTo under sum-reduction: collapse broadcasted dimensions by summing.

def Spec.Tensor.reduceDim {α : Type} {s : Shape} (f : {sliceShape : Shape} → Tensor α sliceShape → α) (axis : ℕ) (x : Tensor α s) (_h : Shape.reducibleAlong axis s) :
    Tensor α (shapeAfterSum s axis)

Generic reduction along a (provably reducible) axis.

reduceDim f axis x applies f to the slices along axis, and returns a tensor whose shape is shapeAfterSum s axis (i.e. that axis is dropped).

def Spec.Tensor.reduceDim.aux {α : Type} (f : {sliceShape : Shape} → Tensor α sliceShape → α) {inShape outShape : Shape} (axisAdjusted : ℕ) (h_eq : outShape = shapeAfterSum inShape axisAdjusted) (t : Tensor α inShape) :
    Tensor α outShape
def Spec.Tensor.reduceSum {α : Type} [Add α] [Zero α] {s : Shape} (axis : ℕ) (t : Tensor α s) (h : Shape.reducibleAlong axis s) :
    Tensor α (shapeAfterSum s axis)

Sum-reduction along a given axis.
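A usage sketch: summing a 2×3 tensor along axis 1 yields a 2-vector at the type level (the `reducibleAlong` evidence is taken as a hypothesis here; in practice it can often be inferred via `valid_axis_inst` using `reduceSumAuto`):

```lean
open Spec.Tensor in
-- Sketch: axis-1 sum of a (2, 3) tensor; the result shape is computed
-- by shapeAfterSum and the axis-drop simp lemmas above.
example {α : Type} [Add α] [Zero α]
    (t : Tensor α (Shape.dim 2 (Shape.dim 3 Shape.scalar)))
    (h : Shape.reducibleAlong 1 (Shape.dim 2 (Shape.dim 3 Shape.scalar))) :
    Tensor α (shapeAfterSum (Shape.dim 2 (Shape.dim 3 Shape.scalar)) 1) :=
  reduceSum 1 t h
```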

def Spec.Tensor.reduceSumAuto {α : Type} [Add α] [Zero α] {s : Shape} (axis : ℕ) [h : Shape.valid_axis_inst axis s] (t : Tensor α s) :
    Tensor α (shapeAfterSum s axis)

Sum-reduction along axis, with axis validity inferred via valid_axis_inst.

def Spec.Tensor.reduceProd {α : Type} [Context α] {s : Shape} (axis : ℕ) (t : Tensor α s) (h : Shape.reducibleAlong axis s) :
    Tensor α (shapeAfterSum s axis)

Product-reduction along a given axis.

def Spec.Tensor.reduceProdAuto {α : Type} [Context α] {s : Shape} (axis : ℕ) (t : Tensor α s) (h : Shape.valid_axis axis s) :
    Tensor α (shapeAfterSum s axis)

Product-reduction along axis when you already have a valid_axis proof.

Get the runtime size of the k-th dimension (0-based), if it exists.

def Spec.Tensor.reduceMean {α : Type} [Context α] {s : Shape} (axis : ℕ) (t : Tensor α s) (h : Shape.reducibleAlong axis s) :
    Tensor α (shapeAfterSum s axis)

Mean-reduction along a given axis.

def Spec.Tensor.reduceMeanAuto {α : Type} [Context α] {s : Shape} (axis : ℕ) (h : Shape.valid_axis_inst axis s) (t : Tensor α s) :
    Tensor α (shapeAfterSum s axis)

Mean-reduction along axis, with axis validity provided as a typeclass argument.
def Spec.Tensor.reduceSumSquared {α : Type} [Context α] {n : ℕ} {s : Shape} (axis : ℕ) (t : Tensor α (Shape.dim n s)) (h : Shape.reducibleAlong axis (Shape.dim n s)) :
    Tensor α (shapeAfterSum (Shape.dim n s) axis)

Sum of squares reduced along an axis (helper for variance).

def Spec.Tensor.reduceVar {α : Type} [Context α] {s : Shape} (axis : ℕ) (t : Tensor α s) (h : Shape.reducibleAlong axis s) :
    Tensor α (shapeAfterSum s axis)

Variance-reduction along a given axis (population variance, divides by n).

def Spec.Tensor.reduceVarAuto {α : Type} [Context α] {s : Shape} (axis : ℕ) (h : Shape.valid_axis_inst axis s) (t : Tensor α s) :
    Tensor α (shapeAfterSum s axis)

Variance-reduction along axis, with axis validity provided as a typeclass argument.
def Spec.Tensor.reduceMin {α : Type} [Context α] {s : Shape} (axis : ℕ) (t : Tensor α s) (h : Shape.reducibleAlong axis s) :
    Tensor α (shapeAfterSum s axis)

Min-reduction along a given axis.

@[irreducible]
def Spec.Tensor.reduceMin.loop {α : Type} [Context α] (inner : Shape) (n' : ℕ) (f : Fin n'.succ → Tensor α inner) (i : ℕ) (acc : Tensor α inner) (hi : i ≤ n') :
    Tensor α inner

def Spec.Tensor.reduceMax {α : Type} [Context α] {s : Shape} (axis : ℕ) (t : Tensor α s) (h : Shape.reducibleAlong axis s) :
    Tensor α (shapeAfterSum s axis)

Max-reduction along a given axis.

@[irreducible]
def Spec.Tensor.reduceMax.loop {α : Type} [Context α] (inner : Shape) (n' : ℕ) (f : Fin n'.succ → Tensor α inner) (i : ℕ) (acc : Tensor α inner) :
    Tensor α inner

def Spec.Tensor.reduceMaxAuto {α : Type} [Context α] {s : Shape} (axis : ℕ) [h : Shape.valid_axis_inst axis s] (t : Tensor α s) :
    Tensor α (shapeAfterSum s axis)

Max-reduction along axis, with axis validity inferred via valid_axis_inst.
def Spec.Tensor.reduceLastDim {α : Type} [Context α] {s : Shape} (f : {sliceShape : Shape} → Tensor α sliceShape → α) (x : Tensor α s) (h : Shape.reducibleAlong (s.rank - 1) s) :
    Tensor α (shapeAfterSum s (s.rank - 1))

Reduce along the last axis of s (i.e. axis rank s - 1).

def Spec.Tensor.reduceLastDimAuto {α : Type} [Context α] {s : Shape} (f : {sliceShape : Shape} → Tensor α sliceShape → α) (x : Tensor α s) [h : Shape.valid_axis_inst (s.rank - 1) s] :
    Tensor α (shapeAfterSum s (s.rank - 1))

Like reduceLastDim, but infers axis validity via valid_axis_inst.

def Spec.Tensor.reduceMeanLast {α : Type} [Context α] {s : Shape} (x : Tensor α s) (h : Shape.reducibleAlong (s.rank - 1) s) :
    Tensor α (shapeAfterSum s (s.rank - 1))

Mean-reduce along the last axis.
def Spec.Tensor.reduceSumLast {α : Type} [Context α] {seqLen embedDim : ℕ} (x : Tensor α (Shape.dim seqLen (Shape.dim embedDim Shape.scalar))) (h : Shape.reducibleAlong ((Shape.dim seqLen (Shape.dim embedDim Shape.scalar)).rank - 1) (Shape.dim seqLen (Shape.dim embedDim Shape.scalar))) :
    Tensor α (shapeAfterSum (Shape.dim seqLen (Shape.dim embedDim Shape.scalar)) ((Shape.dim seqLen (Shape.dim embedDim Shape.scalar)).rank - 1))

Sum-reduce along the last axis of a 2D tensor (seqLen, embedDim).

def Spec.Tensor.reduceProdLast {α : Type} [Context α] {seqLen embedDim : ℕ} (x : Tensor α (Shape.dim seqLen (Shape.dim embedDim Shape.scalar))) (h : Shape.reducibleAlong ((Shape.dim seqLen (Shape.dim embedDim Shape.scalar)).rank - 1) (Shape.dim seqLen (Shape.dim embedDim Shape.scalar))) :
    Tensor α (shapeAfterSum (Shape.dim seqLen (Shape.dim embedDim Shape.scalar)) ((Shape.dim seqLen (Shape.dim embedDim Shape.scalar)).rank - 1))

Product-reduce along the last axis of a 2D tensor (seqLen, embedDim).
def Spec.Tensor.reduceMaxLast {α : Type} [Context α] {s : Shape} (x : Tensor α s) (h : Shape.reducibleAlong (s.rank - 1) s) :
    Tensor α (shapeAfterSum s (s.rank - 1))

Max-reduce along the last axis.

def Spec.Tensor.reduceMinLast {α : Type} [Context α] {s : Shape} (x : Tensor α s) (h : Shape.reducibleAlong (s.rank - 1) s) :
    Tensor α (shapeAfterSum s (s.rank - 1))

Min-reduce along the last axis.

def Spec.Tensor.reduceVarLast {α : Type} [Context α] {n : ℕ} {s : Shape} (x : Tensor α (Shape.dim n s)) (h : Shape.reducibleAlong ((Shape.dim n s).rank - 1) (Shape.dim n s)) :
    Tensor α (shapeAfterSum (Shape.dim n s) ((Shape.dim n s).rank - 1))

Variance-reduce along the last axis (specialized to a leading batch dimension).

def Spec.Tensor.reduceVarLastGeneral {α : Type} [Context α] {n : ℕ} {s : Shape} (x : Tensor α (Shape.dim n s)) (h : Shape.valid_axis_inst ((Shape.dim n s).rank - 1) (Shape.dim n s)) :
    Tensor α (shapeAfterSum (Shape.dim n s) ((Shape.dim n s).rank - 1))

Variance-reduce along the last axis (with axis validity as a typeclass argument).
def Spec.Tensor.reduceMeanLastGeneral {α : Type} [Context α] {s : Shape} (x : Tensor α s) (h : Shape.valid_axis_inst (s.rank - 1) s) :
    Tensor α (shapeAfterSum s (s.rank - 1))

Mean-reduce along the last axis (with axis validity as a typeclass argument).

def Spec.Tensor.reduceMeanLastGeneralWf {α : Type} [Context α] {s : Shape} (x : Tensor α s) [_h_wf : s.WellFormed] (_h_rank : s.rank > 0) (h_valid : Shape.valid_axis_inst (s.rank - 1) s) :
    Tensor α (shapeAfterSum s (s.rank - 1))

Mean-reduce along the last axis, specialized for proofs that assume well-formedness.

def Spec.Tensor.reduceSumLastGeneral {α : Type} [Context α] {s : Shape} (x : Tensor α s) [h : Shape.valid_axis_inst (s.rank - 1) s] :
    Tensor α (shapeAfterSum s (s.rank - 1))

Sum-reduce along the last axis (with axis validity inferred via valid_axis_inst).

Transpose a matrix (m×n) into (n×m).

PyTorch analogy: A.transpose(0, 1) or A.T for 2D tensors.
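At the type level the swap is visible in the shape indices (a sketch: the definition's name is not shown above, so `transposeSpec` is a placeholder):

```lean
-- Sketch: 2D transpose flips the two dimension indices in the type.
example {α : Type} {m n : ℕ}
    (a : Tensor α (Shape.dim m (Shape.dim n Shape.scalar))) :
    Tensor α (Shape.dim n (Shape.dim m Shape.scalar)) :=
  Spec.Tensor.transposeSpec a
```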

                                                                                                                      Instances For

                                                                                                                        Permute a 3D tensor from (a,b,c) to (b,c,a).

                                                                                                                        Instances For

                                                                                                                          Permute a 3D tensor from (a,b,c) to (c,a,b).

                                                                                                                          Instances For

                                                                                                                            Swap the last two axes of a 3D tensor: (a,b,c) to (a,c,b).

                                                                                                                            Instances For
                                                                                                                              def Spec.Tensor.swapFirstTwoSpec {α : Type} {m n : } {s : Shape} (t : Tensor α (Shape.dim m (Shape.dim n s))) :

                                                                                                                              Swap the first two dimensions of a tensor (m,n,...) to (n,m,...).

                                                                                                                              Instances For
def Spec.Tensor.swapAtDepthHelper {β : Type} {shape : Shape} (tensor : Tensor β shape) (d : ℕ) :

                                                                                                                                Helper for swapping adjacent dims at a given depth (see Shape.swapAdjacentAtDepth).

                                                                                                                                Instances For
def Spec.Tensor.swapAtDepthSpec {α : Type} {n : ℕ} {s : Shape} (t : Tensor α (Shape.dim n s)) (depth : ℕ) :

                                                                                                                                  Swap adjacent dimensions at a given depth inside a leading batch dimension.

                                                                                                                                  Instances For

                                                                                                                                    Backward pass for matrix multiplication: returns (dA, dB) given dC.

                                                                                                                                    PyTorch analogy: if C = A @ B, then:

                                                                                                                                    • dA = dC @ Bᵀ
                                                                                                                                    • dB = Aᵀ @ dC
                                                                                                                                    Instances For
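The two gradient formulas can be checked numerically. A numpy sketch (numpy stands in for the spec; the directional-derivative check uses the scalar loss `L = sum(C * dC)`, an assumption made for the test, not part of the spec):

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((2, 3))
B = rng.standard_normal((3, 4))
dC = rng.standard_normal((2, 4))   # upstream gradient of C = A @ B

dA = dC @ B.T    # same shape as A
dB = A.T @ dC    # same shape as B
```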
def Spec.Tensor.bmmSpec {α : Type} [Add α] [Mul α] [Zero α] {batch m n p : ℕ} (A : Tensor α (Shape.dim batch (Shape.dim m (Shape.dim n Shape.scalar)))) (B : Tensor α (Shape.dim batch (Shape.dim n (Shape.dim p Shape.scalar)))) :

                                                                                                                                      Batched matrix multiplication: [batch,m,n] × [batch,n,p] → [batch,m,p].

                                                                                                                                      Instances For
def Spec.Tensor.bmmBackwardSpec {α : Type} [Add α] [Mul α] [Zero α] {batch m n p : ℕ} (A : Tensor α (Shape.dim batch (Shape.dim m (Shape.dim n Shape.scalar)))) (B : Tensor α (Shape.dim batch (Shape.dim n (Shape.dim p Shape.scalar)))) (dC : Tensor α (Shape.dim batch (Shape.dim m (Shape.dim p Shape.scalar)))) :

                                                                                                                                        Backward pass for batched matrix multiplication.

                                                                                                                                        Instances For
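Batched matmul and its backward are the per-batch versions of the 2D formulas; a numpy sketch (numpy's `@` on 3D arrays already batches over the leading axis, matching `torch.bmm`):

```python
import numpy as np

rng = np.random.default_rng(1)
A = rng.standard_normal((5, 2, 3))   # [batch, m, n]
B = rng.standard_normal((5, 3, 4))   # [batch, n, p]
C = A @ B                            # [batch, m, p]

dC = np.ones_like(C)
dA = dC @ B.transpose(0, 2, 1)       # per-batch dC @ Bᵀ
dB = A.transpose(0, 2, 1) @ dC       # per-batch Aᵀ @ dC
```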
                                                                                                                                          def Spec.Tensor.matchShape {α : Type} {s : Shape} (t : Tensor α s) :

                                                                                                                                          Runtime check that a tensor value matches a runtime Shape.

                                                                                                                                          We use this in a few “dynamic” utilities where we have a runtime shape value and want to guard access/casts in a total way.

                                                                                                                                          Instances For
def Spec.Tensor.concatSpec {α : Type} [Inhabited α] {n d : ℕ} (headCount : ℕ) (tensors : List (Tensor α (Shape.dim n (Shape.dim d Shape.scalar)))) (_h_len : tensors.length = headCount) :
                                                                                                                                            Tensor α (Shape.dim n (Shape.dim (headCount * d) Shape.scalar))

                                                                                                                                            Concatenate a list of (n,d) tensors along the last axis, producing (n, headCount*d).

                                                                                                                                            This is mainly used by attention blocks that split/merge heads.

                                                                                                                                            PyTorch analogy: torch.cat(heads, dim=-1) after splitting heads, followed by a reshape.

                                                                                                                                            Instances For
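A numpy sketch of the head-merging pattern (numpy stands in for the spec; each "head" is an (n, d) tensor and the result is (n, headCount * d)):

```python
import numpy as np

n, d, head_count = 4, 8, 3
# One (n, d) tensor per head, filled with its head index for visibility.
heads = [np.full((n, d), float(i)) for i in range(head_count)]
merged = np.concatenate(heads, axis=-1)   # torch.cat(heads, dim=-1)
```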
def Spec.Tensor.concatSpec.buildRow {α : Type} {n d : ℕ} (i : Fin n) (ts : List (Tensor α (Shape.dim n (Shape.dim d Shape.scalar)))) :
                                                                                                                                              List α
                                                                                                                                              Instances For

                                                                                                                                                Concatenate two vectors by appending v2 after v1.

                                                                                                                                                Instances For
def Spec.Tensor.concatDim0Spec {α : Type} {n m : ℕ} {s : Shape} (t1 : Tensor α (Shape.dim n s)) (t2 : Tensor α (Shape.dim m s)) :
                                                                                                                                                  Tensor α (Shape.dim (n + m) s)

                                                                                                                                                  Concatenate along axis 0 (append t2 after t1).

                                                                                                                                                  Instances For
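The axis-0 concatenation can be sketched in numpy (standing in for the spec): the leading lengths add, the inner shape is unchanged.

```python
import numpy as np

t1 = np.zeros((2, 3))   # n = 2 rows of inner shape (3,)
t2 = np.ones((4, 3))    # m = 4 rows of the same inner shape
out = np.concatenate([t1, t2], axis=0)   # shape (n + m, 3)
```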

                                                                                                                                                    Slicing / concatenation on the leading axis #

                                                                                                                                                    concat_dim0_spec is the "append on axis 0" primitive that powers many higher-level utilities (sequence concatenation, channel skip connections, etc.).

                                                                                                                                                    For backprop and for "undoing" concatenations, it is convenient to have an explicit slice operation. We keep the API compact and index-safe:

def Spec.Tensor.sliceRange0Spec {α : Type} {n : ℕ} {s : Shape} (start len : ℕ) (h : len + start ≤ n) (t : Tensor α (Shape.dim n s)) :
                                                                                                                                                    Tensor α (Shape.dim len s)

                                                                                                                                                    Slice len entries along axis 0, starting at start.

                                                                                                                                                    This is the simplest "range slice" one typically needs to express:

• taking the first len channels/tokens,
                                                                                                                                                    • extracting the skip-connection half after a concat,
                                                                                                                                                    • implementing take/drop without changing the inner shape.

                                                                                                                                                    The proof len + start ≤ n makes the slice total (no out-of-bounds behavior).

                                                                                                                                                    Instances For
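A numpy sketch of the range slice (numpy stands in for the spec; the runtime `assert` plays the role of the proof obligation `len + start ≤ n`):

```python
import numpy as np

t = np.arange(10 * 3).reshape(10, 3)     # n = 10 entries along axis 0
start, length = 2, 4
assert length + start <= t.shape[0]      # mirrors the hypothesis len + start ≤ n
sl = t[start:start + length]             # shape (length, 3), never out of bounds
```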
def Spec.Tensor.concatDim0BackwardSpec {α : Type} {n m : ℕ} {s : Shape} (δ : Tensor α (Shape.dim (n + m) s)) :
                                                                                                                                                      Tensor α (Shape.dim n s) × Tensor α (Shape.dim m s)

                                                                                                                                                      Backward (adjoint) of concat_dim0_spec.

                                                                                                                                                      If y = concat_dim0_spec x1 x2, then in reverse-mode we split the upstream gradient δy into:

                                                                                                                                                      • δx1 = the first n entries of δy,
                                                                                                                                                      • δx2 = the last m entries of δy.
                                                                                                                                                      Instances For
def Spec.Tensor.sliceRange0BackwardSpec {α : Type} [Zero α] {n : ℕ} {s : Shape} (start len : ℕ) (_h : len + start ≤ n) (δ : Tensor α (Shape.dim len s)) :
                                                                                                                                                        Tensor α (Shape.dim n s)

                                                                                                                                                        Backward (adjoint) of slice_range0_spec.

                                                                                                                                                        If y = slice_range0_spec start len x, then slice_range0_backward_spec start len δy re-inserts the gradient into the original shape and fills everything outside the slice with zeros.

                                                                                                                                                        Instances For
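Both adjoints can be sketched together in numpy (standing in for the spec): the concat backward splits the upstream gradient, and the slice backward zero-pads it back to the original leading length.

```python
import numpy as np

n, m = 2, 3
delta = np.arange((n + m) * 4, dtype=float).reshape(n + m, 4)

# Adjoint of concat on axis 0: split the upstream gradient.
d1, d2 = delta[:n], delta[n:]

# Adjoint of the range slice: re-insert the slice gradient, zeros elsewhere.
start, length = 1, 2
dslice = np.ones((length, 4))
dfull = np.zeros((5, 4))
dfull[start:start + length] = dslice
```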
def Spec.Tensor.concatSequenceSpec {α : Type} {seqLen1 seqLen2 hiddenSize : ℕ} (seq1 : Tensor α (Shape.dim seqLen1 (Shape.dim hiddenSize Shape.scalar))) (seq2 : Tensor α (Shape.dim seqLen2 (Shape.dim hiddenSize Shape.scalar))) :
                                                                                                                                                          Tensor α (Shape.dim (seqLen1 + seqLen2) (Shape.dim hiddenSize Shape.scalar))

                                                                                                                                                          Concatenate two sequences along time (axis 0), producing a longer sequence.

                                                                                                                                                          If seq1 : (seqLen1 x hidden) and seq2 : (seqLen2 x hidden), this returns (seqLen1 + seqLen2) x hidden by appending seq2 after seq1.

                                                                                                                                                          Do not confuse this with Spec.concatSequenceSpec (defined in NN.Spec.Core.Sequence), which concatenates along the feature dimension for same-length sequences.

                                                                                                                                                          Instances For
def Spec.Tensor.concatSequenceInnerSpec {α : Type} {seqLen hiddenSize1 hiddenSize2 : ℕ} (seq1 : Tensor α (Shape.dim seqLen (Shape.dim hiddenSize1 Shape.scalar))) (seq2 : Tensor α (Shape.dim seqLen (Shape.dim hiddenSize2 Shape.scalar))) :
                                                                                                                                                            Tensor α (Shape.dim seqLen (Shape.dim (hiddenSize1 + hiddenSize2) Shape.scalar))

                                                                                                                                                            Concatenate two sequences along the feature dimension (inner axis).

                                                                                                                                                            Instances For
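The two flavors of sequence concatenation differ only in the axis; a numpy sketch (numpy stands in for the spec) makes the shape bookkeeping explicit:

```python
import numpy as np

seq1 = np.zeros((3, 8))   # (seqLen1, hidden)
seq2 = np.ones((5, 8))    # (seqLen2, hidden), same hidden size
along_time = np.concatenate([seq1, seq2], axis=0)   # (seqLen1 + seqLen2, hidden)

a = np.zeros((4, 8))      # (seqLen, hidden1)
b = np.ones((4, 2))       # (seqLen, hidden2), same sequence length
along_feat = np.concatenate([a, b], axis=-1)        # (seqLen, hidden1 + hidden2)
```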
def Spec.Tensor.expandToColSpec {α : Type} {n : ℕ} {s : Shape} (t : Tensor α (Shape.dim n s)) :

Expand an (n, s) tensor into (n, 1, s) by inserting a singleton dimension after the leading axis.

PyTorch analogy: t.unsqueeze(1); for vectors (inner shape scalar) this coincides with t.unsqueeze(-1).

                                                                                                                                                              Instances For

                                                                                                                                                                Same as expand_to_col_spec, specialized to vectors.

                                                                                                                                                                Instances For
def Spec.Tensor.squeezeColSpec {α : Type} {n : ℕ} {s : Shape} (t : Tensor α (Shape.dim n (Shape.dim 1 s))) :
                                                                                                                                                                  Tensor α (Shape.dim n s)

                                                                                                                                                                  Squeeze a (n,1,s) tensor back into (n,s) by dropping the singleton dimension.

                                                                                                                                                                  Instances For

                                                                                                                                                                    Same as squeeze_col_spec, specialized to vectors.

                                                                                                                                                                    Instances For
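The expand/squeeze pair round-trips a singleton dimension; a numpy sketch for the vector case (numpy stands in for the spec):

```python
import numpy as np

v = np.arange(5.0)                 # shape (5,)
col = np.expand_dims(v, axis=1)    # shape (5, 1): insert the singleton dim
back = np.squeeze(col, axis=1)     # shape (5,): squeeze undoes the expansion
```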
def Spec.Tensor.unsqueezeSpec {α : Type} {n : ℕ} {s : Shape} (t : Tensor α (Shape.dim n s)) (_dim : ℕ) :

                                                                                                                                                                      Unsqueeze (insert a singleton dim). Currently implemented as expand_to_col_spec.

Core uses singleton insertion mainly for column vectors, so this operation is specialized to that use case; general axis insertion could be added by extending this definition.

                                                                                                                                                                      Instances For

                                                                                                                                                                        Turn a vector (n) into a batch of size 1: (1,n).

                                                                                                                                                                        Instances For
def Spec.Tensor.batchToEndSpec {α : Type} {batch : ℕ} {s : Shape} (t : Tensor α (Shape.dim batch s)) :
                                                                                                                                                                          Tensor α (s.appendDim batch)

                                                                                                                                                                          Move a leading batch dimension to the innermost position.

                                                                                                                                                                          Instances For
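Moving a leading batch dimension innermost is an axis move; a numpy sketch (numpy stands in for the spec):

```python
import numpy as np

t = np.zeros((7, 2, 3))         # leading batch dim of size 7
moved = np.moveaxis(t, 0, -1)   # batch moved to the innermost position: (2, 3, 7)
```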

                                                                                                                                                                            Convert channel-first images (b,c,h,w) into channel-last (b,h,w,c).

                                                                                                                                                                            Instances For
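The channel-first to channel-last conversion is the permutation (0, 2, 3, 1); a numpy sketch (numpy stands in for the spec):

```python
import numpy as np

img = np.zeros((2, 3, 4, 5))            # (b, c, h, w)
nhwc = np.transpose(img, (0, 2, 3, 1))  # (b, h, w, c)
```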