TorchLean API

NN.Spec.Core.Tensor.Core

Core tensor datatype (Spec.Tensor) #

This file defines the foundational, shape-indexed tensor type used throughout TorchLean's spec layer:

Tensor α s

Why an inductive / functional representation? #

Instead of storing a flat array plus a shape, the spec tensor is a function from indices:
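
Concretely, the datatype can be sketched as follows (the constructor names `Tensor.scalar` and `Tensor.dim` are taken from the simp lemmas below; the actual definition in the library may differ in detail):

```lean
inductive Shape where
  | scalar : Shape
  | dim    : Nat → Shape → Shape

inductive Tensor (α : Type) : Shape → Type where
  -- a scalar tensor is just a wrapped value
  | scalar : α → Tensor α Shape.scalar
  -- a dim-n tensor is a function from indices to subtensors
  | dim    : {n : Nat} → {s : Shape} → (Fin n → Tensor α s) → Tensor α (Shape.dim n s)
```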

This has three practical benefits:

  1. Shape safety is enforced by the type.
  2. Proofs are natural: you reason by recursion on the Shape / Tensor constructors.
  3. No layout commitment: the spec layer doesn’t bake in row-major vs column-major storage.

For long executable runs, repeated functional updates can create “closure chains”. Use Tensor.materialize (documented below) to rebuild a tensor into an array-backed normal form.

inductive Spec.Tensor (α : Type) :
Shape → Type

Shape-indexed tensor datatype for the spec layer.

This is a functional representation:

  • a scalar tensor is just an α,
  • a tensor whose outermost dimension has size n is a function Fin n → Tensor α s.

This keeps proofs and shape-safe programming simple, and avoids committing to a concrete memory layout in the spec layer.

Instances For

    Runtime note: materialization #

    Tensor α s is a functional representation (Fin n → ...). This is excellent for proofs, but repeated updates (for example, many SGD steps) can build deep chains of closures (fun i => ... (old (old (old i))) ...). Evaluating those chains later becomes progressively more expensive.

    Tensor.materialize rebuilds a tensor into an array-backed normal form (at every dimension), which keeps long-running training loops from degrading.

    It is extensionally the identity (the same mathematical tensor), but it is much friendlier to the runtime evaluator.
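
    For example, a long-running loop might materialize after every functional update. This is only a sketch; `sgdStep` is a hypothetical update function, not part of this file:

    ```lean
    -- Each `sgdStep` builds a new functional tensor on top of the previous
    -- one; `materialize` flattens the closure chain at every iteration.
    def trainLoop {s : Shape} (sgdStep : Tensor Float s → Tensor Float s) :
        Nat → Tensor Float s → Tensor Float s
      | 0,     w => w
      | k + 1, w => trainLoop sgdStep k (sgdStep w).materialize
    ```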

    def Spec.Tensor.materialize {α : Type} {s : Shape} :
    Tensor α s → Tensor α s

    Rebuild a tensor into an array-backed normal form (performance helper).

    Instances For
      @[simp]
      theorem Spec.Tensor.materialize_eq {α : Type} {s : Shape} (t : Tensor α s) :
      t.materialize = t

      Tensor.materialize preserves tensor values (it is extensionally the identity).

      def Spec.Tensor.default {α : Type} [Inhabited α] {s : Shape} :
      Tensor α s

      Default tensor value for any shape (uses Inhabited.default at scalars).

      Instances For
        @[reducible]
        instance Spec.Tensor.inhabited {α : Type} [Inhabited α] {s : Shape} :

        Make Tensor α s inhabited for any shape s.

        def Spec.shapeOf {α : Type} {s : Shape} :
        Tensor α s → Shape

        Recover the (data) shape from a tensor value.

        Instances For

          Extract the scalar value from a scalar tensor.

          Instances For

            Inject a scalar into a scalar tensor.

            Instances For
              @[simp]
              theorem Spec.Tensor.toScalar_ofScalar {α : Type} (x : α) :
              Tensor.toScalar (Tensor.ofScalar x) = x

              toScalar (ofScalar x) = x.

              @[simp]

              ofScalar (toScalar t) = t for scalar tensors.

              Equivalence between Tensor α .scalar and α (useful to reuse algebra instances).

              Instances For
                @[implicit_reducible]

                AddCommMonoid on scalar tensors, transported from α via Tensor.scalarEquiv.

                Equivalence between vectors-as-tensors and functions Fin n → α.

                Instances For
                  @[implicit_reducible]

                  AddCommMonoid on vector tensors, transported from Fin n → α via Tensor.dimScalarEquiv.

                  def Spec.Tensor.castShape {α : Type} {s₁ s₂ : Shape} (t : Tensor α s₁) (h : s₁ = s₂) :
                  Tensor α s₂

                  Cast a tensor along an equality of shapes.

                  Instances For
                    def Spec.Tensor.castVecDim {α : Type} {n m : ℕ} (h : n = m) (t : Tensor α (Shape.dim n Shape.scalar)) :
                    Tensor α (Shape.dim m Shape.scalar)

                    Cast a vector tensor along an equality of dimensions.

                    Instances For

                      Cast lemmas #

                      In dependently-typed proofs (especially graph/tape correctness proofs), the same cast may arise with different proof terms. Since equality proofs are proof-irrelevant, we provide a few small normalization lemmas for Tensor.cast_shape.
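
                      For example, two copies of the same cast carrying different proof terms can be identified directly (illustrative sketch):

                      ```lean
                      example {α : Type} {s₁ s₂ : Shape} (t : Tensor α s₁) (p q : s₁ = s₂) :
                          t.castShape p = t.castShape q :=
                        -- `p` and `q` are definitionally equal by proof irrelevance
                        Spec.Tensor.cast_shape_proof_irrel t
                      ```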

                      @[simp]
                      theorem Spec.Tensor.cast_shape_rfl {α : Type} {s : Shape} (t : Tensor α s) :
                      t.castShape rfl = t

                      Casting a tensor along rfl is the identity.

                      @[simp]
                      theorem Spec.Tensor.cast_shape_self {α : Type} {s : Shape} (t : Tensor α s) (h : s = s) :
                      t.castShape h = t

                      Casting a tensor along a reflexive equality proof is the identity.

                      @[simp]
                      theorem Spec.Tensor.cast_shape_trans {α : Type} {s₁ s₂ s₃ : Shape} (t : Tensor α s₁) (h₁₂ : s₁ = s₂) (h₂₃ : s₂ = s₃) :
                      (t.castShape h₁₂).castShape h₂₃ = t.castShape (h₁₂.trans h₂₃)

                      Tensor.cast_shape composes associatively (cast-by-eq is just Eq.rec).

                      theorem Spec.Tensor.cast_shape_proof_irrel {α : Type} {s₁ s₂ : Shape} (t : Tensor α s₁) {p q : s₁ = s₂} :
                      t.castShape p = t.castShape q

                      Proof-irrelevance for Tensor.cast_shape.

                      theorem Spec.Tensor.eqRec_eq_cast_shape {α : Type} {s₁ s₂ : Shape} (t : Tensor α s₁) (h : s₁ = s₂) :
                      h ▸ t = t.castShape h

                      Rewrite Eq.rec (h ▸ t) as Tensor.cast_shape for uniformity in larger proofs.

                      theorem Spec.Tensor.eqRec_proof_irrel {α : Type} {s₁ s₂ : Shape} (t : Tensor α s₁) {p q : s₁ = s₂} :
                      p ▸ t = q ▸ t

                      Proof-irrelevance for Eq.rec casts of tensors.

                      Indexing helpers #

                      Design notes and PyTorch analogies for the indexing helpers are given with each definition below.

                      def Spec.getSpec {α : Type} {s : Shape} (t : Tensor α s) :
                      List ℕ → Option α

                      Get a scalar by a multi‑index (list of Nats).
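
                      An illustrative sketch of multi-index access (the comment reflects the simp lemmas below):

                      ```lean
                      example (A : Tensor Float (Shape.dim 2 (Shape.dim 2 Shape.scalar))) :
                          Option Float :=
                        -- `some` of entry (1, 0); an out-of-range or too-short
                        -- index list yields `none` instead
                        Spec.getSpec A [1, 0]
                      ```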

                      Instances For
                        @[simp]
                        theorem Spec.get_spec_scalar_nil {α : Type} (value : α) :
                        getSpec (Tensor.scalar value) [] = some value
                        @[simp]
                        theorem Spec.get_spec_scalar_cons {α : Type} (value : α) (i : ℕ) (is : List ℕ) :
                        getSpec (Tensor.scalar value) (i :: is) = none
                        @[simp]
                        theorem Spec.get_spec_dim_nil {α : Type} {n : ℕ} {s : Shape} (values : Fin n → Tensor α s) :
                        getSpec (Tensor.dim values) [] = none
                        @[simp]
                        theorem Spec.get_spec_dim_cons {α : Type} {n : ℕ} {s : Shape} (values : Fin n → Tensor α s) (i : ℕ) (is : List ℕ) :
                        getSpec (Tensor.dim values) (i :: is) = if h : i < n then getSpec (values ⟨i, h⟩) is else none
                        def Spec.getAtSpec {α : Type} {n : ℕ} {s : Shape} (t : Tensor α (Shape.dim n s)) (i : Fin n) :
                        Tensor α s

                        Extract the subtensor at index i along the outermost dimension.

                        Instances For
                          def Spec.get {α : Type} {n : ℕ} {s : Shape} (t : Tensor α (Shape.dim n s)) (i : Fin n) :
                          Tensor α s

                          Alias for get_at_spec (the standard spec-level indexing helper).

                          Instances For
                            @[implicit_reducible]
                            instance Spec.instGetElemTensorDimFinTrue {α : Type} {n : ℕ} {s : Shape} :
                            GetElem (Tensor α (Shape.dim n s)) (Fin n) (Tensor α s) fun (x : Tensor α (Shape.dim n s)) (x_1 : Fin n) => True

                            Enable Lean’s indexing syntax for spec tensors.

                            After this instance, you can write t[i] as notation for Spec.get t i.

                            We use the domain condition True because the index is already a Fin n, so it is always in-bounds by construction.
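
                            Usage sketch:

                            ```lean
                            example {α : Type} {n : Nat} {s : Shape}
                                (t : Tensor α (Shape.dim n s)) (i : Fin n) :
                                Tensor α s :=
                              t[i]  -- elaborates via the instance above to `Spec.get t i`
                            ```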

                            def Spec.Tensor.vecGet {α : Type} {n : ℕ} (x : Tensor α (Shape.dim n Shape.scalar)) (i : Fin n) :
                            α

                            Extract the i-th entry from a vector tensor.

                            Instances For
                              def Spec.get2 {α : Type} {m n : ℕ} (A : Tensor α (Shape.dim m (Shape.dim n Shape.scalar))) (i : Fin m) (j : Fin n) :
                              α

                              Matrix element access: get2 A i j returns A[i, j] as a scalar.

                              Instances For
                                @[implicit_reducible]
                                instance Spec.instGetElemTensorDimScalarProdFinTrue {α : Type} {m n : ℕ} :
                                GetElem (Tensor α (Shape.dim m (Shape.dim n Shape.scalar))) (Fin m × Fin n) α fun (x : Tensor α (Shape.dim m (Shape.dim n Shape.scalar))) (x_1 : Fin m × Fin n) => True

                                Enable Lean’s indexing syntax for matrix-shaped scalar tensors.

                                After this instance, you can write A[(i, j)] as notation for Spec.get2 A i j.
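
                                Usage sketch:

                                ```lean
                                example {α : Type} {m n : Nat}
                                    (A : Tensor α (Shape.dim m (Shape.dim n Shape.scalar)))
                                    (i : Fin m) (j : Fin n) : α :=
                                  A[(i, j)]  -- elaborates to `Spec.get2 A i j`
                                ```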

                                get_at_or_zero is a total variant of get_spec used in places where a default value is more convenient than Option.

                                We keep both get_spec and get_at_or_zero because they serve different roles: get_spec makes out-of-bounds access observable (it returns none), which is what correctness statements want, while get_at_or_zero is total, which is more convenient when composing with operations that expect a plain α.

                                def Spec.getAtOrZero {α : Type} [Zero α] {s : Shape} (t : Tensor α s) :
                                List ℕ → α

                                Total tensor indexing: returns 0 when the index list is out of bounds.

                                Instances For
                                  @[simp]
                                  theorem Spec.get_at_or_zero_scalar_nil {α : Type} [Zero α] (v : α) :
                                  getAtOrZero (Tensor.scalar v) [] = v
                                  @[simp]
                                  theorem Spec.get_at_or_zero_scalar_cons {α : Type} [Zero α] (v : α) (i : ℕ) (is : List ℕ) :
                                  getAtOrZero (Tensor.scalar v) (i :: is) = 0
                                  @[simp]
                                  theorem Spec.get_at_or_zero_dim_nil {α : Type} [Zero α] {n : ℕ} {s : Shape} (values : Fin n → Tensor α s) :
                                  getAtOrZero (Tensor.dim values) [] = 0
                                  @[simp]
                                  theorem Spec.get_at_or_zero_dim_cons {α : Type} [Zero α] {n : ℕ} {s : Shape} (values : Fin n → Tensor α s) (i : ℕ) (is : List ℕ) :
                                  getAtOrZero (Tensor.dim values) (i :: is) = if h : i < n then getAtOrZero (values ⟨i, h⟩) is else 0
                                  def Spec.finZero {n : ℕ} (h : 0 < n) :
                                  Fin n

                                  Construct Fin n value 0 given a proof that 0 < n.

                                  Instances For
                                    def Spec.getHead {α : Type} {n : ℕ} {s : Shape} (t : Tensor α (Shape.dim n s)) :
                                    Option (Tensor α s)

                                    Get the first element along the first dimension (if nonempty).

                                    Instances For
                                      def Spec.getTail {α : Type} {n : ℕ} {s : Shape} (t : Tensor α (Shape.dim n s)) :
                                      Option (Tensor α (Shape.dim (n - 1) s))

                                      Drop the first element along the first dimension (if nonempty).

                                      Instances For
                                        def Spec.tensorCast {α : Type} {s : Shape} (t : Shape) (h : s = t) :
                                        Tensor α s → Tensor α t

                                        Cast a tensor along an equality of shapes.

                                        Instances For
                                          @[simp]
                                          theorem Spec.tensor_cast_eq_cast_shape {α : Type} {s t : Shape} (h : s = t) (x : Tensor α s) :
                                          tensorCast t h x = x.castShape h

                                          tensor_cast is definitionally Tensor.cast_shape (a uniform cast normal form).

                                          def Spec.replicate {α : Type} {s : Shape} :
                                          Tensor α Shape.scalar → Tensor α s

                                          Replicate a scalar tensor to any shape.

                                          Instances For

                                            Slicing helpers.

                                            We keep slice_spec as a focused "first-axis select" operation since it shows up all over the place in spec definitions.

                                            PyTorch analogy: slice_spec t i is t[i].

                                             def Spec.sliceSpec {α : Type} {n : ℕ} {s : Shape} :
                                             Tensor α (Shape.dim n s) → Fin n → Tensor α s

                                            Slice a tensor along its first axis: slice_spec t i = t[i].

                                            Instances For
                                              def Spec.sliceRangeSpec {α : Type} {n : ℕ} {s : Shape} (t : Tensor α (Shape.dim n s)) (start len : ℕ) (h : len + start ≤ n) :
                                              Tensor α (Shape.dim len s)

                                              Slice a contiguous range along the first axis.

                                              This is the spec-level analogue of t[start : start+len] in array/tensor libraries.
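
                                              Illustrative sketch (taking elements 2, 3, 4 of a length-10 vector; `by omega` discharges 3 + 2 ≤ 10):

                                              ```lean
                                              example (t : Tensor Float (Shape.dim 10 Shape.scalar)) :
                                                  Tensor Float (Shape.dim 3 Shape.scalar) :=
                                                Spec.sliceRangeSpec t 2 3 (by omega)
                                              ```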

                                              Instances For

                                                collect_at_index_spec is a "transpose-like" helper that pulls a fixed position out of every batch entry.

                                                This is a small but surprisingly useful building block in attention-like code and dataset manipulations, where you frequently want to reorganize (batch, n, ...) into (n, batch, ...) without committing to a concrete memory layout.

                                                def Spec.collectAtIndexSpec {β : Type} {b n : ℕ} {shape : Shape} (f : Fin b → Tensor β (Shape.dim n shape)) (j : Fin n) :
                                                Tensor β (shape.appendDim b)

                                                Collect the j-th element from each batch entry, producing a tensor with batch dimension moved to the end.

                                                This is a small "transpose-like" helper used in attention-like code and dataset reshaping.
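
                                                Usage sketch, assuming the signature above:

                                                ```lean
                                                -- For each position j < n, gather the j-th entry of every
                                                -- batch element (batch dimension moved per `appendDim`).
                                                example {β : Type} {b n : Nat} {shape : Shape}
                                                    (batch : Fin b → Tensor β (Shape.dim n shape)) (j : Fin n) :
                                                    Tensor β (shape.appendDim b) :=
                                                  Spec.collectAtIndexSpec batch j
                                                ```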

                                                Instances For