# BugZoo: attention-mask semantics
Attention code has its own failure modes: mask polarity, head reshaping, Q/K/V layout, and KV-cache mismatches are easy to get wrong and hard to notice from accuracy tests alone. This file focuses on the mask part, because TorchLean already has a precise theorem stack for it.
Here is the bug-shaped PyTorch pattern we want to rule out:
```python
# Wrong for causal attention: the polarity is flipped. masked_fill writes -inf
# where its mask argument is True, and `mask == False` is True on present/past
# entries, so those get blocked while future tokens are allowed through.
scores = q @ k.transpose(-2, -1) / math.sqrt(d)
mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
weights = torch.softmax(scores.masked_fill(mask == False, -torch.inf), dim=-1)
```
The intended PyTorch version uses true negative infinity on blocked future entries:
```python
future = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
weights = torch.softmax(scores.masked_fill(future, -torch.inf), dim=-1)
# Every strict-future weight is exactly zero: weights[i, j] == 0.0 for all j > i.
assert torch.all(weights.triu(diagonal=1) == 0.0)
```
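To see both behaviors concretely, here is a minimal self-contained contrast (toy sizes; the variable names are illustrative, not TorchLean API):

```python
import math
import torch

T, d = 4, 8
torch.manual_seed(0)
q, k = torch.randn(T, d), torch.randn(T, d)
scores = q @ k.transpose(-2, -1) / math.sqrt(d)
future = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)

buggy = torch.softmax(scores.masked_fill(future == False, -torch.inf), dim=-1)
fixed = torch.softmax(scores.masked_fill(future, -torch.inf), dim=-1)

assert torch.all(fixed.triu(diagonal=1) == 0.0)   # fixed: zero future mass
assert torch.all(buggy[:-1].tril() == 0.0)        # buggy: zero present/past mass
assert torch.isnan(buggy[-1]).all()               # buggy: last row fully masked -> NaN
```

Note that the buggy version's last row is fully masked, so softmax returns NaN across it; this echoes the NaN failure mode reported in PyTorch issue #160064, cited below.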
Lean's ordinary `ℝ` does not contain a literal `-∞`, but mathlib does provide the extended reals `EReal`, where `⊥` is negative infinity and `EReal.exp ⊥ = 0`. We record that exact `-∞` fact first. TorchLean's ordinary tensor softmax then uses the computationally convenient equivalent: blocked logits get a zero numerator before normalization. Both views lead to the exact theorem below.
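In symbols (our notation, with s the scaled score matrix), the zero-numerator view reads:

$$
\text{weights}_{i,j} \;=\; \frac{\mathbf{1}[\,j \le i\,]\,\exp(s_{i,j})}{\sum_{k \le i} \exp(s_{i,k})},
\qquad\text{so}\quad \text{weights}_{i,j} = 0 \ \text{for all}\ j > i.
$$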
References:
- Vaswani et al., “Attention Is All You Need”, NeurIPS 2017. https://arxiv.org/abs/1706.03762
- PyTorch `scaled_dot_product_attention` documentation, for the runtime-style mask interface: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
- PyTorch issue #99282, where `MultiheadAttention(is_causal=True)` was reported ignored when `need_weights=True`: https://github.com/pytorch/pytorch/issues/99282
- PyTorch issue #160064, where fully masked attention heads were reported to produce NaNs when attention weights were requested: https://github.com/pytorch/pytorch/issues/160064
Blocking a logit really means assigning -∞ in the extended-real presentation.
The key -∞ softmax fact: exp(-∞) = 0.
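As a sketch, that fact can be checked against mathlib directly; `EReal.exp_bot` is our guess at the lemma name, so treat it as an assumption:

```lean
import Mathlib

-- Mathlib's extended-real exponential sends ⊥ (negative infinity) to 0.
-- `EReal.exp_bot` is an assumed lemma name; adjust it to your mathlib revision.
example : EReal.exp ⊥ = 0 := EReal.exp_bot
```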
Exact extended-real causal masking of one score-matrix coordinate.
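A hypothetical shape for such a definition, continuing the sketch file started above (`causalMaskE` is our illustrative name, not TorchLean's declaration):

```lean
-- Hypothetical sketch: mask one coordinate of a T×T score matrix exactly,
-- sending strict-future entries (i < j) to ⊥ = -∞ in EReal.
def causalMaskE {T : ℕ} (scores : Fin T → Fin T → ℝ) (i j : Fin T) : EReal :=
  if i < j then ⊥ else (scores i j : EReal)
```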
For a strict-future position, exact causal masking assigns literal -∞.
This is the formal version of the PyTorch operation `scores.masked_fill(future, -torch.inf)` at one matrix coordinate.
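Against the sketch definition above, that statement is immediate:

```lean
-- Strict-future coordinates of the sketch mask are exactly ⊥.
theorem causalMaskE_strictFuture {T : ℕ} (scores : Fin T → Fin T → ℝ)
    {i j : Fin T} (h : i < j) : causalMaskE scores i j = ⊥ := by
  simp [causalMaskE, h]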
Therefore, the strict-future numerator is exactly zero.
This is why TorchLean's attention spec writes this zero numerator directly.
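Chaining the two sketch facts gives the zero numerator (still assuming the `EReal.exp_bot` name):

```lean
-- The strict-future numerator exp(-∞) is exactly 0.
theorem causalMaskE_exp_eq_zero {T : ℕ} (scores : Fin T → Fin T → ℝ)
    {i j : Fin T} (h : i < j) : EReal.exp (causalMaskE scores i j) = 0 := by
  rw [causalMaskE_strictFuture scores h, EReal.exp_bot]
```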
True `-∞` causal attention gets the exact zero-weight theorem.
This is the statement we want for formal output-causality arguments: every strict-future key has
zero attention mass for the current query row. In TorchLean this is represented by
`hardMaskedSoftmaxSpec`, not by a finite real sentinel treated as `-∞`.
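A self-contained sketch of the shape of that statement. `hardMaskedSoftmax` below is our illustrative stand-in, not TorchLean's `hardMaskedSoftmaxSpec`, whose actual signature we do not reproduce here:

```lean
import Mathlib

-- Illustrative hard-masked row softmax: blocked logits contribute a zero
-- numerator before normalization (the ordinary-ℝ view from above).
noncomputable def hardMaskedSoftmax {T : ℕ} (scores : Fin T → Fin T → ℝ)
    (i j : Fin T) : ℝ :=
  (if j ≤ i then Real.exp (scores i j) else 0) /
    ∑ k : Fin T, if k ≤ i then Real.exp (scores i k) else 0

-- Every strict-future key gets exactly zero attention mass: the numerator is
-- the literal 0, so the quotient is 0 whatever the denominator is.
theorem hardMaskedSoftmax_strictFuture {T : ℕ} (scores : Fin T → Fin T → ℝ)
    {i j : Fin T} (h : i < j) : hardMaskedSoftmax scores i j = 0 := by
  unfold hardMaskedSoftmax
  rw [if_neg (not_le.mpr h), zero_div]
```

The point of the extended-real detour is that this zero is definitional, not a rounding artifact of some large negative finite sentinel.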