TorchLean API

NN.MLTheory.SelfSupervised.JEPA

Joint-Embedding Predictive Objective Semantics #

JEPA-style objectives predict target-block representations from context-block representations. This file records the finite-index objective shape without committing to a particular vision backbone, target encoder, or predictor architecture.

Paper anchor: “Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture” (Assran et al., 2023), arXiv:2301.08243. I-JEPA predicts target-block representations from context-block representations rather than reconstructing pixels. At the objective boundary the target branch is treated as an ordinary representation value, which is why jepaLoss_target_ext is useful: the loss depends only on the target values at the selected target indices.

def NN.MLTheory.SelfSupervised.jepaLoss {n : ℕ} {Context Target Pred : Type} (targetIdxs : List (Fin n)) (context : Context) (target : Fin n → Target) (predict : Context → Fin n → Pred) (repLoss : Target → Pred → ℝ) : ℝ

JEPA loss over target-block indices.

context abstracts the context encoder output, target abstracts target-block representations, and predict abstracts the predictor head. This keeps the objective theorem independent from any particular image backbone.
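The defining equation is not displayed on this page. Below is a minimal sketch of one shape the definition could take, assuming the loss sums repLoss over the listed target indices; the ℝ codomain and the sum-over-indices fold are assumptions, chosen to be consistent with the jepaLoss_nil, jepaLoss_append, and jepaLoss_reverse lemmas on this page:

```lean
import Mathlib

-- Sketch only: a plausible defining equation, not the source definition.
-- Each selected target index contributes the representation loss between
-- the target value and the prediction made from the context.
def jepaLossSketch {n : ℕ} {Context Target Pred : Type}
    (targetIdxs : List (Fin n)) (context : Context)
    (target : Fin n → Target) (predict : Context → Fin n → Pred)
    (repLoss : Target → Pred → ℝ) : ℝ :=
  (targetIdxs.map fun i => repLoss (target i) (predict context i)).sum
```

Under this reading, the nil case is List.sum_nil, additivity over ++ follows from List.map_append and List.sum_append, and reverse-invariance follows from List.sum_reverse.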

Instances For
    @[simp]
    theorem NN.MLTheory.SelfSupervised.jepaLoss_nil {n : ℕ} {Context Target Pred : Type} (context : Context) (target : Fin n → Target) (predict : Context → Fin n → Pred) (repLoss : Target → Pred → ℝ) :
    jepaLoss [] context target predict repLoss = 0
    theorem NN.MLTheory.SelfSupervised.jepaLoss_append {n : ℕ} {Context Target Pred : Type} (xs ys : List (Fin n)) (context : Context) (target : Fin n → Target) (predict : Context → Fin n → Pred) (repLoss : Target → Pred → ℝ) :
    jepaLoss (xs ++ ys) context target predict repLoss = jepaLoss xs context target predict repLoss + jepaLoss ys context target predict repLoss
    theorem NN.MLTheory.SelfSupervised.jepaLoss_reverse {n : ℕ} {Context Target Pred : Type} (idxs : List (Fin n)) (context : Context) (target : Fin n → Target) (predict : Context → Fin n → Pred) (repLoss : Target → Pred → ℝ) :
    jepaLoss idxs.reverse context target predict repLoss = jepaLoss idxs context target predict repLoss

    JEPA target-block prediction is invariant under reversing the target-index order.
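For intuition, the abstract types can be instantiated concretely. A hypothetical instantiation (not from this file) with real-valued representations and a squared-error repLoss; it assumes this module is imported, `open NN.MLTheory.SelfSupervised` is in scope, and Mathlib provides ℝ:

```lean
-- Hypothetical instantiation: Context = Target = Pred = ℝ with a
-- squared-error representation loss. Reversing the target-index
-- list leaves the objective unchanged, by jepaLoss_reverse.
example {n : ℕ} (idxs : List (Fin n)) (c : ℝ) (t : Fin n → ℝ)
    (p : ℝ → Fin n → ℝ) :
    jepaLoss idxs.reverse c t p (fun a b => (a - b) ^ 2) =
      jepaLoss idxs c t p (fun a b => (a - b) ^ 2) :=
  jepaLoss_reverse idxs c t p fun a b => (a - b) ^ 2
```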

    theorem NN.MLTheory.SelfSupervised.jepaLoss_target_ext {n : ℕ} {Context Target Pred : Type} (idxs : List (Fin n)) (context : Context) (target₁ target₂ : Fin n → Target) (predict : Context → Fin n → Pred) (repLoss : Target → Pred → ℝ) (h : ∀ (i : Fin n), i ∈ idxs → target₁ i = target₂ i) :
    jepaLoss idxs context target₁ predict repLoss = jepaLoss idxs context target₂ predict repLoss

    Stop-gradient is modeled at the objective boundary: the target representation is an ordinary value passed into the loss, not an output of the online predictor. This theorem states the corresponding extensional property: if two target branches agree on the selected indices, the JEPA loss is the same.
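If the loss folds repLoss over idxs (an assumption about the undisplayed definition), the extensionality property follows by induction on the index list. A self-contained sketch using a stand-in definition lossSketch (hypothetical, not part of this file):

```lean
import Mathlib

-- Stand-in fold definition for the proof-idea sketch.
def lossSketch {n : ℕ} {Target Pred : Type} (idxs : List (Fin n))
    (target : Fin n → Target) (pred : Fin n → Pred)
    (repLoss : Target → Pred → ℝ) : ℝ :=
  (idxs.map fun i => repLoss (target i) (pred i)).sum

-- Extensionality: only target values at the selected indices matter.
-- Induct on idxs; the head index is a member, so h rewrites it, and
-- the inductive hypothesis handles the tail.
example {n : ℕ} {Target Pred : Type} (idxs : List (Fin n))
    (target₁ target₂ : Fin n → Target) (pred : Fin n → Pred)
    (repLoss : Target → Pred → ℝ)
    (h : ∀ i ∈ idxs, target₁ i = target₂ i) :
    lossSketch idxs target₁ pred repLoss =
      lossSketch idxs target₂ pred repLoss := by
  induction idxs with
  | nil => rfl
  | cons i rest ih =>
    have hi : target₁ i = target₂ i := h i (by simp)
    have hrest := ih fun j hj => h j (by simp [hj])
    simp only [lossSketch, List.map_cons, List.sum_cons] at hrest ⊢
    rw [hi, hrest]
```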