PyTorch state_dict Bridge
This module is the general weight-interchange path for PyTorch users.
The important split is:
- Weights move through PyTorch state_dicts. PyTorch’s own documentation recommends saving a module’s learned parameters with torch.save(model.state_dict(), path) because that is the most flexible restoration format.
- Graphs move through graph capture (torch.export, FX, ONNX, or TorchLean NN.IR.Graph). A state_dict alone does not describe the model architecture; it only names tensors.
Lean should not try to parse PyTorch pickle/zip checkpoints directly. Instead, we emit a small Python
adapter that loads a checkpoint with PyTorch, normalizes common wrappers such as
{"state_dict": ...}, and writes shape-checkable JSON:
{
  "params": { "layer.weight": [[...]], "layer.bias": [...] },
  "meta": { "layer.weight": { "shape": [out, in], "dtype": "torch.float32" } }
}
NN.Runtime.PyTorch.Import.Core then parses the "params" object into typed TorchLean tensors.
Architecture-specific loaders are still useful, but only for mapping names and shapes. The transport
format itself is model-agnostic.
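The point of the "meta" object is that a consumer can compare each array against its declared shape before building typed tensors. The following is a minimal Python sketch of that check, not the TorchLean parser itself. The helper names nested_shape and check_payload are hypothetical, and the sketch assumes rectangular, non-empty nested lists.

```python
import json


def nested_shape(value):
    """Return the shape of a rectangular nested list; a bare number has shape []."""
    if not isinstance(value, list):
        return []
    if not value:
        return [0]
    first = nested_shape(value[0])
    for item in value[1:]:
        if nested_shape(item) != first:
            raise ValueError("ragged array")
    return [len(value)] + first


def check_payload(payload):
    """Return the keys whose data does not match the shape declared in "meta"."""
    params = payload["params"]
    meta = payload.get("meta", {})
    mismatched = []
    for key, data in params.items():
        declared = meta.get(key, {}).get("shape")
        if declared is not None and nested_shape(data) != declared:
            mismatched.append(key)
    return mismatched


# Example payload in the documented layout (values are made up for illustration).
payload = json.loads("""
{
  "params": { "layer.weight": [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
              "layer.bias": [0.5, -0.5] },
  "meta": { "layer.weight": { "shape": [2, 3], "dtype": "torch.float32" },
            "layer.bias": { "shape": [2], "dtype": "torch.float32" } }
}
""")
print(check_payload(payload))
```

Keys without a "meta" entry are simply not checked here, which matches the case where the adapter was run without metadata.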
References:
- PyTorch tutorial, "Saving and Loading Models": https://docs.pytorch.org/tutorials/beginner/saving_loading_models.html
- PyTorch torch.export user guide: https://docs.pytorch.org/docs/stable/user_guide/torch_compiler/export.html
- PyTorch FX overview: https://docs.pytorch.org/docs/stable/fx.html
Options for the generated Python checkpoint-to-JSON adapter.
This adapter is intentionally conservative: it accepts tensor-valued entries only, drops common DataParallel prefixes when requested, and writes plain JSON rather than a PyTorch-specific binary format. That makes the output easy to inspect, diff, and parse in Lean.
- functionName : String
  Name of the Python helper function emitted into the generated script.
- stripDataParallelPrefix : Bool
  If true, strip a leading "module." from keys produced by torch.nn.DataParallel.
- includeMeta : Bool
  If true, include a "meta" object with per-key shape and dtype strings.
- weightsOnlyExpr : String
  Python expression passed as weights_only to torch.load.
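To make the options concrete, here is a rough Python mirror of the four fields together with the prefix-stripping behavior. It is an illustrative sketch only: the class AdapterOptions, the helper strip_module_prefix, and every default value shown are assumptions, not the defaults of the Lean structure.

```python
from dataclasses import dataclass


@dataclass
class AdapterOptions:
    # Field names mirror the documented options; the defaults are assumed.
    function_name: str = "export_state_dict_json"
    strip_data_parallel_prefix: bool = True
    include_meta: bool = True
    weights_only_expr: str = "True"


def strip_module_prefix(state_dict, enabled=True):
    """Drop one leading "module." from each key, as added by DataParallel."""
    if not enabled:
        return dict(state_dict)
    prefix = "module."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


opts = AdapterOptions()
wrapped = {"module.layer.weight": [[1.0]], "module.layer.bias": [0.0]}
print(list(strip_module_prefix(wrapped, opts.strip_data_parallel_prefix)))
```

Only a single leading prefix is removed, so a submodule that is itself named "module" deeper in the key path is left alone.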
Render a Python boolean literal.
Emit a standalone Python script that converts a PyTorch checkpoint into TorchLean JSON.
The script handles three common checkpoint layouts:
- a raw state_dict;
- { "state_dict": state_dict };
- { "model_state_dict": state_dict }.
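The layout handling can be sketched in a few lines of Python. This is a hedged illustration of the idea rather than the emitted code: the function name unwrap_state_dict is hypothetical, and the sketch assumes that a wrapper is recognized by its key holding a dict-like value.

```python
def unwrap_state_dict(checkpoint):
    """Return the inner state_dict for the three layouts listed above."""
    if not isinstance(checkpoint, dict):
        raise TypeError("expected a dict-like checkpoint")
    for wrapper_key in ("state_dict", "model_state_dict"):
        inner = checkpoint.get(wrapper_key)
        if isinstance(inner, dict):
            # Wrapped layout: ignore sibling entries such as epoch counters.
            return inner
    # No known wrapper key: treat the checkpoint as a raw state_dict.
    return checkpoint


raw = {"layer.weight": [[1.0, 2.0]], "layer.bias": [0.0]}
print(unwrap_state_dict({"model_state_dict": raw, "epoch": 3}) is raw)
```

Because collections.OrderedDict subclasses dict, the same check also covers the ordered mappings that PyTorch returns from state_dict().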
Usage of the generated script:
python export_state_dict_json.py model.pt model.json
The resulting model.json is accepted by Import.PyTorch.loadWeights?.
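For orientation, a script of this kind might look roughly like the sketch below. It is not the script TorchLean emits: the function names are hypothetical, weights_only=True is an assumed value for the configurable expression, and skipping non-tensor entries (rather than rejecting them) is also an assumption. The torch import is deferred so that the conversion helpers can be read and exercised without PyTorch installed.

```python
import json


def tensor_entry(tensor):
    """Convert one tensor into (nested list, meta record)."""
    data = tensor.detach().cpu().tolist()
    meta = {"shape": list(tensor.shape), "dtype": str(tensor.dtype)}
    return data, meta


def build_payload(state_dict, include_meta=True):
    """Build the {"params": ..., "meta": ...} object from tensor-valued entries."""
    params, meta = {}, {}
    for key, value in state_dict.items():
        if not hasattr(value, "detach"):  # assumed policy: skip non-tensor entries
            continue
        params[key], meta[key] = tensor_entry(value)
    payload = {"params": params}
    if include_meta:
        payload["meta"] = meta
    return payload


def export_checkpoint(src, dst):
    """Load a checkpoint with PyTorch and write it as JSON."""
    import torch  # deferred: only needed when a real checkpoint is loaded

    checkpoint = torch.load(src, map_location="cpu", weights_only=True)
    for wrapper_key in ("state_dict", "model_state_dict"):
        if isinstance(checkpoint, dict) and isinstance(checkpoint.get(wrapper_key), dict):
            checkpoint = checkpoint[wrapper_key]
            break
    with open(dst, "w") as handle:
        json.dump(build_payload(checkpoint), handle)


# Example call, assuming PyTorch is installed and model.pt exists:
# export_checkpoint("model.pt", "model.json")
```

Writing nested lists via tolist() keeps the output plain JSON at the cost of file size, which is the trade-off the adapter makes in exchange for output that is easy to inspect and diff.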