TorchLean API

NN.Proofs.Autograd.Tape.Ops.Conv.BackwardDot

BackwardDot #

Conv2D backward dot-level correctness (bridge lemma).

The runtime autograd engine (NN.Runtime.Autograd.Engine) computes Conv2D gradients via the handwritten spec Spec.conv2dBackwardSpec.

For the analytic spec-level theorem over ℝ, Conv2D is already covered via fderiv/adjoints in NN.Proofs.Autograd.Tape.Ops.Conv.FDeriv.

That file provides a proof-only Conv2D node whose VJP is (fderiv forward)†, so any DAG using it is covered by the global theorem Graph.backpropVec_eq_adjoint_fderiv.

What remains here is the dot/adjointness bridge:

dot (JVP_conv2d …) δ = dot dKernel gK + dot dBias gB + dot dInput gX

where (gK, gB, gX) = Spec.conv2dBackwardSpec … δ.
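By linearity of Spec.dot in its first argument, the JVP term on the left expands into the three directional derivatives of the forward pass (this matches the jvp abbreviation in conv2d_backward_spec_dot below; 0 abbreviates the Spec.fill 0 bias):

    dot (conv2dSpec {kernel := dKernel, bias := 0} input) δ            -- kernel direction
      + dot (biasBroadcast dBias) δ                                    -- bias direction
      + dot (conv2dSpec {kernel := layer.kernel, bias := 0} dInput) δ  -- input direction
      = dot dKernel gK + dot dBias gB + dot dInput gX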

The padding-related rewrites needed for the input-gradient proof are factored into the paddedInput helpers below.

Broadcast a per-channel bias gradient of shape outC across the spatial axes into an outC×outH×outW tensor.

Instances For
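Entrywise, the broadcast is constant along the spatial axes. A sketch of the intended semantics (the corresponding index lemma is not rendered on this page, so treat this as an assumption):

    Spec.getAtOrZero (biasBroadcast dBias) [oc, i, j] = Spec.getAtOrZero dBias [oc]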

    paddedInput matches the forward spec’s padding convention: padding zeros are added on each side of both spatial axes.

    PyTorch analogue: torch.nn.functional.pad (or implicit padding inside conv2d). https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html https://pytorch.org/docs/stable/generated/torch.nn.functional.conv2d.html

    noncomputable def Proofs.Autograd.Conv2D.paddedInput {inC inH inW padding : ℕ} (input : Spec.MultiChannelImage inC inH inW ℝ) :
    Spec.MultiChannelImage inC (inH + 2 * padding) (inW + 2 * padding) ℝ
    Instances For
      theorem Proofs.Autograd.Conv2D.get_at_or_zero_paddedInput {inC inH inW padding : ℕ} (img : Spec.MultiChannelImage inC inH inW ℝ) (c : Fin inC) (p q : ℕ) :
      Spec.getAtOrZero (paddedInput img) [c, p, q] = if _h : p < padding ∨ q < padding then 0 else Spec.getAtOrZero img [c, p - padding, q - padding]
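      For instance, with inH = inW = 1 and padding = 1 the padded image is 3×3. The top/left band (p < 1 or q < 1) reads 0 through the if-branch, the centre recovers the original pixel, and indices past the bottom/right edge also read 0 because getAtOrZero on the unpadded image is out of range there:

          Spec.getAtOrZero (paddedInput img) [c, 0, q] = 0                                  -- p < padding
          Spec.getAtOrZero (paddedInput img) [c, 1, 1] = Spec.getAtOrZero img [c, 0, 0]     -- interior
          Spec.getAtOrZero (paddedInput img) [c, 2, 2] = Spec.getAtOrZero img [c, 1, 1] = 0 -- out of range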

      Output shape helpers (no dilation): these are the standard “convolution arithmetic” formulas. They are kept as definitions so later statements can share the same expression.

      def Proofs.Autograd.Conv2D.outH (inH kH stride padding : ℕ) : ℕ
      Instances For

        Output width helper (no dilation).

        def Proofs.Autograd.Conv2D.outW (inW kW stride padding : ℕ) : ℕ
        Instances For
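        A minimal sketch of the standard no-dilation bodies these helpers are expected to compute; the actual definition bodies are not rendered on this page, so treat the right-hand sides as an assumption:

            def outH (inH kH stride padding : ℕ) : ℕ := (inH + 2 * padding - kH) / stride + 1
            def outW (inW kW stride padding : ℕ) : ℕ := (inW + 2 * padding - kW) / stride + 1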
          theorem Proofs.Autograd.Conv2D.conv2d_spec_noBias_get {inC outC kH kW stride padding inH inW : ℕ} {h1 : inC ≠ 0} {h2 : kH ≠ 0} {h3 : kW ≠ 0} (dKernel : Spec.Tensor (Spec.Shape.dim outC (Spec.Shape.dim inC (Spec.Shape.dim kH (Spec.Shape.dim kW Spec.Shape.scalar))))) (input : Spec.MultiChannelImage inC inH inW ℝ) (oc : Fin outC) (i : Fin (outH inH kH stride padding)) (j : Fin (outW inW kW stride padding)) :
          have layerK := { kernel := dKernel, bias := Spec.fill 0 (Spec.Shape.dim outC Spec.Shape.scalar) }; Spec.getAtOrZero (Spec.conv2dSpec layerK input) [oc, i, j] = ∑ ic : Fin inC, ∑ di : Fin kH, ∑ dj : Fin kW, Spec.getAtOrZero dKernel [oc, ic, di, dj] * Spec.getAtOrZero (paddedInput input) [ic, i * stride + di, j * stride + dj]
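          In conventional notation, writing x̃ for paddedInput input, this is the standard cross-correlation formula for one output pixel:

              y[oc, i, j] = ∑ ic, di, dj, K[oc, ic, di, dj] * x̃[ic, i * stride + di, j * stride + dj]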
          theorem Proofs.Autograd.Conv2D.conv2d_backward_spec_dot {inC outC kH kW stride padding inH inW : ℕ} {h1 : inC ≠ 0} {h2 : kH ≠ 0} {h3 : kW ≠ 0} (layer : Spec.Conv2DSpec inC outC kH kW stride padding h1 h2 h3) (input : Spec.MultiChannelImage inC inH inW ℝ) (δ : Spec.MultiChannelImage outC (outH inH kH stride padding) (outW inW kW stride padding) ℝ) (dKernel : Spec.Tensor (Spec.Shape.dim outC (Spec.Shape.dim inC (Spec.Shape.dim kH (Spec.Shape.dim kW Spec.Shape.scalar))))) (dBias : Spec.Tensor (Spec.Shape.dim outC Spec.Shape.scalar)) (dInput : Spec.MultiChannelImage inC inH inW ℝ) :
          have layerK := { kernel := dKernel, bias := Spec.fill 0 (Spec.Shape.dim outC Spec.Shape.scalar) }; have layer0 := { kernel := layer.kernel, bias := Spec.fill 0 (Spec.Shape.dim outC Spec.Shape.scalar) }; have jvp := Spec.Tensor.addSpec (Spec.conv2dSpec layerK input) ((biasBroadcast dBias).addSpec (Spec.conv2dSpec layer0 dInput)); have grads := Spec.conv2dBackwardSpec layer input δ; Spec.dot jvp δ = Spec.dot dKernel grads.1 + Spec.dot dBias grads.2.1 + Spec.dot dInput grads.2.2

          Main dot-level bridge theorem for Conv2D.

          It states that the triple returned by Spec.conv2dBackwardSpec is the adjoint (w.r.t. Spec.dot) of the corresponding forward-mode directional derivatives with respect to (kernel, bias, input).

          This is the key lemma connecting the handwritten runtime backward to the analytic “VJP is adjoint of fderiv” theorem in NN.Proofs.Autograd.Tape.Ops.Conv.FDeriv.
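          Specializing one direction at a time (set the other two of dKernel, dBias, dInput to zero), linearity of Spec.dot splits the bridge into three separate adjointness facts, schematically (0 again abbreviating the Spec.fill 0 bias):

              dot (conv2dSpec {kernel := dKernel, bias := 0} input) δ            = dot dKernel gK
              dot (biasBroadcast dBias) δ                                        = dot dBias gB
              dot (conv2dSpec {kernel := layer.kernel, bias := 0} dInput) δ      = dot dInput gX

          where (gK, gB, gX) = Spec.conv2dBackwardSpec layer input δ, i.e. grads.1, grads.2.1, grads.2.2 in the statement above.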