TorchLean API

NN.Spec.Module.Flatten

Flatten module wrapper

flatten_spec converts a tensor of shape s into a vector of length Shape.size s, the total number of elements in s.
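The shape arithmetic can be illustrated with a minimal Python sketch. It rests on several assumptions. Tensors are modeled as nested lists. Elements are read in row-major order. The names shape_size and flatten_spec are hypothetical stand-ins for Shape.size and the Lean definition, not the TorchLean API. Under these assumptions, the output length is the product of the dimensions, and a scalar shape () has size 1.

```python
from math import prod
from typing import Any, List, Tuple

Shape = Tuple[int, ...]

def shape_size(shape: Shape) -> int:
    """Hypothetical analogue of Shape.size: product of dims (1 for ())."""
    return prod(shape)

def flatten_spec(x: Any, shape: Shape) -> List[Any]:
    """Hypothetical sketch: flatten a nested-list tensor in row-major order."""
    if len(shape) == 0:
        return [x]  # a scalar becomes a length-1 vector
    if len(x) != shape[0]:
        raise ValueError(f"expected {shape[0]} entries, got {len(x)}")
    out: List[Any] = []
    for sub in x:
        # recurse on the remaining dims and append in order
        out.extend(flatten_spec(sub, shape[1:]))
    return out  # len(out) == shape_size(shape)
```

For example, flatten_spec([[1, 2, 3], [4, 5, 6]], (2, 3)) returns [1, 2, 3, 4, 5, 6], a vector of length shape_size((2, 3)) == 6.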

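Why is the output length computed at the type level? Because the result type records Shape.size s, the vector's length is known statically. A downstream module that expects a vector of that length type-checks without any runtime shape check.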

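If you're thinking in PyTorch: this corresponds to torch.flatten(x) or nn.Flatten(start_dim=0), which collapse all dims. Note that nn.Flatten() with its default start_dim=1 keeps the leading (batch) dimension.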

Wrap flatten_spec as an NNModuleSpec (s -> (Shape.size s)).

The dimensions metadata field is not meaningful for flatten because the output length depends on the whole input shape; exporters should recompute the shape from the typed input.
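The advice about exporters can be illustrated with a hypothetical Python sketch. The exporter derives the output shape from the input shape and does not read a stored dimensions value. ModuleSpec, its fields, and export_shapes are illustrative names, not TorchLean's NNModuleSpec or any real exporter API. The only rule taken from the text is that flatten maps s to a vector of length Shape.size s.

```python
from dataclasses import dataclass, field
from math import prod
from typing import Callable, Dict, Tuple

Shape = Tuple[int, ...]

@dataclass
class ModuleSpec:
    """Hypothetical module record: name, output-shape rule, free-form metadata."""
    name: str
    out_shape: Callable[[Shape], Shape]
    dimensions: Dict[str, int] = field(default_factory=dict)  # unused for flatten

# Flatten: any input shape s maps to a 1-D shape of length prod(s).
flatten = ModuleSpec(name="flatten", out_shape=lambda s: (prod(s),))

def export_shapes(module: ModuleSpec, in_shape: Shape) -> Dict[str, Shape]:
    """Record input/output shapes, deriving the output from the input shape
    rather than trusting the `dimensions` metadata."""
    return {"input": in_shape, "output": module.out_shape(in_shape)}
```

With this sketch, export_shapes(flatten, (2, 3, 4)) yields {"input": (2, 3, 4), "output": (24,)}. The empty dimensions dict plays no role in the result.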
