TorchLean API

NN.Runtime.Autograd.Compiled.GraphM.Neural

GraphM Neural Layers

Normalization and attention builders for proof-compiled graphs.

def Runtime.Autograd.Compiled.GraphM.layerNorm {α Δ : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] [DecidableEq Spec.Shape] {Γ : List Spec.Shape} {seqLen embedDim : ℕ} (h_seq_pos : seqLen > 0) (h_embed_pos : embedDim > 0) (x : Var (Spec.Shape.dim seqLen (Spec.Shape.dim embedDim Spec.Shape.scalar))) (gamma beta : Var (Spec.Shape.dim embedDim Spec.Shape.scalar)) :

Layer normalization (sequence-first), producing the same shape as the input.

PyTorch comparison: torch.nn.LayerNorm / torch.nn.functional.layer_norm (modulo exact layout).

Forward-mode status: implemented by Spec.layerNormJvp, including parameter tangents for gamma and beta.
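As a reading aid, the row-wise computation that layer normalization usually denotes can be sketched in plain Lean on `Array Float`. This is not the TorchLean definition: the names `layerNormRow` and `layerNormSeq` and the stabilizer `eps` are assumptions made for illustration (the signature above takes no epsilon argument), and the sketch assumes each of the `seqLen` rows is normalized over its `embedDim` entries using the biased variance.

```lean
/-- Hypothetical helper: normalize one row to zero mean and unit variance,
    then scale by `gamma` and shift by `beta` entrywise.
    `eps` is an assumed stabilizer, not a parameter of `GraphM.layerNorm`. -/
def layerNormRow (eps : Float) (gamma beta row : Array Float) : Array Float :=
  let n := row.size.toFloat
  let mean := row.foldl (fun (acc : Float) v => acc + v) 0.0 / n
  let var := row.foldl (fun (acc : Float) v => acc + (v - mean) * (v - mean)) 0.0 / n
  let invStd := 1.0 / Float.sqrt (var + eps)
  (row.zip (gamma.zip beta)).map fun (v, g, b) => g * ((v - mean) * invStd) + b

/-- Apply the row normalization to every position of a `seqLen × embedDim` input. -/
def layerNormSeq (eps : Float) (gamma beta : Array Float)
    (x : Array (Array Float)) : Array (Array Float) :=
  x.map (layerNormRow eps gamma beta)

-- Example call with gamma = 1 and beta = 0 (pure normalization).
#eval layerNormSeq 1e-5 #[1.0, 1.0, 1.0] #[0.0, 0.0, 0.0]
  #[#[1.0, 2.0, 3.0], #[4.0, 4.0, 10.0]]
```

The output has the same outer and inner lengths as the input, which mirrors the shape-preservation statement above.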

def Runtime.Autograd.Compiled.GraphM.batchnormChannelFirst {α Δ : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] [DecidableEq Spec.Shape] {Γ : List Spec.Shape} {channels height width : ℕ} (h_c : channels > 0) (h_h : height > 0) (h_w : width > 0) (x : Var (Spec.Shape.dim channels (Spec.Shape.dim height (Spec.Shape.dim width Spec.Shape.scalar)))) (gamma beta : Var (Spec.Shape.dim channels Spec.Shape.scalar)) :

Batch normalization in channel-first layout (no running statistics; spec-level functional form).

PyTorch comparison: torch.nn.BatchNorm2d in NCHW layout (modulo exact semantics/parameters).

Forward-mode status: implemented by Spec.batchNorm2dJvp, including parameter tangents for gamma and beta.
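The functional form without running statistics can be illustrated with a small self-contained Lean sketch. It is not the TorchLean definition. It assumes that, since the input shape is `channels × height × width` with no batch axis, the mean and biased variance are taken per channel over that channel's spatial positions. The name `batchNormChannels`, the flattened per-channel representation, and the stabilizer `eps` are all hypothetical.

```lean
/-- Hypothetical sketch: `x[c]` holds the `height * width` values of channel `c`,
    flattened. Each channel is normalized with its own statistics and then
    scaled and shifted by `gamma[c]` and `beta[c]`.
    `eps` is an assumed stabilizer, not a parameter of `batchnormChannelFirst`. -/
def batchNormChannels (eps : Float) (gamma beta : Array Float)
    (x : Array (Array Float)) : Array (Array Float) :=
  (x.zip (gamma.zip beta)).map fun (chan, g, b) =>
    let n := chan.size.toFloat
    let mean := chan.foldl (fun (acc : Float) v => acc + v) 0.0 / n
    let var := chan.foldl (fun (acc : Float) v => acc + (v - mean) * (v - mean)) 0.0 / n
    let invStd := 1.0 / Float.sqrt (var + eps)
    chan.map fun v => g * ((v - mean) * invStd) + b

-- Two channels, each with a flattened 2 × 2 spatial grid.
#eval batchNormChannels 1e-5 #[1.0, 2.0] #[0.0, 0.5]
  #[#[1.0, 2.0, 3.0, 4.0], #[0.0, 0.0, 1.0, 1.0]]
```

Unlike the layer-normalization case, `gamma` and `beta` here hold one scalar per channel, which matches their `Spec.Shape.dim channels Spec.Shape.scalar` shape in the signature.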

def Runtime.Autograd.Compiled.GraphM.multiHeadAttention {α Δ : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] [DecidableEq Spec.Shape] {Γ : List Spec.Shape} {n numHeads dModel headDim : ℕ} (h1 : n ≠ 0) (wq wk wv : Var (Spec.Shape.dim dModel (Spec.Shape.dim (numHeads * headDim) Spec.Shape.scalar))) (wo : Var (Spec.Shape.dim (numHeads * headDim) (Spec.Shape.dim dModel Spec.Shape.scalar))) (x : Var (Spec.Shape.dim n (Spec.Shape.dim dModel Spec.Shape.scalar))) (mask : Option (Spec.Tensor Bool (Spec.Shape.dim n (Spec.Shape.dim n Spec.Shape.scalar))) := none) :

Multi-head attention primitive (shape-specialized).

PyTorch comparison: torch.nn.MultiheadAttention / scaled dot-product attention.

Forward-mode status: implemented by Spec.MultiHeadAttentionJvp, including tangents for the input and all four projection matrices.
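For orientation, the scaled dot-product step that the comparison refers to can be sketched for a single head in plain Lean. This is a minimal illustration under stated assumptions, not the TorchLean implementation: `Mat`, `dot`, `softmax`, and `attention` are hypothetical names, the optional `mask` is omitted, and the query, key, and value matrices are taken as already projected.

```lean
/-- Hypothetical row-major matrix type used only in this sketch. -/
abbrev Mat := Array (Array Float)

/-- Dot product of two vectors (entries beyond the shorter one are ignored). -/
def dot (u v : Array Float) : Float :=
  (u.zip v).foldl (fun (acc : Float) (p : Float × Float) => acc + p.1 * p.2) 0.0

/-- Softmax of one row, shifted by the row maximum for numerical stability. -/
def softmax (row : Array Float) : Array Float :=
  let m := row.foldl (fun (acc : Float) v => if acc < v then v else acc) (row.getD 0 0.0)
  let exps := row.map fun v => Float.exp (v - m)
  let z := exps.foldl (fun (acc : Float) v => acc + v) 0.0
  exps.map fun v => v / z

/-- Single-head attention: `softmax (q * kᵀ / sqrt headDim) * v`, computed row by row.
    `q`, `k`, `v` are `n × headDim` matrices. -/
def attention (q k v : Mat) (headDim : Nat) : Mat :=
  let scale := 1.0 / Float.sqrt headDim.toFloat
  q.map fun qi =>
    -- Attention weights of this query over all keys.
    let weights := softmax (k.map fun kj => scale * dot qi kj)
    -- Weighted sum of the value rows.
    (weights.zip v).foldl
      (fun (acc : Array Float) (p : Float × Array Float) =>
        (acc.zip p.2).map fun (a, b) => a + p.1 * b)
      (List.replicate headDim (0.0 : Float)).toArray

-- One query attending over two key/value rows.
#eval attention #[#[1.0, 0.0]] #[#[1.0, 0.0], #[0.0, 1.0]] #[#[1.0, 2.0], #[3.0, 4.0]] 2
```

In the usual multi-head formulation, this step runs once per head on `headDim`-wide column blocks of the projected inputs (`x` times `wq`, `wk`, `wv`), and the concatenated `n × (numHeads * headDim)` result is multiplied by `wo` to return to width `dModel`. Whether the primitive above follows exactly this decomposition is not stated here.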
