CNN PyTorch Fixture Import
CNN fixture weight import from a PyTorch-style state_dict.
We mirror the common PyTorch naming convention for modules:
- conv1.weight, conv1.bias
- conv2.weight, conv2.bias
- fc.weight, fc.bias
Each tensor must be encoded as nested JSON arrays whose shape matches the corresponding
TorchLean Shape.
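As a rough illustration of this encoding, the Python sketch below builds a JSON document with the six keys, each tensor written as nested arrays. The helper `zeros` and the concrete dimensions are hypothetical and chosen only for the example; they are not taken from the TorchLean fixture. With a real PyTorch model, one would typically call `.tolist()` on each tensor in `model.state_dict()` rather than build the lists by hand.

```python
import json


def zeros(*dims):
    """Build a nested list of 0.0 with the given dimensions (hypothetical helper)."""
    if not dims:
        return 0.0
    return [zeros(*dims[1:]) for _ in range(dims[0])]


# Illustrative dimensions only; any consistent values would do.
in_c, out_c, k_h, k_w, flat_size = 1, 2, 3, 3, 8

# Keys follow the PyTorch module naming convention used by this import.
state_dict = {
    "conv1.weight": zeros(out_c, in_c, k_h, k_w),
    "conv1.bias": zeros(out_c),
    "conv2.weight": zeros(out_c, out_c, k_h, k_w),
    "conv2.bias": zeros(out_c),
    "fc.weight": zeros(out_c, flat_size),
    "fc.bias": zeros(out_c),
}

# Nested lists serialize directly to nested JSON arrays.
encoded = json.dumps(state_dict)
```

The nesting depth of each array equals the tensor rank, so a kernel is four levels deep and a bias is a flat array.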
Typed view of a PyTorch state_dict for the demo 2-block CNN.
This matches the keys used by the exporter (conv1.*, conv2.*, fc.*) and pins down the exact
shapes expected by TorchLean.
- convW1 : Spec.Tensor Float (Spec.Shape.dim outC (Spec.Shape.dim inC (Spec.Shape.dim kH (Spec.Shape.dim kW Spec.Shape.scalar))))
  First convolution kernel, PyTorch shape (outC, inC, kH, kW).
- convB1 : Spec.Tensor Float (Spec.Shape.dim outC Spec.Shape.scalar)
  First convolution bias.
- convW2 : Spec.Tensor Float (Spec.Shape.dim outC (Spec.Shape.dim outC (Spec.Shape.dim kH (Spec.Shape.dim kW Spec.Shape.scalar))))
  Second convolution kernel, PyTorch shape (outC, outC, kH, kW).
- convB2 : Spec.Tensor Float (Spec.Shape.dim outC Spec.Shape.scalar)
  Second convolution bias.
- linearW : Spec.Tensor Float (Spec.Shape.dim outC (Spec.Shape.dim flatSize Spec.Shape.scalar))
  Classifier weight, PyTorch shape (outC, flatSize).
- linearB : Spec.Tensor Float (Spec.Shape.dim outC Spec.Shape.scalar)
  Classifier bias.
def Import.CNNPyTorch.loadCnnStateDict (inC outC kH kW flatSize : ℕ) (j : Lean.Json) : Option (CnnStateDict inC outC kH kW flatSize)
Load a CNN state dict from JSON (PyTorch state_dict-style keys).
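To make the shape contract concrete, here is a minimal Python sketch of the same idea: look up the six keys and accept the document only if every tensor has the documented shape. It is not the TorchLean implementation. The names `shape_of`, `expected_shapes`, and `load_cnn_state_dict` are hypothetical, the sketch takes JSON text rather than an already parsed value, and returning `None` on failure is an assumption modelled on the `Option` result type.

```python
import json


def shape_of(value):
    """Return the shape of a rectangular nested list of numbers, else None."""
    if isinstance(value, bool):
        return None  # JSON booleans are not tensor entries
    if isinstance(value, (int, float)):
        return ()
    if not isinstance(value, list) or not value:
        return None  # empty lists are rejected for simplicity
    inner = [shape_of(v) for v in value]
    if inner[0] is None or any(s != inner[0] for s in inner):
        return None  # ragged or non-numeric nesting
    return (len(value),) + inner[0]


def expected_shapes(in_c, out_c, k_h, k_w, flat_size):
    """Shapes documented for each state_dict key."""
    return {
        "conv1.weight": (out_c, in_c, k_h, k_w),
        "conv1.bias": (out_c,),
        "conv2.weight": (out_c, out_c, k_h, k_w),
        "conv2.bias": (out_c,),
        "fc.weight": (out_c, flat_size),
        "fc.bias": (out_c,),
    }


def load_cnn_state_dict(in_c, out_c, k_h, k_w, flat_size, text):
    """Return the six tensors as nested lists, or None on any mismatch."""
    try:
        data = json.loads(text)
    except json.JSONDecodeError:
        return None
    if not isinstance(data, dict):
        return None
    expected = expected_shapes(in_c, out_c, k_h, k_w, flat_size)
    for key, shape in expected.items():
        if key not in data or shape_of(data[key]) != shape:
            return None
    return {key: data[key] for key in expected}
```

Because the dimensions are passed in explicitly, a file exported for one architecture is rejected when loaded with different `inC`, `outC`, `kH`, `kW`, or `flatSize` values instead of being silently reshaped.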