CUDA Buffer Kernels FFI #
Foreign-function declarations for TorchLean's float32 Cuda.Buffer kernels: reductions, indexing,
matmul/BMM, attention, broadcast/view helpers, and related tensor operations. The declarations here
are the Lean side of the explicit CUDA trust boundary documented in TRUST_BOUNDARIES.md.
Sum over axis 0 of a 2D buffer in row-major order.
Input b has shape (rows, cols) and is stored as length rows*cols.
Output is length cols (sum down the rows for each column).
Sum over axis 1 of a 2D buffer in row-major order.
Input b has shape (rows, cols) and is stored as length rows*cols.
Output is length rows (sum across the columns for each row).
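As a reading aid, here is a minimal pure-Lean sketch of the two sum reductions over a row-major Array Float. The names are illustrative and these are not the FFI signatures; the real kernels operate on Cuda.Buffer.

```lean
-- Row-major reference semantics for the axis-0 and axis-1 sums.
def sumAxis0Ref (rows cols : Nat) (b : Array Float) : Array Float :=
  (Array.range cols).map fun j =>
    (List.range rows).foldl (fun acc i => acc + b[i * cols + j]!) 0.0

def sumAxis1Ref (rows cols : Nat) (b : Array Float) : Array Float :=
  (Array.range rows).map fun i =>
    (List.range cols).foldl (fun acc j => acc + b[i * cols + j]!) 0.0
```

The max kernels below follow the same indexing with max in place of addition.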
Max over axis 0 of a 2D buffer in row-major order.
Input b has shape (rows, cols) and is stored as length rows*cols.
Output is length cols (max down the rows for each column).
Max over axis 1 of a 2D buffer in row-major order.
Input b has shape (rows, cols) and is stored as length rows*cols.
Output is length rows (max across the columns for each row).
Concatenate two 1D buffers a (length n) and b (length m).
Slice a 1D buffer b (length n) starting at start for len elements.
Requires start + len ≤ n.
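A one-line sketch of the slice semantics under the precondition start + len ≤ n (illustrative name, not the FFI signature):

```lean
-- slice1dRef b start len = b[start], b[start+1], ..., b[start+len-1]
def slice1dRef (b : Array Float) (start len : Nat) : Array Float :=
  (Array.range len).map fun i => b[start + i]!
```

Concatenation is the usual append of the two underlying data blocks.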
Broadcast a row-vector (length cols) to a (rows, cols) matrix.
Output is row-major of length rows*cols, with out[i, j] = vec[j].
Broadcast a column-vector (length rows) to a (rows, cols) matrix.
Output is row-major of length rows*cols, with out[i, j] = vec[i].
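Both broadcasts reduce to simple index arithmetic on the flat row-major output (sketch with illustrative names):

```lean
-- idx = i*cols + j, so idx % cols recovers j and idx / cols recovers i.
def broadcastRowRef (rows cols : Nat) (vec : Array Float) : Array Float :=
  (Array.range (rows * cols)).map fun idx => vec[idx % cols]!  -- out[i,j] = vec[j]

def broadcastColRef (rows cols : Nat) (vec : Array Float) : Array Float :=
  (Array.range (rows * cols)).map fun idx => vec[idx / cols]!  -- out[i,j] = vec[i]
```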
Batched matrix multiply over row-major buffers.
Input:
- A: length batch*m*n representing batch matrices of shape (m, n) (row-major)
- B: length batch*n*p representing batch matrices of shape (n, p) (row-major)
Output:
- length batch*m*p representing batch matrices of shape (m, p) (row-major)
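Index-level reference for this layout, assuming the usual contraction out[b, i, j] = Σ_k A[b, i, k] * B[b, k, j] (illustrative sketch, not the kernel):

```lean
def bmmRef (batch m n p : Nat) (A B : Array Float) : Array Float := Id.run do
  let mut out := #[]
  for bi in [0:batch] do
    for i in [0:m] do
      for j in [0:p] do
        let mut acc := 0.0
        for k in [0:n] do
          acc := acc + A[bi * m * n + i * n + k]! * B[bi * n * p + k * p + j]!
        out := out.push acc
  return out
```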
Real-valued 1D FFT over row-major batches, returning a packed half-spectrum.
Input:
- x: length batch*n, interpreted as shape (batch, n).
Output:
- length batch*(n/2+1)*2, interpreted as shape (batch, n/2+1, 2);
- the last channel stores [real, imag] for each nonredundant frequency bin.
CUDA uses cuFFT R2C under the hood. The CPU stub uses a direct reference DFT, so this primitive
remains available in non-CUDA builds for tests and portability. This is a low-level runtime
primitive; differentiable tensor/autograd wrappers should spell out their backward convention
separately because half-spectrum packing has normalization and conjugate-symmetry edge cases.
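In the spirit of the CPU reference DFT, a per-batch-row sketch of the packed layout (illustrative, O(n²), assuming the usual unnormalized cuFFT R2C forward convention; the inverse below carries the 1/n):

```lean
-- Packed half-spectrum of one length-n real signal: (n/2+1) bins of [re, im].
def rfftPackedRef (x : Array Float) : Array Float := Id.run do
  let n := x.size
  let mut out := #[]
  for k in [0:n / 2 + 1] do
    let mut re := 0.0
    let mut im := 0.0
    for t in [0:n] do
      let θ := -2.0 * 3.141592653589793 * Float.ofNat (k * t) / Float.ofNat n
      re := re + x[t]! * Float.cos θ
      im := im + x[t]! * Float.sin θ
    out := (out.push re).push im
  return out
```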
Inverse of rfft1dPacked for packed half-spectra.
Input:
- spec: length batch*(n/2+1)*2, interpreted as (batch, n/2+1, 2).
Output:
- length batch*n, interpreted as (batch, n).
The CUDA implementation uses cuFFT C2R and explicitly scales by 1/n, matching the CPU reference
and the usual normalized inverse FFT convention used by high-level ML APIs.
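A sketch of the inverse semantics: rebuild the full spectrum from the packed half via conjugate symmetry, then apply the 1/n-normalized inverse DFT (illustrative, O(n²)):

```lean
def irfftPackedRef (spec : Array Float) (n : Nat) : Array Float := Id.run do
  -- conjugate symmetry: X[n-k] = conj (X[k]) for the bins not stored
  let fullRe := fun (k : Nat) =>
    if k ≤ n / 2 then spec[2 * k]! else spec[2 * (n - k)]!
  let fullIm := fun (k : Nat) =>
    if k ≤ n / 2 then spec[2 * k + 1]! else -spec[2 * (n - k) + 1]!
  let mut out := #[]
  for t in [0:n] do
    let mut acc := 0.0
    for k in [0:n] do
      let θ := 2.0 * 3.141592653589793 * Float.ofNat (k * t) / Float.ofNat n
      acc := acc + fullRe k * Float.cos θ - fullIm k * Float.sin θ  -- Re (X[k] e^{iθ})
    out := out.push (acc / Float.ofNat n)
  return out
```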
Fused real-FFT spectral convolution for one FNO1D block.
Input:
- x: length grid*width, row-major shape (grid, width);
- wRe, wIm: length modes*width*width, row-major shape (modes, width, width).
Semantics:
- apply an unnormalized real FFT along the grid axis for each input channel,
- keep frequency bins 0 ≤ k < modes,
- multiply each retained complex vector by wRe[k] + i*wIm[k],
- zero all other bins,
- apply the normalized inverse real FFT.
This is the CUDA/cuFFT-backed runtime primitive intended to replace dense DFT matrix multiplies in float32 FNO examples. The three backward primitives below are its explicit VJP components.
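In formulas, writing x̂ for the unnormalized forward transform along the grid axis, each retained bin is transformed by the per-mode width×width complex matrix (which of the two width axes indexes the output channel is an implementation detail not fixed here):

$$\hat{x}[k, c] = \sum_{t=0}^{\mathrm{grid}-1} x[t, c]\, e^{-2\pi i k t / \mathrm{grid}}, \qquad
\hat{y}[k, o] = \begin{cases} \sum_{c}\big(wRe[k, o, c] + i\, wIm[k, o, c]\big)\,\hat{x}[k, c] & 0 \le k < \mathrm{modes} \\ 0 & \text{otherwise} \end{cases}$$

with the output given by the 1/grid-normalized inverse real FFT of ŷ along the frequency axis.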
VJP component ∂L/∂x for spectralConv1dRfftFwd.
VJP component ∂L/∂wRe for spectralConv1dRfftFwd.
VJP component ∂L/∂wIm for spectralConv1dRfftFwd.
Diagonal selective-scan forward kernel for state-space models.
Inputs:
- A, B, h0: length state, representing per-channel recurrence parameters and initial state,
- X: length seqLen*state, row-major token/state inputs.
Output:
- length seqLen*state, row-major hidden states, with h[t,j] = A[j] * h[t-1,j] + B[j] * X[t,j], starting from h0[j].
This is the runtime primitive corresponding to the proof-facing affine scan contract in
NN.Spec.Layers.SelectiveScan and NN.MLTheory.Proofs.StateSpace.Scan.
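A minimal pure-Lean sketch of the recurrence over row-major buffers (illustrative names; the argument order is not the FFI signature):

```lean
def selectiveScanDiagRef (A B h0 X : Array Float)
    (seqLen state : Nat) : Array Float := Id.run do
  let mut h := h0          -- current hidden state, length state
  let mut out := #[]
  for t in [0:seqLen] do
    let mut hNext := #[]
    for j in [0:state] do
      hNext := hNext.push (A[j]! * h[j]! + B[j]! * X[t * state + j]!)
    h := hNext
    out := out ++ h        -- emit h[t, :]
  return out
```

The token-dependent variant below differs only in reading A and B at [t*state + j] instead of [j].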
Backward kernel for selectiveScanDiagFwd.
Given out = selectiveScanDiagFwd A B X h0 and an upstream gradient dY with the same
seqLen*state layout as out, returns (dA, dB, dX, dH0).
Diagonal selective-scan forward kernel with token-dependent coefficients.
Inputs:
- A, B, X: length seqLen*state, row-major by (time, flattened_state_channel),
- h0: length state.
Output:
- length seqLen*state, with h[t,j] = A[t,j] * h[t-1,j] + B[t,j] * X[t,j].
This is the runtime primitive corresponding to full Mamba-style selective scans where the token controls the affine transition coefficients.
Native fused scaled dot-product attention forward over split attention heads.
Inputs are row-major buffers with shapes:
- Q, K, V: (batch, n, d), where batch is usually the number of heads,
- mask: (batch, n, n) encoded as 0.0/1.0 when hasMask != 0; otherwise ignored.
Output has shape (batch, n, d) and computes the same no-dropout masked attention semantics as:
softmax((Q Kᵀ) * scale + maskFill) V, where blocked mask entries use TorchLean's
-1000.0 fill convention.
This is a fused native runtime primitive, not a proof object. The proof-facing contract is
Spec.flashAttention in NN/Spec/Layers/FlashAttention.lean.
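A per-row sketch of those semantics for one (batch, i) query, with the stated scale and -1000.0 mask fill (illustrative names; assumes mask[j] = 0.0 marks a blocked entry and n ≥ 1):

```lean
def attnRowRef (q K V : Array Float) (n d : Nat)
    (mask : Option (Array Float)) (scale : Float) : Array Float := Id.run do
  let mut scores := #[]
  for j in [0:n] do
    let mut s := 0.0
    for t in [0:d] do
      s := s + q[t]! * K[j * d + t]!       -- (Q Kᵀ)[i, j]
    let s := s * scale + (match mask with
      | some m => if m[j]! == 0.0 then -1000.0 else 0.0
      | none   => 0.0)
    scores := scores.push s
  let mx := scores.foldl max scores[0]!    -- numerically stable softmax
  let exps := scores.map fun v => Float.exp (v - mx)
  let z := exps.foldl (· + ·) 0.0
  let mut out := #[]
  for t in [0:d] do
    let mut acc := 0.0
    for j in [0:n] do
      acc := acc + (exps[j]! / z) * V[j * d + t]!
    out := out.push acc
  return out
```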
Fused VJP component ∂L/∂Q for flashAttentionFwd.
Fused VJP component ∂L/∂K for flashAttentionFwd.
Fused VJP component ∂L/∂V for flashAttentionFwd.
Row-major transpose of a 2D buffer.
Input b has shape (rows, cols) and is stored as length rows*cols.
Output has shape (cols, rows) and is stored as length rows*cols (row-major).
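The transpose is pure index arithmetic on the flat buffer (sketch):

```lean
-- Output index idx = j*rows + i reads input index i*cols + j.
def transpose2dRef (rows cols : Nat) (b : Array Float) : Array Float :=
  (Array.range (rows * cols)).map fun idx =>
    b[(idx % rows) * cols + idx / rows]!
```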
Gather k scalars from a 1D vector using host indices.
Indices that fit in UInt32 but are out of bounds are totalized to 0.
Large Nat values outside the FFI index range are rejected by the runtime.
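Reading the totalization note as "out-of-bounds gathers produce 0.0" (an assumption about the exact convention), a sketch:

```lean
-- Out-of-range indices yield 0.0 rather than panicking (assumed convention).
def gatherScalarsRef (v : Array Float) (idxs : Array Nat) : Array Float :=
  idxs.map fun i => if i < v.size then v[i]! else 0.0
```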
Broadcast a buffer to a new shape (TorchLean Shape.CanBroadcastTo semantics).
Arguments:
- x: input buffer
- inDims: input dimension list (outermost-first)
- outDims: output dimension list (outermost-first)
- axisMap: length outDims.size; axisMap[j] = 0 means the output axis j is an inserted/broadcast axis (input coordinate is 0), otherwise axisMap[j] = inAxis+1 tells which input axis to read.
This shape-driven mapping is generated in Lean from a Shape.CanBroadcastTo proof so the kernel
does not need to interpret the proof object.
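A sketch of how a generated axisMap resolves one output coordinate to a flat input offset (helper name is illustrative; the kernel receives the same data as plain arrays):

```lean
def inputOffset (inDims axisMap outCoord : Array Nat) : Nat := Id.run do
  -- row-major strides for the input shape (dims are outermost-first)
  let mut strides := Array.mkArray inDims.size 1
  for i in [0:inDims.size] do
    let ax := inDims.size - 1 - i
    if ax + 1 < inDims.size then
      strides := strides.set! ax (strides[ax + 1]! * inDims[ax + 1]!)
  let mut off := 0
  for j in [0:axisMap.size] do
    match axisMap[j]! with
    | 0          => pure ()  -- inserted/broadcast axis: input coordinate is 0
    | inAxis + 1 => off := off + outCoord[j]! * strides[inAxis]!
  return off
```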
Adjoint of broadcastTo for sum-accumulation: reduce a broadcasted gradient back to the input
shape by summing over broadcasted axes.
This uses the same (inDims,outDims,axisMap) convention as broadcastTo.
Swap adjacent axes at depth for a contiguous buffer described by dims.
depth = 0 swaps the first two axes; depth = 1 swaps axes 1 and 2; etc.
Reduce-sum along axis for an N-D contiguous buffer described by dims (outermost-first).
The returned buffer is laid out row-major with shape dims with the axis dimension removed.
Gather k rows from a row-major matrix.
Output:
- shape (k, cols) stored row-major as length k*cols.
Indices that fit in UInt32 but are out of bounds are totalized to 0 rows.
Large Nat values outside the FFI index range are rejected by the runtime.
Scatter-add into a single matrix row.
Returns a copy of mat with out[i,:] += rowVec.
Scatter-add k rows given host indices.
Semantics: out = mat with out[indices[r], j] += values[r, j] for each r < k, j < cols.
Indices that fit in UInt32 but are out of bounds are ignored; repeated indices accumulate
(scatter-add). Large Nat values outside the FFI index range are rejected by the runtime.
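A reference loop for the k-row scatter-add semantics, with in-bounds filtering and accumulation on repeats (illustrative, assuming the zero-effect reading of "ignored" for out-of-bounds indices):

```lean
def scatterAddRowsRef (mat : Array Float) (rows cols : Nat)
    (indices : Array Nat) (values : Array Float) : Array Float := Id.run do
  let mut out := mat
  for r in [0:indices.size] do
    let i := indices[r]!
    if i < rows then                      -- out-of-bounds rows are skipped
      for j in [0:cols] do
        out := out.set! (i * cols + j) (out[i * cols + j]! + values[r * cols + j]!)
  return out
```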