TorchLean API

NN.Runtime.Autograd.Engine.TapeM

TapeM #

Tape-building convenience API.

The core autograd runtime (Runtime.Autograd.Tape) is pure and explicitly threaded: each op returns an updated tape plus the new node id. This makes the engine easy to reason about and convenient for proofs, but it can feel verbose in user code.

Runtime.Autograd.TapeM is a small StateT wrapper that threads the tape implicitly, closer to the "define ops; then call backward" ergonomics users expect from frameworks like PyTorch.
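For example, here is a minimal sketch (the tensors xVal, wVal, bVal and the shape s are placeholders; the wrapper names are the ones documented below): building relu (w * x + b) reads as ordinary do-notation, with the tape threaded behind the scenes.

    open Runtime.Autograd in
    /-- Sketch only: `relu (w * x + b)` with the tape threaded implicitly.
        Node ids are plain naturals, so shapes are passed explicitly here. -/
    def buildExample {s : Spec.Shape} (xVal wVal bVal : Spec.Tensor Float s) :
        TapeM Float ℕ := do
      let x ← TapeM.leaf xVal (requires_grad := false)  -- frozen input
      let w ← TapeM.leaf wVal                           -- trainable parameter
      let b ← TapeM.leaf bVal
      let wx ← TapeM.mul (s := s) w x
      let z ← TapeM.add (s := s) wx b
      TapeM.relu (s := s) z                             -- id of the output node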

For training scripts and tests, see also NN.Runtime.Autograd.Utils, which provides small helpers for common patterns (reading scalar losses, extracting typed grads, simple SGD loops).

Reading map #

In order below: the TapeM definition and its runners (run, eval, exec); the tape state helpers and leaf for creating inputs; one StateT wrapper per Tape op (add, mul, matmul, convolutions, pooling, normalization, attention, losses, elementwise functions); and finally backward for reverse-mode autodiff.

@[reducible, inline]
def Runtime.Autograd.TapeM (α β : Type) : Type

A convenient tape-builder monad.

TapeM α β is StateT (Tape α) Result β: a pure tape threaded implicitly with errors reported via Except String. This mirrors the common eager style of building a computation and then calling backward, similar to PyTorch's imperative API, but remains purely functional.

Instances For

def Runtime.Autograd.TapeM.run {α β : Type} (t : Tape α) (m : TapeM α β) :
    Result (β × Tape α)

Run a TapeM computation from an initial tape, returning both the result and the final tape.

Instances For

def Runtime.Autograd.TapeM.eval {α β : Type} (t : Tape α) (m : TapeM α β) :
    Result β

Evaluate a TapeM computation, discarding the final tape.

Instances For

def Runtime.Autograd.TapeM.exec {α β : Type} (t : Tape α) (m : TapeM α β) :
    Result (Tape α)

Execute a TapeM computation, discarding the produced value and returning the final tape.

Instances For
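A quick sketch of the three runners (t0 is an assumed starting tape, m an arbitrary computation):

    open Runtime.Autograd in
    -- Sketch: the three ways to discharge a TapeM computation.
    example (t0 : Tape Float) (m : TapeM Float ℕ) : Unit :=
      let _ := TapeM.run t0 m   -- Result (ℕ × Tape Float): value and final tape
      let _ := TapeM.eval t0 m  -- Result ℕ: value only
      let _ := TapeM.exec t0 m  -- Result (Tape Float): final tape only
      ()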

Get the current tape state.

Instances For

Replace the current tape state.

Instances For
def Runtime.Autograd.TapeM.leaf {α : Type} {s : Spec.Shape} (value : Spec.Tensor α s)
    (name : Option String := none) (requires_grad : Bool := true) : TapeM α ℕ

Create a leaf node holding a concrete tensor value.

A leaf is the "input tensor" analogue: it has no parents. Setting requires_grad := true corresponds to PyTorch tensors created with requires_grad=True.

Instances For

StateT wrapper around Tape.add. PyTorch comparison: torch.add(a, b).

Instances For
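Every op wrapper below follows the same one-line pattern: run the explicitly-threaded Tape op on the current state and store the updated tape. A hedged sketch, assuming Tape.add returns Result (Tape α × ℕ) as described in the module header (addSketch is illustrative, not the actual definition):

    open Runtime.Autograd in
    -- Sketch of the wrapper pattern: lift a pure Tape op into StateT.
    def addSketch {α : Type} [Add α] [Zero α] [DecidableEq Spec.Shape]
        {s : Spec.Shape} (aId bId : ℕ) : TapeM α ℕ := fun t => do
      let (t', nodeId) ← Tape.add (s := s) t aId bId  -- assumed core signature
      pure (nodeId, t')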
def Runtime.Autograd.TapeM.sub {α : Type} [Sub α] [Zero α] [DecidableEq Spec.Shape]
    {s : Spec.Shape} (aId bId : ℕ) : TapeM α ℕ

StateT wrapper around Tape.sub. PyTorch comparison: torch.sub(a, b).

Instances For

StateT wrapper around Tape.mul. PyTorch comparison: torch.mul(a, b).

Instances For

def Runtime.Autograd.TapeM.scale {α : Type} [Mul α] [DecidableEq Spec.Shape]
    {s : Spec.Shape} (xId : ℕ) (c : α) : TapeM α ℕ

StateT wrapper around Tape.scale. PyTorch comparison: c * x, i.e. torch.mul(x, c).

Instances For

def Runtime.Autograd.TapeM.abs {α : Type} [Context α]
    [DecidableRel fun (x1 x2 : α) => x1 > x2] [DecidableEq Spec.Shape]
    {s : Spec.Shape} (xId : ℕ) : TapeM α ℕ

StateT wrapper around Tape.abs. PyTorch comparison: torch.abs(x).

Instances For

def Runtime.Autograd.TapeM.sqrt {α : Type} [Context α]
    [DecidableRel fun (x1 x2 : α) => x1 > x2] [DecidableEq Spec.Shape]
    {s : Spec.Shape} (xId : ℕ) : TapeM α ℕ

StateT wrapper around Tape.sqrt. PyTorch comparison: torch.sqrt(x).

Instances For

def Runtime.Autograd.TapeM.clamp {α : Type} [Context α]
    [DecidableRel fun (x1 x2 : α) => x1 > x2] [DecidableEq Spec.Shape]
    {s : Spec.Shape} (xId : ℕ) (minVal maxVal : α) : TapeM α ℕ

StateT wrapper around Tape.clamp. PyTorch comparison: torch.clamp(x, min, max).

Instances For

def Runtime.Autograd.TapeM.max {α : Type} [Context α]
    [DecidableRel fun (x1 x2 : α) => x1 > x2] [DecidableEq Spec.Shape]
    {s : Spec.Shape} (aId bId : ℕ) : TapeM α ℕ

StateT wrapper around Tape.max. PyTorch comparison: torch.maximum(a, b).

Instances For

def Runtime.Autograd.TapeM.min {α : Type} [Context α]
    [DecidableRel fun (x1 x2 : α) => x1 > x2] [DecidableEq Spec.Shape]
    {s : Spec.Shape} (aId bId : ℕ) : TapeM α ℕ

StateT wrapper around Tape.min. PyTorch comparison: torch.minimum(a, b).

Instances For

def Runtime.Autograd.TapeM.relu {α : Type} [Mul α] [Zero α] [Max α] [One α] [LT α]
    [DecidableRel fun (x1 x2 : α) => x1 > x2] [DecidableEq Spec.Shape]
    {s : Spec.Shape} (xId : ℕ) : TapeM α ℕ

StateT wrapper around Tape.relu. PyTorch comparison: torch.nn.functional.relu(x).

Instances For

def Runtime.Autograd.TapeM.linear {α : Type} [Add α] [Mul α] [Zero α]
    [DecidableEq Spec.Shape] {inDim outDim : ℕ} (wId bId xId : ℕ) : TapeM α ℕ

StateT wrapper around Tape.linear. PyTorch comparison: torch.nn.functional.linear.

Instances For
def Runtime.Autograd.TapeM.matmul {α : Type} [Context α]
    [DecidableRel fun (x1 x2 : α) => x1 > x2] [DecidableEq Spec.Shape]
    {m n p : ℕ} (aId bId : ℕ) : TapeM α ℕ

StateT wrapper around Tape.matmul. PyTorch comparison: torch.matmul(a, b).

Instances For

def Runtime.Autograd.TapeM.concatVectors {α : Type} [Context α]
    [DecidableRel fun (x1 x2 : α) => x1 > x2] [DecidableEq Spec.Shape]
    {n m : ℕ} (aId bId : ℕ) : TapeM α ℕ

StateT wrapper around Tape.concat_vectors. PyTorch comparison: torch.cat([a, b], dim=0) for vectors.

Instances For

def Runtime.Autograd.TapeM.conv2d {α : Type} [Context α]
    [DecidableRel fun (x1 x2 : α) => x1 > x2] [DecidableEq Spec.Shape]
    {inC outC kH kW stride padding inH inW : ℕ}
    {h1 : inC ≠ 0} {h2 : kH ≠ 0} {h3 : kW ≠ 0}
    (kernelId biasId inputId : ℕ) : TapeM α ℕ

StateT wrapper around Tape.conv2d.

PyTorch comparison: torch.nn.functional.conv2d (this codebase uses a single-image specialization; see Tape.conv2d for the exact shape conventions).

Instances For
def Runtime.Autograd.TapeM.convTranspose {α : Type} [Context α]
    [DecidableRel fun (x1 x2 : α) => x1 > x2] [DecidableEq Spec.Shape]
    {d inC outC : ℕ} {kernel stride padding inSpatial : Vector ℕ d}
    (kernelId biasId inputId : ℕ) (name : String := "conv_transpose") : TapeM α ℕ

StateT wrapper around Tape.conv_transpose.

PyTorch comparison: torch.nn.functional.conv_transpose{d}d specialized to a single sample (no batch axis).

Instances For

def Runtime.Autograd.TapeM.convTranspose2d {α : Type} [Context α] [DecidableEq Spec.Shape]
    {inC outC kH kW stride padding inH inW : ℕ}
    {h1 : inC ≠ 0} {h2 : kH ≠ 0} {h3 : kW ≠ 0}
    (kernelId biasId inputId : ℕ) : TapeM α ℕ

StateT wrapper around Tape.conv_transpose2d.

PyTorch comparison: torch.nn.functional.conv_transpose2d (single-image specialization; see Tape.conv_transpose2d for exact shape conventions).

Instances For

def Runtime.Autograd.TapeM.maxPool2d {α : Type} [Context α] [DecidableEq Spec.Shape]
    {kH kW inH inW inC stride : ℕ} {h1 : kH ≠ 0} {h2 : kW ≠ 0} (xId : ℕ) : TapeM α ℕ

StateT wrapper around Tape.max_pool2d. PyTorch comparison: torch.nn.functional.max_pool2d.

Instances For

def Runtime.Autograd.TapeM.maxPool2dPad {α : Type} [Context α] [DecidableEq Spec.Shape]
    {kH kW inH inW inC stride padding : ℕ} {h1 : kH ≠ 0} {h2 : kW ≠ 0} (xId : ℕ) : TapeM α ℕ

StateT wrapper around Tape.max_pool2d_pad. PyTorch comparison: torch.nn.functional.max_pool2d with padding.

Instances For

def Runtime.Autograd.TapeM.smoothMaxPool2d {α : Type} [Context α] [DecidableEq Spec.Shape]
    {kH kW inH inW inC stride : ℕ} {h1 : kH ≠ 0} {h2 : kW ≠ 0} (xId : ℕ) (beta : α) : TapeM α ℕ

StateT wrapper around Tape.smooth_max_pool2d.

This is a differentiable (soft) approximation to max-pooling controlled by beta.

Instances For
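The precise smoothing is defined in Tape.smooth_max_pool2d; a common formulation for such a soft maximum (an assumption for orientation, not a guarantee about this codebase) is the scaled log-sum-exp over each pooling window, which recovers the hard maximum as beta grows:

$$\operatorname{smoothmax}_{\beta}(x_1, \dots, x_k) \;=\; \frac{1}{\beta} \log \sum_{i=1}^{k} e^{\beta x_i} \;\xrightarrow{\;\beta \to \infty\;}\; \max_i x_i$$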
def Runtime.Autograd.TapeM.avgPool2d {α : Type} [Context α] [DecidableEq Spec.Shape]
    {kH kW inH inW inC stride : ℕ} (h1 : kH ≠ 0) (h2 : kW ≠ 0) (xId : ℕ) : TapeM α ℕ

StateT wrapper around Tape.avg_pool2d. PyTorch comparison: torch.nn.functional.avg_pool2d.

Instances For

def Runtime.Autograd.TapeM.avgPool2dPad {α : Type} [Context α] [DecidableEq Spec.Shape]
    {kH kW inH inW inC stride padding : ℕ} (h1 : kH ≠ 0) (h2 : kW ≠ 0) (xId : ℕ) : TapeM α ℕ

StateT wrapper around Tape.avg_pool2d_pad. PyTorch comparison: torch.nn.functional.avg_pool2d with padding.

Instances For
def Runtime.Autograd.TapeM.layerNorm {α : Type} [Context α]
    [DecidableRel fun (x1 x2 : α) => x1 > x2] [DecidableEq Spec.Shape]
    {seqLen embedDim : ℕ} (h_seq_pos : seqLen > 0) (h_embed_pos : embedDim > 0)
    (xId gammaId betaId : ℕ) : TapeM α ℕ

StateT wrapper around Tape.layer_norm. PyTorch comparison: torch.nn.LayerNorm.

Instances For
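For reference, torch.nn.LayerNorm standardizes each position over the embedding dimension d = embedDim and then applies the learned affine parameters gamma and beta (the epsilon term is shown as usual; see Tape.layer_norm for the exact definition used here):

$$\mathrm{LN}(x) = \gamma \odot \frac{x - \mu}{\sqrt{\sigma^2 + \varepsilon}} + \beta, \qquad \mu = \frac{1}{d}\sum_{j=1}^{d} x_j, \quad \sigma^2 = \frac{1}{d}\sum_{j=1}^{d} (x_j - \mu)^2$$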
def Runtime.Autograd.TapeM.batchnormChannelFirst {α : Type} [Context α]
    [DecidableRel fun (x1 x2 : α) => x1 > x2] [DecidableEq Spec.Shape]
    {channels height width : ℕ} (h_c : channels > 0) (h_h : height > 0) (h_w : width > 0)
    (xId gammaId betaId : ℕ) : TapeM α ℕ

StateT wrapper around Tape.batchnorm_channel_first. PyTorch comparison: torch.nn.BatchNorm2d in channel-first layout.

Instances For
def Runtime.Autograd.TapeM.multiHeadAttention {α : Type} [Context α]
    [DecidableRel fun (x1 x2 : α) => x1 > x2] [DecidableEq Spec.Shape]
    {n numHeads dModel headDim : ℕ} (h1 : n ≠ 0) (wqId wkId wvId woId xId : ℕ)
    (mask : Option (Spec.Tensor Bool (Spec.Shape.dim n (Spec.Shape.dim n Spec.Shape.scalar))) := none) :
    TapeM α ℕ

StateT wrapper around Tape.multi_head_attention. PyTorch comparison: torch.nn.MultiheadAttention / scaled dot-product attention.

Instances For
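For reference, the standard scaled dot-product attention that PyTorch-style multi-head attention computes per head (see Tape.multi_head_attention for the exact conventions used here) is

$$\mathrm{Attn}(Q, K, V) = \operatorname{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V$$

where Q, K, V are the projections of the input by the weights at wqId, wkId, wvId, d_k = headDim, the optional boolean mask disables masked positions, and the numHeads head outputs are concatenated and projected by the weight at woId.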
def Runtime.Autograd.TapeM.mseLoss {α : Type} [Add α] [Sub α] [Mul α] [Div α] [Zero α]
    [One α] [Coe ℕ α] [DecidableEq Spec.Shape] {s : Spec.Shape}
    (yhatId targetId : ℕ) : TapeM α ℕ

StateT wrapper around Tape.mse_loss. PyTorch comparison: torch.nn.functional.mse_loss.

Instances For
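For reference, over a tensor with N elements this is the standard

$$\mathrm{MSE}(\hat{y}, t) = \frac{1}{N} \sum_{i=1}^{N} (\hat{y}_i - t_i)^2$$

(the Coe ℕ α instance presumably supplies the coercion of the element count N needed for the average).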

StateT wrapper around Tape.sigmoid. PyTorch comparison: torch.sigmoid.

Instances For

StateT wrapper around Tape.tanh. PyTorch comparison: torch.tanh.

Instances For

StateT wrapper around Tape.softmax (last axis). PyTorch comparison: torch.softmax(x, dim=-1).

Instances For

StateT wrapper around Tape.softplus. PyTorch comparison: torch.nn.functional.softplus.

Instances For

StateT wrapper around Tape.exp. PyTorch comparison: torch.exp.

Instances For

StateT wrapper around Tape.log. PyTorch comparison: torch.log.

Instances For

StateT wrapper around Tape.inv. PyTorch comparison: torch.reciprocal.

Instances For

StateT wrapper around Tape.safe_log (a numerically stable log).

Instances For

StateT wrapper around Tape.sum. PyTorch comparison: torch.sum.

Instances For

Run reverse-mode autodiff from a scalar output and return the accumulated gradients.

This calls Tape.backwardScalar on the current tape and returns a HashMap from node ids to gradient tensors.

Instances For
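Putting the pieces together, a hedged end-to-end sketch of one forward/backward pass (TapeM.backward is the assumed name of this wrapper; t0 and the tensors are placeholders, and NN.Runtime.Autograd.Utils provides real helpers for these patterns):

    open Runtime.Autograd in
    -- Sketch: forward pass, MSE loss, backward, then read w's gradient.
    def stepSketch {s : Spec.Shape} (t0 : Tape Float)
        (xVal yVal wVal : Spec.Tensor Float s) : Result Unit := do
      let prog := do
        let x ← TapeM.leaf xVal (requires_grad := false)
        let y ← TapeM.leaf yVal (requires_grad := false)
        let w ← TapeM.leaf wVal                 -- trainable parameter
        let yhat ← TapeM.mul (s := s) w x
        let loss ← TapeM.mseLoss (s := s) yhat y
        let grads ← TapeM.backward loss         -- assumed wrapper name
        pure (w, grads)
      let ((wId, grads), _finalTape) ← TapeM.run t0 prog
      let _gradW := grads.get? wId              -- HashMap lookup for w's gradient
      pure ()                                   -- update w with SGD etc. here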