Fixed-Sample Training Helpers (API)
Many runnable examples in NN/Examples/Models/* follow the same pattern:
- build a model with nn.withModel,
- wrap it as a ScalarModuleDef (model + supervised loss),
- load or synthesize one supervised sample (x, y),
- run steps optimizer updates on that fixed sample, and
- either print loss0 -> loss1 or write a TrainLog curve.
This module keeps that loop in one place so examples stay short and consistent.
What this is (and is not):
- it is a tutorial helper for fixed-sample runs, not a full dataset trainer;
- it is model-agnostic: callers supply the loss wrapper and optimizer constructor;
- it is backend-agnostic: callers can use it on CPU or CUDA via TorchLean.Options.
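The loop this module factors out can be sketched in miniature. The following is a self-contained toy in plain Lean, with plain Floats standing in for tensors and all names (lossAt, sgdStep, runFixed) illustrative only, not this module's API:

```lean
/-- Squared error of a 1-D linear "model" `w` on the fixed sample `(x, y)`. -/
def lossAt (w x y : Float) : Float := (w * x - y) * (w * x - y)

/-- One SGD step on that loss (hand-derived gradient). -/
def sgdStep (lr w x y : Float) : Float :=
  w - lr * 2.0 * (w * x - y) * x

/-- Run `n` updates on the fixed sample; return `(loss0, loss1)`. -/
def runFixed (n : Nat) (w0 x y : Float) : Float × Float :=
  let w := (List.range n).foldl (fun w _ => sgdStep 0.1 w x y) w0
  (lossAt w0 x y, lossAt w x y)

#eval runFixed 50 0.0 1.0 2.0  -- the final loss is far below the initial one
```

The helpers below play the same role with real models, backends, and optimizers supplied by the caller.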
structure
NN.API.Models.TrainFixed.LossPair
(α : Type)
Before/after scalar losses for a fixed-sample training run.
- loss0 : α
- loss1 : α
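A minimal sketch of constructing and printing a pair for a Float backend, assuming LossPair is a plain structure with exactly the two fields documented above:

```lean
-- Anonymous-constructor syntax: fields are `loss0` then `loss1`.
#eval repr ({ loss0 := 1.25, loss1 := 0.03 } :
  NN.API.Models.TrainFixed.LossPair Float)
```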
def
NN.API.Models.TrainFixed.instReprLossPair.repr
{α✝ : Type}
[Repr α✝]
:
LossPair α✝ → ℕ → Std.Format
def
NN.API.Models.TrainFixed.steps
{α : Type}
[Semantics.Scalar α]
[DecidableEq Shape]
[ToString α]
[Runtime.Scalar α]
[Runtime.Autograd.Torch.Internal.CudaBridge.TensorConv α]
{σ τ : Shape}
(mkModel : nn.M (nn.Sequential σ τ))
(mkModuleDef :
(model : nn.Sequential σ τ) → Runtime.Autograd.TorchLean.ScalarModuleDef (TorchLean.NN.Seq.paramShapes model) [σ, τ])
(mkOptim : (Float → α) → (paramShapes : List Shape) → Runtime.Autograd.TorchLean.Optimizer α paramShapes)
(cast : Float → α)
(opts : Runtime.Autograd.Torch.Options)
(sample : sample.Supervised α σ τ)
(steps : ℕ)
:
One fixed-sample run for an arbitrary scalar backend.
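A hypothetical call sketch for a Float backend. Only the parameter list of TrainFixed.steps is taken from this page; every other name (mkMLP, mseModule, sgd, cpuOpts, xorSample) is an assumed placeholder standing in for definitions from the surrounding examples:

```lean
def runOnce : IO Unit := do
  let pair ← NN.API.Models.TrainFixed.steps
    (mkModel := mkMLP)          -- assumed: an `nn.M (nn.Sequential σ τ)` builder
    (mkModuleDef := mseModule)  -- assumed: model + supervised-loss wrapper
    (mkOptim := sgd)            -- assumed: optimizer constructor
    (cast := id)                -- `Float` backend: the cast is the identity
    (opts := cpuOpts)           -- assumed: `Runtime.Autograd.Torch.Options`
    (sample := xorSample)       -- assumed: one fixed supervised `(x, y)` pair
    (steps := 200)
  IO.println (repr pair)        -- assuming the result is the before/after loss pair
```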
def
NN.API.Models.TrainFixed.curveFloat
{σ τ : Shape}
(mkModel : nn.M (nn.Sequential σ τ))
(mkModuleDef :
(model : nn.Sequential σ τ) → Runtime.Autograd.TorchLean.ScalarModuleDef (TorchLean.NN.Seq.paramShapes model) [σ, τ])
(mkOptim : (paramShapes : List Shape) → Runtime.Autograd.TorchLean.Optimizer Float paramShapes)
(opts : Runtime.Autograd.Torch.Options)
(sample : sample.Supervised Float σ τ)
(steps : ℕ)
:
Fixed-sample run specialized to Float, returning a full per-step curve.
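A matching hypothetical sketch for the Float specialization; as above, all names other than TrainFixed.curveFloat's documented parameters are assumed placeholders. Note that mkOptim here takes no cast argument, since the scalar type is fixed to Float:

```lean
def writeCurve : IO Unit := do
  let curve ← NN.API.Models.TrainFixed.curveFloat
    (mkModel := mkMLP)          -- assumed model builder
    (mkModuleDef := mseModule)  -- assumed loss wrapper
    (mkOptim := sgdFloat)       -- assumed `Float` optimizer constructor
    (opts := cpuOpts)           -- assumed backend options
    (sample := xorSample)       -- assumed fixed sample
    (steps := 200)
  -- assuming the result carries one loss per step, suitable for a TrainLog curve
  IO.println (toString curve)
```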