Generative Models: VAEs and GANs
Generative models learn the underlying distribution of data and can synthesize new samples. Variational Autoencoders (VAEs) and Generative Adversarial Networks (GANs) represent two fundamentally different approaches to generative modeling, each with distinct advantages for different applications. This lesson explores their theoretical foundations, training dynamics, and practical implementations.
Core Concepts
Variational Autoencoders (VAEs)
VAEs combine probabilistic modeling with deep learning by learning a latent variable model. Unlike standard autoencoders, VAEs enforce a distribution on the latent space using the variational inference principle.
The VAE Framework:
- Encoder: Maps input x to a distribution over latent variables z
- Latent Space: A continuous, probabilistic representation
- Decoder: Reconstructs x from samples of z
- KL Divergence: Regularizes the latent distribution toward a standard normal
The VAE is trained by maximizing the evidence lower bound (ELBO); the loss minimized in practice is its negation:
L_VAE = -E[log p(x|z)] + D_KL(q(z|x) || p(z))
Where:
- First term: reconstruction loss (expected negative log-likelihood)
- Second term: KL divergence penalty encouraging the approximate posterior to match the prior
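For the usual choice of a diagonal Gaussian posterior q(z|x) = N(μ, σ²I) and a standard normal prior p(z) = N(0, I), the KL term has a closed form, which is what VAE implementations compute directly:
D_KL(q(z|x) || p(z)) = -1/2 Σ_j (1 + log σ_j² - μ_j² - σ_j²)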
Reparameterization Trick: To backpropagate through sampling, we reparameterize:
z = μ + σ ⊙ ε, where ε ~ N(0, I)
This allows gradients to flow through the sampling operation.
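The trick can be verified in a few lines; a minimal PyTorch sketch (the tensors here are illustrative stand-ins for encoder outputs):

```python
import torch

# Stand-ins for encoder outputs mu and log(sigma^2)
mu = torch.zeros(3, requires_grad=True)
logvar = torch.zeros(3, requires_grad=True)

# Reparameterized sample: z = mu + sigma * eps, with eps ~ N(0, I)
eps = torch.randn(3)
z = mu + torch.exp(0.5 * logvar) * eps

# A loss on z backpropagates into mu and logvar, not into the sampling
z.sum().backward()
print(mu.grad)      # dz/dmu = 1 for every component
print(logvar.grad)  # 0.5 * sigma * eps = 0.5 * eps here, since sigma = 1
```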
Generative Adversarial Networks (GANs)
GANs employ a game-theoretic approach where two networks compete:
- Generator (G): Creates fake samples from noise
- Discriminator (D): Distinguishes real from fake samples
The minimax objective:
min_G max_D V(D,G) = E_x[log D(x)] + E_z[log(1 - D(G(z)))]
Training Dynamics:
- Discriminator maximizes ability to classify real vs. fake
- Generator minimizes discriminator’s ability to distinguish
- In theory, training converges to an equilibrium where the generator's distribution matches the data distribution and D outputs 1/2 everywhere
Challenges in GAN Training:
- Mode collapse: Generator converges to producing limited diversity
- Vanishing gradients: Discriminator becomes too strong
- Training instability: Loss can oscillate wildly
- Architecture sensitivity: Small changes drastically affect convergence
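The vanishing-gradient problem has a standard mitigation: the non-saturating generator loss, where the generator maximizes log D(G(z)) instead of minimizing log(1 - D(G(z))). A small sketch of why, in terms of the discriminator's logit:

```python
import torch

# Discriminator logit for a fake sample; very negative means D(G(z)) ~ 0
logit = torch.tensor(-5.0, requires_grad=True)

# Saturating loss log(1 - D): gradient w.r.t. the logit is -sigmoid(logit),
# which vanishes exactly when the discriminator is winning
torch.log(1 - torch.sigmoid(logit)).backward()
grad_sat = logit.grad.clone()        # ~ -0.0067: almost no learning signal

logit.grad = None
# Non-saturating loss -log D: gradient is -(1 - sigmoid(logit)), close to -1
(-torch.log(torch.sigmoid(logit))).backward()
grad_nonsat = logit.grad             # ~ -0.9933: strong learning signal
```

Most practical training loops, including the BCE-based one later in this lesson, use the non-saturating form by labeling fakes as real in the generator update.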
Comparing VAEs and GANs
| Aspect | VAE | GAN |
|---|---|---|
| Loss Type | Probabilistic (ELBO) | Adversarial |
| Latent Space | Continuous, smooth | Continuous, but no encoder for inference |
| Sample Quality | Good, sometimes blurry | Highly realistic |
| Training Stability | More stable | Unstable |
| Likelihood Tractability | Tractable lower bound | Intractable |
| Mode Coverage | Better coverage | Prone to collapse |
Practical Implementation
VAE Implementation in PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

class VAEEncoder(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x):
        h = F.relu(self.fc1(x.view(x.size(0), -1)))
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

class VAEDecoder(nn.Module):
    def __init__(self, latent_dim=20, hidden_dim=400, output_dim=784):
        super().__init__()
        self.fc1 = nn.Linear(latent_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, z):
        h = F.relu(self.fc1(z))
        return torch.sigmoid(self.fc2(h))

class VAE(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super().__init__()
        self.encoder = VAEEncoder(input_dim, hidden_dim, latent_dim)
        self.decoder = VAEDecoder(latent_dim, hidden_dim, input_dim)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        mu, logvar = self.encoder(x)
        z = self.reparameterize(mu, logvar)
        recon = self.decoder(z)
        return recon, mu, logvar

def vae_loss(recon_x, x, mu, logvar):
    # Reconstruction loss (binary cross-entropy, summed over pixels and batch)
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
    # KL divergence to the standard normal prior (closed form)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

# Training loop
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = VAE(latent_dim=20).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

transform = transforms.Compose([
    transforms.ToTensor(),
])
train_loader = DataLoader(
    datasets.MNIST('./data', train=True, download=True, transform=transform),
    batch_size=128, shuffle=True
)

model.train()
for epoch in range(10):
    total_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = vae_loss(recon_batch, data, mu, logvar)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f'Epoch {epoch+1}, Loss: {total_loss/len(train_loader.dataset):.4f}')

# Generate new samples
with torch.no_grad():
    z = torch.randn(64, 20).to(device)
    samples = model.decoder(z)
    # Visualize samples...
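Because the latent space is continuous and smooth, linearly interpolating between two latent codes decodes into a gradual transition between images. A sketch; the decoder here is an untrained stand-in with the same shape as VAEDecoder above, so the idea stays self-contained:

```python
import torch
import torch.nn as nn

# Untrained stand-in matching VAEDecoder's interface (20 -> 784)
decoder = nn.Sequential(
    nn.Linear(20, 400), nn.ReLU(),
    nn.Linear(400, 784), nn.Sigmoid(),
)

# Two latent codes and 8 evenly spaced blend weights
z0, z1 = torch.randn(20), torch.randn(20)
alphas = torch.linspace(0, 1, 8).unsqueeze(1)   # (8, 1)
z_path = (1 - alphas) * z0 + alphas * z1        # (8, 20), broadcast over alphas

with torch.no_grad():
    frames = decoder(z_path)                    # (8, 784): one image per step
print(frames.shape)  # torch.Size([8, 784])
```

With a trained VAE, reshaping each row to 28x28 and plotting the 8 frames side by side shows one digit morphing into another.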
GAN Implementation with DCGAN
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

class DCGANGenerator(nn.Module):
    def __init__(self, latent_dim=100, img_channels=1):
        super().__init__()
        self.main = nn.Sequential(
            # Input: latent_dim x 1 x 1
            nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            # 512 x 4 x 4
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            # 256 x 8 x 8
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            # 128 x 16 x 16
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            # 64 x 32 x 32
            nn.ConvTranspose2d(64, img_channels, 4, 2, 1, bias=False),
            nn.Tanh()
            # img_channels x 64 x 64
        )

    def forward(self, z):
        return self.main(z)

class DCGANDiscriminator(nn.Module):
    def __init__(self, img_channels=1):
        super().__init__()
        self.main = nn.Sequential(
            # Input: img_channels x 64 x 64
            nn.Conv2d(img_channels, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # 64 x 32 x 32
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            # 128 x 16 x 16
            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            # 256 x 8 x 8
            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            # 512 x 4 x 4
            nn.Conv2d(512, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.main(x).view(-1, 1)

# GAN training loop
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
G = DCGANGenerator(latent_dim=100).to(device)
D = DCGANDiscriminator().to(device)

optimizerG = torch.optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizerD = torch.optim.Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.999))
criterion = nn.BCELoss()

real_label = 1.0
fake_label = 0.0

# DCGAN expects 64x64 inputs in [-1, 1] (the generator ends in Tanh),
# so resize and normalize MNIST accordingly
train_loader = DataLoader(
    datasets.MNIST('./data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.Resize(64),
                       transforms.ToTensor(),
                       transforms.Normalize((0.5,), (0.5,)),
                   ])),
    batch_size=128, shuffle=True
)

for epoch in range(5):
    for batch_idx, (real_data, _) in enumerate(train_loader):
        real_data = real_data.to(device)
        batch_size = real_data.size(0)

        # Train discriminator
        optimizerD.zero_grad()
        # Real images
        output = D(real_data)
        real_target = torch.full((batch_size, 1), real_label, device=device)
        lossD_real = criterion(output, real_target)
        # Fake images (detached so no generator gradients are computed here)
        z = torch.randn(batch_size, 100, 1, 1, device=device)
        fake_data = G(z)
        output = D(fake_data.detach())
        fake_target = torch.full((batch_size, 1), fake_label, device=device)
        lossD_fake = criterion(output, fake_target)
        lossD = lossD_real + lossD_fake
        lossD.backward()
        optimizerD.step()

        # Train generator (non-saturating loss: label fakes as real)
        optimizerG.zero_grad()
        z = torch.randn(batch_size, 100, 1, 1, device=device)
        fake_data = G(z)
        output = D(fake_data)
        target = torch.full((batch_size, 1), real_label, device=device)
        lossG = criterion(output, target)
        lossG.backward()
        optimizerG.step()

        if batch_idx % 100 == 0:
            print(f'Epoch {epoch}, Batch {batch_idx}, D Loss: {lossD.item():.4f}, G Loss: {lossG.item():.4f}')
Advanced Techniques
Spectral Normalization for Stable GAN Training
Spectral normalization constrains each layer's largest singular value to 1, bounding the discriminator's Lipschitz constant and preventing it from overpowering the generator:
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm

class StableDiscriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = spectral_norm(nn.Conv2d(1, 64, 4, 2, 1))
        self.conv2 = spectral_norm(nn.Conv2d(64, 128, 4, 2, 1))
        self.fc = spectral_norm(nn.Linear(128 * 8 * 8, 1))  # assumes 32x32 inputs

    def forward(self, x):
        x = F.leaky_relu(self.conv1(x), 0.2)
        x = F.leaky_relu(self.conv2(x), 0.2)
        x = x.view(x.size(0), -1)
        return self.fc(x)
Wasserstein GAN (WGAN) Loss
WGAN trains the discriminator (renamed the "critic") to estimate the Earth Mover's (Wasserstein-1) distance, which provides useful gradients even when the real and generated distributions barely overlap:
# Replace the sigmoid in the critic with a linear output, and keep the
# critic (approximately) 1-Lipschitz -- via weight clipping in the
# original WGAN, or a gradient penalty in WGAN-GP
def wgan_discriminator_loss(D_real, D_fake):
    # Critic widens the score gap between real and fake samples
    return -torch.mean(D_real) + torch.mean(D_fake)

def wgan_generator_loss(D_fake):
    # Generator raises the critic's score on fakes
    return -torch.mean(D_fake)
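A common way to enforce the Lipschitz constraint is the WGAN-GP gradient penalty: sample points between real and fake batches and penalize the critic's gradient norm for deviating from 1. A sketch, assuming `critic` maps an image batch to one score per sample:

```python
import torch

def gradient_penalty(critic, real, fake, device='cpu'):
    # Random per-sample interpolation between real and fake images
    alpha = torch.rand(real.size(0), 1, 1, 1, device=device)
    mixed = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    scores = critic(mixed)
    # Gradient of the critic's scores w.r.t. the interpolated inputs
    grads = torch.autograd.grad(
        outputs=scores, inputs=mixed,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,
    )[0]
    grad_norm = grads.view(grads.size(0), -1).norm(2, dim=1)
    # Penalize deviation of the per-sample gradient norm from 1
    return ((grad_norm - 1) ** 2).mean()
```

The penalty, typically weighted by λ = 10, is added to the critic loss in place of weight clipping.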
Conditional GANs (cGANs)
Generate images conditioned on class labels:
class ConditionalDCGANGenerator(nn.Module):
    def __init__(self, num_classes=10, latent_dim=100):
        super().__init__()
        self.embedding = nn.Embedding(num_classes, latent_dim)
        # Reuse the DCGAN stack; input dimension doubles to latent + embedding
        self.main = DCGANGenerator(latent_dim=latent_dim * 2).main

    def forward(self, z, labels):
        # Concatenate the class embedding with the latent vector along channels
        c = self.embedding(labels).view(labels.size(0), -1, 1, 1)
        return self.main(torch.cat([z, c], dim=1))
Production Considerations
Evaluation Metrics
For VAEs:
- Reconstruction loss on test set
- Likelihood estimate via IWAE (Importance Weighted Autoencoder)
- Disentanglement metrics (Factor-VAE, Beta-VAE)
For GANs:
- Inception Score (IS): Higher is better
- Fréchet Inception Distance (FID): Lower is better
- Precision and Recall: Balanced evaluation
- Human evaluation: Subjective assessment
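The Inception Score is computed from an ImageNet classifier's predictions p(y|x) on generated samples; confident per-image predictions combined with a diverse marginal p(y) give high scores:
IS = exp( E_x [ D_KL( p(y|x) || p(y) ) ] )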
# pip install pytorch-fid
from pytorch_fid.fid_score import calculate_fid_given_paths

# Calculate FID between directories of real and generated images
fid_score = calculate_fid_given_paths(
    [real_images_path, fake_images_path],
    batch_size=50,
    device='cuda',
    dims=2048  # feature dimension of the InceptionV3 pool3 layer
)
Avoiding Mode Collapse
- Use experience replay: Store generated samples and mix with new ones
- Unrolled GAN: Train discriminator on unrolled generator updates
- Feature matching: Use discriminator features for generator loss
- Minibatch discrimination: Discriminator looks at batch statistics
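Minibatch discrimination can be approximated cheaply with the minibatch standard-deviation trick: append the average per-feature standard deviation across the batch as an extra channel, so the discriminator can detect low-diversity batches. A sketch (the layer name is illustrative):

```python
import torch
import torch.nn as nn

class MinibatchStdDev(nn.Module):
    """Appends one channel holding the mean per-feature std across the batch."""
    def forward(self, x):                       # x: (N, C, H, W)
        std = x.std(dim=0, unbiased=False)      # (C, H, W): spread across the batch
        mean_std = std.mean()                   # scalar summary of batch diversity
        extra = mean_std.expand(x.size(0), 1, x.size(2), x.size(3))
        return torch.cat([x, extra], dim=1)     # (N, C + 1, H, W)

x = torch.randn(8, 64, 4, 4)
print(MinibatchStdDev()(x).shape)  # torch.Size([8, 65, 4, 4])
```

If the generator collapses and every sample in the batch looks alike, the extra channel is near zero, which the discriminator quickly learns to flag as fake.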
Checkpointing and Deployment
# Save best model based on validation metric
if valid_fid < best_fid:
    best_fid = valid_fid
    torch.save({
        'G_state': G.state_dict(),
        'D_state': D.state_dict(),
        'optimizerG': optimizerG.state_dict(),
        'optimizerD': optimizerD.state_dict(),
        'fid': valid_fid
    }, 'best_gan.pt')

# Load for inference
checkpoint = torch.load('best_gan.pt', map_location=device)
G.load_state_dict(checkpoint['G_state'])
G.eval()
with torch.no_grad():
    z = torch.randn(batch_size, 100, 1, 1, device=device)
    images = G(z)
Key Takeaway
Generative models enable machines to learn and reproduce data distributions. VAEs offer interpretable latent spaces with stable training, while GANs excel at generating high-quality samples through adversarial competition. Master both approaches to choose the right tool for your specific generative modeling task.
Practical Exercise
Task: Implement a VAE-GAN hybrid model that combines VAE’s stability with GAN’s sample quality. Train it on MNIST or Fashion-MNIST.
Requirements:
- Build a VAE that encodes images to latent space
- Add a discriminator that distinguishes reconstructed vs. real images
- Implement combined loss: VAE loss + adversarial loss with weight λ
- Track both reconstruction quality (MSE) and sample quality (via visual inspection)
- Generate 16 samples and compare with standard VAE/GAN results
- Implement early stopping if KL divergence goes to zero (posterior collapse)
Evaluation:
- Compare reconstruction error with baseline VAE
- Visual quality of generated samples
- Stability metrics during training (smooth loss curves)
- Ablation: Train with λ=0 (pure VAE) and λ=1 (pure GAN loss) to analyze trade-offs