Backend-Generic Functional API
The Ops interface and curried helper syntax used to write a model once and run it on either the
eager runtime or the compiled graph backend.
Append two TLists.
This is a small utility for bridging between curried APIs and list-of-shapes APIs.
Split a TList α (ss₁ ++ ss₂) into its left and right parts.
This is the inverse of TList.append.
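The following is a minimal standalone Lean sketch of append and split. It assumes a simplified TList indexed by a list over an arbitrary index type σ with element family T. The library instead indexes by Spec.Shape and stores tensors, so these definitions, including the helper TList.uncons, are illustrative and not the library's source.

```lean
-- Simplified, hypothetical stand-in for TList: element i has type `T ss[i]`.
inductive TList {σ : Type} (T : σ → Type) : List σ → Type where
  | nil : TList T []
  | cons {s : σ} {ss : List σ} : T s → TList T ss → TList T (s :: ss)

-- Append: the result is indexed by the concatenated index lists.
def TList.append {σ : Type} {T : σ → Type} {ss₁ ss₂ : List σ} :
    TList T ss₁ → TList T ss₂ → TList T (ss₁ ++ ss₂)
  | .nil, ys => ys
  | .cons x xs, ys => .cons x (TList.append xs ys)

-- Peel the first element off a non-empty TList.
def TList.uncons {σ : Type} {T : σ → Type} {s : σ} {ss : List σ} :
    TList T (s :: ss) → T s × TList T ss
  | .cons x xs => (x, xs)

-- Inverse of append. The left index ss₁ is explicit because it cannot be
-- recovered from the concatenation ss₁ ++ ss₂ alone.
def TList.split {σ : Type} {T : σ → Type} :
    (ss₁ : List σ) → {ss₂ : List σ} → TList T (ss₁ ++ ss₂) → TList T ss₁ × TList T ss₂
  | [], _, xs => (.nil, xs)
  | _ :: ss, _, xs =>
    let (x, rest) := TList.uncons xs
    let (l, r) := TList.split ss rest
    (.cons x l, r)
```

For xs : TList T [s₁], TList.split [s₁] (TList.append xs ys) reduces to (xs, ys); the explicit ss₁ argument is what tells split where to cut.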
Type of a curried function accepting one tensor argument per shape in ss.
For example, Fn α [s₁, s₂] β is Tensor α s₁ → Tensor α s₂ → β.
Convert a function on TList inputs into its curried form.
Convert a curried function into a function on TList inputs.
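The curried view can be sketched the same way: Fn computes the curried type by recursion on the index list, and curry/uncurry convert in each direction. TList is redefined so the block stands alone. Fn.curry, Fn.uncurry, and the toy Ty family are hypothetical names, and the library's Fn is specialized to Tensor α rather than a general element family T.

```lean
-- Simplified, hypothetical stand-ins (not the library's definitions).
inductive TList {σ : Type} (T : σ → Type) : List σ → Type where
  | nil : TList T []
  | cons {s : σ} {ss : List σ} : T s → TList T ss → TList T (s :: ss)

-- `Fn T [s₁, s₂] β` unfolds to `T s₁ → T s₂ → β`.
def Fn {σ : Type} (T : σ → Type) : List σ → Type → Type
  | [], β => β
  | s :: ss, β => T s → Fn T ss β

-- A function consuming a whole TList becomes a curried function.
def Fn.curry {σ : Type} {T : σ → Type} {β : Type} :
    (ss : List σ) → (TList T ss → β) → Fn T ss β
  | [], f => f .nil
  | _ :: ss, f => fun x => Fn.curry ss (fun xs => f (.cons x xs))

-- A curried function is fed one TList element at a time.
def Fn.uncurry {σ : Type} {T : σ → Type} {β : Type} :
    {ss : List σ} → Fn T ss β → TList T ss → β
  | [], f, .nil => f
  | _ :: _, f, .cons x xs => Fn.uncurry (f x) xs

-- Toy index family: `true` marks a Nat slot, `false` a String slot.
def Ty : Bool → Type
  | true => Nat
  | false => String

def addLen : Fn Ty [true, false] Nat := fun (n : Nat) (s : String) => n + s.length

#eval Fn.uncurry addLen (.cons (3 : Nat) (.cons "abc" .nil))  -- 6
```

Because Fn Ty [true, false] Nat unfolds definitionally to Nat → String → Nat, addLen can be written as an ordinary two-argument lambda and still be passed to Fn.uncurry.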
RefList is the reference-analogue of TList: a heterogeneous list of Ref s values indexed by
a shape list.
This is used to write backend-generic code over references (e.g. TensorRefs in eager mode, or
GraphM.Vars in compiled mode).
- nil {Ref : Spec.Shape → Type} : RefList Ref []
- cons {Ref : Spec.Shape → Type} {s : Spec.Shape} {ss : List Spec.Shape} : Ref s → RefList Ref ss → RefList Ref (s :: ss)
Append two RefLists.
Split a RefList Ref (ss₁ ++ ss₂) into its left and right parts.
Split a RefList Ref (ss ++ [τ]) into its prefix and last element.
Type of a curried function over references, one Ref s argument per shape in ss.
This mirrors Curried.Fn, but for Ref-valued arguments (e.g. TensorRefs in eager mode or
GraphM.Vars in compiled mode).
Uncurry a curried reference function to accept a RefList.
Curry a reference function that consumes a RefList.
Apply a curried reference function to a GraphM.VarList.
This is a convenience for the compiled backend, where leaves/inputs are represented as Vars.
Backend-generic interface for building and executing tensor programs.
This typeclass lets you write a single model/loss once (polymorphic over Ops m α) and then choose:
- an eager backend that executes immediately on a runtime tape, or
- a compiled backend that records proved IR (GraphM) for later compilation and proofs.
Each method corresponds to a Tensor op; implementations are expected to match the semantics of the
corresponding Runtime.Autograd.Tape.* / Compiled.GraphM.* operator.
- Ref : Spec.Shape → Type
- const {s : Spec.Shape} : Spec.Tensor α s → m (Ref m α s)
- add {s : Spec.Shape} : Ref m α s → Ref m α s → m (Ref m α s)
- sub {s : Spec.Shape} : Ref m α s → Ref m α s → m (Ref m α s)
- mul {s : Spec.Shape} : Ref m α s → Ref m α s → m (Ref m α s)
- scale {s : Spec.Shape} : Ref m α s → α → m (Ref m α s)
- abs {s : Spec.Shape} : Ref m α s → m (Ref m α s)
- sqrt {s : Spec.Shape} : Ref m α s → m (Ref m α s)
- clamp {s : Spec.Shape} : Ref m α s → α → α → m (Ref m α s)
- max {s : Spec.Shape} : Ref m α s → Ref m α s → m (Ref m α s)
- min {s : Spec.Shape} : Ref m α s → Ref m α s → m (Ref m α s)
- broadcastTo {s₁ s₂ : Spec.Shape} : s₁.CanBroadcastTo s₂ → Ref m α s₁ → m (Ref m α s₂)
- transpose2d {mDim nDim : ℕ} : Ref m α (Spec.Shape.dim mDim (Spec.Shape.dim nDim Spec.Shape.scalar)) → m (Ref m α (Spec.Shape.dim nDim (Spec.Shape.dim mDim Spec.Shape.scalar)))
- transpose3dFirstToLast {a b c : ℕ} : Ref m α (Spec.Shape.dim a (Spec.Shape.dim b (Spec.Shape.dim c Spec.Shape.scalar))) → m (Ref m α (Spec.Shape.dim b (Spec.Shape.dim c (Spec.Shape.dim a Spec.Shape.scalar))))
- transpose3dLastToFirst {a b c : ℕ} : Ref m α (Spec.Shape.dim a (Spec.Shape.dim b (Spec.Shape.dim c Spec.Shape.scalar))) → m (Ref m α (Spec.Shape.dim c (Spec.Shape.dim a (Spec.Shape.dim b Spec.Shape.scalar))))
- transpose3dLastTwo {a b c : ℕ} : Ref m α (Spec.Shape.dim a (Spec.Shape.dim b (Spec.Shape.dim c Spec.Shape.scalar))) → m (Ref m α (Spec.Shape.dim a (Spec.Shape.dim c (Spec.Shape.dim b Spec.Shape.scalar))))
- swapAdjacentAtDepth {s : Spec.Shape} (depth : ℕ) : Ref m α s → m (Ref m α (s.swapAdjacentAtDepth depth))
- reduceSum {s : Spec.Shape} (axis : ℕ) [Spec.Shape.valid_axis_inst axis s] [s.WellFormed] : Ref m α s → m (Ref m α (Spec.Tensor.shapeAfterSum s axis))
- reduceMean {s : Spec.Shape} (axis : ℕ) [Spec.Shape.valid_axis_inst axis s] [s.WellFormed] : Ref m α s → m (Ref m α (Spec.Tensor.shapeAfterSum s axis))
- gatherScalar {n : ℕ} : Ref m α (Spec.Shape.dim n Spec.Shape.scalar) → Fin n → m (Ref m α Spec.Shape.scalar)
- gatherRow {rows cols : ℕ} : Ref m α (Spec.Shape.dim rows (Spec.Shape.dim cols Spec.Shape.scalar)) → Fin rows → m (Ref m α (Spec.Shape.dim cols Spec.Shape.scalar))
- gatherScalarNat {n : ℕ} : Ref m α (Spec.Shape.dim n Spec.Shape.scalar) → ℕ → m (Ref m α Spec.Shape.scalar)
- gatherVecNat {n k : ℕ} : Ref m α (Spec.Shape.dim n Spec.Shape.scalar) → Spec.Tensor ℕ (Spec.Shape.dim k Spec.Shape.scalar) → m (Ref m α (Spec.Shape.dim k Spec.Shape.scalar))
- gatherRowsNat {rows cols k : ℕ} : Ref m α (Spec.Shape.dim rows (Spec.Shape.dim cols Spec.Shape.scalar)) → Spec.Tensor ℕ (Spec.Shape.dim k Spec.Shape.scalar) → m (Ref m α (Spec.Shape.dim k (Spec.Shape.dim cols Spec.Shape.scalar)))
- scatterAddVec {n : ℕ} : Ref m α (Spec.Shape.dim n Spec.Shape.scalar) → Ref m α Spec.Shape.scalar → Fin n → m (Ref m α (Spec.Shape.dim n Spec.Shape.scalar))
- scatterAddRow {rows cols : ℕ} : Ref m α (Spec.Shape.dim rows (Spec.Shape.dim cols Spec.Shape.scalar)) → Ref m α (Spec.Shape.dim cols Spec.Shape.scalar) → Fin rows → m (Ref m α (Spec.Shape.dim rows (Spec.Shape.dim cols Spec.Shape.scalar)))
- matmul {mDim nDim pDim : ℕ} : Ref m α (Spec.Shape.dim mDim (Spec.Shape.dim nDim Spec.Shape.scalar)) → Ref m α (Spec.Shape.dim nDim (Spec.Shape.dim pDim Spec.Shape.scalar)) → m (Ref m α (Spec.Shape.dim mDim (Spec.Shape.dim pDim Spec.Shape.scalar)))
- bmm {batch mDim nDim pDim : ℕ} : Ref m α (Spec.Shape.dim batch (Spec.Shape.dim mDim (Spec.Shape.dim nDim Spec.Shape.scalar))) → Ref m α (Spec.Shape.dim batch (Spec.Shape.dim nDim (Spec.Shape.dim pDim Spec.Shape.scalar))) → m (Ref m α (Spec.Shape.dim batch (Spec.Shape.dim mDim (Spec.Shape.dim pDim Spec.Shape.scalar))))
- concatVectors {nDim mDim : ℕ} : Ref m α (Spec.Shape.dim nDim Spec.Shape.scalar) → Ref m α (Spec.Shape.dim mDim Spec.Shape.scalar) → m (Ref m α (Spec.Shape.dim (nDim + mDim) Spec.Shape.scalar))
- concatDim0 {nDim mDim : ℕ} {s : Spec.Shape} : Ref m α (Spec.Shape.dim nDim s) → Ref m α (Spec.Shape.dim mDim s) → m (Ref m α (Spec.Shape.dim (nDim + mDim) s))
- sliceRange0 {nDim : ℕ} {s : Spec.Shape} (start len : ℕ) (h : len + start ≤ nDim) : Ref m α (Spec.Shape.dim nDim s) → m (Ref m α (Spec.Shape.dim len s))
- maxPool {d C : ℕ} {inSpatial kernel stride padding : Vector ℕ d} {hKernel : ∀ (i : Fin d), kernel.get i ≠ 0} : Ref m α (Spec.Shape.ofList (C :: inSpatial.toList)) → m (Ref m α (Spec.Shape.ofList (C :: (Spec.poolOutSpatialPad inSpatial kernel stride padding).toList)))
- avgPool {d C : ℕ} {inSpatial kernel stride padding : Vector ℕ d} (hKernel : ∀ (i : Fin d), kernel.get i ≠ 0) : Ref m α (Spec.Shape.ofList (C :: inSpatial.toList)) → m (Ref m α (Spec.Shape.ofList (C :: (Spec.poolOutSpatialPad inSpatial kernel stride padding).toList)))
- smoothMaxPool {d C : ℕ} {inSpatial kernel stride padding : Vector ℕ d} {hKernel : ∀ (i : Fin d), kernel.get i ≠ 0} : Ref m α (Spec.Shape.ofList (C :: inSpatial.toList)) → α → m (Ref m α (Spec.Shape.ofList (C :: (Spec.poolOutSpatialPad inSpatial kernel stride padding).toList)))
- maxPool2d {kH kW inH inW inC stride : ℕ} {h1 : kH ≠ 0} {h2 : kW ≠ 0} : Ref m α (Spec.Shape.dim inC (Spec.Shape.dim inH (Spec.Shape.dim inW Spec.Shape.scalar))) → m (Ref m α (Spec.Shape.dim inC (Spec.Shape.dim ((inH - kH) / stride + 1) (Spec.Shape.dim ((inW - kW) / stride + 1) Spec.Shape.scalar))))
- maxPool2dPad {kH kW inH inW inC stride padding : ℕ} {h1 : kH ≠ 0} {h2 : kW ≠ 0} : Ref m α (Spec.Shape.dim inC (Spec.Shape.dim inH (Spec.Shape.dim inW Spec.Shape.scalar))) → m (Ref m α (Spec.Shape.dim inC (Spec.Shape.dim ((inH + 2 * padding - kH) / stride + 1) (Spec.Shape.dim ((inW + 2 * padding - kW) / stride + 1) Spec.Shape.scalar))))
- smoothMaxPool2d {kH kW inH inW inC stride : ℕ} {h1 : kH ≠ 0} {h2 : kW ≠ 0} : Ref m α (Spec.Shape.dim inC (Spec.Shape.dim inH (Spec.Shape.dim inW Spec.Shape.scalar))) → α → m (Ref m α (Spec.Shape.dim inC (Spec.Shape.dim ((inH - kH) / stride + 1) (Spec.Shape.dim ((inW - kW) / stride + 1) Spec.Shape.scalar))))
- avgPool2d {kH kW inH inW inC stride : ℕ} (h1 : kH ≠ 0) (h2 : kW ≠ 0) : Ref m α (Spec.Shape.dim inC (Spec.Shape.dim inH (Spec.Shape.dim inW Spec.Shape.scalar))) → m (Ref m α (Spec.Shape.dim inC (Spec.Shape.dim ((inH - kH) / stride + 1) (Spec.Shape.dim ((inW - kW) / stride + 1) Spec.Shape.scalar))))
- avgPool2dPad {kH kW inH inW inC stride padding : ℕ} (h1 : kH ≠ 0) (h2 : kW ≠ 0) : Ref m α (Spec.Shape.dim inC (Spec.Shape.dim inH (Spec.Shape.dim inW Spec.Shape.scalar))) → m (Ref m α (Spec.Shape.dim inC (Spec.Shape.dim ((inH + 2 * padding - kH) / stride + 1) (Spec.Shape.dim ((inW + 2 * padding - kW) / stride + 1) Spec.Shape.scalar))))
- relu {s : Spec.Shape} : Ref m α s → m (Ref m α s)
- sigmoid {s : Spec.Shape} : Ref m α s → m (Ref m α s)
- tanh {s : Spec.Shape} : Ref m α s → m (Ref m α s)
- softmax {s : Spec.Shape} : Ref m α s → m (Ref m α s)
- logSoftmax {s : Spec.Shape} : Ref m α s → m (Ref m α s)
- softplus {s : Spec.Shape} : Ref m α s → m (Ref m α s)
- exp {s : Spec.Shape} : Ref m α s → m (Ref m α s)
- log {s : Spec.Shape} : Ref m α s → m (Ref m α s)
- inv {s : Spec.Shape} : Ref m α s → m (Ref m α s)
- detach {s : Spec.Shape} : Ref m α s → m (Ref m α s)
- safeLog {s : Spec.Shape} : Ref m α s → α → m (Ref m α s)
- sum {s : Spec.Shape} : Ref m α s → m (Ref m α Spec.Shape.scalar)
- flatten {s : Spec.Shape} : Ref m α s → m (Ref m α (Spec.Shape.dim s.size Spec.Shape.scalar))
- linear {inDim outDim : ℕ} : Ref m α (Spec.Shape.dim outDim (Spec.Shape.dim inDim Spec.Shape.scalar)) → Ref m α (Spec.Shape.dim outDim Spec.Shape.scalar) → Ref m α (Spec.Shape.dim inDim Spec.Shape.scalar) → m (Ref m α (Spec.Shape.dim outDim Spec.Shape.scalar))
- mseLoss {s : Spec.Shape} : Ref m α s → Ref m α s → m (Ref m α Spec.Shape.scalar)
- layerNorm {seqLen embedDim : ℕ} (h_seq_pos : seqLen > 0) (h_embed_pos : embedDim > 0) : Ref m α (Spec.Shape.dim seqLen (Spec.Shape.dim embedDim Spec.Shape.scalar)) → Ref m α (Spec.Shape.dim embedDim Spec.Shape.scalar) → Ref m α (Spec.Shape.dim embedDim Spec.Shape.scalar) → m (Ref m α (Spec.Shape.dim seqLen (Spec.Shape.dim embedDim Spec.Shape.scalar)))
- batchnormChannelFirst {channels height width : ℕ} (h_c : channels > 0) (h_h : height > 0) (h_w : width > 0) : Ref m α (Spec.Shape.dim channels (Spec.Shape.dim height (Spec.Shape.dim width Spec.Shape.scalar))) → Ref m α (Spec.Shape.dim channels Spec.Shape.scalar) → Ref m α (Spec.Shape.dim channels Spec.Shape.scalar) → m (Ref m α (Spec.Shape.dim channels (Spec.Shape.dim height (Spec.Shape.dim width Spec.Shape.scalar))))
- multiHeadAttention {n numHeads dModel headDim : ℕ} (h1 : n ≠ 0) : Ref m α (Spec.Shape.dim dModel (Spec.Shape.dim (numHeads * headDim) Spec.Shape.scalar)) → Ref m α (Spec.Shape.dim dModel (Spec.Shape.dim (numHeads * headDim) Spec.Shape.scalar)) → Ref m α (Spec.Shape.dim dModel (Spec.Shape.dim (numHeads * headDim) Spec.Shape.scalar)) → Ref m α (Spec.Shape.dim (numHeads * headDim) (Spec.Shape.dim dModel Spec.Shape.scalar)) → Ref m α (Spec.Shape.dim n (Spec.Shape.dim dModel Spec.Shape.scalar)) → Option (Spec.Tensor Bool (Spec.Shape.dim n (Spec.Shape.dim n Spec.Shape.scalar))) → m (Ref m α (Spec.Shape.dim n (Spec.Shape.dim dModel Spec.Shape.scalar)))
- conv {d inC outC : ℕ} {kernel stride padding inSpatial : Vector ℕ d} {hInC : inC ≠ 0} {hKernel : ∀ (i : Fin d), kernel.get i ≠ 0} : Ref m α (Spec.Shape.ofList (outC :: inC :: kernel.toList)) → Ref m α (Spec.Shape.dim outC Spec.Shape.scalar) → Ref m α (Spec.Shape.ofList (inC :: inSpatial.toList)) → m (Ref m α (Spec.Shape.ofList (outC :: (Spec.convOutSpatial inSpatial kernel stride padding).toList)))
- convTranspose {d inC outC : ℕ} {kernel stride padding inSpatial : Vector ℕ d} {hInC : inC ≠ 0} {hKernel : ∀ (i : Fin d), kernel.get i ≠ 0} : Ref m α (Spec.Shape.ofList (inC :: outC :: kernel.toList)) → Ref m α (Spec.Shape.dim outC Spec.Shape.scalar) → Ref m α (Spec.Shape.ofList (inC :: inSpatial.toList)) → m (Ref m α (Spec.Shape.ofList (outC :: (Spec.convTransposeOutSpatial inSpatial kernel stride padding).toList)))
- conv2d {inC outC kH kW stride padding inH inW : ℕ} {h1 : inC ≠ 0} {h2 : kH ≠ 0} {h3 : kW ≠ 0} : Ref m α (Spec.Shape.dim outC (Spec.Shape.dim inC (Spec.Shape.dim kH (Spec.Shape.dim kW Spec.Shape.scalar)))) → Ref m α (Spec.Shape.dim outC Spec.Shape.scalar) → Ref m α (Spec.Shape.dim inC (Spec.Shape.dim inH (Spec.Shape.dim inW Spec.Shape.scalar))) → m (Ref m α (Spec.Shape.dim outC (Spec.Shape.dim ((inH + 2 * padding - kH) / stride + 1) (Spec.Shape.dim ((inW + 2 * padding - kW) / stride + 1) Spec.Shape.scalar))))
- convTranspose2d {inC outC kH kW stride padding inH inW : ℕ} {h1 : inC ≠ 0} {h2 : kH ≠ 0} {h3 : kW ≠ 0} : Ref m α (Spec.Shape.dim inC (Spec.Shape.dim outC (Spec.Shape.dim kH (Spec.Shape.dim kW Spec.Shape.scalar)))) → Ref m α (Spec.Shape.dim outC Spec.Shape.scalar) → Ref m α (Spec.Shape.dim inC (Spec.Shape.dim inH (Spec.Shape.dim inW Spec.Shape.scalar))) → m (Ref m α (Spec.Shape.dim outC (Spec.Shape.dim ((inH - 1) * stride - 2 * padding + kH) (Spec.Shape.dim ((inW - 1) * stride - 2 * padding + kW) Spec.Shape.scalar))))
- bernoulliMask {s : Spec.Shape} : Ref m α Spec.Shape.scalar → (seed : ℕ) → m (Ref m α s)
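The pattern that Ops enables can be shown in miniature. The sketch below is hypothetical and much smaller than Ops: MiniOps has three methods and no shape indices, and it takes the reference type as an outParam class parameter, whereas Ops exposes Ref as a field. The model affine is written once. An Id instance evaluates it eagerly on Float values, while a StateM instance returns node ids and records one IR-like string per operation, loosely mirroring the eager-tape versus GraphM split.

```lean
-- Hypothetical toy interface; `Ref` is determined by the monad `m`.
class MiniOps (m : Type → Type) (Ref : outParam Type) where
  const : Float → m Ref
  add   : Ref → Ref → m Ref
  mul   : Ref → Ref → m Ref

-- A model written once against the interface: y = x * 2 + 1.
def affine {m : Type → Type} {R : Type} [Monad m] [MiniOps m R] (x : Float) : m R := do
  let xr ← MiniOps.const x
  let w  ← MiniOps.const 2.0
  let b  ← MiniOps.const 1.0
  let xw ← MiniOps.mul xr w
  MiniOps.add xw b

-- Eager backend: references are plain values and ops run immediately.
instance : MiniOps Id Float where
  const x := pure x
  add a b := pure (a + b)
  mul a b := pure (a * b)

-- Recording backend: references are node ids into a growing list of IR nodes.
abbrev Record := StateM (Array String)

def emit (node : String) : Record Nat := do
  let g ← get
  set (g.push node)
  return g.size  -- id of the node just appended

instance : MiniOps Record Nat where
  const x := emit s!"const {x}"
  add a b := emit s!"add %{a} %{b}"
  mul a b := emit s!"mul %{a} %{b}"

#eval Id.run (affine (m := Id) 3.0)                           -- 3.0 * 2.0 + 1.0 = 7.0
#eval (Id.run (StateT.run (affine (m := Record) 3.0) #[])).2  -- the recorded nodes
```

The recording run produces five nodes: three consts, then mul %0 %1 and add %3 %2. Neither toy backend computes gradients; the point is only that affine is written once and its meaning is chosen by the instance.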
Reference type for the current Ops instance.
In eager mode this will typically be TensorRef; in compiled mode it will typically be
GraphM.Var.
Re-export of Ops.const. PyTorch: torch.tensor(...) / literal constants.
Re-export of Ops.add. PyTorch: torch.add / +.
Re-export of Ops.sub. PyTorch: torch.sub / -.
Re-export of Ops.mul. PyTorch: torch.mul / *.
Re-export of Ops.scale. PyTorch: x * c for a scalar c.
Re-export of Ops.abs. PyTorch: torch.abs.
Re-export of Ops.sqrt. PyTorch: torch.sqrt.
Re-export of Ops.clamp. PyTorch: torch.clamp.
Re-export of Ops.max. PyTorch: torch.maximum.
Re-export of Ops.min. PyTorch: torch.minimum.
Re-export of Ops.broadcastTo. PyTorch: broadcasting / expand.
Re-export of Ops.reshape. PyTorch: reshape / view.
Re-export of Ops.transpose2d. PyTorch: x.t() / transpose.
Re-export of Ops.transpose3dFirstToLast. PyTorch: permute(1,2,0).
Re-export of Ops.transpose3dLastToFirst. PyTorch: permute(2,0,1).
Re-export of Ops.transpose3dLastTwo. PyTorch: transpose(1,2).
Re-export of Ops.swapAdjacentAtDepth (general adjacent-axis swap).
Re-export of Ops.reduceSum. PyTorch: torch.sum(..., dim=axis).
Re-export of Ops.reduceMean. PyTorch: torch.mean(..., dim=axis).
Re-export of Ops.gatherScalar. PyTorch: x[i] (1D).
Re-export of Ops.gatherRow. PyTorch: x[i] (2D row).
Re-export of Ops.gatherScalarNat (the index is a raw ℕ rather than a Fin n).
Re-export of Ops.gatherVecNat (selects k entries of a vector using a tensor of ℕ indices).
Re-export of Ops.gatherRowsNat (selects k rows of a matrix using a tensor of ℕ indices).
Re-export of Ops.scatterAddVec (adds a scalar into entry i of a length-n vector).
Re-export of Ops.scatterAddRow (adds a length-cols vector into row i of a matrix).
Re-export of Ops.matmul. PyTorch: torch.matmul for 2D tensors.
Re-export of Ops.bmm. PyTorch: torch.bmm.
Re-export of Ops.concatVectors. PyTorch: torch.cat([a,b], dim=0) for vectors.
Re-export of Ops.concatDim0. PyTorch: torch.cat(..., dim=0).
Re-export of Ops.sliceRange0. PyTorch: x[start:start+len] on the leading dimension.
Re-export of Ops.maxPool (generic N-D max pooling, channels-first; no batch axis).
PyTorch comparison: torch.nn.functional.max_pool1d / max_pool2d / max_pool3d depending on the spatial rank d.
Re-export of Ops.avgPool (generic N-D average pooling, channels-first; no batch axis).
PyTorch comparison: torch.nn.functional.avg_pool1d / avg_pool2d / avg_pool3d depending on the spatial rank d.
Re-export of Ops.smoothMaxPool (generic N-D smooth max pooling, channels-first; no batch axis).
This is a differentiable approximation to max pooling; PyTorch does not expose it as a single primitive.
Re-export of Ops.maxPool2d. PyTorch: torch.nn.functional.max_pool2d.
Re-export of Ops.maxPool2dPad. PyTorch: max_pool2d(..., padding=...).
Alias for maxPool2dPad (PyTorch-style shorthand).
Re-export of Ops.smoothMaxPool2d (softmax pooling).
Re-export of Ops.avgPool2d. PyTorch: torch.nn.functional.avg_pool2d.
Re-export of Ops.avgPool2dPad. PyTorch: avg_pool2d(..., padding=...).
Alias for avgPool2dPad (PyTorch-style shorthand).
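The padded pooling output size in the maxPool2dPad / avgPool2dPad types is ordinary ℕ arithmetic, so it can be evaluated directly. poolOutSize below is a hypothetical helper that copies the per-axis expression from those signatures; padding = 0 gives the unpadded maxPool2d / avgPool2d form.

```lean
-- Per-axis output size copied from the padded pooling signatures (ℕ arithmetic).
def poolOutSize (inSize kernel stride padding : Nat) : Nat :=
  (inSize + 2 * padding - kernel) / stride + 1

#eval poolOutSize 32 2 2 0  -- 16: a 2-wide, stride-2 window halves 32
#eval poolOutSize 32 3 1 1  -- 32: kernel 3, stride 1, padding 1 preserves the size
#eval poolOutSize 2 3 1 0   -- 1: ℕ subtraction truncates 2 - 3 to 0
```

The last case shows a consequence of ℕ subtraction: a kernel larger than the padded input does not fail to typecheck but yields an output size of 1, since the signatures above only require kH ≠ 0 and kW ≠ 0.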
Re-export of Ops.relu. PyTorch: torch.relu.
Re-export of Ops.sigmoid. PyTorch: torch.sigmoid.
Re-export of Ops.tanh. PyTorch: torch.tanh.
Re-export of Ops.softmax. PyTorch: torch.softmax.
Re-export of Ops.softplus. PyTorch: torch.nn.functional.softplus.
Re-export of Ops.exp. PyTorch: torch.exp.
Re-export of Ops.log. PyTorch: torch.log.
Re-export of Ops.inv (reciprocal). PyTorch: torch.reciprocal.
Re-export of Ops.detach. PyTorch: x.detach().
Re-export of Ops.safeLog.
Re-export of Ops.rand_uniform (deterministic seeded RNG).
Re-export of Ops.bernoulliMask (deterministic dropout-style mask).
Numerically stable log_softmax(x) along the last axis.
This is a backend primitive using the standard max-shifted formulation
log_softmax(x)ᵢ = xᵢ - m - log(Σⱼ exp(xⱼ - m)), where m = maxⱼ xⱼ. Subtracting m keeps every
exponent at or below 0, so exp cannot overflow; this matches the numerical intent of PyTorch's
log_softmax. The optional ε parameter is accepted only to keep existing call sites stable and is
ignored by this primitive; callers that need an epsilon-smoothed logarithm should use safeLog explicitly.
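As a scalar sketch of the max-shifted formula, the Lean function below applies it to a plain List Float rather than a backend tensor. logSoftmaxList is a hypothetical name, and an empty input simply returns an empty list.

```lean
-- Max-shifted log-softmax: xᵢ - m - log Σⱼ exp(xⱼ - m), with m = max x.
def logSoftmaxList : List Float → List Float
  | [] => []
  | x :: rest =>
    let xs := x :: rest
    -- m is the largest entry, so every exp argument below is ≤ 0 and cannot overflow.
    let m := rest.foldl (fun (a b : Float) => if a < b then b else a) x
    let lse := Float.log (xs.foldl (fun (acc v : Float) => acc + Float.exp (v - m)) 0.0)
    xs.map (fun (v : Float) => v - m - lse)

#eval logSoftmaxList [1000.0, 1001.0]  -- ≈ [-1.3133, -0.3133]; Float.exp 1000.0 alone is +∞
```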
SiLU / swish: x * sigmoid(x).
GELU, tanh approximation (PyTorch: torch.nn.functional.gelu with approximate='tanh'; PyTorch's
default is the exact erf-based form):
0.5 * x * (1 + tanh(√(2/π) * (x + 0.044715 * x^3))).
This is defined using existing primitives (tanh, mul, add, scale), so it works in the eager,
compiled, and verifier-IR backends without introducing a new opcode.
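SiLU and the tanh-approximated GELU can be checked numerically on scalars. The sketch uses plain Float functions in place of backend ops (sigmoid, silu, and geluTanh are hypothetical helper names), composing them from exp, tanh, multiplication, and addition as the docstrings describe; π is written as a literal because no π constant is assumed.

```lean
def sigmoid (x : Float) : Float := 1.0 / (1.0 + Float.exp (-x))

-- SiLU / swish: x * sigmoid(x).
def silu (x : Float) : Float := x * sigmoid x

-- tanh-approximated GELU: 0.5 * x * (1 + tanh(√(2/π) * (x + 0.044715 * x^3))).
def geluTanh (x : Float) : Float :=
  let c := Float.sqrt (2.0 / 3.141592653589793)
  0.5 * x * (1.0 + Float.tanh (c * (x + 0.044715 * x * x * x)))

#eval silu 1.0      -- ≈ 0.7311
#eval geluTanh 1.0  -- ≈ 0.8412
```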
Global average pooling over the last two axes of a C×H×W tensor (channels-first).
Returns a vector of length C, averaging each channel over its H×W entries.
Global average pooling over the last two axes of an N×C×H×W tensor (PyTorch default layout).
Returns N×C, averaging each channel over H×W for each batch element.
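For a single C×H×W sample, global average pooling reduces each channel to the mean of its H·W entries. The sketch below stores the tensor as nested lists (channel, then row, then column). globalAvgPoolCHW is a hypothetical name, and an empty channel would produce NaN from 0.0 / 0.0.

```lean
-- Global average pooling for one C×H×W sample stored as nested lists.
def globalAvgPoolCHW (x : List (List (List Float))) : List Float :=
  x.map fun channel =>
    -- Flatten the H×W grid of this channel into one list of values.
    let vals : List Float := channel.foldl (fun (acc row : List Float) => acc ++ row) []
    vals.foldl (fun (acc v : Float) => acc + v) 0.0 / vals.length.toFloat

#eval globalAvgPoolCHW [[[1.0, 2.0], [3.0, 4.0]], [[0.0, 0.0], [0.0, 8.0]]]
-- per-channel means 2.5 and 2.0
```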
Re-export of Ops.sum. PyTorch: x.sum().
Re-export of Ops.flatten. PyTorch: torch.flatten.
Re-export of Ops.linear. PyTorch: torch.nn.functional.linear.
Re-export of Ops.mseLoss. PyTorch: torch.nn.functional.mse_loss.
Re-export of Ops.layerNorm. PyTorch: torch.nn.LayerNorm / torch.nn.functional.layer_norm.
Re-export of Ops.batchnormChannelFirst. PyTorch: torch.nn.BatchNorm2d (conceptually).
Re-export of Ops.multiHeadAttention (self-attention over a single n×dModel sequence, with an optional n×n boolean mask).
Re-export of Ops.conv (generic N-D convolution, channels-first).
PyTorch comparison: torch.nn.functional.conv{d}d specialized to a single sample (no batch axis).
Re-export of Ops.convTranspose (generic N-D transpose convolution, channels-first).
PyTorch comparison: torch.nn.functional.conv_transpose{d}d specialized to a single sample.
Re-export of Ops.conv2d. PyTorch: torch.nn.functional.conv2d (conceptually, no batch axis).
Re-export of Ops.convTranspose2d. PyTorch: torch.nn.functional.conv_transpose2d (conceptually, no batch axis).
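The conv2d and convTranspose2d output sizes can be compared with the same ℕ arithmetic. The helpers below (hypothetical names) copy the per-axis expressions from the two signatures. They show that the transpose is not a size-level inverse when the stride does not divide evenly: with kernel 3, stride 2, and padding 1, inputs of size 7 and 8 both convolve to 4, and the transpose maps 4 back to 7.

```lean
-- Per-axis output size copied from the conv2d signature.
def convOutSize (inSize kernel stride padding : Nat) : Nat :=
  (inSize + 2 * padding - kernel) / stride + 1

-- Per-axis output size copied from the convTranspose2d signature.
def convTransposeOutSize (inSize kernel stride padding : Nat) : Nat :=
  (inSize - 1) * stride - 2 * padding + kernel

#eval (convOutSize 7 3 2 1, convOutSize 8 3 2 1)  -- (4, 4)
#eval convTransposeOutSize 4 3 2 1                 -- 7
```

PyTorch's conv_transpose2d resolves this ambiguity with an output_padding argument. The signature above has no such parameter, so it always produces the smallest input size that convolves to the given output size.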