Import Core #
PyTorch import core (JSON parsing).
The Python side of TorchLean round-trips usually writes a JSON object containing nested arrays of
floats (a lightweight, Lean-readable projection of a PyTorch state_dict). The model-agnostic
adapter emitted by NN.Runtime.PyTorch.Export.StateDict is the intended path from .pt / .pth
checkpoints into this JSON format.
Design note:
In typical PyTorch workflows, weights are often serialized via torch.save(model.state_dict(), ...)
or related checkpoint wrappers. TorchLean deliberately avoids parsing those PyTorch binary formats
directly in Lean. Instead, PyTorch loads the checkpoint and emits a small JSON representation that
is easy to validate against a Lean Shape and easy to diff in tests.
This importer is about weights only. Importing a captured graph (e.g. ONNX or torch.export)
is a separate problem and lives at a different abstraction layer than this JSON state_dict shim.
This module is where we keep the shared logic that most PyTorch → TorchLean importers need:
- parse nested JSON arrays into shape-checked Tensor Float s,
- handle a small amount of “state_dict ergonomics” (key lookup, optional wrappers, index parsing),
- keep everything model-agnostic, so the model-specific code can stay small and readable.
Reading map:
- parseTensor is the core JSON-to-tensor conversion.
- loadWeights? and unwrapParams handle the two JSON layouts we accept.
- getTensor? / getTensorE are the main lookup helpers used by the model-specific importers.
A PyTorch-style state_dict encoded as a JSON object.
defaultTensor s is a zero-filled sentinel used in internal parsing helpers.
It is only used in code paths that are unreachable once we have validated the JSON shape (e.g. after checking an array has exactly the expected length).
Parse a JSON value into a Tensor Float s.
The JSON encoding follows the tensor shape:
- scalars are JSON numbers,
- Shape.dim n s is a JSON array of length n whose entries recursively encode s.
If the JSON payload does not match the expected shape, we return none.
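To make the shape-directed encoding concrete, here is a self-contained sketch of the recursion over Lean's Json. The Shape type below is a toy stand-in, and matchesShape only validates the layout; the real parseTensor follows the same recursion but builds a Tensor Float s (names and signatures here are assumptions, not the actual TorchLean API):

```lean
open Lean (Json)

-- Toy stand-in for the module's shape type (an assumption for illustration).
inductive Shape where
  | scalar : Shape
  | dim    : Nat → Shape → Shape

-- Scalars are JSON numbers; Shape.dim n s is a JSON array of length n
-- whose entries recursively encode s.
partial def matchesShape : Shape → Json → Bool
  | .scalar,  j => j.getNum?.toOption.isSome
  | .dim n s, j =>
    match j.getArr? with
    | .ok entries => entries.size == n && entries.all (matchesShape s)
    | .error _    => false
```

Under this encoding a 2×3 matrix, i.e. Shape.dim 2 (Shape.dim 3 Shape.scalar), is the nested array [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]].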
state_dict helpers #
We use JSON objects keyed by strings because that mirrors PyTorch’s state_dict convention.
Some TorchLean Python scripts wrap the object as { "params": { ... } }; loadWeights? accepts
both formats.
If the object contains a "params" field that is itself an object, unwrap it.
We also merge any other top-level fields (e.g. "meta") into the returned dictionary so model
importers can still read them.
Load weights from JSON, accepting either:
- { ...state_dict... }, or
- { "params": { ...state_dict... } }.
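For concreteness, both of the following payloads should yield the same weights (a sketch; the key names and values are hypothetical, and the "meta" field in the wrapped form illustrates the top-level fields that unwrapParams merges back in):

```lean
-- Flat layout: the JSON object is the state_dict itself.
def flat : String :=
  "{\"fc.weight\": [[0.1, 0.2], [0.3, 0.4]], \"fc.bias\": [0.5, 0.6]}"

-- Wrapped layout: the state_dict sits under "params"; other top-level
-- fields (here "meta") remain visible to model importers after unwrapping.
def wrapped : String :=
  "{\"params\": {\"fc.weight\": [[0.1, 0.2], [0.3, 0.4]], \"fc.bias\": [0.5, 0.6]}, \"meta\": {\"note\": \"demo\"}}"
```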
Look up a key and parse it as a tensor of a given expected shape.
This is the helper most model-specific importers use to keep the “key wiring” readable.
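As an illustration, a model-specific importer might wire keys like this. This is a sketch: StateDict, getTensor?, and the shape constructors are named after this module's docs, but the exact signatures (and the Shape.matrix / Shape.vector helpers) are assumptions:

```lean
-- Hypothetical two-tensor import. The key strings mirror PyTorch's
-- state_dict naming; each lookup is checked against an expected shape,
-- and any mismatch makes the whole import return none.
def loadLinear? (sd : StateDict) :
    Option (Tensor Float (Shape.matrix 2 3) × Tensor Float (Shape.vector 2)) := do
  let w ← getTensor? sd "fc.weight" (Shape.matrix 2 3)
  let b ← getTensor? sd "fc.bias"   (Shape.vector 2)
  return (w, b)
```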
Error-reporting variants (ergonomics) #
Most of the import code in this folder is written in the Option monad to keep examples short.
When you are debugging a round-trip, it is often more helpful to get a concrete reason why an
import failed (missing key vs wrong JSON type vs wrong shape).
The helpers below provide small Except String wrappers around the Option-based core.
Load weights from JSON with an error message on failure.
This is the Except analogue of loadWeights?.
Look up a tensor by key, returning a human-friendly error on failure.
This is the Except analogue of getTensor?.
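A possible shape for such a wrapper, distinguishing the failure modes mentioned above (a sketch; the module's actual implementation and message wording may differ):

```lean
-- Distinguish "missing key" from "wrong JSON type / wrong shape" so that
-- round-trip debugging gets an actionable message instead of a bare none.
def getTensorE (sd : StateDict) (key : String) (s : Shape) :
    Except String (Tensor Float s) :=
  match sd.find? key with
  | none   => .error s!"state_dict: missing key '{key}'"
  | some j =>
    match parseTensor s j with
    | some t => .ok t
    | none   => .error s!"state_dict: key '{key}' does not match the expected shape"
```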
Convenience: parse a length-n vector tensor.
Convenience: parse a rows × cols matrix tensor.
Small parsing helpers used by shape-inferring importers #
Some importers allow variable-width stacks (e.g. a PINN that learns its hidden widths from the checkpoint). Those cases need a little help to infer indices and matrix dimensions from JSON.
Parse keys of the form prefix ++ <nat> ++ suffix.
Example: parseIndexedKey "layers." ".weight" "layers.3.weight" = some 3.
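One way to implement this using core String functions (a sketch; the module's actual definition may differ):

```lean
-- Strip the prefix and suffix, then parse the remaining digits.
-- String.toNat? rejects empty and non-numeric middles, so keys like
-- "layers..weight" or "layers.x.weight" yield none.
def parseIndexedKey (pre suf key : String) : Option Nat := do
  guard (key.startsWith pre && key.endsWith suf)
  (key.drop pre.length |>.dropRight suf.length).toNat?
```

With this sketch, parseIndexedKey "layers." ".weight" "layers.3.weight" evaluates to some 3, matching the example above.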
Infer (rows, cols) for a JSON matrix encoded as an array of arrays.
This helper infers dimensions from the outer length and first row length. Call sites that need stronger validation (all rows same length) should add an explicit check.
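A minimal sketch of that inference over Lean's Json (the name is an assumption):

```lean
open Lean (Json)

-- rows = outer array length; cols = first row's length. Ragged inputs
-- are not rejected here; call sites needing that add their own check.
def inferMatrixDims? (j : Json) : Option (Nat × Nat) := do
  let rows  ← j.getArr?.toOption
  let row0  ← rows[0]?
  let cols0 ← row0.getArr?.toOption
  return (rows.size, cols0.size)
```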
Convenience parsers for function-based constructors #
Some call sites build tensors via Spec.vector_tensor / Spec.matrix_tensor, whose inputs are
functions (Fin n → Float and Fin m → Fin n → Float).
parseFloatVec and parseFloatMatrix keep those call sites readable without duplicating JSON
parsing logic outside this core module.
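A hedged sketch of the vector case, producing the Fin n → Float form that Spec.vector_tensor consumes (the exact signature, and the availability of JsonNumber.toFloat, are assumptions):

```lean
open Lean (Json)

-- Parse a JSON array of n numbers into an index function. The guard makes
-- the getD fallback unreachable for in-range Fin n indices.
def parseFloatVec (n : Nat) (j : Json) : Option (Fin n → Float) := do
  let entries ← j.getArr?.toOption
  guard (entries.size == n)
  let vals ← entries.mapM fun e => e.getNum?.toOption.map (·.toFloat)
  return fun i => vals.getD i.val 0.0
```

parseFloatMatrix follows the same pattern one level up, reusing the vector parser per row.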