TorchLean API

NN.Spec.Models.Vae

Variational autoencoder (VAE) spec

This file gives TorchLean a small, backbone-independent VAE interface:

• Encoder, Decoder, and Model structures over observation and latent shapes.

• Deterministic helpers encode, sampleLatent, and forward, which take the reparameterization noise as an explicit argument.

• Loss components reconstructionLoss, klLoss, and the combined β-VAE loss.

• @[simp] expansion lemmas for forward and loss, used by downstream proofs.

The design mirrors the original VAE formulation of Kingma and Welling (2014), while staying compatible with TorchLean's deterministic spec layer by making the reparameterization noise explicit.
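
Concretely, for encoder outputs (μ, logσ²) and caller-supplied noise ε, the reparameterized latent is

    z = μ + exp(logσ²/2) ⊙ ε,    ε ~ N(0, I),

so every map in this file remains a pure, replayable function of its inputs.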

References:

• Kingma, D. P., and Welling, M. (2014). Auto-Encoding Variational Bayes. ICLR. arXiv:1312.6114.

structure Generative.VAE.Encoder (α : Type) (obs latent : Spec.Shape) [Context α] :
Type

Diagonal-Gaussian encoder q_φ(z|x), returning (μ, logσ²).

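A minimal usage sketch, assuming the two heads are exposed as mean and logvar fields (the names used by forward_eq_decode_reparameterize at the bottom of this page):

    open Generative.VAE in
    example {α : Type} [Context α] {obs latent : Spec.Shape}
        (enc : Encoder α obs latent) (x : Spec.Tensor α obs) :
        Spec.Tensor α latent × Spec.Tensor α latent :=
      (enc.mean x, enc.logvar x)
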
structure Generative.VAE.Decoder (α : Type) (latent obs : Spec.Shape) [Context α] :
Type

Decoder/generator p_θ(x|z) represented by its reconstruction mean.

structure Generative.VAE.Model (α : Type) (obs latent : Spec.Shape) [Context α] :
Type

A VAE is an encoder plus a decoder.

• encoder : Encoder α obs latent

  Approximate posterior network.

• decoder : Decoder α latent obs

  Generative decoder network.

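Because a Model is just the pair of networks, one can be assembled directly from the two fields listed above (a minimal sketch):

    open Generative.VAE in
    example {α : Type} [Context α] {obs latent : Spec.Shape}
        (enc : Encoder α obs latent) (dec : Decoder α latent obs) :
        Model α obs latent :=
      { encoder := enc, decoder := dec }
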
def Generative.VAE.encode {α : Type} [Context α] {obs latent : Spec.Shape} (model : Model α obs latent) (x : Spec.Tensor α obs) :
Spec.Tensor α latent × Spec.Tensor α latent

Encode once and return (μ, logσ²).

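For example, the posterior mean alone can be extracted by destructuring the returned pair (a usage sketch):

    open Generative.VAE in
    example {α : Type} [Context α] {obs latent : Spec.Shape}
        (model : Model α obs latent) (x : Spec.Tensor α obs) :
        Spec.Tensor α latent :=
      let (mu, _logvar) := encode model x
      mu
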
def Generative.VAE.sampleLatent {α : Type} [Context α] {obs latent : Spec.Shape} (model : Model α obs latent) (x : Spec.Tensor α obs) (eps : Spec.Tensor α latent) :
Spec.Tensor α latent

Sample the latent using explicit noise.

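Since the noise is an explicit argument rather than hidden state, sampling is deterministic: replaying the same ε yields the same latent, definitionally.

    open Generative.VAE in
    example {α : Type} [Context α] {obs latent : Spec.Shape}
        (model : Model α obs latent) (x : Spec.Tensor α obs)
        (eps : Spec.Tensor α latent) :
        sampleLatent model x eps = sampleLatent model x eps :=
      rfl
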
def Generative.VAE.forward {α : Type} [Context α] {obs latent : Spec.Shape} (model : Model α obs latent) (x : Spec.Tensor α obs) (eps : Spec.Tensor α latent) :
Spec.Tensor α obs

Full VAE forward pass: encode, reparameterize with explicit noise, then decode.

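A round-trip sketch, assuming (per the signature above) that the result lives back in observation space:

    open Generative.VAE in
    example {α : Type} [Context α] {obs latent : Spec.Shape}
        (model : Model α obs latent) (x : Spec.Tensor α obs)
        (eps : Spec.Tensor α latent) :
        Spec.Tensor α obs :=
      forward model x eps
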
def Generative.VAE.reconstructionLoss {α : Type} [Context α] {obs latent : Spec.Shape} [DecidableRel fun (x1 x2 : α) => x1 > x2] [LE α] (model : Model α obs latent) (x : Spec.Tensor α obs) (eps : Spec.Tensor α latent) :
α

Reconstruction term, using mean-squared error in observation space.

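With x̂ = forward model x eps, this is presumably the squared error averaged over the observation entries:

    L_rec(x, x̂) = (1/|obs|) · Σ_i (x_i − x̂_i)²
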
def Generative.VAE.klLoss {α : Type} [Context α] {obs latent : Spec.Shape} (model : Model α obs latent) (x : Spec.Tensor α obs) :
α

KL term KL(q_φ(z|x) || N(0,I)), averaged across the latent shape.

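For a diagonal Gaussian measured against the standard normal prior, this is the usual closed form, averaged over the latent entries as the docstring states:

    KL(q_φ(z|x) || N(0,I)) = (1/|latent|) · Σ_j ½ (μ_j² + exp(logσ²_j) − logσ²_j − 1)
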
def Generative.VAE.loss {α : Type} [Context α] {obs latent : Spec.Shape} [DecidableRel fun (x1 x2 : α) => x1 > x2] [LE α] (model : Model α obs latent) (beta : α) (x : Spec.Tensor α obs) (eps : Spec.Tensor α latent) :
α

β-VAE objective: reconstruction loss plus β times the diagonal-Gaussian KL term.

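Setting beta to 1 recovers the standard VAE objective. A call sketch:

    open Generative.VAE in
    example {α : Type} [Context α] {obs latent : Spec.Shape}
        [DecidableRel fun (x1 x2 : α) => x1 > x2] [LE α]
        (model : Model α obs latent) (beta : α)
        (x : Spec.Tensor α obs) (eps : Spec.Tensor α latent) : α :=
      loss model beta x eps
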
@[simp]
theorem Generative.VAE.forward_eq_decode_reparameterize {α : Type} [Context α] {obs latent : Spec.Shape} (model : Model α obs latent) (x : Spec.Tensor α obs) (eps : Spec.Tensor α latent) :
forward model x eps = model.decoder.forward (Latent.reparameterizeDiag (model.encoder.mean x) (model.encoder.logvar x) eps)

Forward expansion lemma used by examples and downstream proof files.

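Because the lemma is tagged @[simp], unfolding a forward pass in a proof is a one-liner (a sketch restating the lemma):

    open Generative.VAE in
    example {α : Type} [Context α] {obs latent : Spec.Shape}
        (model : Model α obs latent) (x : Spec.Tensor α obs)
        (eps : Spec.Tensor α latent) :
        forward model x eps
          = model.decoder.forward
              (Latent.reparameterizeDiag (model.encoder.mean x)
                (model.encoder.logvar x) eps) := by
      simp
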
@[simp]
theorem Generative.VAE.loss_eq_reconstruction_add_kl {α : Type} [Context α] {obs latent : Spec.Shape} [DecidableRel fun (x1 x2 : α) => x1 > x2] [LE α] (model : Model α obs latent) (beta : α) (x : Spec.Tensor α obs) (eps : Spec.Tensor α latent) :
loss model beta x eps = reconstructionLoss model x eps + beta * klLoss model x

The VAE objective is exactly reconstruction plus a weighted KL term.