TorchLean API

NN.Proofs.Models.Attention.CausalMask

Causal attention mask laws #

This file proves the exact Boolean semantics of TorchLean's causal and future masks and connects those mask facts to the exact hard-masked attention primitive.

TorchLean's main attention spec uses the proof-facing semantics corresponding to scores.masked_fill(~mask, -torch.inf): blocked entries receive zero softmax numerator, hence zero attention mass.
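As a side note on why the two viewpoints agree: in IEEE floating point, exp(-∞) is exactly 0, so filling blocked scores with -inf and exponentiating produces the same zero numerator that the exact spec assigns directly. A minimal core-Lean check (illustration only, not part of TorchLean's spec):

-- In Float, exp of negative infinity is exactly 0.0, so a score filled with
-- -inf contributes nothing to the softmax numerator.
def negInf : Float := (-1.0) / 0.0

#eval Float.exp negInf         -- 0.000000
#eval Float.exp negInf == 0.0  -- true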


Pointwise access #

The definitions are deliberately simple lower/upper-triangular Boolean tensors, so the access lemmas are definitional. Keeping them as named [simp] theorems lets larger attention proofs use the mask without unfolding the tensor constructors each time.
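For intuition, here is a small self-contained sketch of the same shape of definition over plain functions; causalMaskFn and futureMaskFn are hypothetical stand-ins for TorchLean's tensor-valued masks, and indices are plain Nat rather than Fin n for brevity. The access lemmas hold by rfl, mirroring the definitional simp lemmas below.

-- Toy function-valued masks; the real definitions are Spec.Tensor Bool values.
def causalMaskFn (i j : Nat) : Bool := decide (j ≤ i)
def futureMaskFn (i j : Nat) : Bool := decide (i < j)

-- Pointwise access is definitional.
@[simp] theorem causalMaskFn_get (i j : Nat) : causalMaskFn i j = decide (j ≤ i) := rfl
@[simp] theorem futureMaskFn_get (i j : Nat) : futureMaskFn i j = decide (i < j) := rfl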

@[simp]

Reading causalMask n at row i, column j returns exactly j ≤ i.

@[simp]

Reading futureMask n at row i, column j returns exactly i < j.

@[simp]
theorem NN.Proofs.Models.Attention.get2_map2Spec_matrix {α β γ : Type} {m n : ℕ} (f : α → β → γ) (A : Spec.Tensor α (Spec.Shape.dim m (Spec.Shape.dim n Spec.Shape.scalar))) (B : Spec.Tensor β (Spec.Shape.dim m (Spec.Shape.dim n Spec.Shape.scalar))) (i : Fin m) (j : Fin n) :
Spec.get2 (Spec.Tensor.map2Spec f A B) i j = f (Spec.get2 A i j) (Spec.get2 B i j)

Elementwise binary maps commute with matrix indexing.

This small tensor lemma is useful for attention proofs because masking is implemented as map2Spec over the score matrix and the Boolean mask.
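Here is a minimal sketch of the same commutation fact for plain function-valued matrices; map2Fn is a hypothetical stand-in for Spec.Tensor.map2Spec, and the lemma is again definitional.

-- Elementwise binary map on function-valued m × n matrices.
def map2Fn {α β γ : Type} {m n : Nat} (f : α → β → γ)
    (A : Fin m → Fin n → α) (B : Fin m → Fin n → β) : Fin m → Fin n → γ :=
  fun i j => f (A i j) (B i j)

-- Indexing after mapping equals mapping the indexed entries, by rfl.
theorem map2Fn_apply {α β γ : Type} {m n : Nat} (f : α → β → γ)
    (A : Fin m → Fin n → α) (B : Fin m → Fin n → β) (i : Fin m) (j : Fin n) :
    map2Fn f A B i j = f (A i j) (B i j) := rfl

In this picture, hard masking is map2Fn applied to the score matrix and the Boolean mask with a combinator like fun s keep => if keep then s else blocked, so any pointwise fact about the mask transfers to the masked scores through this lemma.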

Causal blocking and past visibility #

These are the two user-facing laws: causal attention blocks strict future columns and allows every past-or-present column.

A causal mask rejects every strict future key position.

theorem NN.Proofs.Models.Attention.causalMask_allows_past {n : ℕ} (i j : Fin n) (hji : j ≤ i) :

A causal mask admits every past or current key position.

theorem NN.Proofs.Models.Attention.futureMask_marks_future {n : ℕ} (i j : Fin n) (hij : i < j) :

A future mask marks the strict upper triangle, the pointwise complement of the causal lower triangle.

A future mask rejects every past or current key position.
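Under the toy Nat-indexed masks sketched earlier (causalMaskFn, futureMaskFn are illustrative names, not TorchLean's), all four laws reduce to deciding a Nat comparison; the real theorems are the Fin-indexed statements above.

-- Repeated from the earlier sketch so this block stands alone.
def causalMaskFn (i j : Nat) : Bool := decide (j ≤ i)
def futureMaskFn (i j : Nat) : Bool := decide (i < j)

theorem causalMaskFn_blocks_future {i j : Nat} (h : i < j) :
    causalMaskFn i j = false := decide_eq_false (Nat.not_le.mpr h)

theorem causalMaskFn_allows_past {i j : Nat} (h : j ≤ i) :
    causalMaskFn i j = true := decide_eq_true h

theorem futureMaskFn_marks_future {i j : Nat} (h : i < j) :
    futureMaskFn i j = true := decide_eq_true h

theorem futureMaskFn_rejects_past {i j : Nat} (h : j ≤ i) :
    futureMaskFn i j = false := decide_eq_false (Nat.not_lt.mpr h)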

Exact hard-mask attention weights #

For hard masking, blocked entries are not merely assigned a very small logit. Their softmax numerator is definitionally zero. These lemmas are the attention-level facts needed for causal non-interference proofs.

Any blocked coordinate of a hard-masked softmax vector has exactly zero weight.

Any blocked coordinate of a row-wise hard-masked softmax matrix has exactly zero weight.

In exact hard-masked causal softmax, every strict-future attention weight is exactly zero.
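The following Mathlib-based sketch makes these statements concrete in a toy model; hardSoftmax and the lemma names are illustrative, not TorchLean's definitions, and the proofs rest on the blocked numerator being the literal 0.

import Mathlib

-- Toy exact hard-masked softmax row over ℝ: blocked coordinates get numerator 0,
-- not a very negative score.
noncomputable def hardSoftmax {n : ℕ} (mask : Fin n → Bool) (s : Fin n → ℝ) (j : Fin n) : ℝ :=
  (if mask j then Real.exp (s j) else 0) / (∑ k, if mask k then Real.exp (s k) else 0)

-- Any blocked coordinate has exactly zero weight: the numerator is 0, so the
-- quotient is 0 whatever the denominator happens to be.
theorem hardSoftmax_blocked {n : ℕ} (mask : Fin n → Bool) (s : Fin n → ℝ)
    (j : Fin n) (h : mask j = false) : hardSoftmax mask s j = 0 := by
  simp [hardSoftmax, h]

-- Specializing the mask to the causal row of query i: every strictly future
-- key position carries exactly zero attention weight.
theorem hardSoftmax_causal_future {n : ℕ} (s : Fin n → ℝ) (i j : Fin n)
    (hij : i < j) : hardSoftmax (fun k => decide (k ≤ i)) s j = 0 :=
  hardSoftmax_blocked (fun k => decide (k ≤ i)) s j (decide_eq_false (not_le.mpr hij))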