CUDA Conv/Pool FFI
Foreign-function declarations for TorchLean's float32 convolution and pooling kernels. The real
CUDA implementation lives in csrc/cuda/conv_pool/; CPU stubs with the same symbols are used when
TorchLean is built without -K cuda=true.
All buffers are contiguous Cuda.Buffer values; shape, stride, and padding metadata is passed
explicitly across the FFI boundary.
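The docstrings do not spell out the element order inside a contiguous buffer. A minimal sketch, assuming dense row-major (C-order) storage of a channels-first tensor such as (inC, H, W), shows how a multi-index maps to a flat offset. The strides computed here are memory strides, distinct from the convolution stride parameter, and the helper names contiguous_strides and flat_index are hypothetical, not part of the FFI.

```python
from typing import List, Sequence


def contiguous_strides(shape: Sequence[int]) -> List[int]:
    """Element strides of a dense row-major buffer (assumed layout)."""
    strides = [1] * len(shape)
    # The last dimension is innermost; every outer stride is the product
    # of the extents of all dimensions to its right.
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return strides


def flat_index(idx: Sequence[int], shape: Sequence[int]) -> int:
    """Flat offset of multi-index `idx` in a dense row-major buffer."""
    if len(idx) != len(shape):
        raise ValueError("index rank does not match shape rank")
    for i, n in zip(idx, shape):
        if not 0 <= i < n:
            raise IndexError("index out of bounds")
    return sum(i * s for i, s in zip(idx, contiguous_strides(shape)))


# A (inC=2, H=3, W=4) tensor.
print(contiguous_strides([2, 3, 4]))     # [12, 4, 1]
print(flat_index([1, 2, 3], [2, 3, 4]))  # 23
```

Under this assumption, the shape arrays alone determine where every element lives, which is why the metadata has to travel alongside the raw buffer.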
Float32 conv2d forward (device Buffer inputs/outputs).
Float32 conv-transpose2d forward (device Buffer inputs/outputs).
Float32 conv-transpose2d backward: returns (dKernel, dBias, dInput) device buffers.
Float32 N-D transposed convolution forward (channels-first, no batch).
Shapes/parameters:
- inSpatial: length d (input spatial dims)
- kernelSpatial: length d (kernel window)
- stride: length d
- padding: length d
All arrays must have the same length d ≤ 8.
Layout conventions:
- input: (inC, spatial...)
- kernel: (inC, outC, kernelSpatial...)
- bias: (outC)
- output: (outC, outSpatial...), where outSpatial[i] = (inSpatial[i] - 1) * stride[i] - 2*padding[i] + kernelSpatial[i].
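The output-extent formula and the d ≤ 8 limit can be checked on the host before a kernel is launched. A minimal sketch with a hypothetical helper name; the positivity guard is an addition of this sketch rather than something the docstring states, and since the formula has no output-padding or dilation term, neither is modeled.

```python
from typing import List, Sequence

MAX_RANK = 8  # documented limit: all arrays have the same length d <= 8


def conv_transpose_out_spatial(
    in_spatial: Sequence[int],
    kernel_spatial: Sequence[int],
    stride: Sequence[int],
    padding: Sequence[int],
) -> List[int]:
    """outSpatial[i] = (inSpatial[i] - 1) * stride[i] - 2*padding[i] + kernelSpatial[i]."""
    d = len(in_spatial)
    if not (len(kernel_spatial) == len(stride) == len(padding) == d):
        raise ValueError("inSpatial, kernelSpatial, stride and padding must all have length d")
    if d > MAX_RANK:
        raise ValueError(f"d = {d} exceeds the documented limit of {MAX_RANK}")
    out = []
    for i in range(d):
        extent = (in_spatial[i] - 1) * stride[i] - 2 * padding[i] + kernel_spatial[i]
        # Sketch-only guard: reject parameter choices that give an empty output.
        if extent <= 0:
            raise ValueError(f"non-positive output extent {extent} in dim {i}")
        out.append(extent)
    return out


# 2-D example: 4x4 input, 3x3 kernel, stride 2, padding 1.
print(conv_transpose_out_spatial([4, 4], [3, 3], [2, 2], [1, 1]))  # [7, 7]
```

The full output tensor then has shape (outC, *outSpatial), per the layout conventions above.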
Float32 N-D transposed convolution backward.
Returns (dKernel, dBias, dInput) as device buffers.
Array conventions match torchleanConvTransposeFwdCuda.
Float32 N-D convolution forward (channels-first, no batch).
Shapes/parameters:
- inSpatial: length d (spatial dims)
- kernelSpatial: length d (kernel window)
- stride: length d
- padding: length d
All arrays must have the same length d ≤ 8.
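This docstring does not state the output extent of the forward convolution. A minimal sketch, assuming the usual relation for dilation 1 with floor division, outSpatial[i] = (inSpatial[i] + 2*padding[i] - kernelSpatial[i]) / stride[i] + 1. Both the formula and the helper name are assumptions for illustration, not something this page confirms.

```python
from typing import List, Sequence

MAX_RANK = 8  # documented limit: d <= 8


def conv_out_spatial(
    in_spatial: Sequence[int],
    kernel_spatial: Sequence[int],
    stride: Sequence[int],
    padding: Sequence[int],
) -> List[int]:
    """Assumed: outSpatial[i] = (inSpatial[i] + 2*padding[i] - kernelSpatial[i]) // stride[i] + 1."""
    d = len(in_spatial)
    if not (len(kernel_spatial) == len(stride) == len(padding) == d):
        raise ValueError("inSpatial, kernelSpatial, stride and padding must all have length d")
    if d > MAX_RANK:
        raise ValueError(f"d = {d} exceeds the documented limit of {MAX_RANK}")
    out = []
    for n, k, s, p in zip(in_spatial, kernel_spatial, stride, padding):
        if s <= 0:
            raise ValueError("stride must be positive")
        span = n + 2 * p - k
        if span < 0:
            raise ValueError("kernel window is larger than the padded input")
        out.append(span // s + 1)
    return out


print(conv_out_spatial([5, 5], [3, 3], [2, 2], [1, 1]))  # [3, 3]
```

With matching stride and padding, this inverts the transposed-convolution extent whenever inSpatial[i] + 2*padding[i] - kernelSpatial[i] is divisible by stride[i]. For example, a transposed convolution maps extent 3 to (3 - 1)*2 - 2*1 + 3 = 5, and the formula above maps 5 back to 3.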
Float32 N-D convolution backward.
Returns (dKernel, dBias, dInput) as device buffers.
Array conventions match torchleanConvFwdCuda.
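The dInput results of the two backward passes are tied to the other forward pass. For matching stride and padding, convolution and transposed convolution are adjoint operators, so the input-gradient of a convolution is a transposed convolution of dOut with the same kernel, and vice versa. The sketch below checks the identity dot(conv(x), y) = dot(x, convT(y)) for a 1-D, single-channel, bias-free case. It is illustrative only: the names are hypothetical, and this page does not say whether the CUDA backward kernels are implemented this way.

```python
from typing import List, Sequence


def conv1d(x: Sequence[float], w: Sequence[float], stride: int, padding: int) -> List[float]:
    """Gather form: y[o] = sum_j w[j] * x[o*stride - padding + j], zero outside x."""
    n, k = len(x), len(w)
    out_len = (n + 2 * padding - k) // stride + 1
    y = [0.0] * out_len
    for o in range(out_len):
        for j in range(k):
            i = o * stride - padding + j
            if 0 <= i < n:
                y[o] += w[j] * x[i]
    return y


def conv_transpose1d(y: Sequence[float], w: Sequence[float], stride: int, padding: int) -> List[float]:
    """Scatter form: x[o*stride - padding + j] += w[j] * y[o], dropping out-of-range i."""
    m, k = len(y), len(w)
    out_len = (m - 1) * stride - 2 * padding + k
    x = [0.0] * out_len
    for o in range(m):
        for j in range(k):
            i = o * stride - padding + j
            if 0 <= i < out_len:
                x[i] += w[j] * y[o]
    return x


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    if len(a) != len(b):
        raise ValueError("length mismatch")
    return sum(p * q for p, q in zip(a, b))


# 7 + 2*1 - 3 is divisible by 2, so the transposed output has length 7 again.
x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
w = [1.0, -2.0, 3.0]
y = [1.0, 0.0, -1.0, 2.0]
lhs = dot(conv1d(x, w, 2, 1), y)
rhs = dot(x, conv_transpose1d(y, w, 2, 1))
print(lhs, rhs)  # -24.0 -24.0
```

Both sums range over the same (o, j) pairs, which is why the identity holds exactly here; with multiple channels the same argument applies per (inC, outC) kernel slice.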
Float32 max-pool2d forward (channels preserved).
Float32 max-pool2d backward: returns dInput.
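The docstrings do not list the pooling parameters. A minimal single-channel sketch, assuming a square k×k window, stride s, no padding, and ties resolved to the first maximum in row-major window order (all assumptions), shows the usual forward/backward pair. The forward pass records each window's argmax; the backward pass routes dOut to that position and accumulates where windows overlap. "Channels preserved" would correspond to running this independently per channel. Function names are hypothetical.

```python
from typing import List, Tuple

Grid = List[List[float]]
Pos = Tuple[int, int]


def max_pool2d(x: Grid, k: int, s: int) -> Tuple[Grid, List[List[Pos]]]:
    """Single-channel max pool (k x k window, stride s, no padding).

    Returns the pooled values and, per output cell, the input position of its max.
    """
    H, W = len(x), len(x[0])
    oh, ow = (H - k) // s + 1, (W - k) // s + 1
    y: Grid = [[0.0] * ow for _ in range(oh)]
    arg: List[List[Pos]] = [[(0, 0)] * ow for _ in range(oh)]
    for i in range(oh):
        for j in range(ow):
            best = (i * s, j * s)
            for di in range(k):
                for dj in range(k):
                    r, c = i * s + di, j * s + dj
                    # Strict '>' keeps the first maximum on ties (an assumption).
                    if x[r][c] > x[best[0]][best[1]]:
                        best = (r, c)
            y[i][j] = x[best[0]][best[1]]
            arg[i][j] = best
    return y, arg


def max_pool2d_backward(d_out: Grid, arg: List[List[Pos]], H: int, W: int) -> Grid:
    """Route each output gradient to its window's argmax; overlaps accumulate."""
    dx = [[0.0] * W for _ in range(H)]
    for i, row in enumerate(d_out):
        for j, g in enumerate(row):
            r, c = arg[i][j]
            dx[r][c] += g
    return dx


x = [[float(4 * r + c + 1) for c in range(4)] for r in range(4)]  # 1..16
y, arg = max_pool2d(x, 2, 2)
print(y)  # [[6.0, 8.0], [14.0, 16.0]]
```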
Float32 avg-pool2d forward (channels preserved).
Float32 avg-pool2d backward: returns dInput.
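A minimal single-channel sketch of average pooling, assuming a square k×k window, stride s, and no padding, so every window averages exactly k*k elements; none of these parameters are stated in the docstrings, and the function names are hypothetical. The backward pass spreads each dOut value uniformly over its window, accumulating where windows overlap.

```python
from typing import List

Grid = List[List[float]]


def avg_pool2d(x: Grid, k: int, s: int) -> Grid:
    """Single-channel average pool: k x k window, stride s, no padding."""
    H, W = len(x), len(x[0])
    oh, ow = (H - k) // s + 1, (W - k) // s + 1
    return [
        [
            sum(x[i * s + di][j * s + dj] for di in range(k) for dj in range(k)) / (k * k)
            for j in range(ow)
        ]
        for i in range(oh)
    ]


def avg_pool2d_backward(d_out: Grid, k: int, s: int, H: int, W: int) -> Grid:
    """Each output gradient is shared equally among the k*k inputs of its window."""
    dx = [[0.0] * W for _ in range(H)]
    for i, row in enumerate(d_out):
        for j, g in enumerate(row):
            share = g / (k * k)
            for di in range(k):
                for dj in range(k):
                    dx[i * s + di][j * s + dj] += share
    return dx


x = [[float(4 * r + c + 1) for c in range(4)] for r in range(4)]  # 1..16
print(avg_pool2d(x, 2, 2))  # [[3.5, 5.5], [11.5, 13.5]]
```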
Float32 smooth max-pool2d (log-sum-exp surrogate) forward.
This matches Spec.smooth_max_pool2d_spec for Float:
y = log(sum(exp(beta*x))) / beta computed per window, with beta ≠ 0.
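A minimal pure-Python sketch of the per-window formula. It subtracts max(beta*x) before exponentiating, which leaves log(sum(exp(beta*x))) unchanged but avoids overflow for large |beta*x|; whether the CUDA kernel performs this stabilization is not stated here. The window parameters (k×k, stride s, no padding) and function names are assumptions for illustration. For beta > 0 and a window of n elements, the result lies between max(x) and max(x) + log(n)/beta, so it approaches the hard max as beta grows; for beta < 0 it approaches the minimum from below instead.

```python
import math
from typing import List, Sequence


def smooth_max(window: Sequence[float], beta: float) -> float:
    """log(sum(exp(beta*x))) / beta, evaluated stably."""
    if beta == 0:
        raise ValueError("beta must be nonzero")
    z = [beta * v for v in window]
    m = max(z)
    # log(sum(exp(z))) = m + log(sum(exp(z - m))); every exp argument is <= 0.
    return (m + math.log(sum(math.exp(t - m) for t in z))) / beta


def smooth_max_pool2d(x: List[List[float]], beta: float, k: int, s: int) -> List[List[float]]:
    """Single-channel smooth max-pool with a k x k window, stride s, no padding."""
    H, W = len(x), len(x[0])
    oh, ow = (H - k) // s + 1, (W - k) // s + 1
    return [
        [
            smooth_max([x[i * s + di][j * s + dj] for di in range(k) for dj in range(k)], beta)
            for j in range(ow)
        ]
        for i in range(oh)
    ]


# 1000 + log 2; evaluating math.exp(1000.0) directly would raise OverflowError.
print(smooth_max([1000.0, 1000.0], 1.0))
```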
Float32 smooth max-pool2d backward: returns dInput.
VJP matches Spec.smooth_max_pool2d_backward_spec for Float:
dx += dOut * exp(beta*x)/sum(exp(beta*x)) within each window.
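Differentiating y = log(sum(exp(beta*x)))/beta gives dy/dx_i = exp(beta*x_i)/sum(exp(beta*x)). The factor 1/beta cancels, so each window's gradient is a softmax of beta*x whose weights sum to 1. The sketch below computes those weights stably, scatters dOut through them under assumed k×k window, stride s, no-padding parameters, and checks one window against central finite differences. Names are hypothetical; this is a reference sketch, not the CUDA kernel.

```python
import math
from typing import List, Sequence

Grid = List[List[float]]


def smooth_max(window: Sequence[float], beta: float) -> float:
    """log(sum(exp(beta*x))) / beta, evaluated stably."""
    z = [beta * v for v in window]
    m = max(z)
    return (m + math.log(sum(math.exp(t - m) for t in z))) / beta


def smooth_max_weights(window: Sequence[float], beta: float) -> List[float]:
    """softmax(beta * x): the partial derivatives of smooth_max with respect to x."""
    z = [beta * v for v in window]
    m = max(z)
    e = [math.exp(t - m) for t in z]
    total = sum(e)
    return [v / total for v in e]


def smooth_max_pool2d_backward(x: Grid, d_out: Grid, beta: float, k: int, s: int) -> Grid:
    H, W = len(x), len(x[0])
    dx = [[0.0] * W for _ in range(H)]
    for i, row in enumerate(d_out):
        for j, g in enumerate(row):
            coords = [(i * s + di, j * s + dj) for di in range(k) for dj in range(k)]
            weights = smooth_max_weights([x[r][c] for r, c in coords], beta)
            for (r, c), wt in zip(coords, weights):
                dx[r][c] += g * wt  # dx += dOut * exp(beta*x) / sum(exp(beta*x))
    return dx


# Central finite-difference check of the per-window gradient.
win, beta, h = [0.3, -1.2, 0.8, 0.1], 2.0, 1e-6
analytic = smooth_max_weights(win, beta)
numeric = []
for idx in range(len(win)):
    plus = list(win)
    plus[idx] += h
    minus = list(win)
    minus[idx] -= h
    numeric.append((smooth_max(plus, beta) - smooth_max(minus, beta)) / (2 * h))
print(max(abs(a - b) for a, b in zip(analytic, numeric)))
```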