Transformer Feed-Forward Sublayer VJP #
This file proves the vector-Jacobian product (VJP) of the standard position-wise Transformer feed-forward sublayer, stated at the single-vector level:
x ↦ x + W₂ GELU(W₁ x + b₁) + b₂.
The theorem is intentionally about a single token/vector. Batched sequence application is a map over positions, and the full Transformer encoder block additionally composes this FFN residual with multi-head attention (MHA) and LayerNorm. This file provides the clean proof component for the FFN half of that block.
References:
- Vaswani et al., "Attention Is All You Need", NeurIPS 2017.
- Hendrycks and Gimpel, "Gaussian Error Linear Units (GELUs)", arXiv:1606.08415.
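For concreteness, assuming the smooth GELU primitive denotes the exact (non-tanh) form from Hendrycks and Gimpel, the activation and its derivative are
GELU(x) = x · Φ(x),   GELU′(x) = Φ(x) + x · φ(x),
where Φ and φ are the cumulative distribution function and density of the standard normal distribution; GELU′ is the factor that appears pointwise in the VJP formulas noted below.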
Model-vector shape.
Hidden feed-forward shape.
Context for one position-wise FFN: just the input vector.
Saved tensors: first affine, activation, second affine, residual output.
Input vector index.
Helper returning the most recently appended tensor.
First affine output index.
GELU activation output index.
Second affine output index.
Proof-carrying graph for a residual Transformer FFN sublayer.
The two affine maps are fixed LinearSpecs here, so the theorem covers the VJP with respect to the
input vector. Parameter-gradient theorems live at the trainable-parameter/runtime layer.
End-to-end VJP theorem for the residual Transformer feed-forward sublayer.
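For reference, writing h = W₁ x + b₁, the VJP of the map x ↦ x + W₂ GELU(h) + b₂ sends an output cotangent v to
v + W₁ᵀ (GELU′(h) ⊙ (W₂ᵀ v)),
equivalently, the Jacobian is I + W₂ · diag(GELU′(h)) · W₁, with the identity term contributed by the residual connection. This is the pullback the end-to-end theorem certifies at the single-vector level.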
Sequence-shaped FFN residual #
The runtime Transformer applies the same FFN to every token in a (seqLen × dModel) tensor. For the
model-level proof interface below, we package that operation as two fixed affine maps over the
flattened sequence tensor. A concrete shared-weight implementation instantiates these maps with the
usual block-diagonal/time-distributed linear operator; the VJP theorem itself only needs the affine
maps and the smooth GELU primitive.
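As an illustration (the shared-weight instantiation is an assumption about the runtime, not part of the theorem), flattening the (seqLen × dModel) input row by row, one token per block, the time-distributed first affine map has linear part
A₁ = I_seqLen ⊗ W₁,
i.e. the block-diagonal operator that applies the same W₁ to every position, with b₁ broadcast across positions; A₂ is built from W₂ in the same way. Any operator exposing this affine structure, fused or not, fits the interface below.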
Sequence-shaped model stream.
Sequence-shaped FFN hidden stream.
Context for a sequence-level FFN residual block: just the sequence stream.
Saved tensors for the sequence-level residual FFN.
Sequence input index, weakened through saved tensors.
First sequence affine output.
GELU activation output.
Second sequence affine output.
Proof-carrying graph for the sequence-shaped residual FFN:
X ↦ X + A₂(GELU(A₁ X + b₁)) + b₂.
The affine maps are supplied explicitly over flattened sequence tensors. This keeps the theorem usable for shared-weight position-wise FFNs, fused FFN kernels, and future compiler-generated linearizations, as long as they expose the same affine maps.
End-to-end VJP theorem for the sequence-shaped residual FFN.
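For reference, treating A₁ and A₂ as the linear parts (with biases written separately, as in the graph above), the VJP of X ↦ X + A₂(GELU(A₁ X + b₁)) + b₂ sends an output cotangent V to
V + A₁ᵀ (GELU′(A₁ X + b₁) ⊙ (A₂ᵀ V)).
When A₁ and A₂ are the shared-weight block-diagonal operators, this is exactly the single-vector pullback applied independently at every position.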