TorchLean API

NN.Spec.Autograd.AutogradSpec

Spec-level autograd operation specifications (Spec.OpSpec)

This file defines a small interface for reverse-mode differentiation: an atomic operation specification (OpSpec) that pairs a forward function with its vector-Jacobian product (backward), together with an identity spec and sequential composition of specs.

This is intentionally a spec-layer interface: it is independent of any particular runtime autograd engine or tape representation. The runtime code is free to choose its own tape or graph data structures and to store or recompute intermediate values, as long as it implements the same mathematical VJP behavior.

PyTorch analogy: an OpSpec plays roughly the role of a torch.autograd.Function, which likewise pairs a forward with a backward that maps the upstream gradient to input gradients; compose corresponds to the chaining of backward nodes in the backward graph PyTorch builds dynamically.

In TorchLean we deliberately keep this file smaller than a graph IR: OpSpec is the math contract (forward + VJP + composition). We do not want to invent yet another graph/IR here, because the repo already has canonical graph representations elsewhere.

When you want "a real graph", use those. When you want "the spec of an op", use OpSpec.

structure Spec.OpSpec (α : Type) (σ τ : Shape) : Type

Atomic operation specification (forward + VJP/backward).

backward takes the input x and an upstream gradient dL/dy, and returns dL/dx.

Why this signature:

  • Reverse-mode AD never needs a full Jacobian. What it needs is: given an upstream gradient dL/dy, compute the gradient with respect to the input, dL/dx. That’s exactly what backward encodes (a VJP).
  • We pass x to backward because many derivatives depend on the input value. At the spec level we don’t force a “store intermediates vs recompute” strategy; the runtime system can choose. (A minimal sketch of the resulting contract follows this list.)
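
To make the contract concrete, here is a minimal Lean sketch. The field names forward and backward match the description above; Shape and Tensor are hypothetical stand-ins, since this page does not show TorchLean's actual shape-indexed tensor types.

  -- Stand-ins for types not shown on this page (assumptions, not TorchLean's
  -- real definitions).
  abbrev Shape := List Nat

  structure Tensor (α : Type) (σ : Shape) where
    data : List α   -- placeholder payload; the shape lives only in the index

  structure OpSpec (α : Type) (σ τ : Shape) : Type where
    /-- The forward map x ↦ y. -/
    forward : Tensor α σ → Tensor α τ
    /-- The VJP: given the input x and an upstream gradient dL/dy,
        return dL/dx. x is passed because many derivatives depend on it. -/
    backward : Tensor α σ → Tensor α τ → Tensor α σ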
def Spec.OpSpec.id (α : Type) (σ : Shape) : OpSpec α σ σ

The identity OpSpec (forward is identity; backward returns the upstream gradient).
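
Under the stand-in sketch above, the definition is one line per field:

  def OpSpec.id (α : Type) (σ : Shape) : OpSpec α σ σ where
    forward  := fun x => x      -- the identity forward map
    backward := fun _ dy => dy  -- d(id)/dx is the identity, so dy passes through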

def Spec.OpSpec.compose {α : Type} {σ τ υ : Shape} (f : OpSpec α σ τ) (g : OpSpec α τ υ) : OpSpec α σ υ

Sequential composition of two ops with the reverse-mode chain rule.

If f : σ → τ and g : τ → υ, their composition is g ∘ f : σ → υ.

For reverse-mode AD, we compose their VJPs:

• given an upstream gradient dL/dz : υ,
• compute dL/dy : τ using g.backward,
• then compute dL/dx : σ using f.backward.
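
In Jacobian terms, writing J_f for ∂y/∂x and J_g for ∂z/∂y, those two steps compute dL/dy = J_gᵀ · dL/dz and then dL/dx = J_fᵀ · dL/dy = J_fᵀ · J_gᵀ · dL/dz. Evaluating right-to-left means each step is a single VJP, so neither Jacobian matrix is ever materialized.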

You can visualize the dataflow as a compact chain:

  x --f.forward--> y --g.forward--> z

and the reverse pass as the same chain walked backwards:

  dL/dz --g.backward--> dL/dy --f.backward--> dL/dx

This is the core of what PyTorch builds dynamically as a "backward graph" during the forward pass, except here we keep it as an explicit, pure definition. A runtime engine can still choose whether to cache the intermediate y = f.forward x or recompute it; the spec states the mathematical VJP.
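
Continuing the stand-in sketch above, compose reads directly off that picture; note that this sketch recomputes y = f.forward x inside backward, whereas a runtime may cache it instead:

  def OpSpec.compose {α : Type} {σ τ υ : Shape}
      (f : OpSpec α σ τ) (g : OpSpec α τ υ) : OpSpec α σ υ where
    forward := fun x => g.forward (f.forward x)
    backward := fun x dz =>
      let y  := f.forward x      -- intermediate from the forward chain (recomputed here)
      let dy := g.backward y dz  -- dL/dy from the upstream dL/dz
      f.backward x dy            -- dL/dx from dL/dy

  -- Sanity check on the sketch: composing with the identity changes nothing,
  -- on either the forward or the backward pass.
  example {α : Type} {σ τ : Shape} (f : OpSpec α σ τ) : OpSpec α σ τ :=
    (OpSpec.id α σ).compose f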
