TorchLean API

NN.Examples.BugZoo.AttentionMask

BugZoo: attention-mask semantics #

Attention code has its own failure modes: mask polarity, head reshaping, Q/K/V layout, and KV-cache mismatches are easy to get wrong and hard to notice from accuracy tests alone. This file focuses on the mask part, because TorchLean already has a precise theorem stack for it.

Here is the bug-shaped PyTorch pattern we want to rule out:

# Wrong for causal attention: the polarity is flipped, so future tokens are allowed.
scores = q @ k.transpose(-2, -1) / math.sqrt(d)
mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
weights = torch.softmax(scores.masked_fill(mask == False, -torch.inf), dim=-1)

The intended PyTorch version uses true negative infinity on blocked future entries:

future = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
weights = torch.softmax(scores.masked_fill(future, -torch.inf), dim=-1)
assert (weights.triu(diagonal=1) == 0).all()  # zero mass on every j > i
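The polarity bug is easy to see on a tiny example. Below is a minimal, dependency-free sketch of both variants; the helper names `softmax` and `causal_weights` are mine, not TorchLean's or PyTorch's, and the scores are arbitrary:

```python
import math

NEG_INF = float("-inf")

def softmax(row):
    """Softmax over one list of logits; an all-blocked row yields NaN, as torch does."""
    finite = [x for x in row if x != NEG_INF]
    if not finite:
        return [float("nan")] * len(row)
    m = max(finite)
    exps = [0.0 if x == NEG_INF else math.exp(x - m) for x in row]
    z = sum(exps)
    return [e / z for e in exps]

def causal_weights(scores, flipped):
    """Attention weights under a causal mask; flipped=True reproduces the bug."""
    T = len(scores)
    out = []
    for i in range(T):
        row = []
        for j in range(T):
            future = j > i
            blocked = (not future) if flipped else future
            row.append(NEG_INF if blocked else scores[i][j])
        out.append(softmax(row))
    return out

scores = [[0.3, 1.0, -0.5],
          [0.2, 0.1, 0.7],
          [1.5, -0.2, 0.4]]
good = causal_weights(scores, flipped=False)
bad = causal_weights(scores, flipped=True)

# Correct polarity: zero mass on every strict-future key.
assert all(good[i][j] == 0.0 for i in range(3) for j in range(3) if j > i)
assert good[0] == [1.0, 0.0, 0.0]
# Flipped polarity: the first query attends only to its future...
assert bad[0][0] == 0.0 and bad[0][1] > 0.0 and bad[0][2] > 0.0
# ...and the last query has every key blocked, so its row is NaN.
assert bad[2][0] != bad[2][0]
```

Note that the flipped mask produces two distinct symptoms: nonzero future mass for early queries and an all-NaN row for the last query, both of which can slip past coarse accuracy tests.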

Lean's ordinary ℝ does not contain a literal -∞, but mathlib provides the extended reals EReal, where ⊥ is negative infinity and EReal.exp ⊥ = 0. We record that exact -∞ fact first. TorchLean's ordinary tensor softmax then uses the computationally convenient equivalent: blocked logits get a zero numerator before normalization. Both views lead to the same exact theorem below.
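The claimed equivalence of the two views can be checked numerically with a small sketch (function names here are mine, not TorchLean's):

```python
import math

NEG_INF = float("-inf")

def masked_softmax_inf(logits, allowed):
    """View 1: fill blocked logits with -inf, then take an ordinary softmax."""
    filled = [x if a else NEG_INF for x, a in zip(logits, allowed)]
    m = max(x for x in filled if x != NEG_INF)
    exps = [0.0 if x == NEG_INF else math.exp(x - m) for x in filled]
    z = sum(exps)
    return [e / z for e in exps]

def masked_softmax_zero(logits, allowed):
    """View 2 (the spec's shape): blocked entries get a zero numerator
    before normalization; no infinities ever appear."""
    m = max(x for x, a in zip(logits, allowed) if a)
    nums = [math.exp(x - m) if a else 0.0 for x, a in zip(logits, allowed)]
    z = sum(nums)
    return [n / z for n in nums]

logits = [0.4, -1.2, 2.0, 0.0]
allowed = [True, False, True, False]
w1 = masked_softmax_inf(logits, allowed)
w2 = masked_softmax_zero(logits, allowed)

# The two presentations agree entry by entry...
assert all(abs(a - b) < 1e-12 for a, b in zip(w1, w2))
# ...and blocked entries carry exactly zero mass, not merely small mass.
assert w1[1] == 0.0 and w1[3] == 0.0
```

Both functions perform the same floating-point operations on allowed entries, so the agreement here is exact rather than approximate.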


noncomputable def NN.Examples.BugZoo.AttentionMask.exactMaskedLogit (score : ℝ) (allowed : Bool) : EReal

Exact extended-real masked logit: allowed entries keep their real score, blocked entries are literal -∞ (⊥ : EReal).
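Based on that docstring, the definition's body presumably has the following shape; this is a sketch inferred from the description, not the authoritative TorchLean source:

```lean
import Mathlib

-- Sketch: allowed entries coerce their real score into EReal;
-- blocked entries become the literal bottom element ⊥ = -∞.
noncomputable def exactMaskedLogit (score : ℝ) (allowed : Bool) : EReal :=
  if allowed then (score : EReal) else ⊥
```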

    @[simp]

    Blocking a logit really means assigning -∞ in the extended-real presentation.

    @[simp]

    The key -∞ softmax fact: exp(-∞) = 0.

    Exact extended-real causal masking of one score-matrix coordinate.


      For a strict-future position, exact causal masking assigns literal -∞.

      This is the formal version of the PyTorch operation scores.masked_fill(future, -torch.inf) at one matrix coordinate.

      Therefore, the strict-future numerator is exactly zero.

      This is why TorchLean's attention spec writes this zero numerator directly.

True -∞ causal attention gets the exact zero-weight theorem.

      This is the statement we want for formal output-causality arguments: every strict-future key has zero attention mass for the current query row. In TorchLean this is represented by hardMaskedSoftmaxSpec, not by a finite real sentinel treated as -∞.
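For orientation, here is a hedged sketch of what such a zero-weight statement looks like when the masked softmax is modeled with zero numerators. Every name below is hypothetical (a local model, not hardMaskedSoftmaxSpec itself), and the actual TorchLean statement may differ in shape:

```lean
import Mathlib

variable {T : ℕ}

-- Local model (names are mine): blocked numerators are 0, so the
-- strict-future weight vanishes without any limiting argument.
noncomputable def maskedNum (scores : Fin T → Fin T → ℝ) (i k : Fin T) : ℝ :=
  if k ≤ i then Real.exp (scores i k) else 0

noncomputable def maskedWeight (scores : Fin T → Fin T → ℝ) (i j : Fin T) : ℝ :=
  maskedNum scores i j / ∑ k, maskedNum scores i k

-- Every strict-future key gets zero attention mass for query row i.
theorem future_weight_zero (scores : Fin T → Fin T → ℝ)
    (i j : Fin T) (h : i < j) : maskedWeight scores i j = 0 := by
  simp [maskedWeight, maskedNum, not_le.mpr h]
```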