TorchLean API

NN.Proofs.Autograd.Tape.Ops.Norm.BatchNormChannelFirst

BatchNormChannelFirst

Pointwise analytic correctness for a channel-first BatchNorm-like graph.

This matches the existing spec/runtime operator Spec.batchNorm_channel_first used by Runtime.Autograd.Tape.batchnorm_channel_first: it normalizes each channel independently by computing mean/variance over the spatial dimensions (H,W) and then applying per-channel affine parameters (gamma,beta).

The proof is spec-level over ℝ. Because the graph uses sqrt (max x 0) and inv, the statement is pointwise (GraphFDerivCorrectAt) with explicit domain assumptions.
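As a rough numerical illustration of why domain assumptions are needed, the sketch below models the clamped square root as sqrt(max(x, 0)) in plain Python. The helper one_sided_slopes is hypothetical and not part of TorchLean. At x = 0 the left and right difference quotients disagree, so no derivative exists there; at x = 1 both are close to 1/(2·sqrt 1) = 0.5.

```python
import math

def sqrt_clamp(x: float) -> float:
    """Clamped square root: sqrt(max(x, 0))."""
    return math.sqrt(max(x, 0.0))

def one_sided_slopes(f, x: float, h: float = 1e-6):
    """Return (left, right) difference quotients of f at x (illustrative helper)."""
    left = (f(x) - f(x - h)) / h
    right = (f(x + h) - f(x)) / h
    return left, right

# At 0 the left slope is 0 while the right slope grows without bound as h shrinks.
left0, right0 = one_sided_slopes(sqrt_clamp, 0.0)
# At a strictly positive point both slopes agree with 1 / (2 * sqrt(x)).
left1, right1 = one_sided_slopes(sqrt_clamp, 1.0)
```

This is only a finite-difference picture of the obstruction; the formal statement handles it through hypotheses on the evaluation point.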

Note: this is not PyTorch BatchNorm2d over N×H×W with running statistics; it is closer to an InstanceNorm/GroupNorm-style normalization over spatial dimensions per channel.

PyTorch correspondence / citations

@[reducible, inline]

Channel-first tensor shape C×H×W.

@[reducible, inline]

Matrix shape m×n.

@[reducible, inline]

Vector shape k.

@[reducible, inline]

Input context shapes: [x, gamma, beta] with x : C×H×W and gamma/beta : C.

@[reducible, inline]

Flattened spatial size H*W.

@[reducible, inline]

Prefix intermediates up to var_eps (after flattening spatial dimensions).

@[reducible, inline]

Prefix intermediates up to std (adds one more vector).

@[reducible, inline]

Full list of intermediates for the BatchNormChannelFirst graph in this file.

def Proofs.Autograd.BatchNormChannelFirst.idxX {channels height width : ℕ} {ss : List Spec.Shape} :
Idx (ΓBN channels height width ++ ss) (CHWShape channels height width)

Index of the input x in the base BatchNorm context ΓBN channels height width ++ ss.

def Proofs.Autograd.BatchNormChannelFirst.idxGamma {channels height width : ℕ} {ss : List Spec.Shape} :
Idx (ΓBN channels height width ++ ss) (VecShape channels)

Index of the scale vector gamma in the base BatchNorm context ΓBN ++ ss.

def Proofs.Autograd.BatchNormChannelFirst.idxBeta {channels height width : ℕ} {ss : List Spec.Shape} :
Idx (ΓBN channels height width ++ ss) (VecShape channels)

Index of the shift vector beta in the base BatchNorm context ΓBN ++ ss.

Index helper for the last element of an extended context Γ ++ ss ++ [τ].
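A loose, untyped analogy may help here. The sketch below models a context as a Python list of names and an index as an integer position; idx_last is a hypothetical helper, whereas the Lean Idx is a typed index whose representation is not shown in this section.

```python
def idx_last(gamma: list, ss: list) -> int:
    """Position of the final entry tau in gamma + ss + [tau] (hypothetical helper)."""
    return len(gamma) + len(ss)

gamma = ["x", "gamma", "beta"]      # base context
ss = ["xMat", "mean"]               # intermediates produced so far
ctx = gamma + ss + ["mean_b"]       # context extended by one new node

last = idx_last(gamma, ss)
```

The position depends only on the lengths of the two prefixes, which is why one helper can serve every stage of the graph.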

Informal computation (per channel, flattening spatial dims):

Let x : C×H×W, flatten spatial dims to xMat : C×(H*W). Then for each channel c:

  mean_c     := (1/(H*W)) * ∑_{p} xMat[c,p]
  centered   := xMat - mean_b
  var_c      := (1/(H*W)) * ∑_{p} centered[c,p]^2
  std_c      := sqrt(var_c + ε)   (implemented as sqrt_clamp)
  inv_std_c  := 1/std_c
  normalized := centered ⊙ inv_std_b
  scaled     := normalized ⊙ gamma_b
  yMat       := scaled + beta_b
  yChw       := reshape yMat back to C×H×W

This is the stateless, per-example normalization used by the runtime spec Spec.batchNorm_channel_first; it is closer to InstanceNorm/GroupNorm than to BatchNorm with running statistics.
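The informal computation can be sketched end to end in plain Python over nested lists. This is a minimal illustration, not the TorchLean implementation: the function name batchnorm_channel_first, the row-major flattening order, and the sample values are all assumptions made for the example.

```python
import math

def batchnorm_channel_first(x, gamma, beta, eps):
    """x: nested list C×H×W; gamma, beta: length-C lists. Returns a C×H×W nested list."""
    out = []
    for c, plane in enumerate(x):
        width = len(plane[0])
        row = [v for r in plane for v in r]           # flatten H×W into H*W (row-major, assumed)
        n = len(row)
        mean = sum(row) / n                           # mean_c
        centered = [v - mean for v in row]            # xMat - mean_b, restricted to channel c
        var = sum(d * d for d in centered) / n        # var_c (divides by H*W)
        std = math.sqrt(max(var + eps, 0.0))          # sqrt_clamp(var_c + eps)
        inv_std = 1.0 / std                           # inv_std_c
        y = [d * inv_std * gamma[c] + beta[c] for d in centered]
        out.append([y[i:i + width] for i in range(0, n, width)])  # reshape back to H×W
    return out

x = [[[1.0, 2.0], [3.0, 4.0]],
     [[0.0, 0.0], [0.0, 8.0]]]
gamma = [1.0, 2.0]
beta = [0.0, 1.0]
eps = 1e-5
y = batchnorm_channel_first(x, gamma, beta, eps)
```

Each channel of y has mean equal to its beta entry, because the centered values of a channel sum to zero before scaling.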

theorem Proofs.Autograd.BatchNormChannelFirst.hsz_chw_mat {channels height width : ℕ} :
(CHWShape channels height width).size = (MatShape channels (hw height width)).size

noncomputable def Proofs.Autograd.BatchNormChannelFirst.nodeXMat {channels height width : ℕ} :
Node (ΓBN channels height width) (MatShape channels (hw height width))

Reshape x : C×H×W into a matrix xMat : C×(H*W) (flatten spatial dimensions).
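The size equality and the reshape can be pictured with nested lists. The sketch below assumes a row-major layout within each channel, which this section does not state; flatten_spatial is a hypothetical name.

```python
def flatten_spatial(x):
    """C×H×W nested list -> C×(H*W) nested list, row-major within each channel (assumed)."""
    return [[v for row in plane for v in row] for plane in x]

C, H, W = 2, 2, 3
# Encode the coordinates in each value so positions are easy to check.
x = [[[c * 100 + h * 10 + w for w in range(W)] for h in range(H)] for c in range(C)]
x_mat = flatten_spatial(x)

size_chw = C * H * W        # element count of the C×H×W shape
size_mat = C * (H * W)      # element count of the C×(H*W) shape
```

Under this layout, entry (c, h, w) of x lands at column h*W + w of row c in x_mat.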

noncomputable def Proofs.Autograd.BatchNormChannelFirst.g1 {channels height width : ℕ} :
Graph (ΓBN channels height width) [MatShape channels (hw height width)]

Graph prefix producing [xMat].

def Proofs.Autograd.BatchNormChannelFirst.idxXMat {channels height width : ℕ} :
Idx (ΓBN channels height width ++ [MatShape channels (hw height width)]) (MatShape channels (hw height width))

Index of xMat in ΓBN ++ [xMat].

noncomputable def Proofs.Autograd.BatchNormChannelFirst.nodeMean {channels height width : ℕ} :
Node (ΓBN channels height width ++ [MatShape channels (hw height width)]) (VecShape channels)

Per-channel mean over spatial dims: mean : C×(H*W) → C.
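A minimal sketch of the reduction, with a matrix modelled as a list of rows (row_mean is a hypothetical name used only for this illustration):

```python
def row_mean(x_mat):
    """Per-row mean of a C×(H*W) nested list, giving a length-C list."""
    return [sum(row) / len(row) for row in x_mat]

x_mat = [[1.0, 2.0, 3.0, 4.0],
         [0.0, 0.0, 0.0, 8.0]]
mean = row_mean(x_mat)
```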

noncomputable def Proofs.Autograd.BatchNormChannelFirst.g2 {channels height width : ℕ} :
Graph (ΓBN channels height width) [MatShape channels (hw height width), VecShape channels]

Graph prefix producing [xMat, mean].

def Proofs.Autograd.BatchNormChannelFirst.idxMean {channels height width : ℕ} :
Idx (ΓBN channels height width ++ [MatShape channels (hw height width), VecShape channels]) (VecShape channels)

Index of mean in ΓBN ++ [xMat, mean].

noncomputable def Proofs.Autograd.BatchNormChannelFirst.nodeMeanB {channels height width : ℕ} :
Node (ΓBN channels height width ++ [MatShape channels (hw height width), VecShape channels]) (MatShape channels (hw height width))

Broadcast mean back to C×(H*W) (row-wise).

noncomputable def Proofs.Autograd.BatchNormChannelFirst.g3 {channels height width : ℕ} :
Graph (ΓBN channels height width) [MatShape channels (hw height width), VecShape channels, MatShape channels (hw height width)]

Graph prefix producing [xMat, mean, mean_b].

def Proofs.Autograd.BatchNormChannelFirst.idxMeanB {channels height width : ℕ} :
Idx (ΓBN channels height width ++ [MatShape channels (hw height width), VecShape channels, MatShape channels (hw height width)]) (MatShape channels (hw height width))

Index of mean_b in the extended context.

def Proofs.Autograd.BatchNormChannelFirst.idxXMat3 {channels height width : ℕ} :
Idx (ΓBN channels height width ++ [MatShape channels (hw height width), VecShape channels, MatShape channels (hw height width)]) (MatShape channels (hw height width))

Index of xMat in the extended context at stage g3.

noncomputable def Proofs.Autograd.BatchNormChannelFirst.nodeCentered {channels height width : ℕ} :
Node (ΓBN channels height width ++ [MatShape channels (hw height width), VecShape channels, MatShape channels (hw height width)]) (MatShape channels (hw height width))

Center: centered := xMat - mean_b.
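The broadcast and the subtraction can be sketched together. The helper names broadcast_rows and sub are hypothetical; the point of the example is that every row of centered sums to zero.

```python
def broadcast_rows(v, n):
    """Row-wise broadcast: length-C list -> C×n matrix whose row c repeats v[c]."""
    return [[vc] * n for vc in v]

def sub(a, b):
    """Elementwise difference of two equally shaped matrices."""
    return [[p - q for p, q in zip(ra, rb)] for ra, rb in zip(a, b)]

x_mat = [[1.0, 2.0, 3.0, 4.0],
         [0.0, 0.0, 0.0, 8.0]]
mean = [sum(r) / len(r) for r in x_mat]
mean_b = broadcast_rows(mean, len(x_mat[0]))   # same shape as x_mat
centered = sub(x_mat, mean_b)
```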

noncomputable def Proofs.Autograd.BatchNormChannelFirst.g4 {channels height width : ℕ} :
Graph (ΓBN channels height width) [MatShape channels (hw height width), VecShape channels, MatShape channels (hw height width), MatShape channels (hw height width)]

Graph prefix producing [xMat, mean, mean_b, centered].

def Proofs.Autograd.BatchNormChannelFirst.idxCentered {channels height width : ℕ} :
Idx (ΓBN channels height width ++ [MatShape channels (hw height width), VecShape channels, MatShape channels (hw height width), MatShape channels (hw height width)]) (MatShape channels (hw height width))

Index of centered in the extended context at stage g4.

noncomputable def Proofs.Autograd.BatchNormChannelFirst.nodeCenteredSq {channels height width : ℕ} :
Node (ΓBN channels height width ++ [MatShape channels (hw height width), VecShape channels, MatShape channels (hw height width), MatShape channels (hw height width)]) (MatShape channels (hw height width))

Square centered: centered_sq := centered ⊙ centered.

noncomputable def Proofs.Autograd.BatchNormChannelFirst.g5 {channels height width : ℕ} :
Graph (ΓBN channels height width) [MatShape channels (hw height width), VecShape channels, MatShape channels (hw height width), MatShape channels (hw height width), MatShape channels (hw height width)]

Graph prefix producing [xMat, mean, mean_b, centered, centered_sq].

def Proofs.Autograd.BatchNormChannelFirst.idxCenteredSq {channels height width : ℕ} :
Idx (ΓBN channels height width ++ [MatShape channels (hw height width), VecShape channels, MatShape channels (hw height width), MatShape channels (hw height width), MatShape channels (hw height width)]) (MatShape channels (hw height width))

Index of centered_sq in the extended context at stage g5.

noncomputable def Proofs.Autograd.BatchNormChannelFirst.nodeVar {channels height width : ℕ} :
Node (ΓBN channels height width ++ [MatShape channels (hw height width), VecShape channels, MatShape channels (hw height width), MatShape channels (hw height width), MatShape channels (hw height width)]) (VecShape channels)

Per-channel variance over spatial dims: var := mean(centered_sq) producing a length-channels vector.
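Because var is the mean of the squared deviations, it is the population variance (divide by H*W, no Bessel correction). The sketch below computes it per row and compares against Python's statistics.pvariance; row_var is a hypothetical name.

```python
import statistics

def row_var(x_mat):
    """Population variance per row: mean of squared deviations from the row mean."""
    out = []
    for row in x_mat:
        m = sum(row) / len(row)
        centered_sq = [(v - m) ** 2 for v in row]
        out.append(sum(centered_sq) / len(centered_sq))
    return out

x_mat = [[1.0, 2.0, 3.0, 4.0],
         [0.0, 0.0, 0.0, 8.0]]
var = row_var(x_mat)
reference = [statistics.pvariance(row) for row in x_mat]
```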

noncomputable def Proofs.Autograd.BatchNormChannelFirst.g6 {channels height width : ℕ} :
Graph (ΓBN channels height width) [MatShape channels (hw height width), VecShape channels, MatShape channels (hw height width), MatShape channels (hw height width), MatShape channels (hw height width), VecShape channels]

Graph prefix producing [xMat, mean, mean_b, centered, centered_sq, var].

def Proofs.Autograd.BatchNormChannelFirst.idxVar {channels height width : ℕ} :
Idx (ΓBN channels height width ++ [MatShape channels (hw height width), VecShape channels, MatShape channels (hw height width), MatShape channels (hw height width), MatShape channels (hw height width), VecShape channels]) (VecShape channels)

Index of var in the extended context at stage g6.

noncomputable def Proofs.Autograd.BatchNormChannelFirst.nodeVarEps {channels height width : ℕ} (ε : ℝ) :
Node (ΓBN channels height width ++ [MatShape channels (hw height width), VecShape channels, MatShape channels (hw height width), MatShape channels (hw height width), MatShape channels (hw height width), VecShape channels]) (VecShape channels)

Add epsilon: var_eps := var + ε.

noncomputable def Proofs.Autograd.BatchNormChannelFirst.batchNormPrefixVarEps {channels height width : ℕ} (ε : ℝ) :
Graph (ΓBN channels height width) (ssPrefixVarEps channels height width)

Graph prefix computing ssPrefixVarEps (up to var_eps).

def Proofs.Autograd.BatchNormChannelFirst.idxVarEps {channels height width : ℕ} :
Idx (ΓBN channels height width ++ ssPrefixVarEps channels height width) (VecShape channels)

Index of var_eps in ΓBN ++ ssPrefixVarEps.

noncomputable def Proofs.Autograd.BatchNormChannelFirst.nodeStd {channels height width : ℕ} :
Node (ΓBN channels height width ++ ssPrefixVarEps channels height width) (VecShape channels)

Standard deviation: std := sqrt_clamp(var_eps).

This is where the development becomes pointwise: differentiability depends on positivity of var_eps.
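One way to see when positivity holds: the variance is never negative, so any ε > 0 forces var_eps ≥ ε > 0, even for a constant channel. The sketch below checks this worst case numerically. The helper var_eps_of and the value of eps are assumptions for the example, and the exact hypotheses of the formal statement are not shown in this section.

```python
import math

def var_eps_of(row, eps):
    """Population variance of one channel plus eps (illustrative helper)."""
    m = sum(row) / len(row)
    var = sum((v - m) ** 2 for v in row) / len(row)
    return var + eps

eps = 1e-5
constant_channel = [3.0, 3.0, 3.0, 3.0]    # variance is exactly 0 here
ve = var_eps_of(constant_channel, eps)
std = math.sqrt(max(ve, 0.0))               # sqrt_clamp; the clamp is inactive since ve > 0
```

On this domain the clamp never fires, so sqrt_clamp agrees with the ordinary square root.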

noncomputable def Proofs.Autograd.BatchNormChannelFirst.batchNormPrefixStd {channels height width : ℕ} (ε : ℝ) :
Graph (ΓBN channels height width) (ssPrefixStd channels height width)

Graph prefix computing ssPrefixStd (adds std).

def Proofs.Autograd.BatchNormChannelFirst.idxStd {channels height width : ℕ} :
Idx (ΓBN channels height width ++ ssPrefixStd channels height width) (VecShape channels)

Index of std in ΓBN ++ ssPrefixStd.

noncomputable def Proofs.Autograd.BatchNormChannelFirst.nodeInvStd {channels height width : ℕ} :
Node (ΓBN channels height width ++ ssPrefixStd channels height width) (VecShape channels)

Inverse standard deviation: inv_std := 1/std.
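The inverse is the second operation that needs a domain condition. The sketch below compares a central finite difference of 1/s with the analytic derivative -1/s² at a nonzero point, and records that Python cannot evaluate 1/s at s = 0. Whatever convention a formal development adopts at 0, the derivative formula only makes sense for s ≠ 0. The helper names are hypothetical.

```python
def inv(s: float) -> float:
    """Scalar inverse 1/s."""
    return 1.0 / s

def central_diff(f, s: float, h: float = 1e-6) -> float:
    """Central finite-difference estimate of f'(s) (illustrative helper)."""
    return (f(s + h) - f(s - h)) / (2 * h)

s = 2.0
numeric = central_diff(inv, s)
analytic = -1.0 / (s * s)

# Python floats raise at zero, which mirrors the need for std != 0.
try:
    inv(0.0)
    undefined_at_zero = False
except ZeroDivisionError:
    undefined_at_zero = True
```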

noncomputable def Proofs.Autograd.BatchNormChannelFirst.g8 {channels height width : ℕ} (ε : ℝ) :
Graph (ΓBN channels height width) (ssPrefixStd channels height width ++ [VecShape channels])

Graph prefix adding inv_std.

def Proofs.Autograd.BatchNormChannelFirst.idxInvStd {channels height width : ℕ} :
Idx (ΓBN channels height width ++ (ssPrefixStd channels height width ++ [VecShape channels])) (VecShape channels)

Index of inv_std in the extended context after g8.

noncomputable def Proofs.Autograd.BatchNormChannelFirst.nodeInvStdB {channels height width : ℕ} :
Node (ΓBN channels height width ++ (ssPrefixStd channels height width ++ [VecShape channels])) (MatShape channels (hw height width))

Broadcast inv_std back to C×(H*W) (row-wise), producing inv_std_b.

noncomputable def Proofs.Autograd.BatchNormChannelFirst.g9 {channels height width : ℕ} (ε : ℝ) :
Graph (ΓBN channels height width) (ssPrefixStd channels height width ++ [VecShape channels, MatShape channels (hw height width)])

Graph prefix adding inv_std_b.

def Proofs.Autograd.BatchNormChannelFirst.idxCentered9 {channels height width : ℕ} :
Idx (ΓBN channels height width ++ (ssPrefixStd channels height width ++ [VecShape channels, MatShape channels (hw height width)])) (MatShape channels (hw height width))

Index of centered in the context at stage g9.

This is obtained by weakening the earlier idxCentered along the extended intermediate list.
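As an untyped analogy for weakening, the sketch below models contexts as Python lists of names and an index as a position counted from the front. That representation is an assumption; the Lean Idx is typed and its encoding is not shown here. Appending later intermediates leaves the position of centered unchanged.

```python
# Context at stage g4: base inputs followed by the first four intermediates.
ctx_g4 = ["x", "gamma", "beta", "xMat", "mean", "mean_b", "centered"]
idx_centered = len(ctx_g4) - 1          # centered is the last entry at stage g4

# Later stages only append intermediates (names follow the stages in this section).
ctx_g9 = ctx_g4 + ["centered_sq", "var", "var_eps", "std", "inv_std", "inv_std_b"]
idx_centered9 = idx_centered            # weakening: same position, longer context
```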

def Proofs.Autograd.BatchNormChannelFirst.idxInvStdB9 {channels height width : ℕ} :
Idx (ΓBN channels height width ++ (ssPrefixStd channels height width ++ [VecShape channels, MatShape channels (hw height width)])) (MatShape channels (hw height width))

Index of inv_std_b in the context at stage g9.

noncomputable def Proofs.Autograd.BatchNormChannelFirst.nodeNorm {channels height width : ℕ} :
Node (ΓBN channels height width ++ (ssPrefixStd channels height width ++ [VecShape channels, MatShape channels (hw height width)])) (MatShape channels (hw height width))

Normalize: normalized := centered ⊙ inv_std_b.

noncomputable def Proofs.Autograd.BatchNormChannelFirst.g10 {channels height width : ℕ} (ε : ℝ) :
Graph (ΓBN channels height width) (ssPrefixStd channels height width ++ [VecShape channels, MatShape channels (hw height width), MatShape channels (hw height width)])

Graph prefix adding normalized.

def Proofs.Autograd.BatchNormChannelFirst.idxNorm10 {channels height width : ℕ} :
Idx (ΓBN channels height width ++ (ssPrefixStd channels height width ++ [VecShape channels, MatShape channels (hw height width), MatShape channels (hw height width)])) (MatShape channels (hw height width))

Index of normalized in the extended context at stage g10.

noncomputable def Proofs.Autograd.BatchNormChannelFirst.nodeGammaB {channels height width : ℕ} :
Node (ΓBN channels height width ++ (ssPrefixStd channels height width ++ [VecShape channels, MatShape channels (hw height width), MatShape channels (hw height width)])) (MatShape channels (hw height width))

Broadcast gamma : C to C×(H*W) (row-wise), producing gamma_b.

noncomputable def Proofs.Autograd.BatchNormChannelFirst.g11 {channels height width : ℕ} (ε : ℝ) :
Graph (ΓBN channels height width) (ssPrefixStd channels height width ++ [VecShape channels, MatShape channels (hw height width), MatShape channels (hw height width), MatShape channels (hw height width)])

Graph prefix adding gamma_b.

def Proofs.Autograd.BatchNormChannelFirst.idxGammaB11 {channels height width : ℕ} :
Idx (ΓBN channels height width ++ (ssPrefixStd channels height width ++ [VecShape channels, MatShape channels (hw height width), MatShape channels (hw height width), MatShape channels (hw height width)])) (MatShape channels (hw height width))

Index of gamma_b in the extended context at stage g11.

noncomputable def Proofs.Autograd.BatchNormChannelFirst.nodeScaled {channels height width : ℕ} :
Node (ΓBN channels height width ++ (ssPrefixStd channels height width ++ [VecShape channels, MatShape channels (hw height width), MatShape channels (hw height width), MatShape channels (hw height width)])) (MatShape channels (hw height width))

Scale: scaled := normalized ⊙ gamma_b.

noncomputable def Proofs.Autograd.BatchNormChannelFirst.g12 {channels height width : ℕ} (ε : ℝ) :
Graph (ΓBN channels height width) (ssPrefixStd channels height width ++ [VecShape channels, MatShape channels (hw height width), MatShape channels (hw height width), MatShape channels (hw height width), MatShape channels (hw height width)])

Graph prefix adding scaled.

def Proofs.Autograd.BatchNormChannelFirst.idxScaled12 {channels height width : ℕ} :
Idx (ΓBN channels height width ++ (ssPrefixStd channels height width ++ [VecShape channels, MatShape channels (hw height width), MatShape channels (hw height width), MatShape channels (hw height width), MatShape channels (hw height width)])) (MatShape channels (hw height width))

Index of scaled in the extended context at stage g12.

noncomputable def Proofs.Autograd.BatchNormChannelFirst.nodeBetaB {channels height width : ℕ} :
Node (ΓBN channels height width ++ (ssPrefixStd channels height width ++ [VecShape channels, MatShape channels (hw height width), MatShape channels (hw height width), MatShape channels (hw height width), MatShape channels (hw height width)])) (MatShape channels (hw height width))

Broadcast beta : C to C×(H*W) (row-wise), producing beta_b.

noncomputable def Proofs.Autograd.BatchNormChannelFirst.g13 {channels height width : ℕ} (ε : ℝ) :
Graph (ΓBN channels height width) (ssPrefixStd channels height width ++ [VecShape channels, MatShape channels (hw height width), MatShape channels (hw height width), MatShape channels (hw height width), MatShape channels (hw height width), MatShape channels (hw height width)])

Graph prefix adding beta_b.

def Proofs.Autograd.BatchNormChannelFirst.idxBetaB13 {channels height width : ℕ} :
Idx (ΓBN channels height width ++ (ssPrefixStd channels height width ++ [VecShape channels, MatShape channels (hw height width), MatShape channels (hw height width), MatShape channels (hw height width), MatShape channels (hw height width), MatShape channels (hw height width)])) (MatShape channels (hw height width))

Index of beta_b in the extended context at stage g13.

noncomputable def Proofs.Autograd.BatchNormChannelFirst.nodeYMat {channels height width : ℕ} :
Node (ΓBN channels height width ++ (ssPrefixStd channels height width ++ [VecShape channels, MatShape channels (hw height width), MatShape channels (hw height width), MatShape channels (hw height width), MatShape channels (hw height width), MatShape channels (hw height width)])) (MatShape channels (hw height width))

Add bias: yMat := scaled + beta_b.
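The scale and shift stages together form a per-channel affine map. In the sketch below the broadcasts of gamma and beta are folded into the loop, so that y[c][p] = normalized[c][p] * gamma[c] + beta[c]. The name affine_rows and the sample values are assumptions for the example.

```python
def affine_rows(normalized, gamma, beta):
    """Apply y[c][p] = normalized[c][p] * gamma[c] + beta[c], one (gamma, beta) pair per row."""
    return [[v * g + b for v in row] for row, g, b in zip(normalized, gamma, beta)]

normalized = [[-1.0, 0.0, 1.0],
              [2.0, -2.0, 0.0]]
y_mat = affine_rows(normalized, gamma=[2.0, 0.5], beta=[1.0, -1.0])
```

The graph in this file keeps gamma_b, scaled, and beta_b as separate intermediates instead, which gives each stage its own node and index.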

noncomputable def Proofs.Autograd.BatchNormChannelFirst.g14 {channels height width : ℕ} (ε : ℝ) :
Graph (ΓBN channels height width) (ssPrefixStd channels height width ++ [VecShape channels, MatShape channels (hw height width), MatShape channels (hw height width), MatShape channels (hw height width), MatShape channels (hw height width), MatShape channels (hw height width), MatShape channels (hw height width)])

Graph prefix adding yMat.

                                                                                                                    Instances For
                                                                                                                      def Proofs.Autograd.BatchNormChannelFirst.idxYMat {channels height width : ℕ} :
                                                                                                                      Idx (ΓBN channels height width ++ (ssPrefixStd channels height width ++ [VecShape channels, MatShape channels (hw height width), MatShape channels (hw height width), MatShape channels (hw height width), MatShape channels (hw height width), MatShape channels (hw height width), MatShape channels (hw height width)])) (MatShape channels (hw height width))

                                                                                                                      Index of yMat in the extended context after g14.

                                                                                                                      Instances For
                                                                                                                        noncomputable def Proofs.Autograd.BatchNormChannelFirst.nodeYChw {channels height width : ℕ} :
                                                                                                                        Node (ΓBN channels height width ++ (ssPrefixStd channels height width ++ [VecShape channels, MatShape channels (hw height width), MatShape channels (hw height width), MatShape channels (hw height width), MatShape channels (hw height width), MatShape channels (hw height width), MatShape channels (hw height width)])) (CHWShape channels height width)

                                                                                                                        Reshape the matrix output yMat : C×(H*W) back into yChw : C×H×W.

                                                                                                                        Instances For
                                                                                                                          noncomputable def Proofs.Autograd.BatchNormChannelFirst.batchNormGraph {channels height width : ℕ} (ε : ℝ) :
                                                                                                                          Graph (ΓBN channels height width) (ssBatchNorm channels height width)

                                                                                                                          Full BatchNormChannelFirst graph (explicit snoc chain).

                                                                                                                          Instances For
                                                                                                                            noncomputable def Proofs.Autograd.BatchNormChannelFirst.batchNormGraphFderivCorrectAt {channels height width : ℕ} (ε : ℝ) (xV : CtxVec (ΓBN channels height width)) (hVarEpsPos : ∀ (i : Fin (VecShape channels).size), 0 < (CtxVec.get idxVarEps ((batchNormPrefixVarEps ε).evalVec xV)).ofLp i) (hStdNe0 : ∀ (i : Fin (VecShape channels).size), (CtxVec.get idxStd ((batchNormPrefixStd ε).evalVec xV)).ofLp i ≠ 0) :

                                                                                                                            Pointwise proof that batchNormGraph satisfies GraphFDerivCorrectAt.

                                                                                                                            As with LayerNorm, the hypotheses are explicit domain assumptions needed for differentiability of sqrt (after clamp) and inv at the actual execution point.
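To make the role of the two hypotheses concrete, here is a short sketch of the scalar calculus behind them; it is not a restatement of the Lean proof. The symbols v (one entry of var_eps) and s (one entry of std) are notation introduced here for illustration only.

```latex
% v : one entry of var_eps, s : one entry of std (illustrative notation, not library identifiers)
\[
v > 0 \;\Longrightarrow\; \frac{d}{dv}\sqrt{\max(v,0)} = \frac{1}{2\sqrt{v}},
\qquad
s \neq 0 \;\Longrightarrow\; \frac{d}{ds}\, s^{-1} = -\frac{1}{s^{2}}.
\]
```

The clamped square root is not differentiable at v = 0, and inv is not differentiable at s = 0, which is why hVarEpsPos and hStdNe0 are required at the actual evaluation point xV rather than globally.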

                                                                                                                            Instances For
                                                                                                                              theorem Proofs.Autograd.BatchNormChannelFirst.backprop_eq_adjoint_fderiv_batchNorm_channel_first_at {channels height width : ℕ} (ε : ℝ) (xV : CtxVec (ΓBN channels height width)) (seedV : CtxVec (ΓBN channels height width ++ ssBatchNorm channels height width)) (hVarEpsPos : ∀ (i : Fin (VecShape channels).size), 0 < (CtxVec.get idxVarEps ((batchNormPrefixVarEps ε).evalVec xV)).ofLp i) (hStdNe0 : ∀ (i : Fin (VecShape channels).size), (CtxVec.get idxStd ((batchNormPrefixStd ε).evalVec xV)).ofLp i ≠ 0) :

                                                                                                                              Pointwise end-to-end result: backprop equals (fderiv eval)† for batchNormGraph.

                                                                                                                              This is the BatchNormChannelFirst analogue of the global DAG theorem, specialized to the explicit graph construction and with explicit domain assumptions for sqrt/inv.
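Read as a formula, the statement can be sketched as follows. The names F (the evaluation map of batchNormGraph on the context vector), DF(x) (its Fréchet derivative at xV) and δ (the seed seedV) are notation assumed here, not identifiers from the library; the precise Lean statement may package the seed and context differently.

```latex
% F : evaluation map of batchNormGraph, x : the point xV, \delta : the seed seedV (assumed notation)
\[
\operatorname{backprop}(x, \delta) = \bigl(DF(x)\bigr)^{\dagger}\,\delta,
\qquad\text{where}\qquad
\bigl\langle (DF(x))^{\dagger}\delta,\; h \bigr\rangle
  = \bigl\langle \delta,\; DF(x)\,h \bigr\rangle
\quad\text{for all } h.
\]
```

The second identity is the defining property of the adjoint of a continuous linear map between inner product spaces, so the theorem says that reverse-mode accumulation computes exactly this adjoint at xV.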