TorchLean API

NN.Spec.Layers.Normalization

Normalization layers (spec layer) #

This file collects a few normalization operators used throughout TorchLean's spec/model code.

The common pattern is: subtract a mean, divide by sqrt(variance + ε), then apply a per-feature affine transform via gamma and beta.

def Spec.normalizeCore {α : Type} [Context α] (s s_mean s_var s_gamma s_beta : Shape) (epsilon : α) (x : Tensor α s) (mean : Tensor α s_mean) (variance : Tensor α s_var) (gamma : Tensor α s_gamma) (beta : Tensor α s_beta) (cb_mean : s_mean.CanBroadcastTo s) (cb_var : s_var.CanBroadcastTo s) (cb_gamma : s_gamma.CanBroadcastTo s) (cb_beta : s_beta.CanBroadcastTo s) :
Tensor α s

Core normalization routine with explicit broadcast proofs.

This is the shared “math step” behind normalization layers:

y = ((x - mean) / sqrt(variance + ε)) * gamma + beta.
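
For reference, here is the same math step written as a plain scalar function in Lean. This is only an illustrative sketch of the formula above, not part of the TorchLean API:

```lean
-- Scalar reference for the shared normalization step (illustrative only):
-- y = ((x - mean) / sqrt(variance + ε)) * gamma + beta
def normalizeScalarRef (x mean variance gamma beta epsilon : Float) : Float :=
  ((x - mean) / Float.sqrt (variance + epsilon)) * gamma + beta
```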

def Spec.layerNorm {α : Type} [Context α] {seqLen embedDim : ℕ} (x : Tensor α (Shape.dim seqLen (Shape.dim embedDim Shape.scalar))) (gamma beta : Tensor α (Shape.dim embedDim Shape.scalar)) (h_seq_pos : seqLen > 0 := by norm_num) (h_embed_pos : embedDim > 0 := by norm_num) (epsilon : α := Numbers.epsilon) :
Tensor α (Shape.dim seqLen (Shape.dim embedDim Shape.scalar))

LayerNorm over the last dimension of a (seqLen, embedDim) tensor.

Uses epsilon (default Numbers.epsilon) for numerical stability in the denominator.
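
A hypothetical usage sketch, assuming a `Context Float` instance exists; the positivity side conditions are discharged by the default `by norm_num` autoparams:

```lean
-- Hypothetical usage: LayerNorm over a (4, 8) activation (assumes `Context Float`).
example (x : Tensor Float (Shape.dim 4 (Shape.dim 8 Shape.scalar)))
    (gamma beta : Tensor Float (Shape.dim 8 Shape.scalar)) :
    Tensor Float (Shape.dim 4 (Shape.dim 8 Shape.scalar)) :=
  Spec.layerNorm x gamma beta
```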

def Spec.layerNormBackward {α : Type} [Context α] {seqLen embedDim : ℕ} (h_seq_pos : seqLen > 0) (h_embed_pos : embedDim > 0) (x : Tensor α (Shape.dim seqLen (Shape.dim embedDim Shape.scalar))) (gamma _beta : Tensor α (Shape.dim embedDim Shape.scalar)) (grad_output : Tensor α (Shape.dim seqLen (Shape.dim embedDim Shape.scalar))) (epsilon : α := Numbers.epsilon) :
Tensor α (Shape.dim seqLen (Shape.dim embedDim Shape.scalar)) × Tensor α (Shape.dim embedDim Shape.scalar) × Tensor α (Shape.dim embedDim Shape.scalar)

Backward/VJP for layerNorm (returns (dx, dGamma, dBeta)).

def Spec.layerNormJvp {α : Type} [Context α] {seqLen embedDim : ℕ} (h_seq_pos : seqLen > 0) (h_embed_pos : embedDim > 0) (x tangent : Tensor α (Shape.dim seqLen (Shape.dim embedDim Shape.scalar))) (gamma dgamma _beta dbeta : Tensor α (Shape.dim embedDim Shape.scalar)) (epsilon : α := Numbers.epsilon) :
Tensor α (Shape.dim seqLen (Shape.dim embedDim Shape.scalar))

Forward-mode JVP for layerNorm.

For each sequence position, LayerNorm is the map y = gamma ⊙ xhat + beta with xhat = (x - mean(x)) / sqrt(var(x) + eps). The input tangent is transformed by the standard closed form

dxhat = inv_std ⊙ (dx - mean(dx) - xhat ⊙ mean(dx ⊙ xhat)),

and the affine-parameter tangents contribute xhat ⊙ dgamma + dbeta. This is the forward-mode counterpart of the closed-form VJP above and follows the same clamped-variance convention as the forward pass.
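
A hypothetical usage sketch, assuming `Context Float`; unlike layerNorm, this signature takes the positivity proofs explicitly, so they are passed with `by norm_num`:

```lean
-- Hypothetical usage: JVP at x along input tangent dx, with parameter
-- tangents dgamma and dbeta (assumes `Context Float`).
example (x dx : Tensor Float (Shape.dim 4 (Shape.dim 8 Shape.scalar)))
    (gamma dgamma beta dbeta : Tensor Float (Shape.dim 8 Shape.scalar)) :
    Tensor Float (Shape.dim 4 (Shape.dim 8 Shape.scalar)) :=
  Spec.layerNormJvp (by norm_num) (by norm_num) x dx gamma dgamma beta dbeta
```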

def Spec.groupNorm {α : Type} [Context α] {batchSize height width channels groups : ℕ} (x : Tensor α (Shape.dim batchSize (Shape.dim height (Shape.dim width (Shape.dim channels Shape.scalar))))) (gamma beta : Tensor α (Shape.dim channels Shape.scalar)) (h_b : batchSize > 0 := by norm_num) (h_h : height > 0 := by norm_num) (h_w : width > 0 := by norm_num) (h_c : channels > 0 := by norm_num) (h_g : groups > 0 := by norm_num) (h_ge : channels ≥ groups) (h_div : channels % groups = 0) (epsilon : α := Numbers.epsilon) :
Tensor α (Shape.dim batchSize (Shape.dim height (Shape.dim width (Shape.dim channels Shape.scalar))))

GroupNorm for channel-last tensors (batch, height, width, channels).

PyTorch analogy: torch.nn.GroupNorm(num_groups=groups, num_channels=channels) applied per sample. The mean/variance are computed over both the spatial dimensions and the channels within each group, then an affine transform is applied per channel via gamma and beta.

This operator is useful in settings where BatchNorm's dependence on batch statistics is awkward (small batches, verification, or when you want purely per-sample behavior).
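
A hypothetical usage sketch, assuming `Context Float`. Since `groups` is implicit and not determined by the input shape, it is supplied by name; the group side conditions are discharged with `norm_num`:

```lean
-- Hypothetical usage: 2 groups over 8 channels on a (1, 4, 4, 8) tensor.
example
    (x : Tensor Float (Shape.dim 1 (Shape.dim 4 (Shape.dim 4 (Shape.dim 8 Shape.scalar)))))
    (gamma beta : Tensor Float (Shape.dim 8 Shape.scalar)) :
    Tensor Float (Shape.dim 1 (Shape.dim 4 (Shape.dim 4 (Shape.dim 8 Shape.scalar)))) :=
  Spec.groupNorm (groups := 2) x gamma beta
    (h_ge := by norm_num) (h_div := by norm_num)
```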

def Spec.normalizeAlongDim {α : Type} [Context α] {s : Shape} (x gamma beta : Tensor α s) (dim : ℕ) (h_valid : Shape.valid_axis_inst dim s) (h_wf : s.WellFormed) (epsilon : α := Numbers.epsilon) :
Tensor α s

Normalize along a chosen axis dim of a tensor x, using per-element affine parameters gamma and beta of the same shape as x.

This is a "generic building block" that is handy in specs; it is closer to the raw math than to a single PyTorch module. Most named normalizations (LayerNorm, GroupNorm, BatchNorm) are special cases of this pattern with a specific choice of axis set and parameter shape.

def Spec.rmsNorm {α : Type} [Context α] {seqLen embedDim : ℕ} (x : Tensor α (Shape.dim seqLen (Shape.dim embedDim Shape.scalar))) (gamma : Tensor α (Shape.dim embedDim Shape.scalar)) (h_seq_pos : seqLen > 0 := by norm_num) (h_embed_pos : embedDim > 0 := by norm_num) (epsilon : α := Numbers.epsilon) :
Tensor α (Shape.dim seqLen (Shape.dim embedDim Shape.scalar))

RMSNorm over the last dimension of a (seqLen, embedDim) tensor.

Compared to LayerNorm, RMSNorm skips subtracting the mean and normalizes by:

rms(x) = sqrt(mean(x^2) + eps).

This shows up in many Transformer-style models as a cheaper alternative to LayerNorm.
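
For reference, an illustrative scalar sketch of the RMS formula on one feature vector (not part of the TorchLean API):

```lean
-- Illustrative RMSNorm on a single row: divide by sqrt(mean(x^2) + eps),
-- then scale elementwise by gamma. No mean subtraction, unlike LayerNorm.
def rmsNormRowRef (row gamma : Array Float) (eps : Float := 1e-5) : Array Float :=
  let meanSq := row.foldl (fun acc v => acc + v * v) 0.0 / row.size.toFloat
  let rms := Float.sqrt (meanSq + eps)
  (row.zip gamma).map fun (v, g) => g * v / rms
```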

def Spec.weightNorm {α : Type} [Context α] {inDim outDim : ℕ} (weight : Tensor α (Shape.dim outDim (Shape.dim inDim Shape.scalar))) (gamma : Tensor α (Shape.dim outDim Shape.scalar)) (h_out_pos : outDim > 0 := by norm_num) (h_in_pos : inDim > 0 := by norm_num) (epsilon : α := Numbers.epsilon) :
Tensor α (Shape.dim outDim (Shape.dim inDim Shape.scalar))

WeightNorm for a dense weight matrix (outDim, inDim).

This implements the "normalize weight vectors then scale" idea:

• normalize each output row by its L2 norm,
• then rescale by gamma (one scalar per output row).

PyTorch analogy: weight normalization is typically applied as a parametrization of a module's weights rather than as a standalone tensor operator.
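
An illustrative per-row sketch of this idea (not part of the TorchLean API; putting epsilon under the square root is an assumption of the sketch):

```lean
-- Illustrative WeightNorm on one output row: scale the row to (almost)
-- unit L2 norm, then rescale by the row's scalar g.
def weightNormRowRef (row : Array Float) (g : Float) (eps : Float := 1e-5) : Array Float :=
  let norm := Float.sqrt (row.foldl (fun acc w => acc + w * w) 0.0 + eps)
  row.map fun w => g * w / norm
```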

def Spec.batchNorm {α : Type} [Context α] {channels : ℕ} {sSpatial : Shape} (x : Tensor α (Shape.dim channels sSpatial)) (gamma beta : Tensor α (Shape.dim channels Shape.scalar)) (epsilon : α := Numbers.epsilon) [(Shape.dim channels sSpatial).WellFormed] :
Tensor α (Shape.dim channels sSpatial)

Stateless BatchNorm for channel-first tensors of shape .dim channels sSpatial.

This computes per-channel mean/variance over the sSpatial axes and applies:

y = ((x - mean) / sqrt(var + eps)) * gamma + beta.

PyTorch analogy: torch.nn.BatchNorm{1,2,3}d in training mode on an input with batch size N=1. TorchLean does not model the running-statistics update here.
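
A hypothetical usage sketch over an abstract spatial shape, assuming `Context Float` and taking the required WellFormed instance as a hypothesis:

```lean
-- Hypothetical usage: per-channel normalization of a 3-channel tensor.
example {s : Shape} [(Shape.dim 3 s).WellFormed]
    (x : Tensor Float (Shape.dim 3 s))
    (gamma beta : Tensor Float (Shape.dim 3 Shape.scalar)) :
    Tensor Float (Shape.dim 3 s) :=
  Spec.batchNorm x gamma beta
```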

def Spec.instanceNorm {α : Type} [Context α] {channels : ℕ} {sSpatial : Shape} (x : Tensor α (Shape.dim channels sSpatial)) (gamma beta : Tensor α (Shape.dim channels Shape.scalar)) (epsilon : α := Numbers.epsilon) [(Shape.dim channels sSpatial).WellFormed] :
Tensor α (Shape.dim channels sSpatial)

Alias: per-sample normalization over spatial axes ("InstanceNorm-style").

The spec-level batchNorm* operators model the N=1 case (no explicit batch axis and no running-statistics update). Many ML codebases refer to that behavior as instance normalization.

These aliases make that intent explicit without changing the existing API surface.

def Spec.batchNorm2d {α : Type} [Context α] {channels height width : ℕ} (x : MultiChannelImage channels height width α) (gamma beta : Tensor α (Shape.dim channels Shape.scalar)) (h_c : channels > 0 := by norm_num) (h_h : height > 0 := by norm_num) (h_w : width > 0 := by norm_num) (epsilon : α := Numbers.epsilon) :
MultiChannelImage channels height width α

Convenience wrapper: batchNorm specialized to a single channel-first image (C, H, W).
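
A hypothetical usage sketch on a 3-channel 4×4 image, assuming `Context Float`:

```lean
-- Hypothetical usage: stateless BatchNorm over one (3, 4, 4) image.
example (x : MultiChannelImage 3 4 4 Float)
    (gamma beta : Tensor Float (Shape.dim 3 Shape.scalar)) :
    MultiChannelImage 3 4 4 Float :=
  Spec.batchNorm2d x gamma beta
```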

def Spec.batchNorm2dJvp {α : Type} [Context α] {channels height width : ℕ} (x tangent : MultiChannelImage channels height width α) (gamma dgamma _beta dbeta : Tensor α (Shape.dim channels Shape.scalar)) (_h_c : channels > 0 := by norm_num) (_h_h : height > 0 := by norm_num) (_h_w : width > 0 := by norm_num) (epsilon : α := Numbers.epsilon) :
MultiChannelImage channels height width α

Forward-mode JVP for batchNorm2d.

TorchLean's stateless BatchNorm2d computes one set of statistics per channel over the spatial grid. The input tangent therefore uses the same closed-form normalization differential as LayerNorm, but with the mean taken over (height, width) for each channel:

dxhat = inv_std * (dx - mean(dx) - xhat * mean(dx * xhat)).

Affine tangents contribute xhat * dgamma + dbeta channel-wise.

def Spec.instanceNorm2d {α : Type} [Context α] {channels height width : ℕ} (x : MultiChannelImage channels height width α) (gamma beta : Tensor α (Shape.dim channels Shape.scalar)) (h_c : channels > 0 := by norm_num) (h_h : height > 0 := by norm_num) (h_w : width > 0 := by norm_num) (epsilon : α := Numbers.epsilon) :
MultiChannelImage channels height width α

Alias for batchNorm2d (InstanceNorm-style per-image normalization).

def Spec.batchNorm1d {α : Type} [Context α] {channels length : ℕ} (x : Tensor α (Shape.dim channels (Shape.dim length Shape.scalar))) (gamma beta : Tensor α (Shape.dim channels Shape.scalar)) (h_c : channels > 0 := by norm_num) (h_l : length > 0 := by norm_num) (epsilon : α := Numbers.epsilon) :
Tensor α (Shape.dim channels (Shape.dim length Shape.scalar))

Convenience wrapper: batchNorm specialized to a (C, L) tensor (BatchNorm1d-style).

def Spec.instanceNorm1d {α : Type} [Context α] {channels length : ℕ} (x : Tensor α (Shape.dim channels (Shape.dim length Shape.scalar))) (gamma beta : Tensor α (Shape.dim channels Shape.scalar)) (h_c : channels > 0 := by norm_num) (h_l : length > 0 := by norm_num) (epsilon : α := Numbers.epsilon) :
Tensor α (Shape.dim channels (Shape.dim length Shape.scalar))

Alias for batchNorm1d (InstanceNorm-style per-sample normalization).

def Spec.batchNorm3d {α : Type} [Context α] {channels depth height width : ℕ} (x : Tensor α (Shape.dim channels (Shape.dim depth (Shape.dim height (Shape.dim width Shape.scalar))))) (gamma beta : Tensor α (Shape.dim channels Shape.scalar)) (h_c : channels > 0 := by norm_num) (h_d : depth > 0 := by norm_num) (h_h : height > 0 := by norm_num) (h_w : width > 0 := by norm_num) (epsilon : α := Numbers.epsilon) :
Tensor α (Shape.dim channels (Shape.dim depth (Shape.dim height (Shape.dim width Shape.scalar))))

Convenience wrapper: batchNorm specialized to a (C, D, H, W) tensor (BatchNorm3d-style).

def Spec.instanceNorm3d {α : Type} [Context α] {channels depth height width : ℕ} (x : Tensor α (Shape.dim channels (Shape.dim depth (Shape.dim height (Shape.dim width Shape.scalar))))) (gamma beta : Tensor α (Shape.dim channels Shape.scalar)) (h_c : channels > 0 := by norm_num) (h_d : depth > 0 := by norm_num) (h_h : height > 0 := by norm_num) (h_w : width > 0 := by norm_num) (epsilon : α := Numbers.epsilon) :
Tensor α (Shape.dim channels (Shape.dim depth (Shape.dim height (Shape.dim width Shape.scalar))))

Alias for batchNorm3d (InstanceNorm-style per-sample normalization).

def Spec.batchNorm2dBackward {α : Type} [Context α] {channels height width : ℕ} (x : MultiChannelImage channels height width α) (gamma : Tensor α (Shape.dim channels Shape.scalar)) (grad_output : MultiChannelImage channels height width α) (_h_c : channels > 0 := by norm_num) (_h_h : height > 0 := by norm_num) (_h_w : width > 0 := by norm_num) (epsilon : α := Numbers.epsilon) :
MultiChannelImage channels height width α × Tensor α (Shape.dim channels Shape.scalar) × Tensor α (Shape.dim channels Shape.scalar)

Backward/VJP for batchNorm2d.

Returns (dx, dGamma, dBeta). This matches the shape of gradients you expect from a PyTorch-style BatchNorm2d, but note that our forward is the per-image variant (no explicit batch dimension and no running statistics).
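
A hypothetical usage sketch, assuming `Context Float`:

```lean
-- Hypothetical usage: (dx, dGamma, dBeta) for one (3, 4, 4) image.
example (x gradOut : MultiChannelImage 3 4 4 Float)
    (gamma : Tensor Float (Shape.dim 3 Shape.scalar)) :
    MultiChannelImage 3 4 4 Float
      × Tensor Float (Shape.dim 3 Shape.scalar)
      × Tensor Float (Shape.dim 3 Shape.scalar) :=
  Spec.batchNorm2dBackward x gamma gradOut
```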

BatchNorm (inference-time, running statistics) #

PyTorch distinction: torch.nn.BatchNorm* computes batch statistics in training mode (updating its running statistics as a side effect) and uses the stored running statistics in eval mode.

TorchLean keeps things pure and explicit: inference-time BatchNorm takes the running statistics as arguments.

def Spec.batchNormInference {α : Type} [Context α] {channels : ℕ} {sSpatial : Shape} (x : Tensor α (Shape.dim channels sSpatial)) (runningMean runningVar gamma beta : Tensor α (Shape.dim channels Shape.scalar)) (epsilon : α := Numbers.epsilon) :
Tensor α (Shape.dim channels sSpatial)

Inference-time BatchNorm for channel-first tensors of shape .dim channels sSpatial, using fixed running statistics.

Formula (per channel c):

y = ((x - μ) / sqrt(σ² + eps)) * γ + β

This matches the standard evaluation-time behavior of torch.nn.BatchNorm{1,2,3}d (no batch-statistics computation, no running-statistics update).

At inference time, (μ, σ², γ, β) are constants, so this is an affine map in x. See NN.Proofs.Analysis.Normalization.batchNorm_inference_eq_mul_add.
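
Because the statistics are fixed, the whole operator collapses to one multiply and one add per channel. An illustrative scalar sketch (not part of the API):

```lean
-- Per-channel scalar view: with a = γ / sqrt(σ² + eps) and b' = β - μ * a,
-- inference-time BatchNorm is the affine map y = a * x + b'.
def bnInferScalarRef (x mu sigma2 g b eps : Float) : Float :=
  let a := g / Float.sqrt (sigma2 + eps)
  a * x + (b - mu * a)
```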

def Spec.batchNorm2dInference {α : Type} [Context α] {channels height width : ℕ} (x : MultiChannelImage channels height width α) (runningMean runningVar gamma beta : Tensor α (Shape.dim channels Shape.scalar)) (epsilon : α := Numbers.epsilon) :
MultiChannelImage channels height width α

Convenience wrapper: batchNormInference specialized to a single channel-first image (C, H, W).
def Spec.batchNorm1dInference {α : Type} [Context α] {channels length : ℕ} (x : Tensor α (Shape.dim channels (Shape.dim length Shape.scalar))) (runningMean runningVar gamma beta : Tensor α (Shape.dim channels Shape.scalar)) (epsilon : α := Numbers.epsilon) :
Tensor α (Shape.dim channels (Shape.dim length Shape.scalar))

Convenience wrapper: batchNormInference specialized to a (C, L) tensor (BatchNorm1d-style).
def Spec.batchNorm3dInference {α : Type} [Context α] {channels depth height width : ℕ} (x : Tensor α (Shape.dim channels (Shape.dim depth (Shape.dim height (Shape.dim width Shape.scalar))))) (runningMean runningVar gamma beta : Tensor α (Shape.dim channels Shape.scalar)) (epsilon : α := Numbers.epsilon) :
Tensor α (Shape.dim channels (Shape.dim depth (Shape.dim height (Shape.dim width Shape.scalar))))

Convenience wrapper: batchNormInference specialized to a (C, D, H, W) tensor (BatchNorm3d-style).