TorchLean API

NN.Proofs.RuntimeApprox.NF.ConvBackward.Common

Conv2D Backward Approximation #

NF (rounded) backend: Conv2D backward (VJP) runtime→spec approximation.

This file proves soundness of explicit bounds for the three Conv2D gradients computed by Spec.conv2d_backward_spec: the gradients with respect to the input, the kernel, and the bias.

Each gradient has a different nested-indexing pattern, so the proof keeps the three bound families visible rather than hiding them behind one large opaque lemma. The important public objects are the tensor-level bounds (conv2d*BoundTensor), the approximation theorems (approxT_conv2d_*_deriv_spec), and conv2dRevNode, which packages Conv2D as a RevNode so it composes via RevGraph.backprop_approx.

PyTorch analogue: these are the VJP/gradient computations produced by Autograd for Conv2D. https://pytorch.org/docs/stable/autograd.html https://pytorch.org/docs/stable/generated/torch.nn.functional.conv2d.html

Map of this file #

References #

theorem Proofs.RuntimeApprox.NFBackend.specFold5_eq_threadFold5 {α : Type} [AddMonoid α] {outC out_h out_w kH kW : ℕ} (term : Fin outC → Fin out_h → Fin out_w → Fin kH → Fin kW → α) :
have specFold := List.foldl (fun (acc : α) (out_ch : Fin outC) => acc + List.foldl (fun (acc : α) (out_i : Fin out_h) => acc + List.foldl (fun (acc : α) (out_j : Fin out_w) => acc + List.foldl (fun (acc : α) (di : Fin kH) => acc + List.foldl (fun (acc : α) (dj : Fin kW) => acc + term out_ch out_i out_j di dj) 0 (List.finRange kW)) 0 (List.finRange kH)) 0 (List.finRange out_w)) 0 (List.finRange out_h)) 0 (List.finRange outC); have threadFold := List.foldl (fun (accC : α) (out_ch : Fin outC) => List.foldl (fun (accH : α) (out_i : Fin out_h) => List.foldl (fun (accW : α) (out_j : Fin out_w) => List.foldl (fun (accKH : α) (di : Fin kH) => List.foldl (fun (accKW : α) (dj : Fin kW) => accKW + term out_ch out_i out_j di dj) accKH (List.finRange kW)) accW (List.finRange kH)) accH (List.finRange out_w)) accC (List.finRange out_h)) 0 (List.finRange outC); specFold = threadFold
def Proofs.RuntimeApprox.NFBackend.paddedInput {α : Type} [Context α] {inC inH inW padding : ℕ} (img : Spec.MultiChannelImage inC inH inW α) :
Spec.MultiChannelImage inC (inH + 2 * padding) (inW + 2 * padding) α

Padded-input helper used by the Conv2D spec: cast when padding = 0, otherwise padMultiChannel.

Instances For
    theorem Proofs.RuntimeApprox.NFBackend.get_at_or_zero_paddedInput {α : Type} [Context α] {inC inH inW padding : ℕ} (img : Spec.MultiChannelImage inC inH inW α) (c : Fin inC) (p q : ℕ) :
    Spec.getAtOrZero (paddedInput img) [c, p, q] = if _h : p < padding ∨ q < padding then 0 else Spec.getAtOrZero img [c, p - padding, q - padding]
    theorem Proofs.RuntimeApprox.NFBackend.mkInputIdx_match_eq_paddedInput {α : Type} [Context α] {inC inH inW stride padding : ℕ} (img : Spec.MultiChannelImage inC inH inW α) (c : Fin inC) (oi di oj dj : ℕ) :
    (match Spec.Private.mkInputIdx? [oi, oj] [di, dj] [stride, stride] [padding, padding] with | none => 0 | some inIdx => Spec.getAtOrZero img (c :: inIdx)) = Spec.getAtOrZero (paddedInput img) [c, oi * stride + di, oj * stride + dj]
    theorem Proofs.RuntimeApprox.NFBackend.conv2dKernelFoldRead_eq_paddedFold {α : Type} [Context α] {inC outC inH inW outH outW stride padding : ℕ} (input : Spec.MultiChannelImage inC inH inW α) (grad : Spec.MultiChannelImage outC outH outW α) (out_ch : Fin outC) (in_ch : Fin inC) (di dj : ℕ) :
    List.foldl (fun (acc : α) (i : Fin outH) => List.foldl (fun (acc : α) (j : Fin outW) => acc + (match Spec.Private.mkInputIdx? [i, j] [di, dj] [stride, stride] [padding, padding] with | none => 0 | some inIdx => Spec.getAtOrZero input (in_ch :: inIdx)) * Spec.getAtOrZero grad [out_ch, i, j]) acc (List.finRange outW)) 0 (List.finRange outH) = List.foldl (fun (acc : α) (i : Fin outH) => List.foldl (fun (acc : α) (j : Fin outW) => acc + Spec.getAtOrZero (paddedInput input) [in_ch, i * stride + di, j * stride + dj] * Spec.getAtOrZero grad [out_ch, i, j]) acc (List.finRange outW)) 0 (List.finRange outH)
    theorem Proofs.RuntimeApprox.NFBackend.entry_eq_scalar_get_at_or_zero4 {α : Type} [Zero α] {n1 n2 n3 n4 : ℕ} (t : Spec.Tensor α (Spec.Shape.dim n1 (Spec.Shape.dim n2 (Spec.Shape.dim n3 (Spec.Shape.dim n4 Spec.Shape.scalar))))) (i1 : Fin n1) (i2 : Fin n2) (i3 : Fin n3) (i4 : Fin n4) :
    (match match match match t with | Spec.Tensor.dim f => f i1 with | Spec.Tensor.dim g => g i2 with | Spec.Tensor.dim h => h i3 with | Spec.Tensor.dim k => k i4) = Spec.Tensor.scalar (Spec.getAtOrZero t [i1, i2, i3, i4])