Intermediate

Generative Models: VAEs and GANs

Lesson 2 of 4 · Estimated time: 55 min

Generative models learn the underlying distribution of data and can synthesize new samples. Variational Autoencoders (VAEs) and Generative Adversarial Networks (GANs) represent two fundamentally different approaches to generative modeling, each with distinct advantages for different applications. This lesson explores their theoretical foundations, training dynamics, and practical implementations.

Core Concepts

Variational Autoencoders (VAEs)

VAEs combine probabilistic modeling with deep learning by learning a latent variable model. Unlike standard autoencoders, VAEs enforce a distribution on the latent space using the variational inference principle.

The VAE Framework:

  • Encoder: Maps input x to a distribution over latent variables z
  • Latent Space: A continuous, probabilistic representation
  • Decoder: Reconstructs x from samples of z
  • KL Divergence: Regularizes the latent distribution toward a standard normal

The VAE is trained by maximizing the evidence lower bound (ELBO); equivalently, the loss is the negative ELBO, which combines two objectives:

L_VAE = -E[log p(x|z)] + D_KL(q(z|x) || p(z))

Where:

  • First term: reconstruction loss (negative expected log-likelihood under the decoder)
  • Second term: KL divergence penalty encouraging the approximate posterior q(z|x) to match the prior p(z)
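
For the diagonal Gaussian posterior used in practice, q(z|x) = N(μ, diag(σ²)) with prior p(z) = N(0, I), the KL term has a closed form:

D_KL(q(z|x) || p(z)) = -½ Σ_j (1 + log σ_j² - μ_j² - σ_j²)

This is exactly the KLD expression computed in the implementation below.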

Reparameterization Trick: To backpropagate through sampling, we reparameterize:

z = μ + σ ⊙ ε,  where ε ~ N(0, I)

This allows gradients to flow through the sampling operation.
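
A minimal sketch of why this works: gradients reach μ and log σ² through the deterministic transform, while the randomness stays isolated in ε.

```python
import torch

mu = torch.zeros(3, requires_grad=True)
logvar = torch.zeros(3, requires_grad=True)

# Sampling z ~ N(mu, sigma^2) directly would block gradients;
# the reparameterized form keeps the randomness in eps
eps = torch.randn(3)
z = mu + torch.exp(0.5 * logvar) * eps

z.sum().backward()
print(mu.grad)  # dz/dmu = 1 per dimension: tensor([1., 1., 1.])
```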

Generative Adversarial Networks (GANs)

GANs employ a game-theoretic approach where two networks compete:

  • Generator (G): Creates fake samples from noise
  • Discriminator (D): Distinguishes real from fake samples

The minimax objective:

min_G max_D V(D,G) = E_x[log D(x)] + E_z[log(1 - D(G(z)))]
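
For a fixed generator, the inner maximization has a known solution: the optimal discriminator is

D*(x) = p_data(x) / (p_data(x) + p_g(x))

Substituting it back shows that the generator is minimizing the Jensen-Shannon divergence between p_data and p_g, with the optimum V = -log 4 reached when p_g = p_data.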

Training Dynamics:

  • Discriminator maximizes ability to classify real vs. fake
  • Generator minimizes discriminator’s ability to distinguish
  • Converges to equilibrium where generator produces authentic samples

Challenges in GAN Training:

  • Mode collapse: Generator converges to producing limited diversity
  • Vanishing gradients: Discriminator becomes too strong
  • Training instability: Loss can oscillate wildly
  • Architecture sensitivity: Small changes drastically affect convergence
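
The vanishing-gradient problem above is commonly mitigated with the non-saturating generator loss: G maximizes log D(G(z)) instead of minimizing log(1 - D(G(z))), which keeps gradients alive early in training when D confidently rejects fakes. A minimal sketch (the DCGAN training loop later in this lesson applies the same trick via BCE with real labels):

```python
import torch

def generator_loss_nonsaturating(d_fake):
    # -log D(G(z)) gives strong gradients when D(G(z)) is near 0,
    # whereas log(1 - D(G(z))) saturates there
    return -torch.log(d_fake + 1e-8).mean()

print(generator_loss_nonsaturating(torch.tensor([0.5])))  # ~0.6931 (= log 2)
```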

Comparing VAEs and GANs

Aspect                  | VAE                     | GAN
------------------------|-------------------------|-------------------
Loss Type               | Probabilistic (ELBO)    | Adversarial
Latent Space            | Continuous, smooth      | Discontinuous
Sample Quality          | Good, sometimes blurry  | Highly realistic
Training Stability      | More stable             | Unstable
Likelihood Tractability | Tractable lower bound   | Intractable
Mode Coverage           | Better coverage         | Prone to collapse

Practical Implementation

VAE Implementation in PyTorch

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

class VAEEncoder(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x):
        h = F.relu(self.fc1(x.view(x.size(0), -1)))
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

class VAEDecoder(nn.Module):
    def __init__(self, latent_dim=20, hidden_dim=400, output_dim=784):
        super().__init__()
        self.fc1 = nn.Linear(latent_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, z):
        h = F.relu(self.fc1(z))
        return torch.sigmoid(self.fc2(h))

class VAE(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super().__init__()
        self.encoder = VAEEncoder(input_dim, hidden_dim, latent_dim)
        self.decoder = VAEDecoder(latent_dim, hidden_dim, input_dim)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        mu, logvar = self.encoder(x)
        z = self.reparameterize(mu, logvar)
        recon = self.decoder(z)
        return recon, mu, logvar

def vae_loss(recon_x, x, mu, logvar):
    # Reconstruction loss (binary cross-entropy)
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')

    # KL divergence loss
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return BCE + KLD

# Training loop
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = VAE(latent_dim=20).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

transform = transforms.Compose([
    transforms.ToTensor(),
])
train_loader = DataLoader(
    datasets.MNIST('./data', train=True, download=True, transform=transform),
    batch_size=128, shuffle=True
)

model.train()
for epoch in range(10):
    total_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)

        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = vae_loss(recon_batch, data, mu, logvar)

        loss.backward()
        total_loss += loss.item()
        optimizer.step()

    print(f'Epoch {epoch+1}, Loss: {total_loss/len(train_loader.dataset):.4f}')

# Generate new samples
with torch.no_grad():
    z = torch.randn(64, 20).to(device)
    samples = model.decoder(z)
    # Visualize samples...
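
The smooth latent space also supports interpolating between codes; a minimal sketch, using an untrained stand-in with the same shapes as model.decoder (substitute the trained decoder in practice):

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the trained model.decoder above
decoder = nn.Sequential(
    nn.Linear(20, 400), nn.ReLU(),
    nn.Linear(400, 784), nn.Sigmoid(),
)

with torch.no_grad():
    z0, z1 = torch.randn(1, 20), torch.randn(1, 20)
    alphas = torch.linspace(0, 1, 8).view(-1, 1)
    z_path = (1 - alphas) * z0 + alphas * z1  # 8 points on a line in latent space
    frames = decoder(z_path)                  # 8 decoded images, shape (8, 784)
```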

GAN Implementation with DCGAN

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

class DCGANGenerator(nn.Module):
    def __init__(self, latent_dim=100, img_channels=1):
        super().__init__()
        self.main = nn.Sequential(
            # Input: latent_dim x 1 x 1
            nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            # 512 x 4 x 4
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            # 256 x 8 x 8
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            # 128 x 16 x 16
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            # 64 x 32 x 32
            nn.ConvTranspose2d(64, img_channels, 4, 2, 1, bias=False),
            nn.Tanh()
            # img_channels x 64 x 64
        )

    def forward(self, z):
        return self.main(z)

class DCGANDiscriminator(nn.Module):
    def __init__(self, img_channels=1):
        super().__init__()
        self.main = nn.Sequential(
            # Input: img_channels x 64 x 64
            nn.Conv2d(img_channels, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # 64 x 32 x 32
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            # 128 x 16 x 16
            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            # 256 x 8 x 8
            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            # 512 x 4 x 4
            nn.Conv2d(512, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.main(x).view(-1, 1)

# GAN training loop
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
G = DCGANGenerator(latent_dim=100).to(device)
D = DCGANDiscriminator().to(device)

# DCGAN expects 64x64 inputs in [-1, 1] (the generator ends in Tanh),
# so MNIST must be resized and normalized accordingly
train_loader = DataLoader(
    datasets.MNIST('./data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.Resize(64),
                       transforms.ToTensor(),
                       transforms.Normalize((0.5,), (0.5,)),
                   ])),
    batch_size=128, shuffle=True
)

optimizerG = torch.optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizerD = torch.optim.Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.999))
criterion = nn.BCELoss()

real_label = 1.0
fake_label = 0.0

for epoch in range(5):
    for batch_idx, (real_data, _) in enumerate(train_loader):
        real_data = real_data.to(device)
        batch_size = real_data.size(0)

        # Train discriminator
        optimizerD.zero_grad()

        # Real images
        output = D(real_data)
        real_target = torch.full((batch_size, 1), real_label, device=device)
        lossD_real = criterion(output, real_target)

        # Fake images
        z = torch.randn(batch_size, 100, 1, 1, device=device)
        fake_data = G(z)
        output = D(fake_data.detach())
        fake_target = torch.full((batch_size, 1), fake_label, device=device)
        lossD_fake = criterion(output, fake_target)

        lossD = lossD_real + lossD_fake
        lossD.backward()
        optimizerD.step()

        # Train generator
        optimizerG.zero_grad()
        z = torch.randn(batch_size, 100, 1, 1, device=device)
        fake_data = G(z)
        output = D(fake_data)
        target = torch.full((batch_size, 1), real_label, device=device)
        lossG = criterion(output, target)
        lossG.backward()
        optimizerG.step()

        if batch_idx % 100 == 0:
            print(f'Epoch {epoch}, Batch {batch_idx}, D Loss: {lossD.item():.4f}, G Loss: {lossG.item():.4f}')

Advanced Techniques

Spectral Normalization for Stable GAN Training

Spectral normalization constrains the Lipschitz constant of the discriminator, preventing it from becoming too strong:

from torch.nn.utils import spectral_norm

class StableDiscriminator(nn.Module):
    def __init__(self):
        super().__init__()
        # Assumes 32x32 single-channel inputs: 32 -> 16 -> 8 after two stride-2 convs
        self.conv1 = spectral_norm(nn.Conv2d(1, 64, 4, 2, 1))
        self.conv2 = spectral_norm(nn.Conv2d(64, 128, 4, 2, 1))
        self.fc = spectral_norm(nn.Linear(128 * 8 * 8, 1))

    def forward(self, x):
        x = F.leaky_relu(self.conv1(x), 0.2)
        x = F.leaky_relu(self.conv2(x), 0.2)
        x = x.view(x.size(0), -1)
        return self.fc(x)

Wasserstein GAN (WGAN) Loss

WGAN minimizes an approximation of the Earth Mover (Wasserstein-1) distance, which yields more informative gradients and better training stability. The discriminator becomes a "critic" with an unbounded output, and it must be kept approximately 1-Lipschitz, via weight clipping in the original WGAN or a gradient penalty in WGAN-GP:

# Replace the sigmoid in the discriminator (now a critic) with a linear output
# Use -mean(real) + mean(fake) as the critic loss instead of BCE

def wgan_discriminator_loss(D_real, D_fake):
    return -torch.mean(D_real) + torch.mean(D_fake)

def wgan_generator_loss(D_fake):
    return -torch.mean(D_fake)

Conditional GANs (cGANs)

Generate images conditioned on class labels:

class ConditionalDCGANGenerator(nn.Module):
    def __init__(self, num_classes=10, latent_dim=100):
        super().__init__()
        self.embedding = nn.Embedding(num_classes, latent_dim)
        # Same architecture as DCGANGenerator, but the first ConvTranspose2d
        # sees latent_dim * 2 = 200 input channels after concatenation
        self.main = DCGANGenerator(latent_dim * 2).main

    def forward(self, z, labels):
        # Concatenate the class embedding with the latent vector channel-wise
        y = self.embedding(labels).view(labels.size(0), -1, 1, 1)
        return self.main(torch.cat([z, y], dim=1))

Production Considerations

Evaluation Metrics

For VAEs:

  • Reconstruction loss on test set
  • Likelihood estimate via IWAE (Importance Weighted Autoencoder)
  • Disentanglement metrics (Factor-VAE, Beta-VAE)

For GANs:

  • Inception Score (IS): Higher is better
  • Frechet Inception Distance (FID): Lower is better
  • Precision and Recall: Balanced evaluation
  • Human evaluation: Subjective assessment

Computing FID with the pytorch-fid package (dims=2048 selects the standard Inception pool3 features):

from pytorch_fid.fid_score import calculate_fid_given_paths

# Calculate FID between directories of real and generated images
fid_score = calculate_fid_given_paths(
    [real_images_path, fake_images_path],
    batch_size=50,
    device='cuda',
    dims=2048
)

Avoiding Mode Collapse

  1. Use experience replay: Store generated samples and mix with new ones
  2. Unrolled GAN: Train discriminator on unrolled generator updates
  3. Feature matching: Use discriminator features for generator loss
  4. Minibatch discrimination: Discriminator looks at batch statistics
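
Technique 3 can be sketched as follows, assuming the discriminator exposes an intermediate feature map (a hypothetical accessor; any intermediate layer's activations work):

```python
import torch
import torch.nn.functional as F

def feature_matching_loss(real_features, fake_features):
    # Train G to match the mean discriminator features of real samples,
    # rather than to fool the final classification output
    return F.mse_loss(fake_features.mean(dim=0),
                      real_features.mean(dim=0).detach())
```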

Checkpointing and Deployment

# Save the best model based on a validation metric
# (assumes valid_fid is recomputed periodically and best_fid starts at float('inf'))
if valid_fid < best_fid:
    best_fid = valid_fid
    torch.save({
        'G_state': G.state_dict(),
        'D_state': D.state_dict(),
        'optimizerG': optimizerG.state_dict(),
        'optimizerD': optimizerD.state_dict(),
        'fid': valid_fid
    }, 'best_gan.pt')

# Load for inference
checkpoint = torch.load('best_gan.pt', map_location=device)
G.load_state_dict(checkpoint['G_state'])
G.eval()

with torch.no_grad():
    z = torch.randn(batch_size, 100, 1, 1, device=device)
    images = G(z)

Key Takeaway

Generative models enable machines to learn and reproduce data distributions. VAEs offer interpretable latent spaces with stable training, while GANs excel at generating high-quality samples through adversarial competition. Master both approaches to choose the right tool for your specific generative modeling task.

Practical Exercise

Task: Implement a VAE-GAN hybrid model that combines VAE’s stability with GAN’s sample quality. Train it on MNIST or Fashion-MNIST.

Requirements:

  1. Build a VAE that encodes images to latent space
  2. Add a discriminator that distinguishes reconstructed vs. real images
  3. Implement combined loss: VAE loss + adversarial loss with weight λ
  4. Track both reconstruction quality (MSE) and sample quality (via visual inspection)
  5. Generate 16 samples and compare with standard VAE/GAN results
  6. Implement early stopping if KL divergence goes to zero (posterior collapse)

Evaluation:

  • Compare reconstruction error with baseline VAE
  • Visual quality of generated samples
  • Stability metrics during training (smooth loss curves)
  • Ablation: Train with λ=0 (pure VAE) and λ=1 (pure GAN loss) to analyze trade-offs