TorchLean API

NN.Spec.Models.Gnn

Gnn

GNN models (spec layer).

We provide a compact 2-layer GCN with a graph-level mean pooling readout.

This file is intentionally "model glue": the actual message-passing math lives in NN.Spec.Layers.Gnn (in particular Spec.GCNLayerSpec, Spec.gcn_layer_spec, and Spec.gcn_layer_backward_spec). Here we focus on wiring layers together and documenting the end-to-end shapes.

Reference (GCN):

• Kipf & Welling, "Semi-Supervised Classification with Graph Convolutional Networks" (ICLR 2017).

PyTorch ecosystem analogies:

• GCNConv (PyTorch Geometric) for each layer; global_mean_pool for the graph-level readout.

2-layer GCN with gradients

Forward is:

  1. Z₁ = GCN₁(X) (linear message passing)
  2. H₁ = ReLU(Z₁)
  3. H₂ = GCN₂(H₁)
  4. y = mean_nodes(H₂) (graph-level readout)

Diagram (single graph, no batching):

X : (n × inDim)
   └─ GCN₁ ─→ Z₁ : (n × hidDim) ─ ReLU ─→ H₁ : (n × hidDim)
                                          └─ GCN₂ ─→ H₂ : (n × outDim)
                                                     └─ mean over nodes ─→ y : (outDim)
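
For orientation, the forward composition can be sketched in Lean as below. The field names layer1/layer2, the helpers relu and meanNodes, and the calling convention of Spec.gcn_layer_spec are illustrative assumptions, not the actual API; the real entry point is GCN2Spec.forward below.

  -- Sketch only: field and helper names are assumed, not the real TorchLean API.
  def forwardSketch {α : Type} [Context α] {n inDim hidDim outDim : ℕ}
      (m : GCN2Spec n inDim hidDim outDim α)
      (x : Spec.Tensor α (Spec.Shape.dim n (Spec.Shape.dim inDim Spec.Shape.scalar)))
      (h_n : n > 0) :
      Spec.Tensor α (Spec.Shape.dim outDim Spec.Shape.scalar) :=
    let z1 := Spec.gcn_layer_spec m.layer1 x    -- 1. Z₁ = GCN₁(X)  : n × hidDim
    let h1 := relu z1                           -- 2. H₁ = ReLU(Z₁) : n × hidDim
    let h2 := Spec.gcn_layer_spec m.layer2 h1   -- 3. H₂ = GCN₂(H₁) : n × outDim
    meanNodes h2 h_n                            -- 4. y = mean over nodes : outDim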

Backward follows this structure literally, applying each step's VJP in reverse order (see GCN2Spec.backward below).

structure Models.GCN2Spec (n inDim hidDim outDim : ℕ) (α : Type) : Type

A 2-layer GCN "model spec" for a fixed graph with n nodes.

GCNLayerSpec packages the per-layer parameters (including the adjacency/normalization choice), so the model here is just two such layers composed with a nonlinearity and a readout.
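
A plausible rendering of the structure is sketched below. The field names, and the assumption that Spec.GCNLayerSpec is parameterized as (n, inDim, outDim, α) in the same order as GCNLayerGrads further down, are illustrative, not taken from the source.

  -- Sketch with assumed field names; dimensions follow the forward diagram.
  structure GCN2SpecSketch (n inDim hidDim outDim : ℕ) (α : Type) : Type where
    layer1 : Spec.GCNLayerSpec n inDim hidDim α    -- GCN₁ : (n × inDim) → (n × hidDim)
    layer2 : Spec.GCNLayerSpec n hidDim outDim α   -- GCN₂ : (n × hidDim) → (n × outDim)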

def Models.GCN2Spec.forward {α : Type} [Context α] {n inDim hidDim outDim : ℕ} (m : GCN2Spec n inDim hidDim outDim α) (x : Spec.Tensor α (Spec.Shape.dim n (Spec.Shape.dim inDim Spec.Shape.scalar))) (h_n : n > 0) :
    Spec.Tensor α (Spec.Shape.dim outDim Spec.Shape.scalar)

Forward pass for the 2-layer GCN with a graph-level mean pooling readout.

Input:

• x : n × inDim node features.

Output:

• y : outDim graph embedding produced by averaging node embeddings.

The h_n : n > 0 assumption is only used to make the mean pooling well-defined (division by n).
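
In components (schematic indexing), the readout computes

  y j = (∑ i < n, H₂ i j) / n

for each j < outDim.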

structure Models.GCNLayerGrads (n inDim outDim : ℕ) (α : Type) : Type

Per-layer gradients returned by the GCNLayerSpec backward pass.

This mirrors the tuple returned by Spec.gcn_layer_backward_spec:

• dA: gradient w.r.t. the adjacency-like operator used by the layer,
• dW: gradient w.r.t. the weight matrix,
• db: gradient w.r.t. the bias vector.
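
Assuming the layer computes the usual GCN form Z = Â X W + b (the source only says "adjacency-like operator"), the gradient fields would have the shapes sketched below; the concrete Tensor types are inferred from the layer dimensions, not taken from the source.

  -- Shape sketch; assumes Z = Â X W + b with Â : n × n, W : inDim × outDim, b : outDim.
  structure GCNLayerGradsSketch (n inDim outDim : ℕ) (α : Type) : Type where
    dA : Spec.Tensor α (Spec.Shape.dim n (Spec.Shape.dim n Spec.Shape.scalar))
    dW : Spec.Tensor α (Spec.Shape.dim inDim (Spec.Shape.dim outDim Spec.Shape.scalar))
    db : Spec.Tensor α (Spec.Shape.dim outDim Spec.Shape.scalar)
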
structure Models.GCN2Grads (n inDim hidDim outDim : ℕ) (α : Type) : Type

Gradients for both layers of GCN2Spec.

def Models.GCN2Spec.backward {α : Type} [Context α] [DecidableRel fun (x1 x2 : α) => x1 > x2] {n inDim hidDim outDim : ℕ} (m : GCN2Spec n inDim hidDim outDim α) (x : Spec.Tensor α (Spec.Shape.dim n (Spec.Shape.dim inDim Spec.Shape.scalar))) (grad_output : Spec.Tensor α (Spec.Shape.dim outDim Spec.Shape.scalar)) (h_n : n > 0) :
    GCN2Grads n inDim hidDim outDim α × Spec.Tensor α (Spec.Shape.dim n (Spec.Shape.dim inDim Spec.Shape.scalar))

Backward/VJP for GCN2Spec.forward.

This is written in the same "spec style" as the rest of TorchLean:

• recompute small intermediates instead of depending on runtime caches,
• apply VJPs in reverse order,
• keep shapes explicit.

PyTorch analogy: this corresponds to what autograd would do for a graph built from GCNConv → ReLU → GCNConv → global_mean_pool, but expressed as a pure function.
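
Schematically, the body is the reverse-order chain below. All helper names (relu, reluVJP, meanNodesVJP, layer1VJP, layer2VJP) and the GCN2Grads field names are illustrative assumptions; the per-layer VJPs stand for calls to Spec.gcn_layer_backward_spec. The DecidableRel instance for > in the signature is presumably what makes the ReLU mask (Z₁ > 0) computable.

  -- Sketch of the reverse-order VJP chain; names are illustrative.
  let z1 := Spec.gcn_layer_spec m.layer1 x   -- recompute Z₁ (no runtime cache)
  let h1 := relu z1                          -- recompute H₁
  let dH2 := meanNodesVJP grad_output h_n    -- mean-pool VJP: grad_output / n to every node
  let (g2, dH1) := layer2VJP m h1 dH2        -- GCN₂ VJP (Spec.gcn_layer_backward_spec)
  let dZ1 := reluVJP z1 dH1                  -- zero where Z₁ ≤ 0 (uses the > instance)
  let (g1, dX) := layer1VJP m x dZ1          -- GCN₁ VJP (Spec.gcn_layer_backward_spec)
  ({ layer1 := g1, layer2 := g2 }, dX)       -- gradients for both layers + input gradient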
