Normalization layers (spec layer)
This file collects a few normalization operators used throughout TorchLean's spec/model code.
The common pattern is:
- compute per-axis statistics (mean / variance or RMS),
- normalize with an epsilon for numerical stability,
- optionally apply an affine transform (gamma, beta) like PyTorch does.
References (papers + PyTorch behavior)
LayerNorm: Ba et al., "Layer Normalization" (2016): https://arxiv.org/abs/1607.06450
BatchNorm: Ioffe, Szegedy, "Batch Normalization" (2015): https://arxiv.org/abs/1502.03167
GroupNorm: Wu, He, "Group Normalization" (2018): https://arxiv.org/abs/1803.08494
RMSNorm: Zhang, Sennrich, "Root Mean Square Layer Normalization" (2019): https://arxiv.org/abs/1910.07467
WeightNorm: Salimans, Kingma, "Weight Normalization" (2016): https://arxiv.org/abs/1602.07868
PyTorch LayerNorm: https://docs.pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html
PyTorch BatchNorm2d: https://docs.pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html
Core normalization routine with explicit broadcast proofs.
This is the shared “math step” behind normalization layers:
y = ((x - mean) / sqrt(variance + ε)) * gamma + beta.
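As a concrete reading of the formula, here is a minimal NumPy sketch; the function name, argument layout, and eps default are illustrative, not TorchLean's API:

```python
import numpy as np

def normalize_core(x, mean, var, gamma, beta, eps=1e-5):
    """y = ((x - mean) / sqrt(var + eps)) * gamma + beta.

    All arguments are assumed to broadcast against x; the Lean version
    makes these broadcasts explicit with proofs.
    """
    inv_std = 1.0 / np.sqrt(var + eps)
    return (x - mean) * inv_std * gamma + beta
```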
LayerNorm over the last dimension of a (seqLen, embedDim) tensor.
Uses epsilon (default Numbers.epsilon) for numerical stability in the denominator.
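A hedged NumPy sketch of the forward pass for a (seqLen, embedDim) input with per-feature gamma and beta (names are illustrative):

```python
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-5):
    # x: (seqLen, embedDim); statistics over the last axis, per position.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)    # biased variance, as in PyTorch
    xhat = (x - mean) / np.sqrt(var + eps)
    return xhat * gamma + beta             # gamma, beta: (embedDim,)
```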
Backward/VJP for layerNorm (returns (dx, dGamma, dBeta)).
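A NumPy sketch of the standard closed-form gradients, using the same per-position statistics as the forward pass (names are illustrative):

```python
import numpy as np

def layer_norm_vjp(dy, x, gamma, eps=1e-5):
    # dy, x: (seqLen, embedDim); returns (dx, dGamma, dBeta).
    mean = x.mean(axis=-1, keepdims=True)
    inv_std = 1.0 / np.sqrt(x.var(axis=-1, keepdims=True) + eps)
    xhat = (x - mean) * inv_std

    d_beta = dy.sum(axis=0)                # (embedDim,)
    d_gamma = (dy * xhat).sum(axis=0)      # (embedDim,)

    dxhat = dy * gamma
    dx = inv_std * (dxhat
                    - dxhat.mean(axis=-1, keepdims=True)
                    - xhat * (dxhat * xhat).mean(axis=-1, keepdims=True))
    return dx, d_gamma, d_beta
```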
Forward-mode JVP for layerNorm.
For each sequence position, LayerNorm is the map
y = gamma ⊙ xhat + beta with xhat = (x - mean(x)) / sqrt(var(x)+eps).
The input tangent is normalized by the standard closed form
dxhat = inv_std ⊙ (dx - mean(dx) - xhat ⊙ mean(dx ⊙ xhat)),
and affine-parameter tangents contribute xhat ⊙ dgamma + dbeta. This is the forward-mode
counterpart of the closed-form VJP above and follows the same clamped-variance convention as the
forward pass.
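The same closed form in NumPy, with dx, dgamma, dbeta as input tangents (an illustrative sketch, not the Lean signature):

```python
import numpy as np

def layer_norm_jvp(x, dx, gamma, dgamma, dbeta, eps=1e-5):
    mean = x.mean(axis=-1, keepdims=True)
    inv_std = 1.0 / np.sqrt(x.var(axis=-1, keepdims=True) + eps)
    xhat = (x - mean) * inv_std
    # dxhat = inv_std * (dx - mean(dx) - xhat * mean(dx * xhat))
    dxhat = inv_std * (dx
                       - dx.mean(axis=-1, keepdims=True)
                       - xhat * (dx * xhat).mean(axis=-1, keepdims=True))
    return gamma * dxhat + xhat * dgamma + dbeta
```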
GroupNorm for channel-last tensors (batch, height, width, channels).
PyTorch analogy: torch.nn.GroupNorm(num_groups=groups, num_channels=channels) applied per sample.
The mean/variance are computed over both the spatial dimensions and the channels within each
group, then an affine transform is applied per channel via gamma and beta.
This operator is useful in settings where BatchNorm's dependence on batch statistics is awkward (small batches, verification, or when you want purely per-sample behavior).
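A NumPy sketch for a channel-last (batch, height, width, channels) input; grouping contiguous channels via the reshape below is an assumption about the layout:

```python
import numpy as np

def group_norm_nhwc(x, gamma, beta, groups, eps=1e-5):
    n, h, w, c = x.shape
    g = x.reshape(n, h, w, groups, c // groups)
    # Per sample and per group: statistics over H, W and the group's channels.
    mean = g.mean(axis=(1, 2, 4), keepdims=True)
    var = g.var(axis=(1, 2, 4), keepdims=True)
    xhat = ((g - mean) / np.sqrt(var + eps)).reshape(n, h, w, c)
    return xhat * gamma + beta             # gamma, beta: (channels,)
```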
Normalize along a chosen axis dim of a tensor x, using per-element affine parameters gamma
and beta of the same shape as x.
This is a "generic building block" that is handy in specs; it is closer to the raw math than to a single PyTorch module. Most named normalizations (LayerNorm, GroupNorm, BatchNorm) are special cases of this pattern with a specific choice of axis set and parameter shape.
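A NumPy sketch of this building block; note that gamma and beta are elementwise (same shape as x), unlike the per-feature parameters of the named layers:

```python
import numpy as np

def normalize_along(x, dim, gamma, beta, eps=1e-5):
    # gamma, beta: same shape as x (per-element affine parameters).
    mean = x.mean(axis=dim, keepdims=True)
    var = x.var(axis=dim, keepdims=True)
    return (x - mean) / np.sqrt(var + eps) * gamma + beta
```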
RMSNorm over the last dimension of a (seqLen, embedDim) tensor.
Compared to LayerNorm, RMSNorm skips subtracting the mean and normalizes by:
rms(x) = sqrt(mean(x^2) + eps).
This shows up in many Transformer-style models as a cheaper alternative to LayerNorm.
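A NumPy sketch, assuming a single per-feature scale gamma (RMSNorm typically has no beta):

```python
import numpy as np

def rms_norm(x, gamma, eps=1e-5):
    # x: (seqLen, embedDim); no mean subtraction, unlike LayerNorm.
    rms = np.sqrt((x * x).mean(axis=-1, keepdims=True) + eps)
    return x / rms * gamma                 # gamma: (embedDim,)
```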
WeightNorm for a dense weight matrix (outDim, inDim).
This implements the "normalize weight vectors then scale" idea:
- normalize each output row by its L2 norm,
- then rescale by gamma (one scalar per output row).
PyTorch analogy: weight normalization is typically applied as a parametrization of a module's weights rather than as a standalone tensor operator.
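A NumPy sketch of the row-wise version described above (no eps term, mirroring the docstring; names are illustrative):

```python
import numpy as np

def weight_norm(w, gamma):
    # w: (outDim, inDim); gamma: (outDim,), one scale per output row.
    norms = np.linalg.norm(w, axis=1, keepdims=True)   # L2 norm of each row
    return w / norms * gamma[:, None]
```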
Stateless BatchNorm for channel-first tensors of shape .dim channels sSpatial.
This computes per-channel mean/variance over the sSpatial axes and applies:
y = ((x - mean) / sqrt(var + eps)) * gamma + beta.
PyTorch analogy: torch.nn.BatchNorm{1,2,3}d in training mode on an input with batch size N=1.
TorchLean does not model the running-statistics update here.
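A NumPy sketch of the 2-D spatial case (C, H, W); the Lean operator generalizes this to arbitrary sSpatial axes:

```python
import numpy as np

def batch_norm_chw(x, gamma, beta, eps=1e-5):
    # x: (C, H, W); one mean/variance per channel over the spatial axes.
    mean = x.mean(axis=(1, 2), keepdims=True)
    var = x.var(axis=(1, 2), keepdims=True)
    xhat = (x - mean) / np.sqrt(var + eps)
    return xhat * gamma[:, None, None] + beta[:, None, None]
```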
Alias: per-sample normalization over spatial axes ("InstanceNorm-style").
The spec-level batchNorm* operators model the N=1 case (no explicit batch axis and no running
statistics update). Many ML codebases refer to that behavior as instance normalization.
These aliases make that intent explicit without changing the existing API surface.
Convenience wrapper: batchNorm specialized to a single channel-first image (C,H,W).
Forward-mode JVP for batchNorm2d.
TorchLean's stateless BatchNorm2d computes one set of statistics per channel over the spatial
grid. The input tangent therefore uses the same closed-form normalization differential as
LayerNorm, but with the mean taken over (height,width) for each channel:
dxhat = inv_std * (dx - mean(dx) - xhat * mean(dx*xhat)).
Affine tangents contribute xhat * dgamma + dbeta channel-wise.
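In NumPy, this is the LayerNorm JVP with the reduction axes swapped to the spatial grid (an illustrative sketch):

```python
import numpy as np

def batch_norm2d_jvp(x, dx, gamma, dgamma, dbeta, eps=1e-5):
    # x, dx: (C, H, W); statistics per channel over (H, W).
    ax = (1, 2)
    inv_std = 1.0 / np.sqrt(x.var(axis=ax, keepdims=True) + eps)
    xhat = (x - x.mean(axis=ax, keepdims=True)) * inv_std
    dxhat = inv_std * (dx
                       - dx.mean(axis=ax, keepdims=True)
                       - xhat * (dx * xhat).mean(axis=ax, keepdims=True))
    # Broadcast the per-channel (C,) tangents over (C, H, W).
    g, dg, db = (a[:, None, None] for a in (gamma, dgamma, dbeta))
    return g * dxhat + xhat * dg + db
```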
Alias for batchNorm2d (InstanceNorm-style per-image normalization).
Convenience wrapper: batchNorm specialized to a (C, L) tensor (BatchNorm1d-style).
Alias for batchNorm1d (InstanceNorm-style per-sample normalization).
Convenience wrapper: batchNorm specialized to a (C, D, H, W) tensor (BatchNorm3d-style).
Alias for batchNorm3d (InstanceNorm-style per-sample normalization).
Backward/VJP for batchNorm2d.
Returns (dx, dGamma, dBeta). This matches the shape of gradients you expect from a PyTorch-style
BatchNorm2d, but note that our forward is the per-image variant (no explicit batch dimension and no
running statistics).
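A NumPy sketch, structurally the LayerNorm VJP with per-channel spatial statistics (names are illustrative):

```python
import numpy as np

def batch_norm2d_vjp(dy, x, gamma, eps=1e-5):
    # dy, x: (C, H, W); returns (dx, dGamma, dBeta).
    ax = (1, 2)
    inv_std = 1.0 / np.sqrt(x.var(axis=ax, keepdims=True) + eps)
    xhat = (x - x.mean(axis=ax, keepdims=True)) * inv_std

    d_beta = dy.sum(axis=ax)               # (C,)
    d_gamma = (dy * xhat).sum(axis=ax)     # (C,)

    dxhat = dy * gamma[:, None, None]
    dx = inv_std * (dxhat
                    - dxhat.mean(axis=ax, keepdims=True)
                    - xhat * (dxhat * xhat).mean(axis=ax, keepdims=True))
    return dx, d_gamma, d_beta
```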
BatchNorm (inference-time, running statistics)
PyTorch distinction:
- training: normalize using batch statistics (and update running mean/variance);
- inference: normalize using the stored running mean/variance.
TorchLean keeps things pure and explicit: inference-time BatchNorm takes the running statistics as arguments.
Inference-time BatchNorm for channel-first tensors of shape .dim channels sSpatial, using fixed
running statistics.
Formula (per channel c):
y = ((x - μ) / sqrt(σ² + eps)) * γ + β
This matches the standard evaluation-time behavior of torch.nn.BatchNorm{1,2,3}d (no
batch-statistics computation, no running-statistics update).
At inference time, (μ, σ², γ, β) are constants, so this is an affine map in x. See
NN.Proofs.Analysis.Normalization.batchNorm_inference_eq_mul_add.
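A NumPy sketch that also makes the affine-in-x structure explicit (the reshape convention for the per-channel parameters is an assumption):

```python
import numpy as np

def batch_norm_inference(x, running_mean, running_var, gamma, beta, eps=1e-5):
    # x: (C, *spatial); statistics are fixed inputs, never computed from x.
    shape = (-1,) + (1,) * (x.ndim - 1)    # broadcast per-channel params
    mu, var = running_mean.reshape(shape), running_var.reshape(shape)
    g, b = gamma.reshape(shape), beta.reshape(shape)
    # Affine in x: y = a * x + c with constants a and c.
    a = g / np.sqrt(var + eps)
    return a * x + (b - a * mu)
```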
Convenience wrapper: batchNorm_inference specialized to a single channel-first image (C,H,W).
Convenience wrapper: batchNorm_inference specialized to a (C, L) tensor (BatchNorm1d-style).
Convenience wrapper: batchNorm_inference specialized to a (C, D, H, W) tensor (BatchNorm3d-style).