TorchLean API

NN.Spec.Models.VqVae

Vector-quantized VAE (VQ-VAE) spec

VQ-VAE replaces a continuous latent sample with a discrete codebook lookup. This file exposes the core mechanism in a theorem-friendly way:

  1. an encoder produces a continuous latent z_e(x);
  2. a code index selects a codebook vector z_q;
  3. a decoder reconstructs from z_q;
  4. the loss combines reconstruction, codebook, and commitment terms.
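Spelled out, the objective in step 4 takes the following spec-level form (a sketch; here e_k denotes the codebook vector selected by index k, and no stop-gradient operators appear because this spec writes the codebook term symmetrically, unlike the original training-time VQ-VAE objective):

```latex
\mathcal{L}(x, k)
  = \underbrace{\lVert \mathrm{dec}(e_k) - x \rVert^2}_{\text{reconstruction}}
  + \underbrace{\lVert e_k - z_e(x) \rVert^2}_{\text{codebook}}
  + \beta \, \underbrace{\lVert z_e(x) - e_k \rVert^2}_{\text{commitment}}
```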

The nearest-neighbor assignment is deliberately an explicit Fin numCodes argument. That keeps the spec total and avoids hiding a tie-breaking policy in the mathematical layer; runtime code can compute the index however it likes and then pass the verified index into this spec.
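As an illustration of that division of labor, a runtime-side nearest-neighbor search might look like the following sketch. This helper is hypothetical and not part of this file; it assumes a squared-distance function distSq on latent tensors, a linear order on α, and a nonempty codebook, and it bakes in an earliest-index tie-break:

```lean
-- Hypothetical runtime helper (not in this spec): pick the index of the
-- codebook vector closest to z_e(x), then hand the verified Fin numCodes
-- value to `forward`/`loss`. Ties resolve to the earliest index.
def nearestIndex {α : Type} [Context α] [LinearOrder α]
    {obs latent : Spec.Shape} {numCodes : ℕ}
    (model : Model α obs latent numCodes)
    (distSq : Spec.Tensor α latent → Spec.Tensor α latent → α)
    (x : Spec.Tensor α obs) (h : 0 < numCodes) : Fin numCodes :=
  let zE := encode model x
  -- Fold over all code indices, keeping the one with the smallest distance.
  (List.finRange numCodes).foldl
    (fun best k =>
      if distSq (quantized model k) zE < distSq (quantized model best) zE
      then k else best)
    ⟨0, h⟩
```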

Reference:

structure Generative.VQVAE.Encoder (α : Type) (obs latent : Spec.Shape) [Context α] :

Encoder producing the pre-quantized latent vector z_e(x).

structure Generative.VQVAE.Decoder (α : Type) (latent obs : Spec.Shape) [Context α] :

Decoder mapping a codebook vector back to observation space.

structure Generative.VQVAE.Model (α : Type) (obs latent : Spec.Shape) (numCodes : ℕ) [Context α] :

VQ-VAE model: encoder, codebook, and decoder.

• encoder : Encoder α obs latent

  Continuous encoder.

• codebook : Latent.Codebook α numCodes latent

  Finite codebook.

• decoder : Decoder α latent obs

  Decoder from quantized latent vectors.

def Generative.VQVAE.encode {α : Type} [Context α] {obs latent : Spec.Shape} {numCodes : ℕ} (model : Model α obs latent numCodes) (x : Spec.Tensor α obs) :
    Spec.Tensor α latent

Pre-quantized latent z_e(x).

def Generative.VQVAE.quantized {α : Type} [Context α] {obs latent : Spec.Shape} {numCodes : ℕ} (model : Model α obs latent numCodes) (idx : Fin numCodes) :
    Spec.Tensor α latent

Quantized latent z_q, using an explicit code index.

def Generative.VQVAE.forward {α : Type} [Context α] {obs latent : Spec.Shape} {numCodes : ℕ} (model : Model α obs latent numCodes) (_x : Spec.Tensor α obs) (idx : Fin numCodes) :
    Spec.Tensor α obs

VQ-VAE reconstruction from an explicit code assignment.

def Generative.VQVAE.reconstructionLoss {α : Type} [Context α] {obs latent : Spec.Shape} {numCodes : ℕ} [DecidableRel fun (x1 x2 : α) => x1 > x2] [LE α] (model : Model α obs latent numCodes) (x : Spec.Tensor α obs) (idx : Fin numCodes) :
    α

Reconstruction term ||dec(z_q)-x||².

def Generative.VQVAE.codebookLoss {α : Type} [Context α] {obs latent : Spec.Shape} {numCodes : ℕ} [DecidableRel fun (x1 x2 : α) => x1 > x2] [LE α] (model : Model α obs latent numCodes) (x : Spec.Tensor α obs) (idx : Fin numCodes) :
    α

Codebook term ||z_q-z_e||², written symmetrically at spec level.

def Generative.VQVAE.commitmentLoss {α : Type} [Context α] {obs latent : Spec.Shape} {numCodes : ℕ} [DecidableRel fun (x1 x2 : α) => x1 > x2] [LE α] (model : Model α obs latent numCodes) (x : Spec.Tensor α obs) (idx : Fin numCodes) :
    α

Commitment term ||z_e-z_q||², weighted by β in the total objective.

def Generative.VQVAE.loss {α : Type} [Context α] {obs latent : Spec.Shape} {numCodes : ℕ} [DecidableRel fun (x1 x2 : α) => x1 > x2] [LE α] (model : Model α obs latent numCodes) (beta : α) (x : Spec.Tensor α obs) (idx : Fin numCodes) :
    α

VQ-VAE objective: reconstruction + codebook + β commitment.

@[simp]
theorem Generative.VQVAE.quantized_eq_embedding {α : Type} [Context α] {obs latent : Spec.Shape} {numCodes : ℕ} (model : Model α obs latent numCodes) (idx : Fin numCodes) :
    quantized model idx = model.codebook.embedding idx

Quantization by explicit index is exactly codebook lookup.

@[simp]
theorem Generative.VQVAE.loss_eq_reconstruction_add_codebook_add_commitment {α : Type} [Context α] {obs latent : Spec.Shape} {numCodes : ℕ} [DecidableRel fun (x1 x2 : α) => x1 > x2] [LE α] (model : Model α obs latent numCodes) (beta : α) (x : Spec.Tensor α obs) (idx : Fin numCodes) :
    loss model beta x idx = reconstructionLoss model x idx + codebookLoss model x idx + beta * commitmentLoss model x idx

The VQ-VAE objective decomposes into the three standard terms.
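Since both lemmas above are marked @[simp], proofs can unfold the objective into its three terms without naming them; a minimal usage sketch (assuming the same instance arguments as loss):

```lean
-- Sketch: `simp` rewrites the objective into its three standard terms
-- using the @[simp] decomposition lemma above.
example {α : Type} [Context α] {obs latent : Spec.Shape} {numCodes : ℕ}
    [DecidableRel fun (x1 x2 : α) => x1 > x2] [LE α]
    (model : Model α obs latent numCodes) (beta : α)
    (x : Spec.Tensor α obs) (idx : Fin numCodes) :
    loss model beta x idx
      = reconstructionLoss model x idx + codebookLoss model x idx
        + beta * commitmentLoss model x idx := by
  simp
```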