TorchLean API

NN.MLTheory.SelfSupervised.MAE

Masked Autoencoder Objective Semantics #

This module formalizes the finite patch/token core of a masked autoencoder (MAE): patch batches (PatchBatch), pointwise reconstruction from per-patch predictions (reconstruct), an exact-reconstruction predicate (ExactReconstruction), and a masked reconstruction loss over an explicit masked-index list (maeLoss).

The formalization is intentionally small. It captures the semantics that examples and future model helpers should preserve, while leaving ViT blocks, convolutional patch embeddings, and image IO in the executable API layer.

Paper anchor: “Masked Autoencoders Are Scalable Vision Learners” (He, Chen, Xie, Li, Dollár, Girshick, 2021), arXiv:2111.06377. The key objective-level fact we encode is that the reconstruction loss is taken over the masked patch set. Therefore the objective should not depend on an arbitrary ordering of masked patch indices; maeLoss_reverse is the small finite theorem capturing that property.
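For concreteness, with masked index set $M$, target patches $x_i$, and predictions $\hat{x}_i$, the objective has the shape

$$\mathcal{L}(M) \;=\; \sum_{i \in M} \ell(x_i, \hat{x}_i).$$

The paper instantiates $\ell$ as per-patch mean squared error in pixel space and averages over the masked patches; the formalization below keeps $\ell$ abstract as patchLoss and, judging from maeLoss_nil and maeLoss_append, uses the unnormalized sum.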

@[reducible, inline]
def NN.MLTheory.SelfSupervised.PatchBatch (n : ℕ) (Patch : Type) :
    Type

A finite patch collection.

Instances For
    def NN.MLTheory.SelfSupervised.reconstruct {n : ℕ} {Patch Pred : Type} (decode : Fin n → Pred → Patch) (pred : Fin n → Pred) :
    PatchBatch n Patch

    Reconstruct every patch using a reconstruction function.

    Instances For
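A usage sketch: this page shows only the signature, but the docstring suggests the evident pointwise body, assumed below; under that assumption the behavior at each index is definitional.

```lean
-- Sketch only: an assumed pointwise body for `reconstruct`, consistent
-- with the signature and docstring above (not the library's source).
abbrev PatchBatch (n : Nat) (Patch : Type) := Fin n → Patch

def reconstruct {n : Nat} {Patch Pred : Type}
    (decode : Fin n → Pred → Patch) (pred : Fin n → Pred) :
    PatchBatch n Patch :=
  fun i => decode i (pred i)

-- Under the assumed body, decoding at index `i` is definitional:
example {n : Nat} {Patch Pred : Type}
    (decode : Fin n → Pred → Patch) (pred : Fin n → Pred) (i : Fin n) :
    reconstruct decode pred i = decode i (pred i) := rfl
```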

def NN.MLTheory.SelfSupervised.ExactReconstruction {n : ℕ} {Patch : Type} (target recon : PatchBatch n Patch) :
      Prop

      Exact reconstruction predicate for all patches.

      Instances For
        def NN.MLTheory.SelfSupervised.maeLoss {n : ℕ} {Patch Pred : Type} (maskedIdxs : List (Fin n)) (target : PatchBatch n Patch) (pred : Fin n → Pred) (patchLoss : Patch → Pred → ℝ) :
        ℝ

        MAE-style masked reconstruction loss over an explicit masked-index list.

        The list is the serialized representation of the masked set. The theorems below prove that the objective behaves like a set sum for basic reorderings/decompositions.

        Instances For
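For orientation, here is a minimal self-contained sketch of a definition with this shape. The body shown (a sum of per-patch losses mapped over the index list) is an assumption, since this page displays only the signature, but the nil/append lemmas below pin down exactly this behavior.

```lean
import Mathlib

-- Sketch, not the library's source: a masked loss that sums per-patch
-- losses over an explicit index list.
abbrev PatchBatch (n : ℕ) (Patch : Type) := Fin n → Patch

def maeLoss {n : ℕ} {Patch Pred : Type}
    (maskedIdxs : List (Fin n)) (target : PatchBatch n Patch)
    (pred : Fin n → Pred) (patchLoss : Patch → Pred → ℝ) : ℝ :=
  (maskedIdxs.map fun i => patchLoss (target i) (pred i)).sum

-- Under this body, the nil and append lemmas stated below hold by `simp`:
example {n : ℕ} {Patch Pred : Type} (target : PatchBatch n Patch)
    (pred : Fin n → Pred) (patchLoss : Patch → Pred → ℝ) :
    maeLoss [] target pred patchLoss = 0 := by
  simp [maeLoss]

example {n : ℕ} {Patch Pred : Type} (xs ys : List (Fin n))
    (target : PatchBatch n Patch) (pred : Fin n → Pred)
    (patchLoss : Patch → Pred → ℝ) :
    maeLoss (xs ++ ys) target pred patchLoss
      = maeLoss xs target pred patchLoss + maeLoss ys target pred patchLoss := by
  simp [maeLoss]
```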
          @[simp]
          theorem NN.MLTheory.SelfSupervised.maeLoss_nil {n : ℕ} {Patch Pred : Type} (target : PatchBatch n Patch) (pred : Fin n → Pred) (patchLoss : Patch → Pred → ℝ) :
          maeLoss [] target pred patchLoss = 0
          theorem NN.MLTheory.SelfSupervised.maeLoss_append {n : ℕ} {Patch Pred : Type} (xs ys : List (Fin n)) (target : PatchBatch n Patch) (pred : Fin n → Pred) (patchLoss : Patch → Pred → ℝ) :
          maeLoss (xs ++ ys) target pred patchLoss = maeLoss xs target pred patchLoss + maeLoss ys target pred patchLoss
          theorem NN.MLTheory.SelfSupervised.maeLoss_reverse {n : ℕ} {Patch Pred : Type} (idxs : List (Fin n)) (target : PatchBatch n Patch) (pred : Fin n → Pred) (patchLoss : Patch → Pred → ℝ) :
          maeLoss idxs.reverse target pred patchLoss = maeLoss idxs target pred patchLoss

          The MAE loss is invariant under reversing the order of the masked-index list. This is the small formal version of “masked reconstruction is a set objective, not an ordering objective.”
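Under the sketch shape above (a mapped sum), the list-level content of this theorem reduces to the standard fact that list sums are reversal-invariant:

```lean
import Mathlib

-- Order-invariance of a mapped real sum: the list-level fact behind
-- `maeLoss_reverse` under the sketched sum-over-list definition.
example {α : Type} (f : α → ℝ) (l : List α) :
    (l.reverse.map f).sum = (l.map f).sum := by
  simp [List.map_reverse, List.sum_reverse]
```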

          theorem NN.MLTheory.SelfSupervised.maeLoss_eq_zero_of_patch_losses_zero {n : ℕ} {Patch Pred : Type} (idxs : List (Fin n)) (target : PatchBatch n Patch) (pred : Fin n → Pred) (patchLoss : Patch → Pred → ℝ) (h : ∀ (i : Fin n), i ∈ idxs → patchLoss (target i) (pred i) = 0) :
          maeLoss idxs target pred patchLoss = 0

          If every selected patch has zero reconstruction loss, the masked MAE loss is zero.
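Under the same sketch, this is the list fact that a sum of zeros is zero; a self-contained version by induction on the index list:

```lean
import Mathlib

-- If every selected element maps to a zero loss, the summed loss is zero:
-- the list-level content of `maeLoss_eq_zero_of_patch_losses_zero`.
example {α : Type} (f : α → ℝ) (l : List α) (h : ∀ a ∈ l, f a = 0) :
    (l.map f).sum = 0 := by
  induction l with
  | nil => simp
  | cons a t ih =>
    simp [h a (List.mem_cons_self a t),
          ih fun b hb => h b (List.mem_cons_of_mem a hb)]
```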

          theorem NN.MLTheory.SelfSupervised.exactReconstruction_identity {n : ℕ} {Patch : Type} (x : PatchBatch n Patch) :
          ExactReconstruction x (reconstruct (fun (x : Fin n) (p : Patch) => p) x)

          Reconstructing with the identity decoder/prediction is exact.
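A self-contained restatement under the sketched shapes assumed earlier for reconstruct and ExactReconstruction (this page shows only their statements, so the bodies below are assumptions):

```lean
-- Sketch shapes assumed for this page's declarations (not library source):
abbrev PatchBatch (n : Nat) (Patch : Type) := Fin n → Patch

def reconstruct {n : Nat} {Patch Pred : Type}
    (decode : Fin n → Pred → Patch) (pred : Fin n → Pred) :
    PatchBatch n Patch :=
  fun i => decode i (pred i)

def ExactReconstruction {n : Nat} {Patch : Type}
    (target recon : PatchBatch n Patch) : Prop :=
  ∀ i, recon i = target i

-- With the identity decoder and the batch itself as prediction,
-- exactness holds definitionally at every index:
example {n : Nat} {Patch : Type} (x : PatchBatch n Patch) :
    ExactReconstruction x (reconstruct (fun _ p => p) x) :=
  fun _ => rfl
```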