Functional #
TorchLean functional helpers in the style of torch.* building blocks.
These are derived ops built from the small primitive TorchLean.Ops surface, so they work for:
- eager backend (runtime tape), and
- compiled backend (SSA/DAG via Compiled.GraphM), using the same model/loss definition.
The goal is to make losses readable without forcing users to call specialized ops like mse_loss directly.
PyTorch References #
- torch.nn.functional: https://pytorch.org/docs/stable/nn.functional.html
- torch.autograd (detach/stop-grad concepts): https://pytorch.org/docs/stable/autograd.html
- torch.utils.checkpoint: https://pytorch.org/docs/stable/checkpoint.html
AD References #
For background on reverse-mode AD (the idea behind tape-based autograd), see:
- Andreas Griewank and Andrea Walther, Evaluating Derivatives, 2nd ed., 2008.
- Seppo Linnainmaa, 1970 (reverse accumulation / the classic precursor to modern backprop).
Elementwise helpers #
Elementwise square: x ↦ x * x.
PyTorch analogue: torch.square.
Checkpointing (semantics-first identity wrapper) #
Checkpoint wrapper for API parity with PyTorch-style memory-saving patterns.
In this codebase, checkpointing is a semantic identity wrapper (checkpoint f x = f x). Backends
that implement recomputation can refine this hook without changing the mathematical meaning.
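A minimal Lean sketch of the stated semantics (the generic signature here is illustrative; the library's actual checkpoint may be specialized to tensor types):

```lean
-- Spec-level sketch: checkpointing as a semantic identity wrapper.
def checkpoint {α β : Type} (f : α → β) (x : α) : β := f x

-- The defining equation `checkpoint f x = f x` holds definitionally.
example {α β : Type} (f : α → β) (x : α) : checkpoint f x = f x := rfl
```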
Detach / stop-grad #
Stop-gradient boundary (forward identity).
Alias for detach.
Broadcasting helpers #
Broadcasting add: compute x + y after broadcasting both inputs to the target shape t.
PyTorch analogue: torch.add (broadcasting semantics).
Broadcasting multiply: compute x * y after broadcasting both inputs to the target shape t.
PyTorch analogue: torch.mul (broadcasting semantics).
Indexing helpers #
Embedding lookup (gather one row of an embedding table).
Given w : vocab × dim, return w[idx] : dim.
PyTorch analogue: torch.nn.functional.embedding for a single index.
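A spec-level sketch on nested lists (illustrative; the real op works on the library's tensor type):

```lean
-- Embedding lookup as row selection: w has `vocab` rows of `dim` entries.
def embedSpec (w : List (List Float)) (idx : Nat) : List Float :=
  w.getD idx []  -- out-of-range indices fall back to the empty row here
```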
Reductions #
Mean reduction: mean(x) = sum(x) / numel(x).
PyTorch analogue: torch.mean.
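The same equation as a spec-level sketch on a flat list of scalars (not the library's tensor type; the empty-list case divides by zero and is left unguarded here):

```lean
def meanSpec (xs : List Float) : Float :=
  xs.foldl (· + ·) 0.0 / Float.ofNat xs.length

#eval meanSpec [1.0, 2.0, 3.0, 4.0]  -- 2.5
```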
Seeded RNG helpers #
Deterministic U[0,1) tensor generator (seeded).
Deterministic {0,1} mask generator (seeded) with scalar keep-probability input.
Seeded dropout implemented as x * mask / keepProb where mask ∈ {0,1} is sampled from a
deterministic PRNG keyed by seed.
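A spec-level sketch on a list of scalars. The LCG below is an illustrative stand-in for the deterministic PRNG, not TorchLean's actual generator:

```lean
-- Illustrative linear congruential step (Knuth's MMIX LCG constants).
def lcgNext (st : UInt64) : UInt64 :=
  st * 6364136223846793005 + 1442695040888963407

-- x * mask / keepProb, with mask ∈ {0,1} drawn from the seeded stream.
def dropoutSpec (seed : UInt64) (keepProb : Float) : List Float → List Float
  | [] => []
  | x :: xs =>
    let st := lcgNext seed
    -- top 53 bits of the state, mapped into [0, 1)
    let u := Float.ofNat (st >>> 11).toNat / Float.ofNat (1 <<< 53)
    let y := if u < keepProb then x / keepProb else 0.0
    y :: dropoutSpec st keepProb xs
```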
Seeded dropout where the probability is supplied as a scalar tensor ref.
This is useful in model builders where the layer definition stores p as data, avoiding an
ad-hoc Float → α cast in polymorphic model code.
Einsum-ish building blocks #
Matrix matmul: [m,n] × [n,p] → [m,p].
Batched matmul: [batch,m,n] × [batch,n,p] → [batch,m,p].
Typed einsum wrappers (fast, total) #
These are non-Option equivalents for the most common einsum contractions in ML code.
They are intended to be used directly (no string parsing), and serve as the fast-path targets for
einsumDyn.
einsum("ij,jk->ik", A, B) as a typed matmul.
einsum("bij,bjk->bik", A, B) as a typed batched matmul.
Einsum pattern used in attention: bhid,bhjd -> bhij (batched Q·Kᵀ per head).
Einsum pattern used in attention: bhij,bhjd -> bhid (batched Attn·V per head).
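In index notation, these four wrappers compute:

$$C_{ik} = \sum_j A_{ij}\,B_{jk}, \qquad C_{bik} = \sum_j A_{bij}\,B_{bjk},$$
$$S_{bhij} = \sum_d Q_{bhid}\,K_{bhjd}, \qquad O_{bhid} = \sum_j A_{bhij}\,V_{bhjd}.$$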
General einsum (PyTorch-style subscripts; runtime-checked) #
Decidable instance for Shape.well_formed, used by the dynamic einsum lowering.
Local decidability instance for Shape.well_formed (used by the dynamic einsum lowering).
Remove ASCII whitespace to simplify the hand-rolled parser.
Detect whether a list of labels contains any duplicates (order-preserving scan).
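A minimal sketch of such a scan (illustrative name; the actual helper may differ):

```lean
def hasDup [BEq α] : List α → Bool
  | [] => false
  | x :: xs => xs.contains x || hasDup xs
```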
Find the first index of a given label (an equality-specialized analogue of List.findIdx?, returning none when the label is absent).
Convert a permutation of axes into a sequence of adjacent swaps.
This mirrors the IR-side lowering strategy: represent a general permutation as a list of swaps at
depths, then implement swaps with swapAdjacentAtDepth.
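A bubble-sort-style sketch of the idea (illustrative; the IR-side lowering may record swaps differently): sorting the permutation records the adjacent swaps, each tagged with its depth, that undo it.

```lean
-- One bubble pass: the list after its adjacent swaps, plus the depths
-- (0-based positions) at which swaps occurred.
partial def bubblePass : List Nat → Nat → List Nat × List Nat
  | a :: b :: rest, d =>
    if a > b then
      let (tail, swaps) := bubblePass (a :: rest) (d + 1)
      (b :: tail, d :: swaps)
    else
      let (tail, swaps) := bubblePass (b :: rest) (d + 1)
      (a :: tail, swaps)
  | xs, _ => (xs, [])

-- Repeat passes until no swaps remain.
partial def permToSwaps (p : List Nat) : List Nat :=
  let (p', swaps) := bubblePass p 0
  if swaps.isEmpty then [] else swaps ++ permToSwaps p'

#eval permToSwaps [2, 0, 1]  -- [0, 1]
```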
Expand an input operand’s labels to a full label list matching the operand’s rank.
If the subscript contains an ellipsis, this inserts fresh Label.ell labels so that the total
label count matches Shape.rank s.
Small association-list helpers #
To keep this file dependency-light, we represent maps as association lists and use small helpers
instead of Std.HashMap.
Occurrence counts for labels, represented as an association list.
Increment a label’s count (inserting it if absent).
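A sketch of the counter update on association lists (illustrative name):

```lean
def bump [BEq α] (k : α) : List (α × Nat) → List (α × Nat)
  | [] => [(k, 1)]
  | (k', n) :: rest =>
    if k == k' then (k', n + 1) :: rest
    else (k', n) :: bump k rest
```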
Map each label to its concrete dimension size (association list).
Infer a consistent label-to-dimension map from operand label lists and operand shapes.
This implements standard einsum broadcast rules: if a label is seen with both d and 1, we keep
d; if two non-1 sizes disagree, we error.
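The per-label merge rule, as a small sketch (illustrative name):

```lean
-- A label seen with sizes d and 1 resolves to d; distinct non-1 sizes fail.
def mergeDim (old new : Nat) : Option Nat :=
  if old == new then some old
  else if old == 1 then some new
  else if new == 1 then some old
  else none
```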
Compute a Shape.CanBroadcastTo witness at runtime.
This mirrors the typeclass-based broadcasting used elsewhere, but returns an Option so the dynamic
einsum lowering can fail gracefully.
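A spec-level sketch of the check on list-encoded shapes (illustrative; the real version produces a typed witness rather than a Bool): align from the right, padding the source with leading 1s, and require each dimension to match the target or be 1.

```lean
def canBroadcastTo (s t : List Nat) : Bool :=
  let s' := List.replicate (t.length - s.length) 1 ++ s
  s'.length == t.length &&
    ((s'.zip t).all fun (d, td) => d == td || d == 1)

#eval canBroadcastTo [3, 1] [3, 4]  -- true
#eval canBroadcastTo [2, 3] [3, 3]  -- false
```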
Apply a permutation expressed as adjacent swap depths to an existentially-shaped tensor.
This is the runtime “apply swaps” primitive used by both .permute and the dynamic einsum output
reordering.
Reflexive broadcast witness constructor (s can always broadcast to itself).
Compute a permutation that maps src to tgt when duplicates are present.
This is used for the “diagonal embedding” case when output labels contain repeats: we temporarily expand the output with extra axes, then permute back to the requested (possibly-duplicated) order.
Build a diagonal mask tensor (spec-level) for diagonal embedding/extraction.
Given axes p and q, the resulting tensor is 1 when the indices along those axes agree, and 0
otherwise. The ip/iq parameters track the first-seen indices while recursing over dimensions.
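Spelled out for a rank-r index tuple, the mask is

$$M_{i_1 \dots i_r} = \begin{cases} 1 & \text{if } i_p = i_q \\ 0 & \text{otherwise.} \end{cases}$$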
Convenience wrapper around diagMaskSpecAux with fresh index-tracking state.
Shape.ofList is a left-inverse of Shape.toList.
Specialize diagMaskSpec to a concrete Shape by rewriting through Shape.toList.
Shape.appendDim s 1 preserves Shape.size (used to justify reshape tricks).
Permutation list that moves axis to the last position (keeping relative order of others).
Runtime-checked einsum that returns an existential output shape.
Supported:
- multiple inputs, explicit/implicit output, and ellipsis (...).
- repeated labels within an operand (diagonal extraction / trace semantics).
- repeated labels in the output (diagonal embedding / zeroing off-diagonal entries).
Currently unsupported (returns none):
- non-broadcastable size mismatches.
- any case that would require gather/scatter-style indexing (not in the verifier-friendly op set).
This is implemented purely by reordering, reshaping, broadcasting, elementwise multiplication, and summing contracted axes.
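For example, the repeated-label cases correspond to:

$$\texttt{ii->i}:\ D_i = A_{ii}\ \text{(diagonal extraction)}, \qquad \texttt{ii->}:\ \textstyle\sum_i A_{ii}\ \text{(trace)},$$
$$\texttt{i->ii}:\ B_{ij} = \begin{cases} A_i & \text{if } i = j \\ 0 & \text{otherwise} \end{cases}\ \text{(diagonal embedding)}.$$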
einsum with an expected output shape.
Shape/axis helpers #
Swap two adjacent axes at a given nesting depth.
This is the primitive used to implement general permutations via a sequence of adjacent swaps.
It corresponds to the backend op Torch.swapAdjacentAtDepth.
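A spec-level sketch of the swap on axis labels (a list of labels standing in for the real shape and tensor types):

```lean
def swapAdjacent : Nat → List α → List α
  | 0, a :: b :: rest => b :: a :: rest
  | d + 1, a :: rest => a :: swapAdjacent d rest
  | _, xs => xs

#eval swapAdjacent 1 ["a", "b", "c", "d"]  -- ["a", "c", "b", "d"]
```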
Core tensor semantics (PyTorch-style) #
Swap depths that move an axis to the last position (for “reduce along axis” lowering).
Swap depths that move an axis to the front position.
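Concretely, the depth sequences are arithmetic (a sketch with illustrative names): moving axis k to the last position in a rank-r tensor swaps at depths k, k + 1, ..., r - 2, and moving it to the front swaps at depths k - 1, k - 2, ..., 0.

```lean
def swapsToLast (k r : Nat) : List Nat := (List.range (r - 1)).drop k
def swapsToFront (k : Nat) : List Nat := (List.range k).reverse

#eval swapsToLast 1 4  -- [1, 2]
#eval swapsToFront 2   -- [1, 0]
```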
Decidable Shape.well_formed for the dynamic reduction/slicing helpers.
Local decidability instance for Shape.well_formed (used by dynamic reduction/slicing helpers).
Shape.appendDim s 1 preserves size (used to justify reshape in unsqueeze/keepdim code).
Dynamic permutation: like permute, but returns an existential output shape.
PyTorch analogue: torch.permute / Tensor.permute (with runtime checks).
Permutation with an expected output shape.
This is a thin wrapper over permuteDyn that checks the computed shape equals sOut.
Reduce along the last axis with sum, returning the new (existential) shape.
This is the primitive step used by reduceDimsDynCore after it has permuted the requested axis to
the last position.
Like reduceAlongLastSum, but using mean.
Core implementation for dynamic reductions over multiple axes.
This lowers “reduce along axis k” to:
- permute axis k to the last position,
- call reduceLast, and
- optionally re-insert a singleton dimension when keepdim = true.
reduce_sum_dimsDyn and reduce_mean_dimsDyn are just specializations.
Dynamic multi-axis sum reduction (like torch.sum(x, dim=axes, keepdim=...)).
Dynamic multi-axis mean reduction (like torch.mean(x, dim=axes, keepdim=...)).
Dynamic slice on an arbitrary axis.
This lowers slice_range_axisDyn axis start len to:
- permute axis to the front,
- call the axis-0 slice primitive, then
- permute back.
Dynamic softmax over an arbitrary axis (implemented by permuting to last, applying softmax,
permuting back).
Dynamic log_softmax over an arbitrary axis (with optional epsilon for numerical stability).
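For reference, along the chosen axis (exactly where the optional ε enters is left to the implementation):

$$\operatorname{softmax}(x)_i = \frac{e^{x_i}}{\sum_j e^{x_j}}, \qquad \operatorname{log\_softmax}(x)_i = x_i - \log\sum_j e^{x_j}.$$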
Dynamic unsqueeze: insert a singleton dimension at axis.
PyTorch analogue: torch.unsqueeze(x, dim=axis).
Dynamic squeeze along a specific axis, requiring that axis to have size 1.
PyTorch analogue: torch.squeeze(x, dim=axis) (the dim-restricted variant).
Dynamic concatenation of two tensors along axis (existential output shape).
This is the binary helper used by cat_axisDyn. It lowers to concat_dim0 by moving the
requested axis to the front.
Dynamic concatenation of a list of tensors along axis (folding cat_axis2Dyn).
Dynamic stack along a new axis.
PyTorch analogue: torch.stack(xs, dim=axis).
Implementation: unsqueeze each input at axis, then cat along the same axis.
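For example, stacking three [m, n] tensors at axis 0: each is unsqueezed to [1, m, n], and concatenating along axis 0 yields [3, m, n].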
Dynamic split along an axis with explicit split sizes.
PyTorch analogue: torch.split(x, split_sizes, dim=axis).
Dynamic chunk along an axis, given a desired chunk size.
PyTorch analogue: torch.split(x, chunkSize, dim=axis) (a size-based variant of torch.chunk, which takes a chunk count instead).
NCHW → NHWC for 4D tensors, implemented via two adjacent swaps.
NHWC → NCHW for 4D tensors, implemented via two adjacent swaps.
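Concretely: [N, C, H, W] → (swap at depth 1) → [N, H, C, W] → (swap at depth 2) → [N, H, W, C], and NHWC → NCHW applies the same two swaps in the opposite order.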