TorchLean API

NN.Examples.Interop.PyTorch.Transformer.Export

Transformer PyTorch Fixture Export

PyTorch code generator for the small Transformer encoder round-trip fixture.

This file produces a readable Python nn.Module implementation that follows the usual PyTorch structure: multi-head attention (MHA), residual connections, LayerNorm, and a feed-forward network (FFN). In the TorchLean repo we mostly use it as a round-trip companion: generate a reference implementation, train or tweak it in Python if needed, and optionally export the parameters back to Lean as JSON, which the importer modules read.
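A minimal Python sketch of the JSON leg of that round trip. It assumes parameters are exchanged as one flat JSON object that maps parameter names to nested lists of floats. The function names, the key names (`norm1.weight`, `q_proj.weight`, and so on), and this layout are all hypothetical. The schema the TorchLean importer modules actually expect is not described in this section.

```python
import json

def params_to_json(params):
    """Serialize {name: nested list of numbers} to a JSON string.

    Hypothetical layout: a flat object of parameter name -> nested float lists.
    """
    def to_floats(x):
        # Recursively convert lists/tuples of numbers into lists of Python floats.
        if isinstance(x, (list, tuple)):
            return [to_floats(v) for v in x]
        return float(x)
    return json.dumps({name: to_floats(t) for name, t in params.items()}, sort_keys=True)

def json_to_params(text):
    """Inverse of params_to_json: parse the JSON back into nested lists."""
    return json.loads(text)

# Hypothetical parameters for embedDim = 2. In PyTorch these lists would come
# from calling .tolist() on each state_dict tensor.
params = {
    "norm1.weight": [1.0, 1.0],
    "norm1.bias": [0.0, 0.0],
    "q_proj.weight": [[0.1, 0.2], [0.3, 0.4]],
}
text = params_to_json(params)
restored = json_to_params(text)
assert restored == params  # floats survive the JSON round trip exactly
```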

def Export.TransformerPyTorch.generateTransformerEncoderPyTorchClass (seqLen embedDim headCount hiddenDim numLayers : ℕ) (className : String := "TransformerEncoder") :

Render a small Transformer encoder as a Python nn.Module class definition.

This produces readable "reference PyTorch" code (MultiHeadAttention + residual + LayerNorm + FFN), useful for round-trip demos.
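The following dependency-free Python sketch makes the rendered architecture concrete. It runs one encoder layer's forward pass on nested lists. Several details are assumptions, and the code the generator actually emits may differ on any of them:

- post-norm ordering, which is the default for PyTorch's `nn.TransformerEncoderLayer`;
- ReLU in the FFN;
- scaled dot-product attention with `embedDim / headCount` dimensions per head;
- bias-free attention projections;
- weights in the (in, out) orientation, applied as `X @ W`.

The function names are hypothetical. The dictionary keys in `p` mirror the Lean argument names only for readability.

```python
import math

def matmul(A, B):
    # A: n x k, B: k x m, both lists of rows.
    return [[sum(A[i][t] * B[t][j] for t in range(len(B))) for j in range(len(B[0]))]
            for i in range(len(A))]

def add(A, B):
    return [[a + b for a, b in zip(ra, rb)] for ra, rb in zip(A, B)]

def softmax(row):
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in row]
    s = sum(exps)
    return [e / s for e in exps]

def layer_norm(X, gamma, beta, eps=1e-5):
    out = []
    for row in X:
        mu = sum(row) / len(row)
        var = sum((v - mu) ** 2 for v in row) / len(row)
        out.append([g * (v - mu) / math.sqrt(var + eps) + b
                    for v, g, b in zip(row, gamma, beta)])
    return out

def mha(X, Wq, Wk, Wv, Wo, head_count):
    d = len(X[0])
    hd = d // head_count  # per-head dimension
    Q, K, V = matmul(X, Wq), matmul(X, Wk), matmul(X, Wv)
    concat = [[0.0] * d for _ in X]
    for h in range(head_count):
        cols = range(h * hd, (h + 1) * hd)  # columns owned by head h
        for i in range(len(X)):
            scores = [sum(Q[i][c] * K[j][c] for c in cols) / math.sqrt(hd)
                      for j in range(len(X))]
            w = softmax(scores)
            for c in cols:
                concat[i][c] = sum(w[j] * V[j][c] for j in range(len(X)))
    return matmul(concat, Wo)  # output projection over the concatenated heads

def encoder_layer(X, p, head_count):
    # Post-norm: X1 = LN(X + MHA(X)); out = LN(X1 + FFN(X1)), FFN = ReLU(X1 W1 + b1) W2 + b2
    attn = mha(X, p["Wq"], p["Wk"], p["Wv"], p["Wo"], head_count)
    X1 = layer_norm(add(X, attn), p["norm1_gamma"], p["norm1_beta"])
    H = [[max(0.0, v + b) for v, b in zip(row, p["b1"])] for row in matmul(X1, p["W1"])]
    F = [[v + b for v, b in zip(row, p["b2"])] for row in matmul(H, p["W2"])]
    return layer_norm(add(X1, F), p["norm2_gamma"], p["norm2_beta"])

# Usage: seqLen = 3, embedDim = 4, headCount = 2, hiddenDim = 8.
seq_len, embed_dim, head_count, hidden_dim = 3, 4, 2, 8
eye = [[1.0 if i == j else 0.0 for j in range(embed_dim)] for i in range(embed_dim)]
p = {
    "Wq": eye, "Wk": eye, "Wv": eye, "Wo": eye,
    "W1": [[0.1 * (i + j) for j in range(hidden_dim)] for i in range(embed_dim)],
    "b1": [0.0] * hidden_dim,
    "W2": [[0.05 * (i - j) for j in range(embed_dim)] for i in range(hidden_dim)],
    "b2": [0.0] * embed_dim,
    "norm1_gamma": [1.0] * embed_dim, "norm1_beta": [0.0] * embed_dim,
    "norm2_gamma": [1.0] * embed_dim, "norm2_beta": [0.0] * embed_dim,
}
X = [[float(i + j) for j in range(embed_dim)] for i in range(seq_len)]
Y = encoder_layer(X, p, head_count)  # seq_len rows of embed_dim features
```

Under these assumptions, stacking `numLayers` layers amounts to calling `encoder_layer` repeatedly, with each layer's own parameters.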

def Export.TransformerPyTorch.generateTransformerEncoderWithWeights (seqLen embedDim headCount hiddenDim : ℕ) (Wq Wk Wv Wo : Spec.Tensor Float (Spec.Shape.dim embedDim (Spec.Shape.dim embedDim Spec.Shape.scalar))) (W1 : Spec.Tensor Float (Spec.Shape.dim embedDim (Spec.Shape.dim hiddenDim Spec.Shape.scalar))) (W2 : Spec.Tensor Float (Spec.Shape.dim hiddenDim (Spec.Shape.dim embedDim Spec.Shape.scalar))) (b1 : Spec.Tensor Float (Spec.Shape.dim hiddenDim Spec.Shape.scalar)) (b2 norm1_gamma norm1_beta norm2_gamma norm2_beta : Spec.Tensor Float (Spec.Shape.dim embedDim Spec.Shape.scalar)) (className : String := "TransformerEncoder") :

Generate a single-layer Transformer encoder module with an embedded state_dict initializer.

This is meant for round-trip demos where parameters are loaded from TorchLean tensors.

Important convention: NN/Spec transformer weights are stored in the mathematical (in, out) orientation because they are applied as X @ W. PyTorch stores linear weights as (out, in) and applies them as X @ W.T + b. So for every matrix-valued weight we print the transpose when populating the PyTorch state_dict. Vector-valued parameters (b1, b2, and the LayerNorm gamma and beta) have no orientation and are written unchanged.
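The transpose convention can be checked numerically. The Python sketch below uses plain lists and no PyTorch. It stores a weight in the Spec (in, out) orientation and converts it to the (out, in) layout of `nn.Linear.weight`. It then confirms that `X @ W` and `X @ weight.T` give the same result. The helper names and the `linear.*` keys are hypothetical stand-ins for whatever keys the generated state_dict uses.

```python
def transpose(W):
    """Transpose a matrix stored as a list of rows."""
    return [list(col) for col in zip(*W)]

def matmul(A, B):
    """Plain matrix product A @ B for lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def to_torch_layout(spec_params):
    """Convert (in, out) Spec matrices to PyTorch's (out, in) layout.

    Matrix-valued entries are transposed; vector entries (biases, LayerNorm
    gamma/beta) are copied unchanged. Key names are hypothetical.
    """
    out = {}
    for name, value in spec_params.items():
        is_matrix = len(value) > 0 and isinstance(value[0], list)
        out[name] = transpose(value) if is_matrix else list(value)
    return out

# Spec orientation: W has shape (in, out) = (3, 2) and is applied as X @ W.
W_spec = [[1.0, 2.0],
          [3.0, 4.0],
          [5.0, 6.0]]
torch_params = to_torch_layout({"linear.weight": W_spec, "linear.bias": [0.0, 0.0]})
torch_weight = torch_params["linear.weight"]   # shape (out, in) = (2, 3)

X = [[1.0, 0.0, -1.0]]                          # one row with in = 3 features
y_spec = matmul(X, W_spec)                      # Lean/Spec: X @ W
y_torch = matmul(X, transpose(torch_weight))    # PyTorch: X @ weight.T (bias is zero here)
assert y_spec == y_torch == [[-4.0, -4.0]]
```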
