TorchLean API

NN.Spec.Core.TensorArray

TensorArray: a simple array-backed tensor representation

Spec.Tensor is the canonical, shape-indexed tensor type for the spec layer. It is great for proofs and pure definitions, but it is not always the most convenient representation for executable, array-backed computation.

TensorArray.Tensor is a lightweight alternative: a flat, row-major Array α buffer paired with a runtime shape list and a proof (shape_valid) that the buffer size matches the product of the dimensions.

The bridge back to Spec.Tensor lives in NN/Spec/Core/TensorBridge.lean.

This representation exists so that array-style computations can be written and executed directly, while shape_valid keeps an explicit, machine-checked link between the buffer and its shape.

structure TensorArray.Tensor (α : Type) (shape : List Nat) :
Type

A tensor is represented as a flat array of elements and a shape (as a list of dimensions). The shape_valid proof ensures the array size matches the product of the shape dimensions.

  • data : Array α

    Flat row-major data buffer.

  • shape_valid : self.data.size = List.foldl (fun (x1 x2 : Nat) => x1 * x2) 1 shape

    Proof that the buffer length matches the product of the runtime dimensions.

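For concreteness, here is a minimal sketch of building a tensor directly from its fields, assuming the module imports under the path shown above and that the size equation reduces by decide on concrete literals (later examples below reuse these illustrative definitions):

  import NN.Spec.Core.TensorArray

  open TensorArray

  -- Illustrative 2 × 3 tensor of Nats: the buffer has 6 = 2 * 3 elements,
  -- so the size/shape equation holds by evaluation.
  def t23 : Tensor Nat [2, 3] :=
    { data := #[0, 1, 2, 3, 4, 5], shape_valid := by decide }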

def TensorArray.shapeProd (shape : List Nat) :
Nat

Product of dimensions for a runtime shape list.

This is the runtime analogue of Spec.Shape.size. We keep it as a def (not just a local let) because it appears everywhere shape_valid is constructed or rewritten.
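A quick sanity check, continuing the sketch file above:

  example : shapeProd [2, 3, 4] = 24 := by decide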
def TensorArray.ofArray {α : Type} (arr : Array α) (shape : List Nat) (h : arr.size = shapeProd shape) :
Tensor α shape

Build a tensor from an array when you already have a size proof.

Design choice:

• Tensor stores the shape at the type level (Tensor α shape), so callers must provide h : arr.size = shapeProd shape.
• This makes "I reshaped / reinterpreted the data" a conscious action with an explicit proof; see the usage sketch below.
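Usage sketch (continuing the same file; v3 and v31 are illustrative names): the caller supplies the size proof explicitly, and on concrete literals by decide is one plausible way to discharge it:

  -- 3 = shapeProd [3] holds by evaluation.
  def v3 : Tensor Nat [3] := ofArray #[10, 20, 30] [3] (by decide)

  -- Reinterpreting the same buffer under another shape needs a fresh proof.
  def v31 : Tensor Nat [3, 1] := ofArray #[10, 20, 30] [3, 1] (by decide)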
@[simp]
theorem TensorArray.shapeProd_nil :
shapeProd [] = 1

Base case for shapeProd: the empty shape has product 1.

theorem TensorArray.foldl_mul_factor (n : Nat) (ns : List Nat) :
List.foldl (fun (x1 x2 : Nat) => x1 * x2) n ns = n * List.foldl (fun (x1 x2 : Nat) => x1 * x2) 1 ns

Helper lemma: factoring a left multiplication out of the foldl product.

This is used to prove shapeProd_cons and similar "shape product algebra" facts.

@[simp]
theorem TensorArray.shapeProd_cons (n : Nat) (ns : List Nat) :
shapeProd (n :: ns) = n * shapeProd ns

Step case for shapeProd: the product of (n :: ns) is n * shapeProd ns.
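Since both equations are @[simp] lemmas, shape-product goals over cons-lists typically close by simp; a small sketch:

  example (n : Nat) (ns : List Nat) :
      shapeProd (n :: n :: ns) = n * (n * shapeProd ns) := by
    simp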

theorem TensorArray.List.length_zipWith {α β γ : Type} (f : α → β → γ) (l1 : List α) (l2 : List β) :
(List.zipWith f l1 l2).length = min l1.length l2.length

Length of List.zipWith is the minimum of the input lengths.

We keep these small list lemmas local to this file because they are only needed to justify shape_valid proofs for array operations like zipWith.

theorem TensorArray.List.length_map {α β : Type} (f : α → β) (l : List α) :
(List.map f l).length = l.length

Length of a map is the same as the input length.

theorem TensorArray.List.length_flatten {α : Type} (l : List (List α)) :
(List.flatten l).length = (l.map List.length).sum

Length of flatten is the sum of the lengths of the inner lists.

This kind of fact comes up any time we build an Array/List by flattening a list-of-lists.

def TensorArray.flatIndexAux (shape indices : List Nat) (acc : Nat) :
Option Nat

Compute the flat index for a given multi-index.

Returns none if the indices are out of bounds or the rank mismatches.

def TensorArray.flatIndex (shape indices : List Nat) :
Option Nat

theorem TensorArray.flatIndexAux_lt (shape indices : List Nat) (acc idx : Nat) :
flatIndexAux shape indices acc = some idx → idx < (acc + 1) * shapeProd shape

flatIndexAux returns an index that is bounded by the "mixed-radix" size implied by the remaining shape.

Intuition: starting with accumulator acc, the recursion computes something of the form acc * shapeProd shape + tail, where tail < shapeProd shape.

theorem TensorArray.flatIndex_lt_shapeProd (shape indices : List Nat) (idx : Nat) :
flatIndex shape indices = some idx → idx < shapeProd shape

If flatIndex succeeds, the resulting index is in-bounds for the flattened tensor.
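Concretely, for a row-major [2, 3] shape the flat index of [i, j] is i * 3 + j. A sketch, assuming flatIndex reduces under kernel evaluation:

  example : flatIndex [2, 3] [1, 2] = some 5 := by decide

  -- An out-of-bounds index and a rank mismatch both yield none.
  example : flatIndex [2, 3] [1, 3] = none := by decide
  example : flatIndex [2, 3] [1] = none := by decide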

def TensorArray.get? {α : Type} {shape : List Nat} [Inhabited α] (t : Tensor α shape) (indices : List Nat) :
Option α

Get an element at the given multi-index.

Returns none if:

• the rank mismatches, or
• any index is out of bounds.

This is the array-backed analogue of Spec.get_spec.
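Usage sketch with #eval, reusing the illustrative t23 from the structure example above:

  #eval TensorArray.get? t23 [1, 2]   -- expected: some 5
  #eval TensorArray.get? t23 [2, 0]   -- expected: none (first index out of bounds)
  #eval TensorArray.get? t23 [1]      -- expected: none (rank mismatch)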
def TensorArray.map {α β : Type} {shape : List Nat} (f : α → β) (t : Tensor α shape) :
Tensor β shape

Map a function over all elements (shape preserved).

The only subtlety is the shape_valid proof: mapping doesn't change the array length.

def TensorArray.map2 {α β γ : Type} {shape : List Nat} (f : α → β → γ) (t₁ : Tensor α shape) (t₂ : Tensor β shape) :
Tensor γ shape

Elementwise binary operation (shape preserved).

We require both tensors to have the same shape at the type level, so shape mismatches are unrepresentable here.
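Sketch: because the shape lives in the type, mismatched operands to map2 are rejected at elaboration time rather than at run time (doubled and summed are illustrative names):

  def doubled : Tensor Nat [2, 3] := map (· * 2) t23

  def summed : Tensor Nat [2, 3] := map2 (· + ·) t23 doubled

  -- map2 (· + ·) t23 v3   -- does not elaborate: Tensor Nat [2, 3] vs Tensor Nat [3]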
def TensorArray.sum {α : Type} [Add α] [Zero α] {shape : List Nat} (t : Tensor α shape) :
α

Reduce by summing all elements (flattened).

This ignores the tensor's rank and sums over data directly. PyTorch analogy: t.sum() (over all axes).
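For example, summing the illustrative t23 defined earlier:

  #eval TensorArray.sum t23   -- expected: 15 (= 0 + 1 + 2 + 3 + 4 + 5)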
def TensorArray.reshape {α : Type} {shape1 shape2 : List Nat} (t : Tensor α shape1) (h : shapeProd shape1 = shapeProd shape2) :
Tensor α shape2

Reshape a tensor to a new shape with the same number of elements.

This is "view"-style: it reuses the same underlying data array. The proof h is the only thing that changes.
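Sketch: flattening the 2 × 3 example into a length-6 vector; on concrete shapes the element-count proof reduces by evaluation:

  def flat6 : Tensor Nat [6] := reshape t23 (by decide)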
def TensorArray.full {α : Type} (shape : List Nat) (val : α) :
Tensor α shape

Create a tensor filled with a constant value.

PyTorch analogy: torch.full(shape, val).

def TensorArray.add {α : Type} [Add α] {shape : List Nat} (t₁ t₂ : Tensor α shape) :
Tensor α shape

Elementwise addition.

def TensorArray.mul {α : Type} [Mul α] {shape : List Nat} (t₁ t₂ : Tensor α shape) :
Tensor α shape

Elementwise multiplication.
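A combined sketch of full, add, and mul on a small square shape (illustrative names):

  def ones : Tensor Nat [2, 2] := full [2, 2] 1

  def threes : Tensor Nat [2, 2] := full [2, 2] 3

  #eval TensorArray.sum (add ones threes)   -- expected: 16 (four elements of 4)
  #eval TensorArray.sum (mul ones threes)   -- expected: 12 (four elements of 3)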
def TensorArray.relu {α : Type} [LT α] [Zero α] [DecidableLT α] {shape : List Nat} (t : Tensor α shape) :
Tensor α shape

ReLU activation (elementwise max with zero).

This is written in the simplest "executable" style: a branch on x > 0. For interval/real semantics, use the Spec layer ops; this module is for array-backed computations.
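Sketch over Int, which carries the required LT, Zero, and DecidableLT instances:

  def signs : Tensor Int [4] := ofArray #[-2, -1, 0, 3] [4] (by decide)

  #eval (relu signs).data   -- expected: #[0, 0, 0, 3]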
def TensorArray.matvec {α : Type} [Add α] [Mul α] [Zero α] [Inhabited α] {m n : Nat} (mat : Tensor α [m, n]) (vec : Tensor α [n]) :
Tensor α [m]

Matrix-vector multiplication: an (m × n) matrix times an (n)-vector gives an (m)-vector.

This is a direct reference implementation intended for small sizes and clarity. If you need performance, you generally want the runtime/TorchLean path instead.
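Worked sketch: the matrix [[1, 2], [3, 4]] applied to the vector [10, 20] (A and x2 are illustrative names):

  def A : Tensor Nat [2, 2] := ofArray #[1, 2, 3, 4] [2, 2] (by decide)

  def x2 : Tensor Nat [2] := ofArray #[10, 20] [2] (by decide)

  #eval (matvec A x2).data   -- expected: #[50, 110] (1*10 + 2*20 and 3*10 + 4*20)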
def TensorArray.linear {α : Type} [Add α] [Mul α] [Zero α] [Inhabited α] {m n : Nat} (W : Tensor α [m, n]) (b : Tensor α [m]) (x : Tensor α [n]) :
Tensor α [m]

Linear layer: y = W x + b.

PyTorch analogy: torch.nn.Linear(n, m) forward pass with weight W and bias b.
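Sketch, reusing A and x2 from the matvec example with an illustrative bias b2:

  def b2 : Tensor Nat [2] := ofArray #[1, 1] [2] (by decide)

  #eval (linear A b2 x2).data   -- expected: #[51, 111]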