Runtime Autograd Helpers
Model-shaped and function-shaped differentiation helpers over TorchLean runtime programs.
Parameter pack type for a given model (a TensorPack over Seq.paramShapes).
Loss function over a model output and a target.
It is expressed in terms of RefTy, so the same definition works for both eager and compiled
execution.
Initialize model parameters by casting the model's Float initializers elementwise using cast.
Initialize model parameters using the runtime literal injection API.Runtime.ofFloat.
Pack explicit weight and bias tensors for a single Layers.linear model.
Mean-squared error loss (mse) between yhat and y.
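As a language-neutral illustration of what an MSE loss computes, here is a minimal Python sketch. It assumes the loss averages over all elements; the reduction TorchLean actually uses is not stated on this page, and the function below is not TorchLean code.

```python
def mse(yhat, y):
    """Mean of the squared elementwise differences of two equal-length sequences."""
    if len(yhat) != len(y):
        raise ValueError("yhat and y must have the same length")
    if not yhat:
        raise ValueError("inputs must be non-empty")
    # Assumption: the reduction is a mean over all elements.
    return sum((a - b) ** 2 for a, b in zip(yhat, y)) / len(yhat)
```

For example, `mse([1.0, 2.0, 3.0], [1.0, 2.0, 5.0])` averages the squared errors 0, 0 and 4.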
Cross-entropy loss between logits and one-hot targets. PyTorch analogue: nn.CrossEntropyLoss.
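The following Python sketch shows the usual formula for cross-entropy from raw logits against a one-hot target: apply a numerically stable log-softmax, then take the negative target-weighted sum. It handles a single sample with no batch reduction, which is an assumption made for brevity; the helper names are hypothetical and not part of TorchLean.

```python
import math


def log_softmax(logits):
    """Log-softmax computed with the max-subtraction trick for stability."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_sum for z in logits]


def cross_entropy(logits, target):
    """Negative sum of target[i] * log_softmax(logits)[i] for one sample."""
    return -sum(t * lp for t, lp in zip(target, log_softmax(logits)))
```

With a one-hot target, only the log-probability of the true class contributes to the sum.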
Detach the model output before feeding it into a loss.
This is useful when the loss serves only as a metric and no gradients should flow back through the model output.
Build a TorchLean Program that computes a scalar loss from (params, x, target).
This is the bridge between Seq.program (which produces model outputs) and the autograd entry
points (which expect a scalar-valued program).
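The bridging idea can be sketched in plain Python as function composition: wrap a model and a loss into one function of (params, x, target) that returns a scalar. Everything here (`make_loss_program`, `linear`, `mse`) is a hypothetical stand-in chosen for illustration, not the TorchLean API.

```python
def make_loss_program(model, loss):
    """Compose a model and a loss into one scalar-valued function."""
    def program(params, x, target):
        return loss(model(params, x), target)
    return program


# Hypothetical one-feature linear model: params = (weight, bias).
def linear(params, x):
    w, b = params
    return [w * xi + b for xi in x]


# Hypothetical loss: mean of squared elementwise differences.
def mse(yhat, y):
    return sum((a - b) ** 2 for a, b in zip(yhat, y)) / len(yhat)


scalar_loss = make_loss_program(linear, mse)
value = scalar_loss((2.0, 1.0), [0.0, 1.0], [1.0, 2.0])
```

Because `scalar_loss` returns a single number, it has the shape that gradient, JVP, and HVP entry points typically expect.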
VJP of the model output w.r.t. parameters.
VJP of the model output w.r.t. inputs.
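For reference, the standard definition of a vector-Jacobian product is given below. The notation is chosen here and does not correspond to TorchLean identifiers: the model is written $y = f(\theta, x)$ with output $y \in \mathbb{R}^{m}$, and $\bar{y}$ is the cotangent supplied by the caller.

```latex
\[
\operatorname{vjp}_{\theta}(\bar{y}) = \bar{y}^{\top}\,\frac{\partial f}{\partial \theta}(\theta, x),
\qquad
\operatorname{vjp}_{x}(\bar{y}) = \bar{y}^{\top}\,\frac{\partial f}{\partial x}(\theta, x),
\qquad
\bar{y} \in \mathbb{R}^{m}.
\]
```

Each product has the same shape as the quantity being differentiated against (parameters or inputs) and can be computed in one backward pass without forming the full Jacobian.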
Jacobian (reverse-mode) of the model output w.r.t. parameters, returned as rows.
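In reverse mode, a Jacobian is usually assembled one row at a time, which likely explains the row-wise return format. In the sketch below, $p$ is the number of parameters, $m$ the number of outputs, and $e_i$ the $i$-th standard basis vector of $\mathbb{R}^{m}$. These symbols are notation introduced here; whether TorchLean follows exactly this scheme is an assumption.

```latex
\[
J = \frac{\partial f}{\partial \theta}(\theta, x) \in \mathbb{R}^{m \times p},
\qquad
J_{i,:} = e_i^{\top} J \quad (i = 1, \dots, m).
\]
```

Under this scheme the cost is $m$ backward passes, one per output component.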
Gradient of loss(model(params, x), target) w.r.t. parameters.
Gradient of loss(model(params, x), target) w.r.t. inputs (x and target).
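Both gradients follow from the chain rule. Writing $\hat{y} = f(\theta, x)$ for the model output, $t$ for the target, and $\ell$ for the loss (notation introduced here, not TorchLean names):

```latex
\[
L(\theta, x, t) = \ell\bigl(f(\theta, x),\, t\bigr),
\qquad
\nabla_{\theta} L = \Bigl(\frac{\partial f}{\partial \theta}\Bigr)^{\!\top} \nabla_{\hat{y}}\,\ell(\hat{y}, t),
\]
\[
\nabla_{x} L = \Bigl(\frac{\partial f}{\partial x}\Bigr)^{\!\top} \nabla_{\hat{y}}\,\ell(\hat{y}, t),
\qquad
\nabla_{t} L = \nabla_{t}\,\ell(\hat{y}, t).
\]
```

The parameter and input gradients are therefore VJPs of the model with cotangent $\nabla_{\hat{y}}\,\ell$, while the target gradient involves only the loss.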
JVP of a scalar loss w.r.t. parameters in direction vparams.
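For a scalar loss, the JVP is the directional derivative along the chosen direction. In the notation used here (not TorchLean identifiers), with $v$ standing for vparams and the input and target held fixed:

```latex
\[
\operatorname{jvp}(v)
= \left.\frac{d}{d\varepsilon}\, L(\theta + \varepsilon v)\right|_{\varepsilon = 0}
= \nabla_{\theta} L(\theta)^{\top} v .
\]
```

The result is a single scalar, obtainable in one forward-mode pass.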
HVP (Hessian-vector product) of a scalar loss w.r.t. parameters in direction vparams.
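A Hessian-vector product can be defined without materializing the Hessian, as the directional derivative of the gradient. The notation is chosen here, with $v$ standing for vparams; how TorchLean composes its differentiation passes to obtain this is not stated on this page.

```latex
\[
H v
= \nabla_{\theta}^{2} L(\theta)\, v
= \left.\frac{d}{d\varepsilon}\, \nabla_{\theta} L(\theta + \varepsilon v)\right|_{\varepsilon = 0} .
\]
```

The result has the same shape as the parameters.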
Type of a pure tensor function expressed in RefTy form.
This matches the calling convention expected by TorchLean.Program/autodiff compilation.
Turn an Fn into a single-input TorchLean Program.
Forward-mode Jacobian (rows) of a pure function.
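To illustrate the technique, here is a small self-contained Python sketch of forward-mode differentiation using dual numbers. Each JVP against a basis direction yields one Jacobian column, and the columns are then transposed into rows. The `Dual` class and both helper functions are hypothetical, support only addition and multiplication, and are not how TorchLean implements this.

```python
class Dual:
    """Number of the form val + eps * tan, where eps**2 == 0."""

    def __init__(self, val, tan=0.0):
        self.val = val
        self.tan = tan

    @staticmethod
    def lift(x):
        """Wrap a plain number as a Dual with zero tangent."""
        return x if isinstance(x, Dual) else Dual(x)

    def __add__(self, other):
        other = Dual.lift(other)
        return Dual(self.val + other.val, self.tan + other.tan)

    __radd__ = __add__

    def __mul__(self, other):
        other = Dual.lift(other)
        # Product rule: (a * b)' = a * b' + a' * b.
        return Dual(self.val * other.val,
                    self.val * other.tan + self.tan * other.val)

    __rmul__ = __mul__


def jvp(f, x, v):
    """Tangent of f at x in direction v, computed in one forward pass."""
    out = f([Dual(xi, vi) for xi, vi in zip(x, v)])
    return [Dual.lift(o).tan for o in out]


def jacobian_rows(f, x):
    """Jacobian of f at x as a list of rows, one row per output."""
    n = len(x)
    # The basis direction e_j yields column j of the Jacobian.
    cols = [jvp(f, x, [1.0 if i == j else 0.0 for i in range(n)])
            for j in range(n)]
    # Transpose the columns into rows.
    return [list(row) for row in zip(*cols)]
```

For `f = lambda x: [x[0] * x[1], x[0] + x[1]]`, calling `jacobian_rows(f, [2.0, 5.0])` takes two forward passes, one per input component.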
Hessian of a scalar-valued pure function.