TorchLean API

NN.Proofs.Autograd.Tape.Nodes.Shape

Shape #

Additional analytic (HasFDerivAt) tape nodes for shape permutations.

These nodes are linear/isometric and are useful for models that do explicit reshaping and dimension permutations (e.g. Multi-Head Attention head splitting/combining).

theorem Proofs.Autograd.TapeNodes.ShapeOps.inner_castVec_left {n m : ℕ} (h : n = m) (x : Vec n) (y : Vec m) :
inner (castVec h x) y = inner x (castVec h.symm y)

Move castVec across the left argument of an inner product.
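Since `castVec` is a pure type cast, it commutes with any pairing once the length equality is substituted. A minimal self-contained sketch of this pattern (with `Float` standing in for the real scalars, an abstract pairing in place of `inner`, and primed names that are illustrative, not TorchLean's):

```lean
def Vec' (n : Nat) := Fin n → Float

/-- Stand-in for `castVec`: cast a vector along an equality of lengths. -/
def castVec' {n m : Nat} (h : n = m) (v : Vec' n) : Vec' m := h ▸ v

/-- The cast moves across either argument of any length-indexed pairing:
after substituting `h`, both casts reduce to the identity. -/
example (inner' : {k : Nat} → Vec' k → Vec' k → Float)
    {n m : Nat} (h : n = m) (x : Vec' n) (y : Vec' m) :
    inner' (castVec' h x) y = inner' x (castVec' h.symm y) := by
  subst h; rfl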

theorem Proofs.Autograd.TapeNodes.ShapeOps.castVec_proof_irrel {n m : ℕ} (h₁ h₂ : n = m) (v : Vec n) :
castVec h₁ v = castVec h₂ v

castVec is proof-irrelevant in its equality argument.

reshape is linear: on vectors it is just a type cast along Shape.size equality. We implement it as a Node to keep the DAG theorem applicable.

noncomputable def Proofs.Autograd.TapeNodes.ShapeOps.reshape {Γ : List Spec.Shape} {s₁ s₂ : Spec.Shape} (idx : Idx Γ s₁) (h : s₁.size = s₂.size) :
Node Γ s₂

reshape node: reinterpret the same underlying coordinates as a different shape.

This is only definable when Shape.size s₁ = Shape.size s₂; at the vector level it is a cast.

PyTorch analogue: view/reshape operations that do not change the total number of elements. https://pytorch.org/docs/stable/tensor_view.html
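At the vector level, the content of `reshape` can be sketched self-containedly (with `Float` standing in for the real scalars; `reshapeVec` is an illustrative name, not TorchLean's):

```lean
/-- Vector-level sketch of `reshape`: a cast along `h : n = m`. Every
coordinate is kept in place, so the map is linear and isometric. -/
def reshapeVec {n m : Nat} (h : n = m) (v : Fin n → Float) : Fin m → Float :=
  h ▸ v

-- Reinterpreting a length-6 vector (say, as the coordinates of a 2 × 3
-- tensor) leaves each entry untouched:
example (v : Fin 6 → Float) (i : Fin 6) : reshapeVec rfl v i = v i := rfl
```

This mirrors why the node's derivative is trivial: the forward map is the identity on coordinates, so it is its own Fréchet derivative.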

Instances For
    noncomputable def Proofs.Autograd.TapeNodes.ShapeOps.reshapeFderiv {Γ : List Spec.Shape} {s₁ s₂ : Spec.Shape} (idx : Idx Γ s₁) (h : s₁.size = s₂.size) :

    NodeFDerivCorrect for reshape (it is linear/isometric).

    Instances For

      flatten is a specialization of reshape to the canonical vector shape (.dim (Shape.size s) .scalar).

      flatten node: specialization of reshape to the canonical vector shape (.dim (Shape.size s) .scalar).

      PyTorch analogue: flatten when applied to a contiguous tensor. https://pytorch.org/docs/stable/generated/torch.flatten.html
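The size bookkeeping that makes the `flatten` cast legal can be checked by computation. An illustrative instance for a 2 × 3 tensor shape (using the `size` convention `.dim k rest ↦ k * size rest`, `size .scalar = 1`, suggested by the signatures above):

```lean
-- `.dim 2 (.dim 3 .scalar)` has size 2 * (3 * 1) = 6, and the canonical
-- vector shape `.dim 6 .scalar` has size 6 * 1 = 6, so the sizes agree
-- definitionally and the reshape/flatten cast applies.
example : 2 * (3 * 1) = 6 * 1 := rfl
```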

      Instances For

        NodeFDerivCorrect for flatten.

        Instances For
          noncomputable def Proofs.Autograd.TapeNodes.ShapeOps.reindexVec {n m : ℕ} (e : Fin n ≃ Fin m) :
          Vec n → Vec m

          Reindex a vector along a Fin equivalence (coordinate permutation/renaming).
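A minimal sketch of coordinate reindexing along an equivalence, assuming `Vec n` behaves like a function type on `Fin n` (`Float` stands in for the real scalars; `reindexVec'` is an illustrative name, and `Equiv` comes from Mathlib):

```lean
import Mathlib.Logic.Equiv.Defs

/-- Sketch of `reindexVec`: the `j`-th output coordinate reads the input
at the preimage `e.symm j`, so the map permutes/renames coordinates. -/
def reindexVec' {n m : Nat} (e : Fin n ≃ Fin m) (v : Fin n → Float) :
    Fin m → Float :=
  fun j => v (e.symm j)

-- Reindexing along the identity equivalence changes nothing:
example (v : Fin 3 → Float) (i : Fin 3) :
    reindexVec' (Equiv.refl (Fin 3)) v i = v i := rfl
```

Pulling back along `e.symm` (rather than pushing forward along `e`) is what makes the map functorial in `e` and an isometry: it is a pure relabeling of coordinates.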

          Instances For
            noncomputable def Proofs.Autograd.TapeNodes.ShapeOps.reindexLin {n m : ℕ} (e : Fin n ≃ Fin m) :

            The linear map induced by reindexVec.

            Instances For

              Move reindexVec across the left argument of an inner product.

              Underlying coordinate permutation for swapping the first two axes of a 3D tensor.

              Instances For
                noncomputable def Proofs.Autograd.TapeNodes.ShapeOps.swapFirstTwo3d {Γ : List Spec.Shape} {m n : ℕ} {rest : Spec.Shape} (idx : Idx Γ (Spec.Shape.dim m (Spec.Shape.dim n rest))) :

                Swap the first two axes of a 3D tensor shape: .dim m (.dim n rest) ↦ .dim n (.dim m rest).

                This is implemented as a coordinate permutation (a linear isometry).

                PyTorch analogue: transpose(0, 1) on a 3D tensor. https://pytorch.org/docs/stable/generated/torch.transpose.html
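The underlying index bookkeeping, sketched for an `m × n × r` tensor stored row-major: entry `(i, j, k)` lives at flat index `i * (n * r) + (j * r + k)` before the swap and `j * (m * r) + (i * r + k)` after. The two flat index spaces have the same size, which is what makes the permutation well-typed:

```lean
-- Size equality behind the swap; `r` stands in for `Shape.size rest`.
example (m n r : Nat) : m * (n * r) = n * (m * r) := by
  rw [← Nat.mul_assoc, Nat.mul_comm m n, Nat.mul_assoc]
```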

                Instances For

                  NodeFDerivCorrect for swap_first_two3d (linear coordinate permutation).

                  Instances For
                    def Proofs.Autograd.TapeNodes.ShapeOps.transposeLastTwoEquiv (a b c : ℕ) :
                    Fin (a * (b * (c * 1))) ≃ Fin (a * (c * (b * 1)))

                    Underlying coordinate permutation for transposing the last two axes of a 3D tensor.

                    Instances For

                      Transpose the last two axes of a 3D tensor: .dim a (.dim b (.dim c .scalar)) ↦ .dim a (.dim c (.dim b .scalar)).

                      This is another coordinate permutation used in attention (switching K to Kᵀ while keeping head/batch axes).
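The size equality that `transposeLastTwoEquiv` relies on, checked directly in the shape's own `a * (b * (c * 1))` form:

```lean
-- Transposing the last two axes of an a × b × c tensor permutes flat
-- indices within each of the `a` outer blocks; the total size is
-- unchanged:
example (a b c : Nat) : a * (b * (c * 1)) = a * (c * (b * 1)) := by
  rw [Nat.mul_one, Nat.mul_one, Nat.mul_comm b c]
```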

                      Instances For

                        NodeFDerivCorrect for transpose3d_last_two (linear coordinate permutation).

                        Instances For