IR axis operations #
IR axis-ops smoke tutorial.
This file is a small “regression guard” for three ops where TorchLean’s IR uses an explicit axis:
- softmax axis (PyTorch: torch.softmax(x, dim=axis))
- concat axis (PyTorch: torch.cat(xs, dim=axis))
- layernorm axis (PyTorch: F.layer_norm(x, normalized_shape=x.shape[axis:]))
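A minimal PyTorch sketch of the three calls above, each applied on a non-last axis. The shapes and the choice of axis here are illustrative only, not the tutorial's actual test data:

import torch
import torch.nn.functional as F

axis = 1
x = torch.randn(2, 3, 4)
xs = [torch.randn(2, 3, 4), torch.randn(2, 5, 4)]

soft = torch.softmax(x, dim=axis)                        # normalizes along axis 1
cat = torch.cat(xs, dim=axis)                            # shapes agree except on axis 1 -> (2, 8, 4)
ln = F.layer_norm(x, normalized_shape=x.shape[axis:])    # normalizes over dims 1 and 2
print(soft.shape, cat.shape, ln.shape)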
Why this tutorial exists:
- These three ops are easy to accidentally restrict to “last axis only” (because the spec primitives we reuse are last-axis).
- The denotational IR semantics intentionally supports the PyTorch meaning on any valid axis: it implements non-last axes by reshaping/permuting into a form the spec primitive already supports (see the sketch after this list).
- The compiled IRExec backend is more conservative today. This tutorial runs compiled execution only for supported cases and prints an explicit skip for known backend gaps, instead of pretending the backend covers more than it does.
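A sketch of the reduce-to-last-axis trick mentioned above, written in PyTorch terms for concreteness. The IR semantics performs the analogous permute/reshape at the Spec level; softmax_via_last_axis is an illustrative name, not TorchLean's API:

import torch

def softmax_via_last_axis(x: torch.Tensor, axis: int) -> torch.Tensor:
    # Softmax over `axis`, using only a "last-axis" softmax primitive.
    moved = x.movedim(axis, -1)            # bring the target axis to the last position
    out = torch.softmax(moved, dim=-1)     # the last-axis-only primitive
    return out.movedim(-1, axis)           # restore the original axis order

x = torch.randn(2, 3, 4)
assert torch.allclose(softmax_via_last_axis(x, 1), torch.softmax(x, dim=1))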
Run:
lake exe torchlean ir_axis_ops --dtype float --backend eager
Test Shapes #
We keep shapes small so this tutorial runs instantly, but still exercises the “axis is not last / not 0” code paths.
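For example (illustrative shapes, not the tutorial's actual ones): with a rank-3 tensor, axis 1 is neither the first nor the last axis, so a backend that silently fell back to the last axis would produce a different result.

import torch

x = torch.randn(2, 3, 4)
print(torch.allclose(torch.softmax(x, dim=1), torch.softmax(x, dim=-1)))  # False: the axis matters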
Small IR Graphs #
Runner Helpers #
def NN.Examples.Advanced.IRAxisOps.runOne {α : Type} [API.Semantics.Scalar α]
    [DecidableEq Spec.Shape] [ToString α] [API.Runtime.Scalar α]
    (tag : String) (g : IR.Graph) (payload : IR.Payload α)
    (inputShape : Spec.Shape) (x : Spec.Tensor α inputShape)
    (outputId : Fin g.nodes.size) (runCompiled : Bool := true) :
def NN.Examples.Advanced.IRAxisOps.runOnce {α : Type} [API.Semantics.Scalar α]
    [DecidableEq Spec.Shape] [ToString α] [API.Runtime.Scalar α]
    (cast : Float → α) :