TorchLean API

NN.Spec.Core.TensorReductionShape.Reductions

Reductions

Fold, sum/product/mean/variance, axis reductions, and last-axis reductions.


@[irreducible]
def Spec.Tensor.tensorFoldlSpec {α β : Type} (f : β → α → β) (init : β) {s : Shape} :
Tensor α s → β

Left fold over all tensor elements.

@[irreducible]
def Spec.Tensor.tensorFoldlSpec.go {α β : Type} (f : β → α → β) (n : ℕ) (s : Shape) (values : Fin n → Tensor α s) (i : ℕ) (acc : β) :
β

@[irreducible]
def Spec.Tensor.tensorFoldrSpec {α β : Type} (f : α → β → β) (init : β) {s : Shape} :
Tensor α s → β

Right fold over all tensor elements.

@[irreducible]
def Spec.Tensor.tensorFoldrSpec.go {α β : Type} (f : α → β → β) (n : ℕ) (s : Shape) (values : Fin n → Tensor α s) (i : ℕ) (acc : β) :
β

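The two folds differ only in the order in which entries are combined. The sketch below is a hypothetical model that represents a tensor as a list of rows visited in row-major order. The names flattenRows, tensorFoldl and tensorFoldr are illustrative stand-ins, not TorchLean's definitions, and the row-major order is an assumption. Consing each entry onto the accumulator makes the visiting order visible:

```lean
-- Hypothetical model: a 2-D tensor as a list of rows (not TorchLean's `Tensor`).
def flattenRows {α : Type} (t : List (List α)) : List α :=
  t.foldr (· ++ ·) []

-- Left fold: combines entries first-to-last, f (f (f init x₀) x₁) x₂ ...
def tensorFoldl {α β : Type} (f : β → α → β) (init : β) (t : List (List α)) : β :=
  (flattenRows t).foldl f init

-- Right fold: combines entries last-to-first, f x₀ (f x₁ (f x₂ ... init))
def tensorFoldr {α β : Type} (f : α → β → β) (init : β) (t : List (List α)) : β :=
  (flattenRows t).foldr f init

-- Consing onto the accumulator exposes the visiting order.
example : tensorFoldl (fun acc x => x :: acc) [] ([[1, 2], [3, 4]] : List (List Nat)) = [4, 3, 2, 1] := by decide
example : tensorFoldr (fun x acc => x :: acc) [] ([[1, 2], [3, 4]] : List (List Nat)) = [1, 2, 3, 4] := by decide
```

With an associative and commutative operation such as +, the two folds agree. They differ for order-sensitive accumulators like the list builder above.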
def Spec.Tensor.sumSpec {α : Type} [Add α] [Zero α] {s : Shape} (t : Tensor α s) :
α

Sum all elements of a tensor.

def Spec.Tensor.prodSpec {α : Type} [Context α] {s : Shape} (t : Tensor α s) :
α

Product of all elements of a tensor.

@[reducible, inline]
abbrev Spec.Tensor.productSpec {α : Type} [Context α] {s : Shape} (t : Tensor α s) :
α

Short name for prodSpec.

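Both whole-tensor reductions can be read as folds that start from the operation's identity element. The sketch below assumes a hypothetical model in which the tensor's entries have been flattened to a List Int. The real functions are generic over α and the tensor's shape, and the choice of 0 and 1 as starting values is an assumption:

```lean
-- Hypothetical model: a tensor's entries flattened in row-major order.
-- Sum folds with `+` from 0; product folds with `*` from 1.
def sumSpec (xs : List Int) : Int := xs.foldl (fun acc x => acc + x) 0
def prodSpec (xs : List Int) : Int := xs.foldl (fun acc x => acc * x) 1

example : sumSpec [1, 2, 3, 4] = 10 := by decide
example : prodSpec [1, 2, 3, 4] = 24 := by decide
-- An empty tensor reduces to the identity element.
example : sumSpec [] = 0 ∧ prodSpec [] = 1 := by decide
```

Starting from the identity is what makes the empty-tensor case well defined. The [Zero α] requirement on sumSpec is consistent with that choice.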
def Spec.Tensor.countSpec {α : Type} {s : Shape} (t : Tensor α s) :

Count the number of scalar entries in a tensor (= Shape.size).

def Spec.Tensor.anySpec {α : Type} {s : Shape} (p : α → Bool) (t : Tensor α s) :

true if any entry satisfies p.

def Spec.Tensor.allSpec {α : Type} {s : Shape} (p : α → Bool) (t : Tensor α s) :

true if all entries satisfy p.

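Counting, any and all need no arithmetic on α. The sketch below uses the same hypothetical flattened-list model. The names mirror the documented functions only for illustration:

```lean
-- Hypothetical model: a tensor's entries flattened in row-major order.
def countSpec {α : Type} (xs : List α) : Nat := xs.length

def anySpec {α : Type} (p : α → Bool) (xs : List α) : Bool :=
  xs.foldl (fun acc x => acc || p x) false

def allSpec {α : Type} (p : α → Bool) (xs : List α) : Bool :=
  xs.foldl (fun acc x => acc && p x) true

-- A 2 × 3 tensor has 2 * 3 = 6 entries.
example : countSpec ([1, 2, 3, 4, 5, 6] : List Nat) = 6 := by decide
example : anySpec (fun (x : Nat) => decide (5 < x)) [1, 7, 3] = true := by decide
example : allSpec (fun (x : Nat) => decide (0 < x)) [1, 7, 3] = true := by decide
-- Empty tensor: anySpec is false and allSpec is true.
example : anySpec (fun (x : Nat) => decide (5 < x)) [] = false := by decide
example : allSpec (fun (x : Nat) => decide (5 < x)) [] = true := by decide
```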
def Spec.Tensor.dotSpec {α : Type} [Context α] {s : Shape} (a b : Tensor α s) :
α

Dot product: sum (a ⊙ b).

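The dot product pairs entries at the same position and sums the products. The sketch below flattens both tensors to lists, which is an assumption made for brevity:

```lean
-- Hypothetical model: both tensors flattened in the same row-major order.
-- In TorchLean the shared shape `s` guarantees equal lengths; `zipWith`
-- would silently truncate mismatched lists, which the typed version rules out.
def dotSpec (a b : List Int) : Int :=
  (List.zipWith (fun x y => x * y) a b).foldl (fun acc x => acc + x) 0

-- 1*4 + 2*5 + 3*6 = 32
example : dotSpec [1, 2, 3] [4, 5, 6] = 32 := by decide
```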
def Spec.Tensor.meanSpec {α : Type} [Context α] {s : Shape} :
Tensor α s → α

Mean of all elements, treating every nested dimension as part of one flat collection.

def Spec.Tensor.varianceSpec {α : Type} [Context α] {s : Shape} :
Tensor α s → α

Variance of all elements (population variance, divides by n).

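Let x_1, ..., x_n be the n entries of a tensor. The flattening order does not matter for either quantity. The population mean and variance described above are:

```latex
\mu = \frac{1}{n}\sum_{i=1}^{n} x_i,
\qquad
\sigma^2 = \frac{1}{n}\sum_{i=1}^{n} \left(x_i - \mu\right)^2
```

For example, entries 2, 4, 6, 8 give μ = 20/4 = 5 and σ² = (9 + 1 + 1 + 9)/4 = 5. Dividing by n − 1 instead would give the unbiased sample variance 20/3. According to its docstring, varianceSpec does not compute that.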

Output shape after summing along axis (drops that dimension).

@[simp]

simp lemma: dropping axis 1 from a 2D (nQ+1)×(nK+1) shape yields (nQ+1).

@[simp]

simp lemma: dropping axis 1 from a 2D nQ×nK shape yields nQ.

@[simp]

simp lemma: dropping axis 3 from a 4D b×h×w×c shape yields b×h×w.

@[simp]

simp lemma: dropping axis 0 from a positive .dim (n+1) s yields s.

@[simp]
theorem Spec.Tensor.shape_after_sum_succ {n : ℕ} {s : Shape} {k : ℕ} :
shapeAfterSum (Shape.dim (n + 1) s) (k + 1) = Shape.dim (n + 1) (shapeAfterSum s k)

simp lemma: dropping axis k+1 recurses into the tail shape.

@[simp]

simp lemma: dropping axis 0 from a 2D (kH+1)×(kW+1) yields (kW+1).

@[simp]
theorem Spec.Tensor.shape_after_sum_zero_alt (n : ℕ) (inner : Shape) :
shapeAfterSum (Shape.dim n inner) 0 = inner

simp lemma: dropping axis 0 from .dim n inner yields inner (even when n=0).
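The simp lemmas above determine how shapeAfterSum behaves on .dim shapes. The sketch below is a minimal model that satisfies them. It uses a hypothetical Shape type with scalar and dim constructors, mirroring Shape.scalar and Shape.dim in the signatures here; it is not TorchLean's actual definition. The scalar case is not covered by any lemma above, so its behaviour here is an assumption:

```lean
-- Hypothetical stand-in for TorchLean's `Shape` (not the library definition).
inductive Shape where
  | scalar : Shape
  | dim : Nat → Shape → Shape
  deriving Repr, DecidableEq

-- Drop the dimension at index `axis`. The `.scalar` case is an assumption:
-- a scalar has no axis to drop, so it is returned unchanged.
def shapeAfterSum : Shape → Nat → Shape
  | .scalar, _ => .scalar
  | .dim _ inner, 0 => inner
  | .dim n inner, k + 1 => .dim n (shapeAfterSum inner k)

-- The documented simp lemmas hold definitionally for this model.
theorem shape_after_sum_zero_alt (n : Nat) (inner : Shape) :
    shapeAfterSum (.dim n inner) 0 = inner := rfl

theorem shape_after_sum_succ {n k : Nat} {s : Shape} :
    shapeAfterSum (.dim (n + 1) s) (k + 1) = .dim (n + 1) (shapeAfterSum s k) := rfl

-- Axis 1 of an nQ × nK shape is dropped, leaving nQ.
example (nQ nK : Nat) : shapeAfterSum (.dim nQ (.dim nK .scalar)) 1 = .dim nQ .scalar := rfl

-- Axis 3 of b × h × w × c is dropped, leaving b × h × w.
example (b h w c : Nat) :
    shapeAfterSum (.dim b (.dim h (.dim w (.dim c .scalar)))) 3 = .dim b (.dim h (.dim w .scalar)) := rfl
```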

Reflexive broadcast proof (s can broadcast to itself).


Build a broadcast proof from the reduced shape back to the original shape.

We use this when a backward pass computes something in the reduced shape (e.g. a mean/variance) and we need to broadcast it back to match the original tensor shape.

def Spec.Tensor.reduceFirstDim {α : Type} {innerShape : Shape} {n : ℕ} (f : {sliceShape : Shape} → Tensor α sliceShape → α) (t : Tensor α (Shape.dim n innerShape)) :
Tensor α innerShape

Reduce a tensor of shape (n, innerShape) by applying f across the first axis.

This is the basic “reduce over axis 0” primitive that we reuse to implement broadcast-adjoints and multi-axis reducers.

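Reducing over axis 0 works in two steps. For every position of the inner shape, gather the n values stacked along the first axis, then collapse them with f.

The sketch below covers the 2-D case only. It assumes a hypothetical model where the tensor is n rows of m entries and f takes a 1-D slice as a list; TorchLean's f is instead polymorphic over the slice shape. The explicit m parameter stands in for the statically known innerShape:

```lean
-- Hypothetical 2-D model: `rows` has n rows, each holding m entries
-- (so the "inner shape" is a length-m vector).
def reduceFirstDim {α : Type} (f : List α → α) (m : Nat) (rows : List (List α)) : List α :=
  -- For each inner position j, gather the slice along axis 0 and reduce it with f.
  (List.range m).map fun j => f (rows.filterMap fun row => row[j]?)

def listSum (xs : List Int) : Int := xs.foldl (fun acc x => acc + x) 0
def listProd (xs : List Int) : Int := xs.foldl (fun acc x => acc * x) 1

example : reduceFirstDim listSum 3 [[1, 2, 3], [4, 5, 6]] = [5, 7, 9] := by decide
example : reduceFirstDim listProd 3 [[1, 2, 3], [4, 5, 6]] = [4, 10, 18] := by decide
```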

Reduce a gradient from a broadcast target shape back to the original input shape.

This is the adjoint of broadcastTo for sum-reduction: broadcast duplicates values, so the backward pass sums contributions across broadcasted dimensions.

PyTorch analogy: this is the logic behind "sum over broadcasted dimensions" that happens in autograd for expand + elementwise ops.

def Spec.Tensor.reduceFromBroadcastTo {α : Type} [Add α] [Zero α] {s₁ s₂ : Shape} :
s₁.CanBroadcastTo s₂ → Tensor α s₂ → Tensor α s₁

Adjoint of broadcastTo under sum-reduction: collapse broadcasted dimensions by summing.

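As a concrete case, broadcasting a length-m vector to shape (n, m) copies it into every row, so the adjoint sums the upstream gradient down each column. The sketch below uses a hypothetical list-of-rows model; broadcastRows, reduceFromBroadcastRows, flattenRows and dot are illustrative names, not TorchLean's. The last check confirms the adjoint identity ⟨broadcast v, g⟩ = ⟨v, reduce g⟩ on one example:

```lean
-- Hypothetical model: a 2-D tensor as a list of rows.
def flattenRows {α : Type} (t : List (List α)) : List α :=
  t.foldr (· ++ ·) []

-- Forward: broadcast a length-m vector to n identical rows, i.e. (m) → (n, m).
def broadcastRows {α : Type} (n : Nat) (v : List α) : List (List α) :=
  List.replicate n v

-- Adjoint under sum-reduction: each entry of `v` was copied into every row,
-- so its gradient is the sum of the upstream gradient down that column.
def reduceFromBroadcastRows (m : Nat) (g : List (List Int)) : List Int :=
  (List.range m).map fun j => (g.filterMap fun row => row[j]?).foldl (fun acc x => acc + x) 0

def dot (a b : List Int) : Int :=
  (List.zipWith (fun x y => x * y) a b).foldl (fun acc x => acc + x) 0

example : reduceFromBroadcastRows 2 [[1, 10], [100, 1000], [3, 4]] = [104, 1014] := by decide

-- Both sides evaluate to 2132.
example :
    dot (flattenRows (broadcastRows 3 [1, 2])) (flattenRows [[1, 10], [100, 1000], [3, 4]]) =
      dot [1, 2] (reduceFromBroadcastRows 2 [[1, 10], [100, 1000], [3, 4]]) := by decide
```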
def Spec.Tensor.reduceDim {α : Type} {s : Shape} (f : {sliceShape : Shape} → Tensor α sliceShape → α) (axis : ℕ) (x : Tensor α s) (_h : Shape.reducibleAlong axis s) :
Tensor α (shapeAfterSum s axis)

Generic reduction along a (provably reducible) axis.

reduceDim f axis x applies f to each slice along axis and returns a tensor of shape shapeAfterSum s axis, i.e. with that axis dropped.

def Spec.Tensor.reduceDim.aux {α : Type} (f : {sliceShape : Shape} → Tensor α sliceShape → α) {inShape outShape : Shape} (axisAdjusted : ℕ) (h_eq : outShape = shapeAfterSum inShape axisAdjusted) (t : Tensor α inShape) :
Tensor α outShape

def Spec.Tensor.reduceSum {α : Type} [Add α] [Zero α] {s : Shape} (axis : ℕ) (t : Tensor α s) (h : Shape.reducibleAlong axis s) :
Tensor α (shapeAfterSum s axis)

Sum-reduction along a given axis.

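On a 2-D tensor of shape (n, m), summing along axis 1 yields shape (n) and summing along axis 0 yields shape (m), consistent with shapeAfterSum dropping that dimension. In PyTorch terms these are t.sum(dim=1) and t.sum(dim=0). The sketch below uses a hypothetical list-of-rows model, with illustrative function names:

```lean
-- Hypothetical 2-D model: an (n, m) tensor as n rows of m entries.
def rowSum (row : List Int) : Int := row.foldl (fun acc x => acc + x) 0

-- axis = 1: sum within each row; shape (n, m) becomes (n).
def reduceSumAxis1 (t : List (List Int)) : List Int :=
  t.map rowSum

-- axis = 0: sum down each column; shape (n, m) becomes (m).
def reduceSumAxis0 (m : Nat) (t : List (List Int)) : List Int :=
  (List.range m).map fun j => rowSum (t.filterMap fun row => row[j]?)

example : reduceSumAxis1 [[1, 2, 3], [4, 5, 6]] = [6, 15] := by decide
example : reduceSumAxis0 3 [[1, 2, 3], [4, 5, 6]] = [5, 7, 9] := by decide
```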
def Spec.Tensor.reduceSumAuto {α : Type} [Add α] [Zero α] {s : Shape} (axis : ℕ) [h : Shape.valid_axis_inst axis s] (t : Tensor α s) :
Tensor α (shapeAfterSum s axis)

Sum-reduction along axis, with axis validity inferred via valid_axis_inst.

def Spec.Tensor.reduceProd {α : Type} [Context α] {s : Shape} (axis : ℕ) (t : Tensor α s) (h : Shape.reducibleAlong axis s) :
Tensor α (shapeAfterSum s axis)

Product-reduction along a given axis.

def Spec.Tensor.reduceProdAuto {α : Type} [Context α] {s : Shape} (axis : ℕ) (t : Tensor α s) (h : Shape.valid_axis axis s) :
Tensor α (shapeAfterSum s axis)

Product-reduction along axis when you already have a valid_axis proof.


Get the runtime size of the k-th dimension (0-based), if it exists.

def Spec.Tensor.reduceMean {α : Type} [Context α] {s : Shape} (axis : ℕ) (t : Tensor α s) (h : Shape.reducibleAlong axis s) :
Tensor α (shapeAfterSum s axis)

Mean-reduction along a given axis.

def Spec.Tensor.reduceMeanAuto {α : Type} [Context α] {s : Shape} (axis : ℕ) (h : Shape.valid_axis_inst axis s) (t : Tensor α s) :
Tensor α (shapeAfterSum s axis)

Mean-reduction along axis, with axis validity supplied as an explicit Shape.valid_axis_inst argument.

def Spec.Tensor.reduceSumSquared {α : Type} [Context α] {n : ℕ} {s : Shape} (axis : ℕ) (t : Tensor α (Shape.dim n s)) (h : Shape.reducibleAlong axis (Shape.dim n s)) :

Sum of squares reduced along an axis (helper for variance).

def Spec.Tensor.reduceVar {α : Type} [Context α] {s : Shape} (axis : ℕ) (t : Tensor α s) (h : Shape.reducibleAlong axis s) :
Tensor α (shapeAfterSum s axis)

Variance-reduction along a given axis (population variance, divides by n).

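Axis reductions apply the whole-tensor formulas to one slice per output position. Consider a 2-D tensor x of shape (n, m) reduced along axis 0. Here the divisor is the size n of the reduced axis, as the docstring states. Each output entry j is then:

```latex
\mu_j = \frac{1}{n}\sum_{i=1}^{n} x_{ij},
\qquad
\sigma_j^2 = \frac{1}{n}\sum_{i=1}^{n} \left(x_{ij} - \mu_j\right)^2
           = \frac{1}{n}\sum_{i=1}^{n} x_{ij}^2 - \mu_j^2,
\qquad j = 1, \dots, m
```

The second form shows one way a sum-of-squares helper such as reduceSumSquared could be used. This documentation does not state whether TorchLean computes the variance that way.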
def Spec.Tensor.reduceVarAuto {α : Type} [Context α] {s : Shape} (axis : ℕ) (h : Shape.valid_axis_inst axis s) (t : Tensor α s) :
Tensor α (shapeAfterSum s axis)

Variance-reduction along axis, with axis validity supplied as an explicit Shape.valid_axis_inst argument.

def Spec.Tensor.reduceMin {α : Type} [Context α] {s : Shape} (axis : ℕ) (t : Tensor α s) (h : Shape.reducibleAlong axis s) :
Tensor α (shapeAfterSum s axis)

Min-reduction along a given axis.

@[irreducible]
def Spec.Tensor.reduceMin.loop {α : Type} [Context α] (inner : Shape) (n' : ℕ) (f : Fin n'.succ → Tensor α inner) (i : ℕ) (acc : Tensor α inner) (hi : i ≤ n') :
Tensor α inner

def Spec.Tensor.reduceMax {α : Type} [Context α] {s : Shape} (axis : ℕ) (t : Tensor α s) (h : Shape.reducibleAlong axis s) :
Tensor α (shapeAfterSum s axis)

Max-reduction along a given axis.

@[irreducible]
def Spec.Tensor.reduceMax.loop {α : Type} [Context α] (inner : Shape) (n' : ℕ) (f : Fin n'.succ → Tensor α inner) (i : ℕ) (acc : Tensor α inner) :
Tensor α inner

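The Fin n'.succ argument of reduceMin.loop and reduceMax.loop suggests that the reduced axis is assumed nonempty. That lets the accumulator start from one slice rather than from an identity element, which min and max lack for a general α.

The sketch below shows this seeding strategy along axis 0. It assumes a hypothetical model where slices are equal-length lists; reduceMaxAxis0 and reduceMinAxis0 are illustrative names:

```lean
-- Hypothetical model: slices along the reduced axis are equal-length lists.
-- The axis is nonempty, so the fold is seeded with the first slice
-- instead of an identity element.
def reduceMaxAxis0 (first : List Int) (rest : List (List Int)) : List Int :=
  rest.foldl (fun acc slice => List.zipWith max acc slice) first

def reduceMinAxis0 (first : List Int) (rest : List (List Int)) : List Int :=
  rest.foldl (fun acc slice => List.zipWith min acc slice) first

example : reduceMaxAxis0 [3, 1, 4] [[1, 5, 9], [2, 6, 5]] = [3, 6, 9] := by decide
example : reduceMinAxis0 [3, 1, 4] [[1, 5, 9], [2, 6, 5]] = [1, 1, 4] := by decide
```

Each step combines the accumulator with the next slice elementwise. That is consistent with acc having the inner shape in the loop signatures.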
def Spec.Tensor.reduceMaxAuto {α : Type} [Context α] {s : Shape} (axis : ℕ) [h : Shape.valid_axis_inst axis s] (t : Tensor α s) :
Tensor α (shapeAfterSum s axis)

Max-reduction along axis, with axis validity inferred via valid_axis_inst.

def Spec.Tensor.reduceLastDim {α : Type} [Context α] {s : Shape} (f : {sliceShape : Shape} → Tensor α sliceShape → α) (x : Tensor α s) (h : Shape.reducibleAlong (s.rank - 1) s) :
Tensor α (shapeAfterSum s (s.rank - 1))

Reduce along the last axis of s (i.e. axis rank s - 1).

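The last-axis variants fix the axis index to s.rank - 1. The self-contained sketch below computes the resulting output shape. It uses a hypothetical Shape/rank/shapeAfterSum model rather than TorchLean's definitions:

```lean
-- Hypothetical stand-in for TorchLean's `Shape` (not the library definition).
inductive Shape where
  | scalar : Shape
  | dim : Nat → Shape → Shape
  deriving Repr, DecidableEq

def Shape.rank : Shape → Nat
  | .scalar => 0
  | .dim _ s => Shape.rank s + 1

-- Drop the axis at the given index (agrees with the documented simp lemmas;
-- the `.scalar` case is an assumption).
def shapeAfterSum : Shape → Nat → Shape
  | .scalar, _ => .scalar
  | .dim _ inner, 0 => inner
  | .dim n inner, k + 1 => .dim n (shapeAfterSum inner k)

-- Output shape of a last-axis reduction: the axis index is rank - 1.
def shapeAfterLastDim (s : Shape) : Shape :=
  shapeAfterSum s (s.rank - 1)

-- (seqLen, embedDim) → (seqLen)
example : shapeAfterLastDim (.dim 4 (.dim 8 .scalar)) = .dim 4 .scalar := rfl
-- (b, h, w, c) → (b, h, w)
example : shapeAfterLastDim (.dim 2 (.dim 3 (.dim 5 (.dim 7 .scalar)))) =
    .dim 2 (.dim 3 (.dim 5 .scalar)) := rfl
-- Natural-number subtraction: a scalar has rank 0, and 0 - 1 = 0.
example : Shape.rank .scalar - 1 = 0 := rfl
```

Because s.rank - 1 is natural-number subtraction, a scalar shape (rank 0) yields axis 0 rather than an error. That is presumably why these functions take a reducibleAlong or valid_axis_inst proof.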
def Spec.Tensor.reduceLastDimAuto {α : Type} [Context α] {s : Shape} (f : {sliceShape : Shape} → Tensor α sliceShape → α) (x : Tensor α s) [h : Shape.valid_axis_inst (s.rank - 1) s] :
Tensor α (shapeAfterSum s (s.rank - 1))

Like reduceLastDim, but infers axis validity via valid_axis_inst.

def Spec.Tensor.reduceMeanLast {α : Type} [Context α] {s : Shape} (x : Tensor α s) (h : Shape.reducibleAlong (s.rank - 1) s) :
Tensor α (shapeAfterSum s (s.rank - 1))

Mean-reduce along the last axis.

def Spec.Tensor.reduceSumLast {α : Type} [Context α] {seqLen embedDim : ℕ} (x : Tensor α (Shape.dim seqLen (Shape.dim embedDim Shape.scalar))) (h : Shape.reducibleAlong ((Shape.dim seqLen (Shape.dim embedDim Shape.scalar)).rank - 1) (Shape.dim seqLen (Shape.dim embedDim Shape.scalar))) :

Sum-reduce along the last axis of a 2D tensor (seqLen, embedDim).

def Spec.Tensor.reduceProdLast {α : Type} [Context α] {seqLen embedDim : ℕ} (x : Tensor α (Shape.dim seqLen (Shape.dim embedDim Shape.scalar))) (h : Shape.reducibleAlong ((Shape.dim seqLen (Shape.dim embedDim Shape.scalar)).rank - 1) (Shape.dim seqLen (Shape.dim embedDim Shape.scalar))) :

Product-reduce along the last axis of a 2D tensor (seqLen, embedDim).

def Spec.Tensor.reduceMaxLast {α : Type} [Context α] {s : Shape} (x : Tensor α s) (h : Shape.reducibleAlong (s.rank - 1) s) :
Tensor α (shapeAfterSum s (s.rank - 1))

Max-reduce along the last axis.

def Spec.Tensor.reduceMinLast {α : Type} [Context α] {s : Shape} (x : Tensor α s) (h : Shape.reducibleAlong (s.rank - 1) s) :
Tensor α (shapeAfterSum s (s.rank - 1))

Min-reduce along the last axis.

def Spec.Tensor.reduceVarLast {α : Type} [Context α] {n : ℕ} {s : Shape} (x : Tensor α (Shape.dim n s)) (h : Shape.reducibleAlong ((Shape.dim n s).rank - 1) (Shape.dim n s)) :

Variance-reduce along the last axis (specialized to a leading batch dimension).

def Spec.Tensor.reduceVarLastGeneral {α : Type} [Context α] {n : ℕ} {s : Shape} (x : Tensor α (Shape.dim n s)) (h : Shape.valid_axis_inst ((Shape.dim n s).rank - 1) (Shape.dim n s)) :

Variance-reduce along the last axis (axis validity supplied as an explicit Shape.valid_axis_inst argument).

def Spec.Tensor.reduceMeanLastGeneral {α : Type} [Context α] {s : Shape} (x : Tensor α s) (h : Shape.valid_axis_inst (s.rank - 1) s) :
Tensor α (shapeAfterSum s (s.rank - 1))

Mean-reduce along the last axis (axis validity supplied as an explicit Shape.valid_axis_inst argument).

def Spec.Tensor.reduceMeanLastGeneralWf {α : Type} [Context α] {s : Shape} (x : Tensor α s) [_h_wf : s.WellFormed] (_h_rank : s.rank > 0) (h_valid : Shape.valid_axis_inst (s.rank - 1) s) :
Tensor α (shapeAfterSum s (s.rank - 1))

Mean-reduce along the last axis, specialized for proofs that assume well-formedness.

def Spec.Tensor.reduceSumLastGeneral {α : Type} [Context α] {s : Shape} (x : Tensor α s) [h : Shape.valid_axis_inst (s.rank - 1) s] :
Tensor α (shapeAfterSum s (s.rank - 1))

Sum-reduce along the last axis (with axis validity inferred via valid_axis_inst).