TorchLean API

NN.Examples.Interop.PyTorch.Transformer.Import

Transformer PyTorch Fixture Import #

Transformer fixture weight import from JSON.

In the spec layer, our transformer encoder parameters are explicit tensors (query/key/value/output projections, feed-forward weights, and LayerNorm affine parameters). In PyTorch these are usually spread across multiple nn.Linear and nn.LayerNorm submodules.

For round-trip demos we accept a stable, explicit key format in JSON: Wq, Wk, Wv, Wo, W1, W2, b1, b2, norm1_gamma, norm1_beta, norm2_gamma, norm2_beta.
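For illustration, a minimal fixture in this explicit format might look as follows. The twelve key names come from the list above; the tiny 2×2 shapes are placeholders for a toy encoder and are not prescribed by the importer:

```json
{
  "Wq": [[0.1, 0.2], [0.3, 0.4]],
  "Wk": [[0.1, 0.2], [0.3, 0.4]],
  "Wv": [[0.1, 0.2], [0.3, 0.4]],
  "Wo": [[0.1, 0.2], [0.3, 0.4]],
  "W1": [[0.1, 0.2], [0.3, 0.4]],
  "W2": [[0.1, 0.2], [0.3, 0.4]],
  "b1": [0.0, 0.0],
  "b2": [0.0, 0.0],
  "norm1_gamma": [1.0, 1.0],
  "norm1_beta": [0.0, 0.0],
  "norm2_gamma": [1.0, 1.0],
  "norm2_beta": [0.0, 0.0]
}
```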

We also accept the nested PyTorch module keys emitted by Export.TransformerPyTorch.generateTransformerEncoderWithWeights, such as layers.0.mha.q_proj.weight. That keeps generated export state dicts loadable by both PyTorch and this Lean importer.
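As a sketch only, the nested format uses dotted PyTorch module paths as keys. Of the names below, only layers.0.mha.q_proj.weight is confirmed above; the remaining key names follow common PyTorch submodule conventions and are assumptions about the exporter's output:

```json
{
  "layers.0.mha.q_proj.weight": [[0.1, 0.2], [0.3, 0.4]],
  "layers.0.mha.k_proj.weight": [[0.1, 0.2], [0.3, 0.4]],
  "layers.0.norm1.weight": [1.0, 1.0],
  "layers.0.norm1.bias": [0.0, 0.0]
}
```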

structure Import.TransformerPyTorch.TransformerEncoderStateDict (embedDim headCount hiddenDim : ℕ) :

Typed view of a single-layer Transformer encoder state_dict (Float tensors).

This is the normalized typed view returned by the JSON loader. The loader accepts both TorchLean's explicit keys and the nested PyTorch module keys emitted by the exporter.

Instances For
    def Import.TransformerPyTorch.loadTransformerEncoderStateDict (embedDim headCount hiddenDim : ℕ) (j : Lean.Json) :
    Option (TransformerEncoderStateDict embedDim headCount hiddenDim)

    Load a Transformer encoder state dict from JSON, accepting either supported export key format.

    Instances For
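A minimal usage sketch of the loader. The dimensions 64/4/256 and the file-reading wrapper are illustrative assumptions; only loadTransformerEncoderStateDict and its signature come from this page:

```lean
open Import.TransformerPyTorch

-- Parse a JSON fixture file and normalize it into the typed view.
-- Returns `none` if the file is not valid JSON or if required keys are missing.
def loadFixture (path : System.FilePath) :
    IO (Option (TransformerEncoderStateDict 64 4 256)) := do
  let raw ← IO.FS.readFile path
  match Lean.Json.parse raw with
  | .ok j    => pure (loadTransformerEncoderStateDict 64 4 256 j)
  | .error _ => pure none
```

Because the loader normalizes both key formats into the same typed view, the same loadFixture works on fixtures written by hand and on state dicts produced by the exporter.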