TorchLean API

NN.Runtime.Autograd.Compiled.GraphM.Convolution

GraphM Convolution Ops

N-dimensional and two-dimensional convolution and transposed-convolution builders.

def Runtime.Autograd.Compiled.GraphM.conv {α Δ : Type} [Context α] [DecidableEq Spec.Shape] {Γ : List Spec.Shape} {d inC outC : ℕ} {kernel stride padding inSpatial : Vector ℕ d} {hInC : inC ≠ 0} {hKernel : ∀ (i : Fin d), kernel.get i ≠ 0} (w : Var (Spec.Shape.ofList (outC :: inC :: kernel.toList))) (b : Var (Spec.Shape.dim outC Spec.Shape.scalar)) (x : Var (Spec.Shape.ofList (inC :: inSpatial.toList))) :
MWith α Δ Γ (Var (Spec.Shape.ofList (outC :: (Spec.convOutSpatial inSpatial kernel stride padding).toList)))

N-dimensional convolution (channels-first) on a single sample tensor.

The input shape is (inC, spatial...), the kernel shape is (outC, inC, kernelSpatial...), and the bias shape is (outC). Output spatial sizes use the PyTorch-style floor-division formula (no dilation): out[a] = (in[a] + 2*padding[a] - kernel[a]) / stride[a] + 1.
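The following is a minimal sketch of the per-axis rule above. It is written in plain Lean 4 (core only) over `List Nat` rather than TorchLean's `Vector`. `convOutDim` and `convOutSpatialSketch` are hypothetical helpers and are not part of the library. The sketch assumes no dilation.

```lean
/-- Hypothetical helper: output length of one spatial axis,
    `(in + 2 * padding - kernel) / stride + 1`, with ℕ floor division. -/
def convOutDim (inSz kernel stride padding : Nat) : Nat :=
  (inSz + 2 * padding - kernel) / stride + 1

/-- Hypothetical helper: apply `convOutDim` axis by axis.
    Stops at the shortest list; the real API uses length-`d` vectors instead. -/
def convOutSpatialSketch : List Nat → List Nat → List Nat → List Nat → List Nat
  | i :: is, k :: ks, s :: ss, p :: ps =>
      convOutDim i k s p :: convOutSpatialSketch is ks ss ps
  | _, _, _, _ => []

-- 3-D example: input spatial (16, 32, 32), kernel 3 on every axis,
-- stride (1, 2, 2), padding 1 on every axis.
-- Axis 0: (16 + 2 - 3) / 1 + 1 = 16; axes 1 and 2: (32 + 2 - 3) / 2 + 1 = 16.
example : convOutSpatialSketch [16, 32, 32] [3, 3, 3] [1, 2, 2] [1, 1, 1] = [16, 16, 16] := by
  decide
```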

The forward-mode JVP uses bilinearity in the kernel and input, with the bias entering additively: d(conv(k,b,x)) = conv(k,0,dx) + conv(dk,db,x).
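The bilinearity argument can be checked on a toy case. This sketch assumes a hypothetical 1-D, single-channel, stride-1, unpadded cross-correlation `conv1` over `Int`, so the arithmetic is exact. `conv1`, `addL`, and `subL` are illustrative names, not TorchLean functions. Perturbing kernel, bias, and input together changes the output by the two JVP terms plus the second-order term conv(dk,0,dx). A first-order JVP drops that last term.

```lean
/-- Hypothetical 1-D, single-channel cross-correlation: stride 1, no padding.
    Output i is `b + Σ_j k[j] * x[i + j]`; output length is `x.length - k.length + 1`. -/
def conv1 (k : List Int) (b : Int) (x : List Int) : List Int :=
  (List.range (x.length + 1 - k.length)).map fun i =>
    b + ((k.zip (x.drop i)).map fun (p : Int × Int) => p.1 * p.2).foldl (· + ·) (0 : Int)

/-- Elementwise sum and difference of equal-length lists. -/
def addL (a c : List Int) : List Int := List.zipWith (· + ·) a c
def subL (a c : List Int) : List Int := List.zipWith (· - ·) a c

-- Primal point: k = [1, 2], b = 5, x = [1, 0, 2].
-- Tangent:      dk = [3, -1], db = 2, dx = [0, 1, -1].

-- JVP predicted by bilinearity: conv1 k 0 dx + conv1 dk db x = [2, -1] + [5, 0].
example : addL (conv1 [1, 2] 0 [0, 1, -1]) (conv1 [3, -1] 2 [1, 0, 2]) = [7, -1] := by decide

-- Exact change: conv1 (k + dk) (b + db) (x + dx) - conv1 k b x = [12, 12] - [6, 9].
example :
    subL (conv1 (addL [1, 2] [3, -1]) (5 + 2) (addL [1, 0, 2] [0, 1, -1]))
         (conv1 [1, 2] 5 [1, 0, 2]) = [6, 3] := by decide

-- The gap [6, 3] - [7, -1] = [-1, 4] is exactly the second-order term conv1 dk 0 dx.
example : conv1 [3, -1] 0 [0, 1, -1] = [-1, 4] := by decide
```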

def Runtime.Autograd.Compiled.GraphM.convTranspose {α Δ : Type} [Context α] [DecidableEq Spec.Shape] {Γ : List Spec.Shape} {d inC outC : ℕ} {kernel stride padding inSpatial : Vector ℕ d} {hInC : inC ≠ 0} {hKernel : ∀ (i : Fin d), kernel.get i ≠ 0} (w : Var (Spec.Shape.ofList (inC :: outC :: kernel.toList))) (b : Var (Spec.Shape.dim outC Spec.Shape.scalar)) (x : Var (Spec.Shape.ofList (inC :: inSpatial.toList))) :
MWith α Δ Γ (Var (Spec.Shape.ofList (outC :: (Spec.convTransposeOutSpatial inSpatial kernel stride padding).toList)))

N-dimensional transposed convolution (channels-first) on a single sample tensor (no batch axis).

Conventions:

• input shape is (inC, spatial...),
• kernel shape is (inC, outC, kernelSpatial...) (PyTorch layout),
• bias shape is (outC),
• output spatial sizes use: out[a] = (in[a] - 1) * stride[a] - 2*padding[a] + kernel[a] (with output_padding = 0).

PyTorch comparison: torch.nn.functional.conv_transpose1d, conv_transpose2d, or conv_transpose3d (for d = 1, 2, 3), specialized to a single sample.
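The transposed-convolution shape rule can be read as a scatter. Each input element adds a scaled copy of the kernel into the output at offset i * stride, and padding then trims entries from both ends. Below is a hypothetical 1-D, single-channel version over `Int`. `convT1` is not a TorchLean function, and the sketch assumes output_padding = 0 and a non-empty input.

```lean
/-- Hypothetical 1-D, single-channel transposed convolution (output_padding = 0).
    Input `x[i]` adds `x[i] * k[j]` at buffer position `i * stride + j`; the buffer
    has length `(n - 1) * stride + K`, and `padding` entries are then dropped from
    each end, leaving `(n - 1) * stride - 2 * padding + K` outputs, each offset by `b`. -/
def convT1 (k : List Int) (b : Int) (x : List Int) (stride padding : Nat) : List Int :=
  let full := (x.length - 1) * stride + k.length
  let buf := (List.range full).map fun t =>
    -- Sum x[i] * k[t - i * stride] over the inputs i whose kernel window covers t.
    ((List.range x.length).map fun i =>
        if i * stride ≤ t ∧ t - i * stride < k.length then
          x[i]?.getD 0 * k[t - i * stride]?.getD 0
        else 0).foldl (· + ·) (0 : Int)
  ((buf.drop padding).take (full - 2 * padding)).map (· + b)

-- x = [1, 2, 3], kernel [1, 1, 1], stride 2, padding 1:
-- buffer [1, 1, 3, 2, 5, 3, 3], cropped to (3 - 1) * 2 - 2 * 1 + 3 = 5 entries.
example : convT1 [1, 1, 1] 0 [1, 2, 3] 2 1 = [1, 3, 2, 5, 3] := by decide

-- The bias is added to every output position after cropping.
example : convT1 [1, 1, 1] 10 [1, 2, 3] 2 1 = [11, 13, 12, 15, 13] := by decide
```

The kernel is not flipped here. This follows PyTorch's convention that a transposed convolution is the input-gradient of a cross-correlation.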

Forward-mode JVP uses bilinearity: d(convTranspose(k,b,x)) = convTranspose(k,0,dx) + convTranspose(dk,db,x).

def Runtime.Autograd.Compiled.GraphM.conv2d {α Δ : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] [DecidableEq Spec.Shape] {Γ : List Spec.Shape} {inC outC kH kW stride padding inH inW : ℕ} {h1 : inC ≠ 0} {h2 : kH ≠ 0} {h3 : kW ≠ 0} (kernel : Var (Spec.Shape.dim outC (Spec.Shape.dim inC (Spec.Shape.dim kH (Spec.Shape.dim kW Spec.Shape.scalar))))) (bias : Var (Spec.Shape.dim outC Spec.Shape.scalar)) (input : Var (Spec.Shape.dim inC (Spec.Shape.dim inH (Spec.Shape.dim inW Spec.Shape.scalar)))) :
MWith α Δ Γ (Var (Spec.Shape.dim outC (Spec.Shape.dim ((inH + 2 * padding - kH) / stride + 1) (Spec.Shape.dim ((inW + 2 * padding - kW) / stride + 1) Spec.Shape.scalar))))

2D convolution (channels-first) on a single image tensor, with the same stride and padding on both spatial axes.

PyTorch comparison: torch.nn.functional.conv2d (without a batch dimension).

Forward-mode JVP uses bilinearity: d(conv2d(k,b,x)) = conv2d(k,0,dx) + conv2d(dk,db,x).
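The 2D output shape is the N-dimensional rule with one stride and one padding shared by both axes. The check below uses a hypothetical helper `conv2dOutHW` that copies the shape expression from the signature above, assuming all sizes are ℕ. The signature carries nonzero hypotheses for inC, kH, and kW but not for stride. As written, ℕ truncated subtraction and Lean's n / 0 = 0 make the last two cases evaluate to size 1.

```lean
/-- Hypothetical helper: (H, W) of `conv2d`, copied from its signature.
    The same `stride` and `padding` apply to both axes. -/
def conv2dOutHW (inH inW kH kW stride padding : Nat) : Nat × Nat :=
  ((inH + 2 * padding - kH) / stride + 1, (inW + 2 * padding - kW) / stride + 1)

-- 28×28 input, 5×5 kernel, stride 1, padding 2 keeps the size: (28 + 4 - 5) / 1 + 1 = 28.
example : conv2dOutHW 28 28 5 5 1 2 = (28, 28) := by decide
-- 3×3 kernel, stride 2, padding 1 halves it: (28 + 2 - 3) / 2 + 1 = 14.
example : conv2dOutHW 28 28 3 3 2 1 = (14, 14) := by decide
-- Non-square: 32×24 input, 3×5 kernel, stride 1, padding 0 gives 30×20.
example : conv2dOutHW 32 24 3 5 1 0 = (30, 20) := by decide
-- ℕ edge cases: 4 + 0 - 7 truncates to 0, and n / 0 = 0 in Lean,
-- so an oversized kernel and stride 0 both produce size 1.
example : conv2dOutHW 4 4 7 7 1 0 = (1, 1) := by decide
example : conv2dOutHW 28 28 3 3 0 1 = (1, 1) := by decide
```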

def Runtime.Autograd.Compiled.GraphM.convTranspose2d {α Δ : Type} [Context α] [DecidableEq Spec.Shape] {Γ : List Spec.Shape} {inC outC kH kW stride padding inH inW : ℕ} {h1 : inC ≠ 0} {h2 : kH ≠ 0} {h3 : kW ≠ 0} (kernel : Var (Spec.Shape.dim inC (Spec.Shape.dim outC (Spec.Shape.dim kH (Spec.Shape.dim kW Spec.Shape.scalar))))) (bias : Var (Spec.Shape.dim outC Spec.Shape.scalar)) (input : Var (Spec.Shape.dim inC (Spec.Shape.dim inH (Spec.Shape.dim inW Spec.Shape.scalar)))) :
MWith α Δ Γ (Var (Spec.Shape.dim outC (Spec.Shape.dim ((inH - 1) * stride - 2 * padding + kH) (Spec.Shape.dim ((inW - 1) * stride - 2 * padding + kW) Spec.Shape.scalar))))

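2D transposed convolution (channels-first) on a single image tensor. It uses the same stride and padding on both spatial axes, and the shape has no output_padding term (equivalent to output_padding = 0).

PyTorch comparison: torch.nn.functional.conv_transpose2d (without a batch dimension).

Forward-mode JVP uses bilinearity: d(convTranspose2d(k,b,x)) = convTranspose2d(k,0,dx) + convTranspose2d(dk,db,x).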
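A similar shape check works for the transposed 2D rule. It uses a hypothetical helper `convT2dOutHW` that copies the signature's expression, plus a one-axis forward helper `convOutDim` (also hypothetical). The second case shows the ambiguity that output_padding resolves in PyTorch. A stride-2 convolution with kernel 3 and padding 1 maps both 27 and 28 to 14. With output_padding fixed at 0, the transposed operator maps 14 back to 27.

```lean
/-- Hypothetical helper: (H, W) of `convTranspose2d`, copied from its signature. -/
def convT2dOutHW (inH inW kH kW stride padding : Nat) : Nat × Nat :=
  ((inH - 1) * stride - 2 * padding + kH, (inW - 1) * stride - 2 * padding + kW)

/-- Hypothetical helper: one axis of the forward `conv2d` output size. -/
def convOutDim (inSz k stride padding : Nat) : Nat :=
  (inSz + 2 * padding - k) / stride + 1

-- Kernel 4, stride 2, padding 1 doubles the resolution exactly: 13 * 2 - 2 + 4 = 28.
example : convT2dOutHW 14 14 4 4 2 1 = (28, 28) := by decide

-- Kernel 3, stride 2, padding 1: the forward conv sends both 27 and 28 to 14 ...
example : convOutDim 27 3 2 1 = 14 ∧ convOutDim 28 3 2 1 = 14 := by decide
-- ... and with output_padding = 0 the transpose sends 14 back to 27 only.
example : convT2dOutHW 14 14 3 3 2 1 = (27, 27) := by decide
```

In PyTorch, passing output_padding=1 to torch.nn.functional.conv_transpose2d would yield 28 in the second case. The signature above has no such parameter.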
