TorchLean API

NN.Examples.Interop.PyTorch.MLP.Import

MLP PyTorch Fixture Import

MLP fixture weight import from a PyTorch-style state_dict.

On the Python side we usually write JSON (nested lists of floats) under keys that mirror the names you would see in model.state_dict():

This file keeps the parsing logic in one place so the rest of the codebase can talk in terms of typed Lean tensors.

structure Import.MLPPyTorch.MlpStateDict (inDim hidDim outDim : ℕ) : Type

Typed view of an MLP PyTorch state_dict (two linear layers).

We keep the tensors as Float because these importers are meant for runtime demos: train in Python, export to JSON, then run/verify in TorchLean.
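This page does not list MlpStateDict's fields, so here is a minimal sketch of what a two-layer state dict holds, under stated assumptions. The name MlpWeightsSketch and its fields w1, b1, w2, b2 are hypothetical, and plain Array (Array Float) stands in for TorchLean's tensor type. The layout follows PyTorch's nn.Linear, whose weight has shape (out_features, in_features).

```lean
/-- Hypothetical stand-in for `MlpStateDict`: plain arrays instead of TorchLean tensors.
Layout follows PyTorch's `nn.Linear`, where `weight` is (out_features, in_features). -/
structure MlpWeightsSketch (inDim hidDim outDim : Nat) where
  w1 : Array (Array Float)  -- hidDim rows, each of length inDim
  b1 : Array Float          -- length hidDim
  w2 : Array (Array Float)  -- outDim rows, each of length hidDim
  b2 : Array Float          -- length outDim
  deriving Repr

/-- Runtime shape check: the dimension parameters above are phantom, so nothing
in the types stops a mis-shaped array from being stored. -/
def MlpWeightsSketch.wellShaped {inDim hidDim outDim : Nat}
    (sd : MlpWeightsSketch inDim hidDim outDim) : Bool :=
  sd.w1.size == hidDim && sd.w1.all (·.size == inDim) && sd.b1.size == hidDim &&
  sd.w2.size == outDim && sd.w2.all (·.size == hidDim) && sd.b2.size == outDim

-- A 2 → 1 → 1 network with correctly shaped arrays.
#eval (⟨#[#[1, 2]], #[0], #[#[3]], #[0]⟩ : MlpWeightsSketch 2 1 1).wellShaped  -- true
```

Because the dimensions are phantom here, the sketch has to check shapes at runtime; a structure built from tensors indexed by inDim, hidDim and outDim, as MlpStateDict's signature suggests, can rule out mismatches at compile time instead.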

def Import.MLPPyTorch.loadMlpStateDict (inDim hidDim outDim : ℕ) (j : Lean.Json) :
    Option (MlpStateDict inDim hidDim outDim)

Load an MLP state dict from JSON (accepts both key conventions described above).
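The following is a hedged sketch of the kind of decoding a loader like this has to do, written against Lean core's Lean.Json API (fromJson?, Json.getObjVal?, Except.toOption) rather than TorchLean's internals. The helpers floatVec?, floatMat? and loadLinear? are hypothetical, and the key names are passed in as parameters because the exact conventions are not reproduced here. Every step yields none on a missing key, a non-numeric entry, or a shape mismatch, in line with the Option return type.

```lean
import Lean.Data.Json
open Lean

/-- Hypothetical helper: decode a JSON list of numbers that must have length `n`. -/
def floatVec? (j : Json) (n : Nat) : Option (Array Float) := do
  let v ← (fromJson? j : Except String (Array Float)).toOption
  guard (v.size == n)
  return v

/-- Hypothetical helper: decode a `rows × cols` matrix given as a list of rows. -/
def floatMat? (j : Json) (rows cols : Nat) : Option (Array (Array Float)) := do
  let m ← (fromJson? j : Except String (Array (Array Float))).toOption
  guard (m.size == rows && m.all (·.size == cols))
  return m

/-- Hypothetical: load one linear layer stored PyTorch-style (weight is out × in).
Key names are parameters because this sketch does not assume TorchLean's conventions. -/
def loadLinear? (j : Json) (wKey bKey : String) (inF outF : Nat) :
    Option (Array (Array Float) × Array Float) := do
  let w ← floatMat? (← (j.getObjVal? wKey).toOption) outF inF
  let b ← floatVec? (← (j.getObjVal? bKey).toOption) outF
  return (w, b)

-- Try one naming convention, then fall back to another with `<|>`.
#eval
  let j := Json.mkObj [("w", toJson (#[#[1.0, 2.0]] : Array (Array Float))),
                       ("b", toJson (#[0.5] : Array Float))]
  loadLinear? j "weight" "bias" 2 1 <|> loadLinear? j "w" "b" 2 1
```

Falling back from one key convention to another with <|> is one plausible way to accept two naming schemes; the actual loader may combine them differently.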

def Import.MLPPyTorch.toLinearSpecs {inDim hidDim outDim : ℕ} (sd : MlpStateDict inDim hidDim outDim) :
    Spec.LinearSpec Float inDim hidDim × Spec.LinearSpec Float hidDim outDim

Construct the pair of Float LinearSpecs for the two layers (inDim to hidDim, then hidDim to outDim) from an MLP state dict.
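A minimal sketch of what a Float linear layer computes and how a state dict splits into a pair of layers, as toLinearSpecs does. LinearSketch, LinearSketch.apply and toLinearSketches are hypothetical names, since the fields of Spec.LinearSpec are not shown on this page. The arithmetic matches torch.nn.Linear for a single input vector: y = W x + b, with W stored as out × in.

```lean
/-- Hypothetical stand-in for a Float linear layer, PyTorch layout (weight is out × in). -/
structure LinearSketch where
  weight : Array (Array Float)  -- one row per output feature
  bias   : Array Float          -- one entry per output feature

/-- y_i = bias_i + Σ_j weight_ij * x_j, i.e. `torch.nn.Linear` on one sample. -/
def LinearSketch.apply (l : LinearSketch) (x : Array Float) : Array Float :=
  (l.weight.zip l.bias).map fun (row, b) =>
    (row.zip x).foldl (fun acc (w, xi) => acc + w * xi) b

/-- Hypothetical: split the four state-dict arrays into (input layer, output layer). -/
def toLinearSketches (w1 : Array (Array Float)) (b1 : Array Float)
    (w2 : Array (Array Float)) (b2 : Array Float) : LinearSketch × LinearSketch :=
  (⟨w1, b1⟩, ⟨w2, b2⟩)

#eval LinearSketch.apply ⟨#[#[1, 2]], #[0.5]⟩ #[3, 4]  -- 0.5 + 1*3 + 2*4 = 11.5
```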


Convenience: run the Float MLP forward pass on an input, given a state dict.
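A sketch of a two-layer forward pass over plain Float arrays. The activation between the layers is an assumption (ReLU), since this page does not say which one TorchLean uses, and the names affine, relu and mlpForwardSketch are hypothetical.

```lean
/-- Hypothetical affine map y = W x + b, with W stored row-major as out × in. -/
def affine (w : Array (Array Float)) (b : Array Float) (x : Array Float) : Array Float :=
  (w.zip b).map fun (row, bi) =>
    (row.zip x).foldl (fun acc (wij, xj) => acc + wij * xj) bi

/-- ASSUMPTION: ReLU between the layers; the actual activation is not stated here. -/
def relu (v : Array Float) : Array Float :=
  v.map fun a => if a > 0 then a else 0

/-- Two-layer forward pass: layer 2 applied to relu (layer 1 applied to x). -/
def mlpForwardSketch (w1 : Array (Array Float)) (b1 : Array Float)
    (w2 : Array (Array Float)) (b2 : Array Float) (x : Array Float) : Array Float :=
  affine w2 b2 (relu (affine w1 b1 x))

-- Hidden pre-activation #[-1, 3], after ReLU #[0, 3], output #[3].
#eval mlpForwardSketch #[#[1, -1], #[0, 1]] #[0, 0] #[#[1, 1]] #[0] #[2, 3]
```

Array.zip truncates to the shorter argument, so this sketch silently produces wrong results on mis-shaped weights; that is one reason to validate shapes at load time, or to carry the dimensions in the types as the typed MlpStateDict does.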
