Spec modules (NNModuleSpec) and shape-safe composition
The NN/Spec/Layers/* files define reference layer specs: usually a parameter record plus a pure
forward function (and sometimes explicit gradient formulas).
This file packages those specs into a small, uniform module interface:
NNModuleSpec α inShape outShape
which is just a forward function plus lightweight metadata (kind, export_func)
used by tooling (export/extraction) and by the runtime/IR pipeline described in
the TorchLean paper (arXiv:2602.22631).
We keep that metadata separate from the semantics: changing kind/toPyTorch should never change
what forward means.
SpecChain is a dependent composition operator that enforces intermediate shape agreement at
compile time, so you can build pipelines without runtime shape casts.
Mental model (PyTorch analogy):
- NNModuleSpec is like a compact, pure nn.Module with just a forward.
- SpecChain is like nn.Sequential, but shape-safe by construction.
Diagram:
x : Tensor α s
|
v
[m1 : s -> t] forward
|
v
[m2 : t -> u] forward
|
v
y : Tensor α u
The point of all this is practical: you want shape mistakes to show up as type errors, not as runtime exceptions.
Export-related metadata carried alongside a spec module.
This is intentionally informal (mostly strings). It is not part of the math we prove about a model. We use it for demos and for "roughly equivalent PyTorch" pretty-printing.
- toPyTorch : String
A PyTorch-style rendering for docs/demos (metadata only).
Extra integer metadata used by some exporters (interpretation depends on
kind).
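To make the shape of this record concrete, here is a minimal Lean 4 sketch (not the TorchLean source). `toPyTorch` comes from the field list above; the integer field's real name is not shown, so `extraInts : List Int` is a hypothetical placeholder, as are the `linearExport` and `reluExport` values.

```lean
/-- Sketch of the export metadata record (not the TorchLean definition). -/
structure ExportFunctions where
  /-- PyTorch-style rendering for docs/demos (metadata only). -/
  toPyTorch : String
  /-- Hypothetical placeholder for the extra integer metadata; how an
      exporter reads it would depend on the module's `kind`. -/
  extraInts : List Int := []

/-- Hypothetical example: a linear layer's rendering plus its (in, out) sizes. -/
def linearExport : ExportFunctions :=
  { toPyTorch := "nn.Linear(4, 2)", extraInts := [4, 2] }

/-- Fields with defaults may be omitted. -/
def reluExport : ExportFunctions := { toPyTorch := "nn.ReLU()" }

#eval linearExport.extraInts  -- [4, 2]
```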
A pure module: from inShape to outShape without runtime state.
- forward : Spec.Tensor α inShape → Spec.Tensor α outShape
The pure forward map; this is the module's entire semantics.
- kind : String
Tag used by export/extraction tooling (metadata only).
- export_func : ExportFunctions
Export-related metadata (metadata only).
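A self-contained Lean 4 sketch of this record, under stated assumptions: `Shape` and `Tensor` are hypothetical stand-ins for `Spec.Shape`/`Spec.Tensor` (a list of dimensions, and flat data whose shape exists only in the type), `extraInts` is a placeholder for the unnamed integer field, and `relu` is an illustrative module. The `example` checks the point made at the top of the file: editing `kind` leaves `forward` unchanged, provable by `rfl`.

```lean
-- Hypothetical stand-ins for `Spec.Shape` / `Spec.Tensor`:
-- a tensor is flat data; its shape is tracked only at the type level.
abbrev Shape := List Nat

structure Tensor (α : Type) (s : Shape) where
  data : List α

structure ExportFunctions where
  toPyTorch : String
  extraInts : List Int := []  -- placeholder for the extra integer metadata

/-- Sketch of a pure module: a forward map plus metadata. -/
structure NNModuleSpec (α : Type) (inShape outShape : Shape) where
  forward : Tensor α inShape → Tensor α outShape
  kind : String
  export_func : ExportFunctions

/-- Illustrative shape-preserving module: elementwise ReLU on integers. -/
def relu (s : Shape) : NNModuleSpec Int s s where
  forward x := ⟨x.data.map (fun v => max v 0)⟩
  kind := "relu"
  export_func := { toPyTorch := "nn.ReLU()" }

-- Changing metadata does not change what `forward` means.
example (s : Shape) :
    ({ (relu s) with kind := "renamed" }).forward = (relu s).forward := rfl

#eval ((relu [3]).forward ⟨[1, -2, 3]⟩).data  -- [1, 0, 3]
```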
Dependent chain of spec modules ensuring intermediate shapes match at compile-time. Use this for shape-safe composition without runtime casting.
- single {α : Type} {s t : Spec.Shape} (m : NNModuleSpec α s t) : SpecChain α s t
- comp {α : Type} {s t u : Spec.Shape} (a : SpecChain α s t) (b : SpecChain α t u) : SpecChain α s u
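Under the same hypothetical stand-ins, the family can be sketched as below; `relu` and `sumRow` are illustrative modules, not TorchLean layers. Because `comp` reuses one index `t` for both halves, the commented-out `badChain` would fail to elaborate: the mismatch surfaces as a type error, not a runtime exception.

```lean
-- Hypothetical stand-ins for `Spec.Shape` / `Spec.Tensor`.
abbrev Shape := List Nat

structure Tensor (α : Type) (s : Shape) where
  data : List α

structure ExportFunctions where
  toPyTorch : String
  extraInts : List Int := []  -- placeholder for the extra integer metadata

structure NNModuleSpec (α : Type) (inShape outShape : Shape) where
  forward : Tensor α inShape → Tensor α outShape
  kind : String
  export_func : ExportFunctions

/-- Sketch of the chain: the shared index `t` in `comp` forces agreement. -/
inductive SpecChain (α : Type) : Shape → Shape → Type where
  | single {s t : Shape} (m : NNModuleSpec α s t) : SpecChain α s t
  | comp {s t u : Shape} (a : SpecChain α s t) (b : SpecChain α t u) : SpecChain α s u

def relu (s : Shape) : NNModuleSpec Int s s where
  forward x := ⟨x.data.map (fun v => max v 0)⟩
  kind := "relu"
  export_func := { toPyTorch := "nn.ReLU()" }

/-- Illustrative shape-changing module: sum a length-`n` vector to shape [1]. -/
def sumRow (n : Nat) : NNModuleSpec Int [n] [1] where
  forward x := ⟨[x.data.foldl (· + ·) 0]⟩
  kind := "sum"
  export_func := { toPyTorch := "torch.sum(x, dim=-1, keepdim=True)" }

-- Shapes agree at every boundary: [3] → [3] → [1].
def goodChain : SpecChain Int [3] [1] :=
  .comp (.single (relu [3])) (.single (sumRow 3))

-- Rejected if uncommented: `sumRow 3` outputs shape [1], but `relu [2]`
-- expects input shape [2], so the middle index cannot be unified.
-- def badChain : SpecChain Int [3] [2] :=
--   .comp (.single (sumRow 3)) (.single (relu [2]))
```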
Forward evaluation over a SpecChain by structural composition.
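A sketch of such an evaluator, again over hypothetical stand-ins (the evaluator's real name is not shown, so `SpecChain.forward` is assumed): `single m` runs `m`, and `comp a b` runs `a` and feeds its output to `b`, which is the top-to-bottom flow in the diagram above.

```lean
-- Hypothetical stand-ins for `Spec.Shape` / `Spec.Tensor`.
abbrev Shape := List Nat

structure Tensor (α : Type) (s : Shape) where
  data : List α

structure ExportFunctions where
  toPyTorch : String
  extraInts : List Int := []  -- placeholder for the extra integer metadata

structure NNModuleSpec (α : Type) (inShape outShape : Shape) where
  forward : Tensor α inShape → Tensor α outShape
  kind : String
  export_func : ExportFunctions

inductive SpecChain (α : Type) : Shape → Shape → Type where
  | single {s t : Shape} (m : NNModuleSpec α s t) : SpecChain α s t
  | comp {s t u : Shape} (a : SpecChain α s t) (b : SpecChain α t u) : SpecChain α s u

/-- Structural evaluation: `comp a b` means "run `a`, then run `b`". -/
def SpecChain.forward {α : Type} : {s t : Shape} → SpecChain α s t → Tensor α s → Tensor α t
  | _, _, .single m, x => m.forward x
  | _, _, .comp a b, x => SpecChain.forward b (SpecChain.forward a x)

def relu (s : Shape) : NNModuleSpec Int s s where
  forward x := ⟨x.data.map (fun v => max v 0)⟩
  kind := "relu"
  export_func := { toPyTorch := "nn.ReLU()" }

def sumRow (n : Nat) : NNModuleSpec Int [n] [1] where
  forward x := ⟨[x.data.foldl (· + ·) 0]⟩
  kind := "sum"
  export_func := { toPyTorch := "torch.sum(x, dim=-1, keepdim=True)" }

def net : SpecChain Int [3] [1] :=
  .comp (.single (relu [3])) (.single (sumRow 3))

-- relu: [1, -2, 3] ↦ [1, 0, 3]; then sum: ↦ [4]
#eval (net.forward ⟨[1, -2, 3]⟩).data  -- [4]
```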
Right-associative composition helper.
This is the ergonomic "append a module to a chain" operator used at call sites:
net : SpecChain α s t
net |>.compose_right m2 |>.compose_right m3
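A minimal sketch, assuming `compose_right` takes the chain so far and a single module (the exact signature is not shown): it wraps both in `comp`, so each call extends the chain on the right while the shared middle shape must still line up. `relu`, `sumRow`, `SpecChain.forward`, and the `Shape`/`Tensor` stand-ins are hypothetical.

```lean
-- Hypothetical stand-ins for `Spec.Shape` / `Spec.Tensor`.
abbrev Shape := List Nat

structure Tensor (α : Type) (s : Shape) where
  data : List α

structure ExportFunctions where
  toPyTorch : String
  extraInts : List Int := []  -- placeholder for the extra integer metadata

structure NNModuleSpec (α : Type) (inShape outShape : Shape) where
  forward : Tensor α inShape → Tensor α outShape
  kind : String
  export_func : ExportFunctions

inductive SpecChain (α : Type) : Shape → Shape → Type where
  | single {s t : Shape} (m : NNModuleSpec α s t) : SpecChain α s t
  | comp {s t u : Shape} (a : SpecChain α s t) (b : SpecChain α t u) : SpecChain α s u

def SpecChain.forward {α : Type} : {s t : Shape} → SpecChain α s t → Tensor α s → Tensor α t
  | _, _, .single m, x => m.forward x
  | _, _, .comp a b, x => SpecChain.forward b (SpecChain.forward a x)

/-- Sketch: append one module to the right end of a chain. -/
def SpecChain.compose_right {α : Type} {s t u : Shape}
    (chain : SpecChain α s t) (m : NNModuleSpec α t u) : SpecChain α s u :=
  .comp chain (.single m)

def relu (s : Shape) : NNModuleSpec Int s s where
  forward x := ⟨x.data.map (fun v => max v 0)⟩
  kind := "relu"
  export_func := { toPyTorch := "nn.ReLU()" }

def sumRow (n : Nat) : NNModuleSpec Int [n] [1] where
  forward x := ⟨[x.data.foldl (· + ·) 0]⟩
  kind := "sum"
  export_func := { toPyTorch := "torch.sum(x, dim=-1, keepdim=True)" }

-- Pipe-style construction, read left to right: [3] → [3] → [1] → [1].
def net : SpecChain Int [3] [1] :=
  SpecChain.single (relu [3]) |>.compose_right (sumRow 3) |>.compose_right (relu [1])

#eval (net.forward ⟨[1, -2, 3]⟩).data  -- [4]
```

Appending `relu [2]` instead of `relu [1]` at the last step would be rejected the same way as a mismatched `comp`.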
Extract the list of kind tags from a chain (left-to-right).
Extract (kind, toPyTorch) pairs from a chain (left-to-right).
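Both extractors can be written as plain structural folds over the chain. In this sketch the names `kinds` and `exportPairs` are hypothetical (the real names are not shown), as are the stand-in types and the example modules; concatenating the left sub-chain before the right one is what yields the left-to-right order.

```lean
-- Hypothetical stand-ins for `Spec.Shape` / `Spec.Tensor`.
abbrev Shape := List Nat

structure Tensor (α : Type) (s : Shape) where
  data : List α

structure ExportFunctions where
  toPyTorch : String
  extraInts : List Int := []  -- placeholder for the extra integer metadata

structure NNModuleSpec (α : Type) (inShape outShape : Shape) where
  forward : Tensor α inShape → Tensor α outShape
  kind : String
  export_func : ExportFunctions

inductive SpecChain (α : Type) : Shape → Shape → Type where
  | single {s t : Shape} (m : NNModuleSpec α s t) : SpecChain α s t
  | comp {s t u : Shape} (a : SpecChain α s t) (b : SpecChain α t u) : SpecChain α s u

/-- Kind tags, left to right. -/
def SpecChain.kinds {α : Type} : {s t : Shape} → SpecChain α s t → List String
  | _, _, .single m => [m.kind]
  | _, _, .comp a b => SpecChain.kinds a ++ SpecChain.kinds b

/-- (kind, toPyTorch) pairs, left to right. -/
def SpecChain.exportPairs {α : Type} : {s t : Shape} → SpecChain α s t → List (String × String)
  | _, _, .single m => [(m.kind, m.export_func.toPyTorch)]
  | _, _, .comp a b => SpecChain.exportPairs a ++ SpecChain.exportPairs b

def relu (s : Shape) : NNModuleSpec Int s s where
  forward x := ⟨x.data.map (fun v => max v 0)⟩
  kind := "relu"
  export_func := { toPyTorch := "nn.ReLU()" }

def sumRow (n : Nat) : NNModuleSpec Int [n] [1] where
  forward x := ⟨[x.data.foldl (· + ·) 0]⟩
  kind := "sum"
  export_func := { toPyTorch := "torch.sum(x, dim=-1, keepdim=True)" }

def net : SpecChain Int [3] [1] :=
  .comp (.single (relu [3])) (.single (sumRow 3))

#eval net.kinds  -- ["relu", "sum"]
#eval net.exportPairs
```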
Lift a module to apply independently over a leading sequence dimension.
This is a "map over time" helper for sequence models:
- input has shape (seqLen, elemIn),
- output has shape (seqLen, elemOut),
- and we apply the same module at each timestep.
In PyTorch terms: torch.vmap-style mapping, or a common pattern like:
ys = torch.stack([m(x[t]) for t in range(x.shape[0])])
We keep this as a single helper so call sites stay small and the intent is obvious.
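A minimal sketch of the idea under the same toy representation, where a tensor of shape (T, elem) is stored as T consecutive blocks of `numel elem` scalars. `mapOverTime`, `numel`, `splitSteps`, and `sumRow` are hypothetical names, and reusing the inner module's `kind`/`export_func` is an assumption rather than TorchLean's behavior.

```lean
-- Hypothetical stand-ins for `Spec.Shape` / `Spec.Tensor`.
abbrev Shape := List Nat

structure Tensor (α : Type) (s : Shape) where
  data : List α

structure ExportFunctions where
  toPyTorch : String
  extraInts : List Int := []  -- placeholder for the extra integer metadata

structure NNModuleSpec (α : Type) (inShape outShape : Shape) where
  forward : Tensor α inShape → Tensor α outShape
  kind : String
  export_func : ExportFunctions

/-- Number of scalars in one element of the given shape. -/
def numel (s : Shape) : Nat := s.foldl (· * ·) 1

/-- Split flat data into `T` consecutive chunks of `k` scalars (one per timestep). -/
def splitSteps {α : Type} (k : Nat) : Nat → List α → List (List α)
  | 0, _ => []
  | T + 1, xs => xs.take k :: splitSteps k T (xs.drop k)

/-- Hypothetical "map over time": run `m` on every timestep, then concatenate. -/
def mapOverTime {α : Type} {elemIn elemOut : Shape} (T : Nat)
    (m : NNModuleSpec α elemIn elemOut) : NNModuleSpec α (T :: elemIn) (T :: elemOut) where
  forward x :=
    let steps := splitSteps (numel elemIn) T x.data
    ⟨steps.foldr (fun (d : List α) (acc : List α) => (m.forward ⟨d⟩).data ++ acc)
      ([] : List α)⟩
  kind := m.kind                -- reuse the inner tag (an assumption)
  export_func := m.export_func  -- likewise

/-- Illustrative per-timestep module: sum a length-`n` vector to shape [1]. -/
def sumRow (n : Nat) : NNModuleSpec Int [n] [1] where
  forward x := ⟨[x.data.foldl (· + ·) 0]⟩
  kind := "sum"
  export_func := { toPyTorch := "torch.sum(x, dim=-1, keepdim=True)" }

-- Shape (2, 3) ↦ (2, 1): each row of three is summed independently.
#eval ((mapOverTime 2 (sumRow 3)).forward ⟨[1, 2, 3, 4, 5, 6]⟩).data  -- [6, 15]
```

Since the lifted module's type (T :: elemIn) → (T :: elemOut) is derived from `m`'s own shapes, it can be dropped into a SpecChain like any other module.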