Global pooling as NNModuleSpecs
Global pooling is a common "bridge" from convolutional feature maps to a classifier head: it reduces the spatial dimensions and keeps only the channel axis.
This file wraps the flat global pooling specs from NN/Spec/Layers/GlobalPooling.lean as
NNModuleSpecs so they can be composed with SpecChain (and recognized by export tooling).
We focus on the flat variants because most model definitions here use them directly:
- GlobalAvgPool2DFlatModuleSpec returns a length-inC vector (the per-channel mean).
- GlobalMaxPool2DFlatModuleSpec returns a length-inC vector (the per-channel maximum).
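For reference, these are the standard per-channel reductions for an input x of shape (C, H, W). They are stated here as a sketch of the intended semantics, not quoted from GlobalPooling.lean:

```latex
\mathrm{avg}(x)_c = \frac{1}{H W} \sum_{h=1}^{H} \sum_{w=1}^{W} x_{c,h,w},
\qquad
\mathrm{max}(x)_c = \max_{1 \le h \le H,\; 1 \le w \le W} x_{c,h,w},
\qquad c = 1, \dots, C.
```

The average divides by H·W and the maximum is taken over a set that must be nonempty; both conditions hold exactly when H ≠ 0 and W ≠ 0, which is presumably why both module specs below take the proofs hH and hW.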
def Spec.GlobalAvgPool2DFlatModuleSpec {α : Type} [Context α] {inC inH inW : ℕ}
    (hH : inH ≠ 0) (hW : inW ≠ 0) (layer : GlobalAvgPool2DSpec := { }) :
    ModSpec.NNModuleSpec α (Shape.dim inC (Shape.dim inH (Shape.dim inW Shape.scalar))) (Shape.dim inC Shape.scalar)
Global average pooling (flattened): (C,H,W) -> (C).
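To make the (C, H, W) -> (C) contract concrete, here is a minimal executable sketch of flat global average pooling. It assumes a hypothetical channel-major nested-list representation (`List (List (List Float))`) rather than the library's `Shape`-indexed tensors, and the names `planeSumCount` and `globalAvgPoolFlat?` are illustrative only, not part of this file's API:

```lean
/-- Sum and element count of one H×W spatial slice (hypothetical helper). -/
def planeSumCount (plane : List (List Float)) : Float × Nat :=
  plane.foldl (fun acc row => (row.foldl (· + ·) acc.1, acc.2 + row.length)) (0.0, 0)

/-- Hypothetical flat global average pooling over nested lists: (C, H, W) -> (C).
    Returns `none` if any channel's slice is empty; the module spec instead
    rules that case out up front with `hH : inH ≠ 0` and `hW : inW ≠ 0`. -/
def globalAvgPoolFlat? (x : List (List (List Float))) : Option (List Float) :=
  x.mapM fun plane =>
    let p := planeSumCount plane
    if p.2 = 0 then none else some (p.1 / p.2.toFloat)

-- Two channels, each a 2×2 slice; expected channel means: 2.5 and 2.0
#eval globalAvgPoolFlat? [[[1, 2], [3, 4]], [[0, 0], [0, 8]]]
```

Returning `Option` keeps the sketch total without proofs; passing nonemptiness proofs, as the module spec does, lets the real definition return a plain length-inC vector.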
def Spec.GlobalMaxPool2DFlatModuleSpec {α : Type} [Context α] {inC inH inW : ℕ}
    (hH : inH ≠ 0) (hW : inW ≠ 0) (layer : GlobalMaxPool2DSpec := { }) :
    ModSpec.NNModuleSpec α (Shape.dim inC (Shape.dim inH (Shape.dim inW Shape.scalar))) (Shape.dim inC Shape.scalar)
Global max pooling (flattened): (C,H,W) -> (C).
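A matching sketch for flat global max pooling, under the same assumption of a hypothetical channel-major nested-list input instead of the library's tensor type. The names `maxStep`, `planeMax?` and `globalMaxPoolFlat?` are illustrative, not this file's API:

```lean
/-- Fold step (hypothetical helper): keep the larger of the running maximum and `v`. -/
def maxStep (acc : Option Float) (v : Float) : Option Float :=
  match acc with
  | none => some v
  | some m => some (max m v)

/-- Maximum of one H×W spatial slice, or `none` if the slice is empty. -/
def planeMax? (plane : List (List Float)) : Option Float :=
  plane.foldl (fun acc row => row.foldl maxStep acc) none

/-- Hypothetical flat global max pooling: (C, H, W) -> (C).
    Fails with `none` if any channel has an empty slice, the case that
    `hH : inH ≠ 0` and `hW : inW ≠ 0` exclude in the module spec. -/
def globalMaxPoolFlat? (x : List (List (List Float))) : Option (List Float) :=
  x.mapM planeMax?

-- Two channels, each a 2×2 slice; expected channel maxima: 4.0 and 8.0
#eval globalMaxPoolFlat? [[[1, 2], [3, 4]], [[0, 0], [0, 8]]]
```

Starting the fold from `none` avoids choosing an arbitrary sentinel such as negative infinity; the empty-slice case surfaces as `none` instead of a misleading value.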