MLP PyTorch Fixture Import #
MLP fixture weight import from a PyTorch-style state_dict.
On the Python side we usually write JSON (nested lists of floats) under keys that mirror the names
you would see in `model.state_dict()`:
- `fc1.weight`, `fc1.bias`, `fc2.weight`, `fc2.bias` for a hand-written `nn.Module` with `fc1`/`fc2`, or
- `layers.0.weight`, `layers.0.bias`, ... if the model was built from an `nn.Sequential`.
This file keeps the parsing logic in one place so the rest of the codebase can talk in terms of typed Lean tensors.
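The Python-side export step can be sketched as follows. In a real script the nested lists would come from calling `tensor.tolist()` on each entry of `model.state_dict()`; plain lists stand in for tensors here so the sketch runs without torch installed, and the filename is illustrative.

```python
import json

# Hand-rolled stand-in for a state_dict: same key names and nesting a
# two-layer PyTorch MLP would produce, but with hard-coded values.
state_dict = {
    "fc1.weight": [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],  # shape (hidden=3, input=2)
    "fc1.bias":   [0.0, 0.0, 0.0],                        # shape (3,)
    "fc2.weight": [[1.0, 1.0, 1.0]],                      # shape (output=1, hidden=3)
    "fc2.bias":   [0.5],                                  # shape (1,)
}

# Write the fixture JSON that the Lean importer consumes.
with open("mlp_fixture.json", "w") as f:
    json.dump(state_dict, f)
```

With torch available, the dictionary above would instead be built as `{k: v.tolist() for k, v in model.state_dict().items()}`.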
Typed view of an MLP PyTorch state_dict (two linear layers).

We keep the tensors as `Float` because these importers are meant for runtime demos: train in Python,
export to JSON, then run/verify in TorchLean.

- `w1 : Spec.Tensor Float (Spec.Shape.dim hidDim (Spec.Shape.dim inDim Spec.Shape.scalar))`: first linear layer weight, PyTorch shape `(hidden, input)`.
- `b1 : Spec.Tensor Float (Spec.Shape.dim hidDim Spec.Shape.scalar)`: bias for layer 1.
- `w2 : Spec.Tensor Float (Spec.Shape.dim outDim (Spec.Shape.dim hidDim Spec.Shape.scalar))`: second linear layer weight, PyTorch shape `(output, hidden)`.
- `b2 : Spec.Tensor Float (Spec.Shape.dim outDim Spec.Shape.scalar)`: bias for layer 2.
Load an MLP state dict from JSON (accepts both key conventions described above).
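The key-alias lookup the loader performs can be sketched in Python. The helper name and the alias table are illustrative, not the actual Lean API; in particular, the `layers.2` index for the second linear layer assumes a `Sequential` with an activation module between the two `Linear` layers, so the exact indices depend on how the model was assembled.

```python
# Illustrative alias table: each logical field may appear under either the
# hand-written fc1/fc2 names or nn.Sequential-style indexed names.
ALIASES = {
    "w1": ["fc1.weight", "layers.0.weight"],
    "b1": ["fc1.bias",   "layers.0.bias"],
    "w2": ["fc2.weight", "layers.2.weight"],
    "b2": ["fc2.bias",   "layers.2.bias"],
}

def lookup(state_dict, field):
    """Return the value under the first alias present, else raise KeyError."""
    for key in ALIASES[field]:
        if key in state_dict:
            return state_dict[key]
    raise KeyError(f"no key for {field}; tried {ALIASES[field]}")
```

A loader built on this helper accepts either convention transparently: it tries the `fc*` name first and falls back to the indexed name.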
Construct a `LinearSpec` for `Float` from an MLP state dict.
Convenience: run the Float MLP forward given a state dict and input.
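As a reference for what that forward pass computes, here is the two-layer MLP in plain Python, consuming weights in the PyTorch layout `(out_features, in_features)`. The ReLU activation is an assumption for illustration; substitute whatever nonlinearity the Lean forward actually applies.

```python
def matvec(w, x):
    # w has PyTorch shape (out, in); each row dotted with x yields one output.
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]

def mlp_forward(w1, b1, w2, b2, x):
    # Layer 1 plus ReLU (assumed activation), then layer 2.
    h = [max(0.0, v + b) for v, b in zip(matvec(w1, x), b1)]
    return [v + b for v, b in zip(matvec(w2, h), b2)]

y = mlp_forward(
    w1=[[1.0, 0.0], [0.0, 1.0]], b1=[0.0, -1.0],
    w2=[[1.0, 1.0]],             b2=[0.5],
    x=[2.0, 3.0],
)
# h = relu([2.0, 2.0]) = [2.0, 2.0]; y = [2.0 + 2.0 + 0.5] = [4.5]
```

Note that because the weights keep the PyTorch `(out, in)` layout, no transpose is needed: each output coordinate is a row of the weight matrix dotted with the input.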