TorchLean API

NN.Spec.Models.Gan

Generative adversarial network (GAN) spec

This module gives TorchLean a small, total GAN interface: generator and discriminator records, a combined model pair, generation and scoring helpers, and least-squares (LSGAN) training losses.

We choose LSGAN as the baseline spec because it is total, compact, and verifier-friendly. Classical logistic GAN losses can be built on top of the same Generator/Discriminator records, but they need additional domain discipline around log.
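The contrast can be written out explicitly. The LSGAN formulas below match the MSE targets used by the loss definitions in this module; the logistic forms are the classical objectives the paragraph alludes to:

```latex
% LSGAN objectives: total for any real-valued score D.
\mathcal{L}_D = \big(D(x_{\text{real}}) - 1\big)^2 + \big(D(G(z)) - 0\big)^2
\qquad
\mathcal{L}_G = \big(D(G(z)) - 1\big)^2

% Logistic (original) GAN objectives: the logs are undefined at
% D = 0 and D = 1, so a total spec would need side conditions
% 0 < D(x) < 1 everywhere they appear.
\mathcal{L}_D = -\log D(x_{\text{real}}) - \log\big(1 - D(G(z))\big)
\qquad
\mathcal{L}_G = -\log D(G(z))
```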

structure Generative.GAN.Generator (α : Type) (latent obs : Spec.Shape) [Context α] :

Generator G_θ : z ↦ x_fake.

structure Generative.GAN.Discriminator (α : Type) (obs : Spec.Shape) [Context α] :

Discriminator/critic D_φ : x ↦ score, represented as a scalar tensor.
structure Generative.GAN.Model (α : Type) (latent obs : Spec.Shape) [Context α] :

GAN model pair.

• generator : Generator α latent obs

  Latent-to-observation generator.

• discriminator : Discriminator α obs

  Observation-to-score discriminator.
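As a rough, self-contained sketch of the record layout (a hypothetical miniature over plain functions and Float scores, not TorchLean's actual Spec.Tensor API):

```lean
-- Hypothetical miniature: tensors replaced by plain types,
-- scores by Float. Mirrors the Generator/Discriminator/Model shape.
structure Generator (Latent Obs : Type) where
  run : Latent → Obs

structure Discriminator (Obs : Type) where
  score : Obs → Float

structure Model (Latent Obs : Type) where
  generator : Generator Latent Obs
  discriminator : Discriminator Obs

-- Generate a fake sample, then score it.
def generate {Latent Obs : Type} (m : Model Latent Obs) (z : Latent) : Obs :=
  m.generator.run z

def fakeScore {Latent Obs : Type} (m : Model Latent Obs) (z : Latent) : Float :=
  m.discriminator.score (generate m z)
```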
def Generative.GAN.generate {α : Type} [Context α] {latent obs : Spec.Shape} (model : Model α latent obs) (z : Spec.Tensor α latent) :

Generate a fake sample.

def Generative.GAN.fakeScore {α : Type} [Context α] {latent obs : Spec.Shape} (model : Model α latent obs) (z : Spec.Tensor α latent) :

Discriminator score on a fake sample G(z).

def Generative.GAN.realScore {α : Type} [Context α] {latent obs : Spec.Shape} (model : Model α latent obs) (x : Spec.Tensor α obs) :

Discriminator score on a real sample.

Scalar tensor filled with 1, the LSGAN "real" target.

Scalar tensor filled with 0, the LSGAN "fake" target.
def Generative.GAN.discriminatorLoss {α : Type} [Context α] {latent obs : Spec.Shape} [DecidableRel fun (x1 x2 : α) => x1 > x2] [LE α] (model : Model α latent obs) (xReal : Spec.Tensor α obs) (z : Spec.Tensor α latent) : α

Least-squares discriminator loss:

MSE(D(x_real), 1) + MSE(D(G(z)), 0).
def Generative.GAN.generatorLoss {α : Type} [Context α] {latent obs : Spec.Shape} [DecidableRel fun (x1 x2 : α) => x1 > x2] [LE α] (model : Model α latent obs) (z : Spec.Tensor α latent) : α

Least-squares generator loss:

MSE(D(G(z)), 1).
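A scalar sketch of both losses, assuming the tensor-level MSE reduces to squared error on a single score (illustrative, not the module's actual definitions):

```lean
-- Squared error on scalar scores stands in for the tensor MSE.
def mse (pred target : Float) : Float :=
  (pred - target) * (pred - target)

-- MSE(D(x_real), 1) + MSE(D(G(z)), 0)
def discriminatorLoss (dReal dFake : Float) : Float :=
  mse dReal 1.0 + mse dFake 0.0

-- MSE(D(G(z)), 1)
def generatorLoss (dFake : Float) : Float :=
  mse dFake 1.0

-- A confident discriminator (dReal ≈ 1, dFake ≈ 0) drives its loss
-- toward 0; a fooled one (dFake ≈ 1) drives the generator's toward 0.
#eval discriminatorLoss 0.9 0.2  -- ≈ 0.05
#eval generatorLoss 0.2          -- ≈ 0.64
```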
@[simp]
theorem Generative.GAN.fakeScore_eq_discriminator_generate {α : Type} [Context α] {latent obs : Spec.Shape} (model : Model α latent obs) (z : Spec.Tensor α latent) :

Fake scoring expands to discriminator-after-generator.

@[simp]
theorem Generative.GAN.discriminatorLoss_eq {α : Type} [Context α] {latent obs : Spec.Shape} [DecidableRel fun (x1 x2 : α) => x1 > x2] [LE α] (model : Model α latent obs) (xReal : Spec.Tensor α obs) (z : Spec.Tensor α latent) :

The LSGAN discriminator objective is the sum of real and fake score-regression terms.
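The unfolding-lemma pattern can be shown in a self-contained toy (hypothetical names; the real statement is over Spec.Tensor):

```lean
structure M where
  g : Nat → Nat      -- toy generator
  d : Nat → Float    -- toy discriminator

def fake (m : M) (z : Nat) : Float := m.d (m.g z)

-- Definitional unfolding: fake scoring is discriminator-after-generator,
-- so the simp lemma holds by rfl.
@[simp] theorem fake_eq (m : M) (z : Nat) :
    fake m z = m.d (m.g z) := rfl
```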