GraphM Neural Layers
Normalization and attention builders for proof-compiled graphs.
Layer normalization (sequence-first), producing the same shape as the input.
PyTorch comparison: torch.nn.LayerNorm / torch.nn.functional.layer_norm (modulo exact layout).
Forward-mode status: implemented by Spec.layerNormJvp, including parameter tangents for
gamma and beta.
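As an illustration of what a layer-norm JVP with parameter tangents computes, here is a minimal pure-Python sketch. It is not the Spec.layerNormJvp definition: the function names, the eps default, the use of biased variance, and the restriction to a single feature vector (one sequence position) are all assumptions made for the example.

```python
import math

def layer_norm(x, gamma, beta, eps=1e-5):
    """Normalize one feature vector; eps and biased variance are assumptions."""
    n = len(x)
    mean = sum(x) / n
    var = sum((v - mean) ** 2 for v in x) / n
    inv_std = 1.0 / math.sqrt(var + eps)
    x_hat = [(v - mean) * inv_std for v in x]
    y = [g * h + b for g, h, b in zip(gamma, x_hat, beta)]
    return y, x_hat, inv_std

def layer_norm_jvp(x, gamma, beta, dx, dgamma, dbeta, eps=1e-5):
    """Return (y, dy) where dy is the tangent of y along (dx, dgamma, dbeta)."""
    y, x_hat, inv_std = layer_norm(x, gamma, beta, eps)
    n = len(x)
    mean_dx = sum(dx) / n
    mean_xhat_dx = sum(h * d for h, d in zip(x_hat, dx)) / n
    # Tangent of the normalized input: project out the mean and the x_hat direction.
    dx_hat = [(d - mean_dx - h * mean_xhat_dx) * inv_std
              for d, h in zip(dx, x_hat)]
    # Product rule on y = gamma * x_hat + beta.
    dy = [dg * h + g * dh + db
          for dg, h, g, dh, db in zip(dgamma, x_hat, gamma, dx_hat, dbeta)]
    return y, dy
```

The output y has the same length as x, matching the shape-preserving behavior described above. A sequence-first input would apply the same computation independently at each position.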
Batch normalization in channel-first layout (no running statistics; spec-level functional form).
PyTorch comparison: torch.nn.BatchNorm2d in NCHW layout, normalizing with the current batch's statistics since no running statistics are kept (modulo exact semantics/parameters).
Forward-mode status: implemented by Spec.batchNorm2dJvp, including parameter tangents for
gamma and beta.
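The functional, no-running-statistics form of batch normalization can be sketched as follows. This is an illustrative pure-Python version, not the library's definition: the function name, the nested-list representation of an NCHW tensor, the eps default, and the use of biased variance are assumptions.

```python
import math

def batch_norm2d(x, gamma, beta, eps=1e-5):
    """x is a nested list indexed [N][C][H][W]; gamma and beta have length C.

    Statistics are computed per channel over the N, H, and W axes of the
    current batch only (no running mean or variance is kept).
    """
    n, c = len(x), len(x[0])
    h, w = len(x[0][0]), len(x[0][0][0])
    count = n * h * w
    out = [[[[0.0] * w for _ in range(h)] for _ in range(c)] for _ in range(n)]
    for ch in range(c):
        vals = [x[i][ch][r][k]
                for i in range(n) for r in range(h) for k in range(w)]
        mean = sum(vals) / count
        var = sum((v - mean) ** 2 for v in vals) / count  # biased (assumption)
        inv_std = 1.0 / math.sqrt(var + eps)
        for i in range(n):
            for r in range(h):
                for k in range(w):
                    out[i][ch][r][k] = (
                        gamma[ch] * (x[i][ch][r][k] - mean) * inv_std + beta[ch]
                    )
    return out
```

Because the channel statistics come from the batch itself, each output channel has mean beta[ch] over the N, H, and W axes.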
Multi-head attention primitive (shape-specialized).
PyTorch comparison: torch.nn.MultiheadAttention / scaled dot-product attention.
Forward-mode status: implemented by Spec.MultiHeadAttentionJvp, including tangents for the
input and all four projection matrices (query, key, value, and output).
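For orientation, the forward computation of multi-head scaled dot-product self-attention with four projection matrices can be sketched as below. This is a hedged illustration only: the function names, the row-vector convention (x times W), and the omission of biases, masking, and dropout are assumptions, and the sketch does not reproduce the library's shape-specialized primitive or its JVP.

```python
import math

def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(p * q for p, q in zip(row, col)) for col in zip(*b)]
            for row in a]

def softmax(row):
    """Numerically stable softmax over one list of scores."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def multi_head_attention(x, w_q, w_k, w_v, w_o, num_heads):
    """Self-attention on x of shape [seq, d_model]; weights are [d_model, d_model]."""
    d_model = len(x[0])
    assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
    d_head = d_model // num_heads
    q, k, v = matmul(x, w_q), matmul(x, w_k), matmul(x, w_v)
    scale = 1.0 / math.sqrt(d_head)
    heads = []
    for h in range(num_heads):
        lo, hi = h * d_head, (h + 1) * d_head
        qh = [row[lo:hi] for row in q]
        kh = [row[lo:hi] for row in k]
        vh = [row[lo:hi] for row in v]
        # Scaled dot-product scores, then a softmax over the key positions.
        scores = [[scale * sum(a * b for a, b in zip(qi, kj)) for kj in kh]
                  for qi in qh]
        weights = [softmax(r) for r in scores]
        heads.append(matmul(weights, vh))
    # Concatenate the heads along the feature axis, then apply the output projection.
    concat = [sum((head[t] for head in heads), []) for t in range(len(x))]
    return matmul(concat, w_o)
```

A forward-mode derivative of this function would carry one tangent for x and one for each of w_q, w_k, w_v, and w_o, which is the set of tangents listed above.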