TorchLean API

NN.Runtime.PyTorch.Import.Core

Import Core #

PyTorch import core (JSON parsing).

The Python side of TorchLean round-trips usually writes a JSON object containing nested arrays of floats (a lightweight, Lean-readable projection of a PyTorch state_dict). The model-agnostic adapter emitted by NN.Runtime.PyTorch.Export.StateDict is the intended path from .pt / .pth checkpoints into this JSON format.

Design note:

In typical PyTorch workflows, weights are often serialized via torch.save(model.state_dict(), ...) or related checkpoint wrappers. TorchLean deliberately avoids parsing those PyTorch binary formats directly in Lean. Instead, PyTorch loads the checkpoint and emits a small JSON representation that is easy to validate against a Lean Shape and easy to diff in tests.
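For concreteness, here is the kind of payload such a script might emit (illustrative only; the parameter names are made up, and the optional "meta" field depends on the exporting script):

```json
{
  "params": {
    "linear.weight": [[0.1, -0.2], [0.3, 0.4]],
    "linear.bias": [0.0, 1.5]
  },
  "meta": { "source": "model.pt" }
}
```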

This importer is about weights only. Importing a captured graph (e.g. ONNX or torch.export) is a separate problem and lives at a different abstraction layer than this JSON state_dict shim.

This module is where we keep the shared logic that most PyTorch → TorchLean importers need.

Reading map:

@[reducible, inline]

A PyTorch-style state_dict encoded as a JSON object.

defaultTensor s is a zero-filled sentinel used in internal parsing helpers.

It is only used in code paths that are unreachable once we have validated the JSON shape (e.g. after checking an array has exactly the expected length).

Parse a JSON value into a Tensor Float s.

The JSON encoding follows the tensor shape:

• scalars are JSON numbers,
• Shape.dim n s is a JSON array of length n whose entries recursively encode s.

If the JSON payload does not match the expected shape, we return none.
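The shape-directed encoding above can be sketched as a recursive check. The Shape inductive below is an illustrative stand-in (only Shape.dim is named in this module's docs; the scalar base case and the real Tensor parser live elsewhere in TorchLean):

```lean
open Lean

/-- Illustrative stand-in for TorchLean's shape type. -/
inductive Shape where
  | scalar : Shape
  | dim : Nat → Shape → Shape

/-- Does `j` encode a tensor of shape `s`? Scalars are JSON numbers;
    `dim n s` is a JSON array of length `n` whose entries match `s`. -/
partial def matchesShape : Shape → Json → Bool
  | .scalar, .num _  => true
  | .dim n s, .arr a => a.size == n && a.all (matchesShape s)
  | _, _             => false
```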

state_dict helpers #

We use JSON objects keyed by strings because that mirrors PyTorch’s state_dict convention. Some TorchLean Python scripts wrap the object as { "params": { ... } }; loadWeights? accepts both formats.

Read a JSON value as a StateDict.

If the object contains a "params" field that is itself an object, unwrap it.

We also merge any other top-level fields (e.g. "meta") into the returned dictionary so model importers can still read them.

Load weights from JSON, accepting either:

• { ...state_dict... }, or
• { "params": { ...state_dict... } }.
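Both accepted payloads are easy to build for a quick test (a sketch; Json.parse is Lean core, and the loader call in the comment is hypothetical usage):

```lean
open Lean

-- Flat form: the object *is* the state_dict.
def flat : Except String Json :=
  Json.parse "{\"w\": [1.0, 2.0]}"

-- Wrapped form: the state_dict sits under "params".
def wrapped : Except String Json :=
  Json.parse "{\"params\": {\"w\": [1.0, 2.0]}}"

-- `loadWeights?` should yield the same weights for both.
```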

Look up a JSON field by key in a StateDict.

Look up a JSON field and require it to be a JSON object.

Look up a JSON field and require it to be a JSON string.

Look up a key and parse it as a tensor of a given expected shape.

This is the helper most model-specific importers use to keep the “key wiring” readable.

Error-reporting variants (ergonomics) #

Most of the import code in this folder is written in the Option monad to keep examples short. When you are debugging a round-trip, it is often more helpful to get a concrete reason why an import failed (missing key vs wrong JSON type vs wrong shape).

The helpers below provide small Except String wrappers around the Option-based core.
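The wrapping pattern itself is small; a sketch (the helper name orError is hypothetical):

```lean
/-- Lift an `Option`-based core function into `Except String`,
    attaching a concrete reason for the failure. -/
def orError {α : Type} (msg : String) : Option α → Except String α
  | some a => .ok a
  | none   => .error msg

-- e.g. `orError s!"missing key '{key}'" (lookup result)` at each call site.
```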

Load weights from JSON with an error message on failure.

This is the Except analogue of loadWeights?.

Look up a tensor by key, returning a human-friendly error on failure.

This is the Except analogue of getTensor?.

@[reducible, inline]

Convenience: parse a length-n vector tensor.

@[reducible, inline]

Convenience: parse a rows × cols matrix tensor.

Small parsing helpers used by shape-inferring importers #

Some importers allow variable-width stacks (e.g. a PINN that learns its hidden widths from the checkpoint). Those cases need a little help to infer indices and matrix dimensions from JSON.

Parse keys of the form prefix ++ <nat> ++ suffix.

Example: parseIndexedKey "layers." ".weight" "layers.3.weight" = some 3.
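A possible implementation of that pattern, using only core String operations (a sketch, primed to avoid clashing with the module's own definition):

```lean
def parseIndexedKey' (pre suf key : String) : Option Nat :=
  if key.startsWith pre && key.endsWith suf then
    -- Strip the prefix and suffix, then read the middle as a Nat.
    ((key.drop pre.length).dropRight suf.length).toNat?
  else
    none

-- parseIndexedKey' "layers." ".weight" "layers.3.weight" = some 3
```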

Infer (rows, cols) for a JSON matrix encoded as an array of arrays.

This helper infers dimensions from the outer length and first row length. Call sites that need stronger validation (all rows same length) should add an explicit check.
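A sketch of that inference (hypothetical name; note it deliberately trusts the first row, as described above):

```lean
open Lean

def inferMatrixDims? (j : Json) : Option (Nat × Nat) := do
  let rows  ← j.getArr?.toOption      -- outer array = rows
  let first ← rows[0]?                -- first row determines cols
  let cols  ← first.getArr?.toOption
  pure (rows.size, cols.size)
```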

Drop the last element of a list of Nat (used when inferring hidden layer widths).

Convenience parsers for function-based constructors #

Some call sites build tensors via Spec.vector_tensor / Spec.matrix_tensor, whose inputs are functions (Fin n → Float and Fin m → Fin n → Float).

parseFloatVec and parseFloatMatrix keep those call sites readable without duplicating JSON parsing logic outside this core module.

Parse a JSON array into a Fin n → Float function.
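The length check that makes such a function total can be sketched as follows (hypothetical helper; the real parser also has to convert the JSON numbers):

```lean
/-- Repackage a length-checked `Array Float` as the `Fin n → Float`
    function that `Spec.vector_tensor` expects. The `0.0` default is
    unreachable: `i.val < xs.size` once the size check has passed. -/
def toFinFun? (n : Nat) (xs : Array Float) : Option (Fin n → Float) :=
  if xs.size = n then some (fun i => xs.getD i.val 0.0) else none
```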
def Import.PyTorch.parseFloatMatrix (rows cols : ℕ) (j : Lean.Json) :
    Option (Fin rows → Fin cols → Float)

Parse a JSON matrix into a Fin rows → Fin cols → Float function.