TorchLean API

NN.Spec.Autograd.Ops

Autograd OpSpecs (spec layer)

This file defines small OpSpec building blocks (forward + VJP) for common tensor operations. The definitions are intentionally direct mathematical contracts and live purely in the spec layer.

Where this sits in TorchLean:

This file does not mirror every runtime method one-for-one. It is the reusable adapter layer for operations whose input-gradient VJP is naturally expressed as a single OpSpec. Larger multi-input/parameterized layers (convolution, attention, batchnorm, pooling, RNG) still have precise specs and runtime implementations, but their full backward state usually belongs in layer/runtime code rather than in this compact unary interface.

Elementwise lifting helpers

def Spec.liftElementwise {α : Type} {s : Shape} (f : α → α) :
Tensor α s → Tensor α s

Lift a scalar function to a tensor by pointwise map.

PyTorch analogy: most torch.* pointwise ops are vectorized elementwise maps.

def Spec.liftElementwiseBackward {α : Type} [Mul α] {s : Shape} (df : α → α) :
Tensor α s → Tensor α s → Tensor α s

Lift an elementwise backward using the chain rule: dL/dx = df(x) * dL/dy pointwise.

This is the standard VJP pattern for elementwise ops.

PyTorch analogy: the "local backward" rule for a pointwise op multiplies by the derivative mask.

def Spec.reluOp {α : Type} [Mul α] [One α] [Zero α] [Max α] [LT α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {s : Shape} :
OpSpec α s s

Elementwise ReLU OpSpec on any shape.

PyTorch analogy: torch.relu(x) / torch.nn.functional.relu(x).

def Spec.sigmoidOp {α : Type} [Context α] {s : Shape} :
OpSpec α s s

Elementwise sigmoid OpSpec on any shape.

PyTorch analogy: torch.sigmoid(x).
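
To make the forward/VJP pairing concrete, here is a minimal OpSpec-shaped record on List Float with a sigmoid instance. OpSpec' is a sketch for this page, not TorchLean's OpSpec (which is shape-indexed and uses the Context α backend):

-- Minimal OpSpec-like record: a forward map and a VJP (input, dL/dy ↦ dL/dx).
structure OpSpec' where
  forward : List Float → List Float
  vjp : List Float → List Float → List Float

def sigmoid' (v : Float) : Float := 1.0 / (1.0 + Float.exp (-v))

-- σ'(x) = σ(x) * (1 - σ(x)), wired through the elementwise chain rule.
def sigmoidOp' : OpSpec' where
  forward x := x.map sigmoid'
  vjp x g := List.zipWith (fun xi gi => let s := sigmoid' xi; s * (1.0 - s) * gi) x g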

def Spec.tanhOp {α : Type} [Context α] {s : Shape} :
OpSpec α s s

Elementwise tanh OpSpec on any shape.

PyTorch analogy: torch.tanh(x).

def Spec.softplusOp {α : Type} [Context α] {s : Shape} :
OpSpec α s s

Elementwise softplus OpSpec on any shape.

PyTorch analogy: torch.nn.functional.softplus(x).

def Spec.swishOp {α : Type} [Context α] {s : Shape} :
OpSpec α s s

Elementwise Swish / SiLU OpSpec on any shape.

PyTorch analogy: torch.nn.functional.silu(x).

@[reducible, inline]
abbrev Spec.siluOp {α : Type} [Context α] {s : Shape} :
OpSpec α s s

Alias for swishOp, using the common SiLU name.

def Spec.eluOp {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {s : Shape} (eluAlpha : α) :
OpSpec α s s

Elementwise ELU OpSpec on any shape.

def Spec.geluOp {α : Type} [Context α] [OfScientific α] {s : Shape} :
OpSpec α s s

Elementwise tanh-approximate GELU OpSpec on any shape.

PyTorch analogy: torch.nn.functional.gelu(x, approximate="tanh").

def Spec.sinhOp {α : Type} [Context α] {s : Shape} :
OpSpec α s s

Elementwise hyperbolic sine OpSpec.

def Spec.coshOp {α : Type} [Context α] {s : Shape} :
OpSpec α s s

Elementwise hyperbolic cosine OpSpec.

def Spec.softmaxOp {α : Type} [Context α] {s : Shape} :
OpSpec α s s

Last-axis softmax OpSpec on any shape.

This is a true softmax along the last axis (applied independently over all outer slices).

PyTorch analogy: torch.softmax(x, dim=-1).
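
For a single row, the forward and its VJP, dL/dx = y ⊙ (g − ⟨g, y⟩) with y = softmax(x), can be sketched in plain Lean (stand-in names, not the TorchLean spec):

-- Stable softmax on one row: shift by the max, exponentiate, normalize.
def softmaxRow (x : List Float) : List Float :=
  let m := x.foldl max (-1.0 / 0.0)   -- -inf is the identity for max
  let e := x.map (fun v => Float.exp (v - m))
  let z := e.foldl (· + ·) 0.0
  e.map (· / z)

-- VJP: dL/dx = y ⊙ (g − ⟨g, y⟩), recomputing y from x.
def softmaxVjpRow (x g : List Float) : List Float :=
  let y := softmaxRow x
  let dot := (List.zipWith (· * ·) g y).foldl (· + ·) 0.0
  List.zipWith (fun yi gi => yi * (gi - dot)) y g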

def Spec.logSoftmaxOp {α : Type} [Context α] {s : Shape} :
OpSpec α s s

Stable last-axis log-softmax OpSpec.

Backward recomputes the forward output so the VJP matches logSoftmaxBackwardSpec. Runtime engines may cache that output instead.
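
Since exp of the log-softmax output is the softmax itself, the recompute-style VJP is dL/dx = g − softmax(x) · Σ g. A row-level sketch, reusing softmaxRow from above:

-- Log-softmax VJP on one row: dL/dx = g − softmax(x) * (Σ g).
def logSoftmaxVjpRow (x g : List Float) : List Float :=
  let y := softmaxRow x
  let gSum := g.foldl (· + ·) 0.0
  List.zipWith (fun yi gi => gi - yi * gSum) y g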


Linear layers

def Spec.linearOp {α : Type} [Add α] [Mul α] [Zero α] [One α] {inDim outDim : ℕ} (m : LinearSpec α inDim outDim) :
OpSpec α (Shape.dim inDim Shape.scalar) (Shape.dim outDim Shape.scalar)

Linear layer as an OpSpec: y = W x + b.

This OpSpec only returns the input gradient dL/dx. Parameter gradients for W and b are intentionally not part of OpSpec (those live at the graph/runtime level).

PyTorch analogy: torch.nn.Linear forward, with autograd producing grads for x/W/b.
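
A list-level sketch of the forward and the input-only VJP, dL/dx = Wᵀ · dL/dy, with W as a list of rows (stand-in names, not TorchLean's LinearSpec machinery):

def matVec (W : List (List Float)) (x : List Float) : List Float :=
  W.map (fun row => (List.zipWith (· * ·) row x).foldl (· + ·) 0.0)

-- Forward: y = W x + b.
def linearForward (W : List (List Float)) (b x : List Float) : List Float :=
  List.zipWith (· + ·) (matVec W x) b

-- Input VJP: dL/dx = Wᵀ g, computed as a sum of gradient-scaled rows.
def linearVjpInput (W : List (List Float)) (g : List Float) : List Float :=
  let zero := List.replicate (W.headD []).length 0.0
  (List.zipWith (fun row gi => row.map (· * gi)) W g).foldl
    (fun acc row => List.zipWith (· + ·) acc row) zero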

def Spec.scalarOf {α : Type} :
Tensor α Shape.scalar → α

Extract the scalar value from a scalar tensor.

We use this when an upstream gradient is a scalar (e.g. for reduced losses). In PyTorch this is the common pattern "loss is scalar, so grad_output is a scalar too".

@[reducible, inline]
abbrev Spec.scalarValue {α : Type} :
Tensor α Shape.scalar → α

Alias for scalarOf (clarifies intent at call sites).

def Spec.binaryElemOp {α : Type} [Context α] {s : Shape} (rhs : Tensor α s) (f dfdx : α → α → α) :
OpSpec α s s

Generic elementwise binary OpSpec with captured right-hand tensor and d/dx.

This is a "closure style" op: we treat the RHS tensor as a captured constant and only return the VJP with respect to the LHS input.

PyTorch analogy: in a tape/graph, rhs is typically another node; here we are writing the "lhs-only" derivative for convenience.
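
The closure pattern in miniature, reusing the OpSpec' stand-in from above (divByOp' is a hypothetical example, not a TorchLean name):

-- Capture rhs; forward applies f pointwise, VJP multiplies by dfdx pointwise.
def binaryElemOp' (rhs : List Float) (f dfdx : Float → Float → Float) : OpSpec' where
  forward x := List.zipWith f x rhs
  vjp x g := List.zipWith (· * ·) (List.zipWith dfdx x rhs) g

-- Example: x / rhs, whose lhs derivative is d/dx (x / r) = 1 / r.
def divByOp' (rhs : List Float) : OpSpec' :=
  binaryElemOp' rhs (· / ·) (fun _ r => 1.0 / r)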

def Spec.scaleOp {α : Type} [Context α] {s : Shape} (c : α) :
OpSpec α s s

Scale by a constant scalar.

PyTorch analogy: x * c where c is a scalar constant.


Unary elementwise ops

def Spec.negOp {α : Type} [Context α] {s : Shape} :
OpSpec α s s

Negation (-x).

def Spec.absOp {α : Type} [Context α] {s : Shape} :
OpSpec α s s

Absolute value (uses signSpec for the subgradient).

PyTorch analogy: torch.abs(x). At x = 0 we pick the subgradient 0.

def Spec.smoothAbsOp {α : Type} [Context α] {s : Shape} (ε : α := Numbers.epsilon) :
OpSpec α s s

Smooth absolute value (a differentiable surrogate for abs).

This is useful when you want to avoid a kink at 0 in optimization. PyTorch analogy: there is no single canonical smooth_abs, but it is similar in spirit to sqrt(x^2 + eps)-style smoothings.

def Spec.expOp {α : Type} [Context α] {s : Shape} :
OpSpec α s s

Elementwise exp.

PyTorch analogy: torch.exp(x).

def Spec.logOp {α : Type} [Context α] {s : Shape} :
OpSpec α s s

Elementwise natural logarithm.

Domain discipline: this is the raw mathematical/PyTorch-style rule. The VJP multiplies by 1/x, so callers should use it only when the input is strictly positive. Runtime backends are allowed to reject nonpositive inputs rather than silently manufacture a gradient. Use safeLogOp when the intended model is log(x + ε).

PyTorch analogy: torch.log(x).

def Spec.safeLogOp {α : Type} [Context α] {s : Shape} (ε : α := Numbers.epsilon) :
OpSpec α s s

Elementwise log with epsilon shift, log(x + ε).

This is the default facade-safe logarithm: it is total as a spec expression and its VJP uses 1/(x+ε) pointwise.

PyTorch analogy: often written manually as torch.log(x + eps).
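
The safe-log pattern on lists, with eps standing in for Numbers.epsilon (a sketch, reusing the OpSpec' stand-in):

-- Total forward log(x + ε) with the matching VJP g / (x + ε).
def safeLogOp' (eps : Float := 1e-7) : OpSpec' where
  forward x := x.map (fun v => Float.log (v + eps))
  vjp x g := List.zipWith (fun xi gi => gi / (xi + eps)) x g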

def Spec.sqrtOp {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {s : Shape} :
OpSpec α s s

Elementwise square root.

Domain discipline: TorchLean's spec-level sqrtSpec is total by clamping the forward value on nonpositive inputs. The VJP follows that convention and returns zero where x <= 0, rather than introducing an artificial 1/ε spike.

PyTorch analogy: torch.sqrt(x) on the positive region, with an explicit TorchLean subgradient choice outside the classical domain.
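
The clamping convention in list form: the forward is total and the subgradient is zero wherever x <= 0 (a sketch of the stated contract, not the TorchLean code):

def sqrtOp' : OpSpec' where
  forward x := x.map (fun v => if v > 0.0 then Float.sqrt v else 0.0)
  -- d/dx √x = 1 / (2 √x) on the positive region; zero subgradient elsewhere.
  vjp x g := List.zipWith
    (fun xi gi => if xi > 0.0 then gi / (2.0 * Float.sqrt xi) else 0.0) x g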

def Spec.squareOp {α : Type} [Context α] {s : Shape} :
OpSpec α s s

Elementwise square (x^2).

def Spec.powOp {α : Type} [Context α] {s : Shape} (rhs : Tensor α s) :
OpSpec α s s

Elementwise power with a captured RHS exponent tensor.

This is the VJP with respect to the base x for x ^ rhs. Domain restrictions are the usual ones for the scalar backend's power operation.

def Spec.invOp {α : Type} [Context α] {s : Shape} :
OpSpec α s s

Elementwise reciprocal, 1/x.

Domain discipline: this is the raw reciprocal. Its VJP is -1/x^2, so callers should use it only when zero is excluded by the surrounding invariant. Use safeInvOp when the intended model is 1/(x+ε).

PyTorch analogy: torch.reciprocal(x) or 1 / x.

def Spec.safeInvOp {α : Type} [Context α] {s : Shape} :
OpSpec α s s

Elementwise epsilon-shifted reciprocal, 1/(x+ε).

This is the safe facade counterpart to invOp: the forward pass delegates to safedivSpec with unit numerator, and the VJP is the derivative of the same shifted expression.

PyTorch analogy: usually written manually as 1.0 / (x + eps).


Binary ops capturing a right-hand tensor

def Spec.addOp {α : Type} [Context α] {s : Shape} (rhs : Tensor α s) :
OpSpec α s s

Add a captured RHS tensor (x + rhs).

def Spec.subOp {α : Type} [Context α] {s : Shape} (rhs : Tensor α s) :
OpSpec α s s

Subtract a captured RHS tensor (x - rhs).

def Spec.mulOp {α : Type} [Context α] {s : Shape} (rhs : Tensor α s) :
OpSpec α s s

Elementwise multiply by a captured RHS tensor.

def Spec.divOp {α : Type} [Context α] {s : Shape} (rhs : Tensor α s) :
OpSpec α s s

Elementwise divide by a captured RHS tensor.

Domain discipline: this is the raw division rule. The VJP multiplies by 1/rhs, so callers should only use it when the captured denominator is known nonzero. Use safeDivOp when the intended model is x/(rhs+ε).

def Spec.safeDivOp {α : Type} [Context α] {s : Shape} (rhs : Tensor α s) :
OpSpec α s s

Elementwise safe division by a captured RHS tensor, x/(rhs+ε).

PyTorch analogy: usually written manually as x / (rhs + eps).

def Spec.minOp {α : Type} [Context α] {s : Shape} (rhs : Tensor α s) :
OpSpec α s s

Elementwise min with captured RHS.

We pick a subgradient via a <= mask (ties go to the left input).

PyTorch analogy: torch.minimum(x, rhs) (subgradient convention is an implementation detail).

def Spec.maxOp {α : Type} [Context α] {s : Shape} (rhs : Tensor α s) :
OpSpec α s s

Elementwise max with captured RHS.

We pick a subgradient via a >= mask (ties go to the left input).

PyTorch analogy: torch.maximum(x, rhs) (subgradient convention is an implementation detail).
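
Both masks in list form: the <= mask (min) and the >= mask (max) each route the gradient to the left input on ties. A sketch, reusing OpSpec':

def minOp' (rhs : List Float) : OpSpec' where
  forward x := List.zipWith min x rhs
  vjp x g := List.zipWith (· * ·) g
    (List.zipWith (fun xi ri => if xi ≤ ri then 1.0 else 0.0) x rhs)

def maxOp' (rhs : List Float) : OpSpec' where
  forward x := List.zipWith max x rhs
  vjp x g := List.zipWith (· * ·) g
    (List.zipWith (fun xi ri => if xi ≥ ri then 1.0 else 0.0) x rhs)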

def Spec.leakyReluOp {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {s : Shape} (αₗ : α) :
OpSpec α s s

Leaky ReLU with slope parameter.

PyTorch analogy: torch.nn.functional.leaky_relu(x, negative_slope=alpha_l).

def Spec.clampOp {α : Type} [Context α] {s : Shape} (minVal maxVal : α) :
OpSpec α s s

Clamp OpSpec with a fixed interval.

We choose the standard subgradient 1 strictly inside the interval and 0 at/outside the boundaries, matching clampDerivativeSpec.
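
The stated subgradient in list form: gradient passes only strictly inside (minVal, maxVal). A sketch of the convention, reusing OpSpec':

def clampOp' (minVal maxVal : Float) : OpSpec' where
  forward x := x.map (fun v => min (max v minVal) maxVal)
  vjp x g := List.zipWith
    (fun xi gi => if minVal < xi ∧ xi < maxVal then gi else 0.0) x g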


Loss OpSpecs

def Spec.mseLossOp {α : Type} [Context α] {s : Shape} (target : Tensor α s) :
OpSpec α s Shape.scalar

MSE loss (returns a scalar), capturing the target.
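
For L = mean (x − t)², the input gradient is dL/dx = 2 (x − t) / n, scaled by the upstream scalar gradient. A list-level sketch with stand-in names:

def mseForward (target x : List Float) : Float :=
  let d := List.zipWith (· - ·) x target
  (d.foldl (fun acc v => acc + v * v) 0.0) / Float.ofNat x.length

-- gOut is the upstream scalar gradient (1.0 when the loss is the root).
def mseVjp (target x : List Float) (gOut : Float) : List Float :=
  let n := Float.ofNat x.length
  List.zipWith (fun xi ti => gOut * 2.0 * (xi - ti) / n) x target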

def Spec.maeLossOp {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {s : Shape} (target : Tensor α s) :
OpSpec α s Shape.scalar

MAE loss (returns a scalar), capturing the target.

def Spec.huberLossOp {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {s : Shape} (target : Tensor α s) (delta : α := 1) :
OpSpec α s Shape.scalar

Huber loss (returns a scalar), capturing the target.

def Spec.crossEntropyLossOp {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {s : Shape} (target : Tensor α s) (epsilon : α := Numbers.epsilon) :
OpSpec α s Shape.scalar

Cross-entropy loss (returns a scalar), capturing the target distribution.

This is "cross-entropy between distributions": target is p, yhat is q. PyTorch analogy: closer to -(p * log(q)).mean() than to the logits-based torch.nn.functional.cross_entropy default.

def Spec.crossEntropyLogitsLossOp {α : Type} [Context α] {s : Shape} (target : Tensor α s) :
OpSpec α s Shape.scalar

Logits-based cross-entropy loss, capturing the target distribution.
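
For a normalized target p and L = −Σ p · log softmax(x), the classic identity gives dL/dx = softmax(x) − p. A row-level sketch reusing softmaxRow (reduction scaling is folded into gOut):

def ceLogitsVjpRow (p x : List Float) (gOut : Float) : List Float :=
  List.zipWith (fun yi pi => gOut * (yi - pi)) (softmaxRow x) p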

def Spec.binaryCrossEntropyLossOp {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {s : Shape} (target : Tensor α s) (epsilon : α := Numbers.epsilon) :
OpSpec α s Shape.scalar

Binary cross-entropy loss on probability tensors, capturing the target tensor.

def Spec.cosineSimilarityLossOp {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {s : Shape} (target : Tensor α s) (epsilon : α := Numbers.epsilon) :
OpSpec α s Shape.scalar

Cosine-similarity loss, capturing the target tensor.

def Spec.hingeLossOp {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {s : Shape} (target : Tensor α s) :
OpSpec α s Shape.scalar

Hinge loss (returns a scalar), capturing the target.

def Spec.poissonLossOp {α : Type} [Context α] {s : Shape} (target : Tensor α s) :
OpSpec α s Shape.scalar

Poisson loss (returns a scalar), capturing the target.

def Spec.logCoshLossOp {α : Type} [Context α] {s : Shape} (target : Tensor α s) :
OpSpec α s Shape.scalar

Log-cosh loss (returns a scalar), capturing the target.


Normalization

def Spec.layerNormOp {α : Type} [Context α] (seqLen embedDim : ℕ) (gamma beta : Tensor α (Shape.dim embedDim Shape.scalar)) (h_seq_pos : seqLen > 0 := by decide) (h_embed_pos : embedDim > 0 := by decide) :
OpSpec α (Shape.dim seqLen (Shape.dim embedDim Shape.scalar)) (Shape.dim seqLen (Shape.dim embedDim Shape.scalar))

LayerNorm OpSpec over (seqLen × embedDim). Parameters gamma/beta are captured. Backward returns only ∂L/∂x (parameter grads are not returned at this level).

def Spec.identityOp {α : Type} {s : Shape} :
OpSpec α s s

Identity op: pass-through forward and backward.


Shape/structure ops

def Spec.reshapeOp {α : Type} {s t : Shape} [Inhabited α] (h : s.size = t.size) :
OpSpec α s t

Reshape op (requires a size-equality proof).

PyTorch analogy: x.reshape(...) (or view), but here the shape relationship is explicit.

Matrix transpose (2D) op.

PyTorch analogy: x.transpose(0, 1) for a matrix.

def Spec.constantOp {α : Type} [Context α] {s : Shape} (value : α) [Inhabited α] :
OpSpec α s s

Fill a tensor with a constant (ignores input).

PyTorch analogy: torch.full_like(x, value) (but here we keep the input only to fit the OpSpec shape, and ignore its content).

Replicate a scalar to any shape; backward sums gradients back to a scalar.

PyTorch analogy: broadcasting a scalar in arithmetic, and in backward accumulating by sum.

def Spec.applyMaskOp {α : Type} [Context α] {s : Shape} (mask : Tensor Bool s) :
OpSpec α s s

Apply boolean mask: keep where mask is true, else set 0.

PyTorch analogy: torch.where(mask, x, 0).

def Spec.dropoutInferenceOp {α : Type} [Context α] {s : Shape} (p : α) :
OpSpec α s s

Deterministic inference-style dropout scaling.

def Spec.dropoutMaskedOp {α : Type} [Context α] {s : Shape} (p : α) (mask : Tensor Bool s) :
OpSpec α s s

Masked inverted-dropout OpSpec with an explicit mask.
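
Inverted dropout in list form: keep where the mask is true, rescale by 1/(1 − p), and apply the same mask and scale in the VJP. A sketch reusing OpSpec':

def dropoutMaskedOp' (p : Float) (mask : List Bool) : OpSpec' where
  forward x := List.zipWith (fun xi m => if m then xi / (1.0 - p) else 0.0) x mask
  vjp _ g := List.zipWith (fun gi m => if m then gi / (1.0 - p) else 0.0) g mask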

Right-multiply by fixed matrix: X (m×n) ↦ X·B (m×p).

Left-multiply by fixed matrix: X (n×p) ↦ A·X (m×p).

def Spec.bmmRightOp {α : Type} [Context α] {batch m n p : ℕ} (B : Tensor α (Shape.dim batch (Shape.dim n (Shape.dim p Shape.scalar)))) :
OpSpec α (Shape.dim batch (Shape.dim m (Shape.dim n Shape.scalar))) (Shape.dim batch (Shape.dim m (Shape.dim p Shape.scalar)))

Batched matrix multiply with captured RHS: A ↦ A @ B.

def Spec.bmmLeftOp {α : Type} [Context α] {batch m n p : ℕ} (A : Tensor α (Shape.dim batch (Shape.dim m (Shape.dim n Shape.scalar)))) :
OpSpec α (Shape.dim batch (Shape.dim n (Shape.dim p Shape.scalar))) (Shape.dim batch (Shape.dim m (Shape.dim p Shape.scalar)))

Batched matrix multiply with captured LHS: B ↦ A @ B.

def Spec.embeddingOnehotOp {α : Type} [Context α] {vocab embedDim seqLen : ℕ} (emb : EmbeddingSpec vocab embedDim α) :
OpSpec α (Shape.dim seqLen (Shape.dim vocab Shape.scalar)) (Shape.dim seqLen (Shape.dim embedDim Shape.scalar))

One-hot embedding as an OpSpec over the one-hot input. Parameter gradients stay outside OpSpec; this wrapper returns only dOneHot.

def Spec.expandToColOp {α : Type} {n : ℕ} {s : Shape} :
OpSpec α (Shape.dim n s) (Shape.dim n (Shape.dim 1 s))

Expand a vector to a column (unsqueeze a trailing dim of size 1); the backward is the inverse squeeze.

def Spec.squeezeColOp {α : Type} {n : ℕ} {s : Shape} :
OpSpec α (Shape.dim n (Shape.dim 1 s)) (Shape.dim n s)

Squeeze a trailing singleton dim (n,1,...) to (n,...) (adjoint unsqueezes).

def Spec.concatDim0LeftOp {α : Type} {n m : ℕ} {s : Shape} (rhs : Tensor α (Shape.dim m s)) :
OpSpec α (Shape.dim n s) (Shape.dim (n + m) s)

Concatenate along the leading dimension with a captured RHS, returning the gradient slice for the LHS input.

def Spec.concatDim0RightOp {α : Type} {n m : ℕ} {s : Shape} (lhs : Tensor α (Shape.dim n s)) :
OpSpec α (Shape.dim m s) (Shape.dim (n + m) s)

Concatenate along the leading dimension with a captured LHS, returning the gradient slice for the RHS input.

def Spec.sliceRange0Op {α : Type} {n : ℕ} {s : Shape} [Zero α] (start len : ℕ) (h : len + start ≤ n) :
OpSpec α (Shape.dim n s) (Shape.dim len s)

Slice a leading-axis range; backward inserts the upstream gradient into the original shape.
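
The forward takes the range [start, start + len); the adjoint scatters the upstream gradient back into zeros of the original length. A list-level sketch:

def sliceForward (start len : Nat) (x : List Float) : List Float :=
  (x.drop start).take len

-- Adjoint: embed g at offset start inside n zeros.
def sliceVjp (start len n : Nat) (g : List Float) : List Float :=
  List.replicate start 0.0 ++ g ++ List.replicate (n - start - len) 0.0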


Reductions and broadcasting

def Spec.reduceSumOp {α : Type} [Context α] {s : Shape} (axis : ℕ) [valid : Shape.valid_axis_inst axis s] [wf : s.WellFormed] [Inhabited α] :

Reduce-sum along axis using a valid_axis proof; backward broadcasts back.

PyTorch analogy: torch.sum(x, dim=axis) (with keepdim=false).

def Spec.binaryBroadcastOp {α : Type} [Context α] {s1 s2 t : Shape} [Inhabited α] (rhs : Tensor α s2) (cbx : s1.CanBroadcastTo t) (cby : s2.CanBroadcastTo t) (f dfdx : α → α → α) (reduce_back : Tensor α t → Tensor α s1) :
OpSpec α s1 t

Generic broadcasting-aware binary OpSpec.

The caller supplies:

• explicit broadcast proofs (CanBroadcastTo) for both sides, and
• a reduce_back map that takes a gradient in the broadcasted shape t and reduces it back to the left shape s1.

PyTorch analogy: this is where PyTorch's implicit broadcasting rules and reduction-of-broadcasted gradients ("sum over broadcasted dimensions") happen. In TorchLean we keep those shape relations explicit.
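
The reduce_back contract in its simplest instance: a scalar broadcast against a row must have its gradient summed back over the broadcasted dimension. A sketch with stand-in names:

-- Forward: broadcast the scalar c across x (the broadcasted shape t).
def broadcastAddScalar (c : Float) (x : List Float) : List Float :=
  x.map (· + c)

-- reduce_back for the scalar side: sum the gradient over the broadcast dim.
def reduceBackToScalar (g : List Float) : Float :=
  g.foldl (· + ·) 0.0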

def Spec.addBroadcastOp {α : Type} [Context α] {s1 s2 t : Shape} [Inhabited α] (rhs : Tensor α s2) (cbx : s1.CanBroadcastTo t) (cby : s2.CanBroadcastTo t) (reduce_back : Tensor α t → Tensor α s1) :
OpSpec α s1 t

Convenience: broadcasting-aware add with caller-provided reduction.

def Spec.mulBroadcastOp {α : Type} [Context α] {s1 s2 t : Shape} [Inhabited α] (rhs : Tensor α s2) (cbx : s1.CanBroadcastTo t) (cby : s2.CanBroadcastTo t) (reduce_back : Tensor α t → Tensor α s1) :
OpSpec α s1 t

Convenience: broadcasting-aware mul with caller-provided reduction.
