TorchLean API

NN.MLTheory.Generative.Latent.VQVAE

VQ-VAE theory #

VQ-VAE has one mathematically delicate implementation choice: nearest-neighbor code assignment. TorchLean's spec keeps that assignment explicit as a Fin numCodes, so the core codebook semantics are total and easy to audit. Runtime code may compute the index using a CUDA, Python, or Lean argmin; once the index is supplied, the following facts are definitional.
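A minimal self-contained sketch of that design, with hypothetical names (`quantize`, `codebook`) standing in for TorchLean's actual API:

```lean
-- Illustrative mini-version of the design (names are not TorchLean's API):
-- quantization takes the code index as an explicit argument, so the
-- spec-side function is a total codebook lookup with no argmin inside.
def quantize {numCodes d : Nat} (codebook : Fin numCodes → Fin d → Float)
    (idx : Fin numCodes) : Fin d → Float :=
  codebook idx

-- Whatever runtime produced `idx`, the lookup itself is definitional:
example {numCodes d : Nat} (codebook : Fin numCodes → Fin d → Float)
    (idx : Fin numCodes) : quantize codebook idx = codebook idx :=
  rfl
```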

We also prove the real-valued nearest-code optimality lemma used by vector quantization: if an index is selected as an argmin of squared Euclidean distance to the encoder output, then the corresponding quantization loss is minimal among all codebook choices.
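In symbols, with codebook vectors $e_j$ and encoder output $z$, the lemma (stated formally in the last subsection below) reads:

$$
\mathrm{idx} \in \operatorname*{arg\,min}_{j}\, \lVert z - e_j \rVert_2^2
\;\implies\;
\forall j,\; \lVert z - e_{\mathrm{idx}} \rVert_2^2 \le \lVert z - e_j \rVert_2^2 .
$$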

Reference:

@[simp]
theorem NN.MLTheory.Generative.Latent.VQVAE.quantized_is_codebook_lookup {α : Type} [Context α] {obs latent : Spec.Shape} {numCodes : ℕ} (model : Generative.VQVAE.Model α obs latent numCodes) (idx : Fin numCodes) :

Quantization with an explicit code index is codebook lookup.

@[simp]
theorem NN.MLTheory.Generative.Latent.VQVAE.forward_eq_decoder_codebook {α : Type} [Context α] {obs latent : Spec.Shape} {numCodes : ℕ} (model : Generative.VQVAE.Model α obs latent numCodes) (x : Spec.Tensor α obs) (idx : Fin numCodes) :

VQ-VAE reconstruction decodes the selected codebook vector.

@[simp]
theorem NN.MLTheory.Generative.Latent.VQVAE.vqvae_loss_decomposition {α : Type} [Context α] {obs latent : Spec.Shape} {numCodes : ℕ} [DecidableRel fun (x1 x2 : α) => x1 > x2] [LE α] (model : Generative.VQVAE.Model α obs latent numCodes) (beta : α) (x : Spec.Tensor α obs) (idx : Fin numCodes) :

The VQ-VAE loss splits into reconstruction, codebook, and commitment terms.
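For orientation, this is the shape of the standard VQ-VAE objective from the literature, with $\mathrm{sg}$ the stop-gradient operator (TorchLean's executable terms may differ in normalization and in how stop-gradient is modeled):

$$
\mathcal{L}(x,\mathrm{idx})
= \underbrace{\lVert x - \mathrm{dec}(e_{\mathrm{idx}})\rVert_2^2}_{\text{reconstruction}}
+ \underbrace{\lVert \mathrm{sg}[\mathrm{enc}(x)] - e_{\mathrm{idx}}\rVert_2^2}_{\text{codebook}}
+ \beta\,\underbrace{\lVert \mathrm{enc}(x) - \mathrm{sg}[e_{\mathrm{idx}}]\rVert_2^2}_{\text{commitment}} .
$$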

Connection to the shared latent-objective algebra #

noncomputable def NN.MLTheory.Generative.Latent.VQVAE.vqvaeObjectiveTerms {obs latent : Spec.Shape} {numCodes : ℕ} [DecidableRel fun (x1 x2 : ℝ) => x1 > x2] (model : Generative.VQVAE.Model ℝ obs latent numCodes) (x : Spec.Tensor ℝ obs) (idx : Fin numCodes) :

Package VQ-VAE reconstruction, codebook, and commitment terms as a weighted three-term objective.
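A self-contained sketch of the packaged shape (the structure and field names are illustrative; TorchLean's shared objective algebra may be organized differently):

```lean
-- Illustrative stand-in for a weighted three-term objective:
-- base + middle + weight * regularizer.
structure ThreeTerm (α : Type) where
  base        : α  -- reconstruction term
  middle      : α  -- codebook term
  regularizer : α  -- commitment term

def ThreeTerm.eval {α : Type} [Add α] [Mul α]
    (t : ThreeTerm α) (beta : α) : α :=
  t.base + t.middle + beta * t.regularizer
```

Read this way, `vqvae_loss_eq_weightedThreeTerm` below says the executable loss is exactly the evaluation of the packaged terms at weight β.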

theorem NN.MLTheory.Generative.Latent.VQVAE.vqvae_loss_eq_weightedThreeTerm {obs latent : Spec.Shape} {numCodes : ℕ} [DecidableRel fun (x1 x2 : ℝ) => x1 > x2] (model : Generative.VQVAE.Model ℝ obs latent numCodes) (beta : ℝ) (x : Spec.Tensor ℝ obs) (idx : Fin numCodes) :

VQ-VAE loss is exactly the shared base + middle + β * regularizer objective.

@[simp]
theorem NN.MLTheory.Generative.Latent.VQVAE.vqvae_loss_zero_beta {obs latent : Spec.Shape} {numCodes : ℕ} [DecidableRel fun (x1 x2 : ℝ) => x1 > x2] (model : Generative.VQVAE.Model ℝ obs latent numCodes) (x : Spec.Tensor ℝ obs) (idx : Fin numCodes) :

At commitment weight β = 0, the VQ-VAE loss reduces to reconstruction plus codebook loss.

theorem NN.MLTheory.Generative.Latent.VQVAE.vqvae_loss_eq_reconstruction_of_zero_quantization {obs latent : Spec.Shape} {numCodes : ℕ} [DecidableRel fun (x1 x2 : ℝ) => x1 > x2] (model : Generative.VQVAE.Model ℝ obs latent numCodes) (beta : ℝ) (x : Spec.Tensor ℝ obs) (idx : Fin numCodes) (hcode : Generative.VQVAE.codebookLoss model x idx = 0) (hcommit : Generative.VQVAE.commitmentLoss model x idx = 0) :

If the selected code matches the encoder output, in the sense that both quantization penalties vanish, the VQ-VAE objective reduces to the reconstruction loss.
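Concretely, substituting both hypotheses into the three-term decomposition above:

$$
\mathcal{L} = \mathcal{L}_{\mathrm{recon}} + 0 + \beta \cdot 0 = \mathcal{L}_{\mathrm{recon}} .
$$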

theorem NN.MLTheory.Generative.Latent.VQVAE.vqvae_loss_mono_beta_of_commitment_nonneg {obs latent : Spec.Shape} {numCodes : ℕ} [DecidableRel fun (x1 x2 : ℝ) => x1 > x2] (model : Generative.VQVAE.Model ℝ obs latent numCodes) (x : Spec.Tensor ℝ obs) (idx : Fin numCodes) {beta₁ beta₂ : ℝ} (hbeta : beta₁ ≤ beta₂) (hcommit : 0 ≤ Generative.VQVAE.commitmentLoss model x idx) :
Generative.VQVAE.loss model beta₁ x idx ≤ Generative.VQVAE.loss model beta₂ x idx

Commitment-weight monotonicity for the executable VQ-VAE loss.

Once a verifier or model-specific theorem establishes that the commitment term is nonnegative, increasing β cannot decrease the objective.
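The underlying algebra, over the decomposition above: the two losses differ only in the commitment weight, so

$$
\mathcal{L}_{\beta_2} - \mathcal{L}_{\beta_1}
= (\beta_2 - \beta_1)\,\mathcal{L}_{\mathrm{commit}} \ge 0
\qquad \text{when } \beta_1 \le \beta_2 \text{ and } 0 \le \mathcal{L}_{\mathrm{commit}} .
$$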

Nearest-code optimality #

noncomputable def NN.MLTheory.Generative.Latent.VQVAE.squaredL2 {d : ℕ} (x y : Fin d → ℝ) : ℝ

Squared Euclidean distance on finite real coordinate vectors.
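A definition matching this docstring, assumed here for the sketches that follow (TorchLean's actual definition may be stated differently):

```lean
import Mathlib

-- Sum of squared coordinate differences on `Fin d → ℝ`;
-- `noncomputable` because real-number arithmetic is.
noncomputable def squaredL2 {d : ℕ} (x y : Fin d → ℝ) : ℝ :=
  ∑ i, (x i - y i) ^ 2
```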

def NN.MLTheory.Generative.Latent.VQVAE.IsNearestCode {numCodes d : ℕ} (embedding : Fin numCodes → Fin d → ℝ) (z : Fin d → ℝ) (idx : Fin numCodes) : Prop

The predicate that idx is a nearest code for encoder output z under squared Euclidean distance. Ties are allowed; tie-breaking is an implementation detail outside this theorem.
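The natural unfolding of the predicate, continuing the `squaredL2` sketch above (assumed to match TorchLean's definition up to notation):

```lean
-- `idx` is weakly closest: no other code is strictly nearer to `z`.
-- The non-strict ≤ is what makes ties acceptable.
def IsNearestCode {numCodes d : ℕ} (embedding : Fin numCodes → Fin d → ℝ)
    (z : Fin d → ℝ) (idx : Fin numCodes) : Prop :=
  ∀ j : Fin numCodes, squaredL2 z (embedding idx) ≤ squaredL2 z (embedding j)
```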


Squared Euclidean distance is nonnegative.

@[simp]

Exact code matches have zero quantization distance.

theorem NN.MLTheory.Generative.Latent.VQVAE.exactCodeMatch_isNearestCode {numCodes d : ℕ} {embedding : Fin numCodes → Fin d → ℝ} {z : Fin d → ℝ} {idx : Fin numCodes} (hmatch : z = embedding idx) :
IsNearestCode embedding z idx

If the encoder output exactly equals a code, that code is a nearest code.
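The proof idea, sketched against the toy definitions above: after rewriting with the match hypothesis, the selected distance is zero, and squared distances are nonnegative:

```lean
example {numCodes d : ℕ} (embedding : Fin numCodes → Fin d → ℝ)
    (z : Fin d → ℝ) (idx : Fin numCodes) (hmatch : z = embedding idx) :
    IsNearestCode embedding z idx := by
  intro j
  have h0 : squaredL2 z (embedding idx) = 0 := by
    simp [squaredL2, hmatch]
  have hnn : 0 ≤ squaredL2 z (embedding j) := by
    unfold squaredL2
    positivity
  linarith
```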

theorem NN.MLTheory.Generative.Latent.VQVAE.exactCodeMatch_selected_distance_zero {numCodes d : ℕ} {embedding : Fin numCodes → Fin d → ℝ} {z : Fin d → ℝ} {idx : Fin numCodes} (hmatch : z = embedding idx) :
squaredL2 z (embedding idx) = 0

Exact code matches have zero selected quantization distance.

theorem NN.MLTheory.Generative.Latent.VQVAE.nearestCode_minimizes_quantization_loss {numCodes d : ℕ} {embedding : Fin numCodes → Fin d → ℝ} {z : Fin d → ℝ} {idx j : Fin numCodes} (hidx : IsNearestCode embedding z idx) :
squaredL2 z (embedding idx) ≤ squaredL2 z (embedding j)

Nearest-code optimality for VQ-VAE.

Once the runtime argmin has returned an index satisfying IsNearestCode, the selected code's quantization distance is no larger than the distance to any other code. This is the formal contract that lets CUDA/Python/Lean argmin implementations plug into the same spec semantics.
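On the Lean side the contract can even be discharged generically: a nonempty codebook always admits some nearest code. A sketch on the definitions above, using Mathlib's `Finset.exists_min_image`:

```lean
example {numCodes d : ℕ} [NeZero numCodes]
    (embedding : Fin numCodes → Fin d → ℝ) (z : Fin d → ℝ) :
    ∃ idx, IsNearestCode embedding z idx := by
  obtain ⟨idx, -, hmin⟩ :=
    Finset.exists_min_image Finset.univ
      (fun j => squaredL2 z (embedding j)) Finset.univ_nonempty
  exact ⟨idx, fun j => hmin j (Finset.mem_univ j)⟩
```

CUDA or Python runtimes instead certify `IsNearestCode` for the concrete index they return.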