BatchNormChannelFirst #
Pointwise analytic correctness for a channel-first BatchNorm-like graph.
This matches the existing spec/runtime operator Spec.batchNorm_channel_first used by
Runtime.Autograd.Tape.batchnorm_channel_first:
it normalizes each channel independently by computing mean/variance over the spatial
dimensions (H,W) and then applying per-channel affine parameters (gamma,beta).
The proof is spec-level over ℝ. Because the graph uses sqrt (max x 0) and inv, the
statement is pointwise (GraphFDerivCorrectAt) with explicit domain assumptions.
Note: this is not PyTorch BatchNorm2d over N×H×W with running statistics; it is closer to
an InstanceNorm/GroupNorm-style normalization over spatial dimensions per channel.
PyTorch correspondence / citations #
- Reference: the BatchNorm2d and InstanceNorm2d docs (for naming/background; this file is a simpler, stateless normalization graph).
  https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html
  https://pytorch.org/docs/stable/generated/torch.nn.InstanceNorm2d.html
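The distinction matters for where the statistics are reduced. A small NumPy sketch (the arrays and shapes here are illustrative, not part of the Lean development): BatchNorm2d reduces over the batch and spatial dims (N,H,W), while this file's spec reduces over the spatial dims (H,W) only, per example and per channel, as InstanceNorm2d does.

```python
import numpy as np

rng = np.random.default_rng(0)
N, C, H, W = 2, 3, 4, 5
x = rng.normal(size=(N, C, H, W))

# Statistics as in this file's spec: one mean per (example, channel),
# reduced over the spatial dims (H, W) only (InstanceNorm-style).
inst_mean = x.mean(axis=(2, 3))        # shape (N, C)

# BatchNorm2d-style statistics: one mean per channel,
# reduced over batch and spatial dims (N, H, W).
batch_mean = x.mean(axis=(0, 2, 3))    # shape (C,)

# For N > 1 the two generally disagree for any single example.
print(np.allclose(inst_mean[0], batch_mean))
```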
Channel-first tensor shape C×H×W.
Matrix shape m×n.
Vector shape k.
Input context shapes: [x, gamma, beta] with x : C×H×W and gamma/beta : C.
Flattened spatial size H*W.
Prefix intermediates up to var_eps (after flattening spatial dimensions).
Prefix intermediates up to std (adds one more vector).
Full list of intermediates for the BatchNormChannelFirst graph in this file.
Index of the input x in the base BatchNorm context ΓBN channels height width ++ ss.
Index of the scale vector gamma in the base BatchNorm context ΓBN ++ ss.
Index of the shift vector beta in the base BatchNorm context ΓBN ++ ss.
Index helper for the last element of an extended context Γ ++ ss ++ [τ].
Informal computation (per channel, flattening spatial dims):
Let x : C×H×W and flatten the spatial dims to xMat : C×(H*W). Writing v_b for the row-wise broadcast of a length-channels vector v to C×(H*W), for each channel c:
mean_c := (1/(H*W)) * ∑_{p} xMat[c,p]
centered := xMat - mean_b
var_c := (1/(H*W)) * ∑_{p} centered[c,p]^2
std_c := sqrt(var_c + ε) (implemented as sqrt_clamp)
inv_std_c := 1/std_c
normalized := centered ⊙ inv_std_b
scaled := normalized ⊙ gamma_b
yMat := scaled + beta_b
yChw := reshape yMat back to C×H×W
This is the stateless, per-example normalization used by the runtime spec
Spec.batchNorm_channel_first; it is closer to InstanceNorm/GroupNorm than to BatchNorm with
running statistics.
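The chain of intermediates above can be mirrored numerically. A NumPy sketch (variable names follow the informal computation; this is an illustration of the spec-level math, not code from the Lean development), checked against a one-shot formula:

```python
import numpy as np

rng = np.random.default_rng(1)
C, H, W = 3, 4, 5
eps = 1e-5
x = rng.normal(size=(C, H, W))
gamma = rng.normal(size=C)
beta = rng.normal(size=C)

xMat = x.reshape(C, H * W)                # flatten spatial dims
mean = xMat.mean(axis=1)                  # per-channel mean, shape (C,)
centered = xMat - mean[:, None]           # subtract broadcast mean_b
centered_sq = centered * centered
var = centered_sq.mean(axis=1)            # per-channel variance, shape (C,)
var_eps = var + eps
std = np.sqrt(np.maximum(var_eps, 0.0))   # sqrt_clamp: sqrt (max x 0)
inv_std = 1.0 / std
normalized = centered * inv_std[:, None]  # multiply by broadcast inv_std_b
scaled = normalized * gamma[:, None]      # multiply by broadcast gamma_b
yMat = scaled + beta[:, None]             # add broadcast beta_b
yChw = yMat.reshape(C, H, W)              # reshape back to C×H×W

# One-shot formula for comparison (np.var is the population variance).
direct = gamma[:, None, None] * (x - x.mean(axis=(1, 2), keepdims=True)) \
    / np.sqrt(x.var(axis=(1, 2), keepdims=True) + eps) + beta[:, None, None]
print(np.allclose(yChw, direct))  # True
```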
Index of mean_b in the extended context.
Index of xMat in the extended context at stage g3.
Center: centered := xMat - mean_b.
Graph prefix producing [xMat, mean, mean_b, centered].
Index of centered in the extended context at stage g4.
Square centered: centered_sq := centered ⊙ centered.
Graph prefix producing [xMat, mean, mean_b, centered, centered_sq].
Index of centered_sq in the extended context at stage g5.
Per-channel variance over spatial dims: var := mean(centered_sq) producing a length-channels
vector.
Graph prefix producing [xMat, mean, mean_b, centered, centered_sq, var].
Index of var in the extended context at stage g6.
Add epsilon: var_eps := var + ε.
Graph prefix computing ssPrefixVarEps (up to var_eps).
Index of var_eps in ΓBN ++ ssPrefixVarEps.
Standard deviation: std := sqrt_clamp(var_eps).
This is where the development becomes pointwise: differentiability depends on positivity of
var_eps.
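Why positivity matters at this node: sqrt_clamp x = sqrt (max x 0) is differentiable exactly when x > 0, with derivative 1/(2·sqrt x); at 0 the one-sided difference quotient grows without bound. A small numerical sketch (sqrt_clamp is reimplemented here for illustration):

```python
import math

def sqrt_clamp(x: float) -> float:
    """sqrt (max x 0): total on all of R, like the graph's sqrt node."""
    return math.sqrt(max(x, 0.0))

def diff_quot(f, x, h=1e-7):
    """Central finite-difference approximation of f'(x)."""
    return (f(x + h) - f(x - h)) / (2 * h)

# At a positive point the difference quotient matches 1 / (2 * sqrt x) ...
x = 0.25
print(abs(diff_quot(sqrt_clamp, x) - 1.0 / (2.0 * math.sqrt(x))))  # ~0

# ... but near 0 the one-sided quotient blows up, which is why the
# correctness statement assumes positivity of var_eps pointwise.
print((sqrt_clamp(1e-12) - sqrt_clamp(0.0)) / 1e-12)  # 1e6
```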
Graph prefix computing ssPrefixStd (adds std).
Index of std in ΓBN ++ ssPrefixStd.
Inverse standard deviation: inv_std := 1/std.
Broadcast inv_std back to C×(H*W) (row-wise), producing inv_std_b.
Index of centered in the context at stage g9.
This is obtained by weakening the earlier idxCentered along the extended intermediate list.
Index of inv_std_b in the context at stage g9.
Normalize: normalized := centered ⊙ inv_std_b.
Graph prefix adding normalized.
Index of normalized in the extended context at stage g10.
Broadcast gamma : C to C×(H*W) (row-wise), producing gamma_b.
Graph prefix adding gamma_b.
Index of gamma_b in the extended context at stage g11.
Scale: scaled := normalized ⊙ gamma_b.
Graph prefix adding scaled.
Index of scaled in the extended context at stage g12.
Broadcast beta : C to C×(H*W) (row-wise), producing beta_b.
Graph prefix adding beta_b.
Index of beta_b in the extended context at stage g13.
Add bias: yMat := scaled + beta_b.
Graph prefix adding yMat.
Index of yMat in the extended context after g14.
Reshape the matrix output yMat : C×(H*W) back into yChw : C×H×W.
Full BatchNormChannelFirst graph (explicit snoc chain).
Pointwise proof that batchNormGraph satisfies GraphFDerivCorrectAt.
As with LayerNorm, the hypotheses are explicit domain assumptions needed for differentiability
of sqrt (after clamp) and inv at the actual execution point.
Pointwise end-to-end result: backprop equals (fderiv eval)† for batchNormGraph.
This is the BatchNormChannelFirst analogue of the global DAG theorem, specialized to the explicit
graph construction and with explicit domain assumptions for sqrt/inv.
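The shape of this end-to-end claim can be sanity-checked numerically: the analytic backward pass for a stateless per-channel normalization should agree with finite differences of the forward map. A NumPy sketch (the closed-form gradient below is the textbook batch/instance-norm backward formula, not code from this development):

```python
import numpy as np

rng = np.random.default_rng(2)
C, H, W = 2, 3, 3
N = H * W                            # flattened spatial size per channel
eps = 1e-5
x = rng.normal(size=(C, N))          # input, already flattened to C×(H*W)
gamma = rng.normal(size=(C, 1))
beta = rng.normal(size=(C, 1))
g = rng.normal(size=(C, N))          # upstream gradient dL/dy for L = sum(g * y)

def forward(x):
    mu = x.mean(axis=1, keepdims=True)
    var = ((x - mu) ** 2).mean(axis=1, keepdims=True)
    inv_std = 1.0 / np.sqrt(var + eps)
    return gamma * (x - mu) * inv_std + beta

# Textbook analytic backward for per-channel normalization.
mu = x.mean(axis=1, keepdims=True)
var = ((x - mu) ** 2).mean(axis=1, keepdims=True)
inv_std = 1.0 / np.sqrt(var + eps)
xhat = (x - mu) * inv_std
dxhat = g * gamma
dx = (inv_std / N) * (N * dxhat
                      - dxhat.sum(axis=1, keepdims=True)
                      - xhat * (dxhat * xhat).sum(axis=1, keepdims=True))

# Finite-difference check of dL/dx for L = sum(g * forward(x)).
h = 1e-6
dx_num = np.zeros_like(x)
for idx in np.ndindex(*x.shape):
    xp, xm = x.copy(), x.copy()
    xp[idx] += h
    xm[idx] -= h
    dx_num[idx] = ((g * forward(xp)).sum() - (g * forward(xm)).sum()) / (2 * h)

print(np.max(np.abs(dx - dx_num)))  # small
```

This checks only the gradient with respect to x; the Lean theorem covers the full adjoint of the evaluation map, including gamma and beta, under the stated sqrt/inv domain assumptions.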