Foundations

Convolutional Neural Networks

Lesson 3 of 4 · Estimated Time: 50 min

Convolutional Neural Networks (CNNs) exploit the spatial structure of images. Rather than treating images as flat vectors, CNNs use convolutional filters that preserve spatial relationships, enabling efficient image understanding.

Convolution Operation: The Core Concept

A convolution applies a learnable filter across the image:

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

# Manually compute a single convolution
image = torch.tensor([
    [1, 2, 3, 4],
    [5, 6, 7, 8],
    [9, 10, 11, 12],
    [13, 14, 15, 16]
], dtype=torch.float32).unsqueeze(0).unsqueeze(0)  # Add batch and channel dims

# Simple edge detection filter
kernel = torch.tensor([
    [-1, 0, 1],
    [-2, 0, 2],
    [-1, 0, 1]
], dtype=torch.float32).unsqueeze(0).unsqueeze(0)

# Apply convolution
output = F.conv2d(image, kernel, padding=0, stride=1)
print(f"Input shape: {image.shape}")  # [1, 1, 4, 4]
print(f"Output shape: {output.shape}")  # [1, 1, 2, 2]
print(f"Output values:\n{output}")

# Visualize
fig, axes = plt.subplots(1, 3, figsize=(12, 3))

axes[0].imshow(image[0, 0], cmap='gray')
axes[0].set_title('Input Image')

axes[1].imshow(kernel[0, 0], cmap='gray')
axes[1].set_title('Kernel (Edge Detector)')

axes[2].imshow(output[0, 0], cmap='gray')
axes[2].set_title('Output')

for ax in axes:
    ax.axis('off')

plt.tight_layout()
plt.show()
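Note that PyTorch's conv2d actually computes cross-correlation (the kernel is applied without flipping). As a sanity check, the 2x2 output above can be reproduced with a plain NumPy sliding window over the same 4x4 input and 3x3 kernel; `cross_correlate2d` is an illustrative helper, not a library function.

```python
import numpy as np

# Same 4x4 input and 3x3 kernel as above, in plain NumPy
img = np.arange(1, 17, dtype=np.float32).reshape(4, 4)
ker = np.array([[-1, 0, 1],
                [-2, 0, 2],
                [-1, 0, 1]], dtype=np.float32)

def cross_correlate2d(x, k):
    """Valid (no padding, stride 1) cross-correlation, as conv2d computes it."""
    kh, kw = k.shape
    oh, ow = x.shape[0] - kh + 1, x.shape[1] - kw + 1
    out = np.zeros((oh, ow), dtype=x.dtype)
    for i in range(oh):
        for j in range(ow):
            # Elementwise product of the kernel with the current window, summed
            out[i, j] = np.sum(x[i:i + kh, j:j + kw] * k)
    return out

result = cross_correlate2d(img, ker)
print(result)  # [[8. 8.]
               #  [8. 8.]]
```

Every output is 8 because the input's horizontal gradient is uniform: each column exceeds the one to its left by exactly 1, and the kernel weights that difference by (1 + 2 + 1) * 2 = 8.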

Key parameters:

  • Kernel size: Filter dimensions (3x3, 5x5, 7x7)
  • Stride: How many pixels the filter moves each step
  • Padding: Zero-padding around image edges
  • Output size: floor((H - K + 2P) / S) + 1 (where H=height, K=kernel size, P=padding, S=stride)
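The output-size formula is easy to wrap in a small helper; the cases below check it against shapes used in this lesson (the 4x4 manual example, and a 28x28 MNIST-style input). The helper name is illustrative, not a library function.

```python
def conv_output_size(h, k, p=0, s=1):
    """Spatial output size of a convolution: floor((H - K + 2P) / S) + 1."""
    return (h - k + 2 * p) // s + 1

print(conv_output_size(4, 3))             # 2  (the manual example above)
print(conv_output_size(28, 3, p=1))       # 28 ("same" padding for a 3x3 kernel)
print(conv_output_size(28, 3, p=1, s=2))  # 14 (strided downsampling)
```

The second case explains the padding=1 used throughout this lesson: for a 3x3 kernel it keeps the spatial size unchanged.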

Building CNN Architectures

Start simple, then add complexity:

import torch.nn as nn

# Simple CNN: 2 convolutional layers + 2 fully connected
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()

        # Convolutional layers
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Flatten and fully connected
        # After two 2x2 pools: 28x28 -> 14x14 -> 7x7
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.relu3 = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        # Conv block 1
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.pool1(x)

        # Conv block 2
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.pool2(x)

        # Flatten
        x = x.view(x.size(0), -1)

        # Fully connected
        x = self.fc1(x)
        x = self.relu3(x)
        x = self.dropout(x)
        x = self.fc2(x)

        return x

# Test
model = SimpleCNN()
x = torch.randn(32, 1, 28, 28)  # Batch of 32 MNIST images
output = model(x)
print(f"Output shape: {output.shape}")  # [32, 10]
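The parameter count of SimpleCNN can be verified by hand; the arithmetic below assumes exactly the layer shapes defined above (these helpers are for illustration only).

```python
def conv_params(c_in, c_out, k):
    """Weights plus biases of a Conv2d layer: c_out * c_in * k * k + c_out."""
    return c_out * c_in * k * k + c_out

def linear_params(n_in, n_out):
    """Weights plus biases of a Linear layer."""
    return n_out * n_in + n_out

total = (conv_params(1, 32, 3)             # conv1: 320
         + conv_params(32, 64, 3)          # conv2: 18,496
         + linear_params(64 * 7 * 7, 128)  # fc1:   401,536
         + linear_params(128, 10))         # fc2:   1,290
print(total)  # 421642
```

Note that fc1 holds the vast majority of the parameters, a common pattern in small CNNs: the convolutional layers are cheap because their filters are shared across all spatial positions.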

Classic Architectures: LeNet and Beyond

LeNet, introduced by LeCun et al. in 1998, established the convolution-pool-classifier template that CNNs still follow:

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()

        self.features = nn.Sequential(
            nn.Conv2d(1, 6, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(6, 16, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

        self.classifier = nn.Sequential(
            # For the classic 32x32 input (MNIST padded from 28x28): 32 -> 28 -> 14 -> 10 -> 5
            nn.Linear(16 * 5 * 5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, 10)
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x
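Tracing the spatial sizes for the classic 32x32 input shows where the 16 * 5 * 5 classifier input comes from; the sketch below applies the output-size arithmetic step by step.

```python
def conv_out(h, k, p=0, s=1):
    """Spatial output size: floor((H - K + 2P) / S) + 1."""
    return (h - k + 2 * p) // s + 1

h = 32               # classic LeNet input size
h = conv_out(h, 5)   # conv 5x5, no padding -> 28
h = h // 2           # 2x2 max pool         -> 14
h = conv_out(h, 5)   # conv 5x5             -> 10
h = h // 2           # 2x2 max pool         -> 5
print(h, 16 * h * h)  # 5 400, matching nn.Linear(16 * 5 * 5, 120)
```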

Modern architectures build on this foundation:

  • AlexNet (2012): Deep networks with ReLU, GPU training
  • VGG (2014): Very deep networks, showed depth matters
  • ResNet (2015): Skip connections solved vanishing gradients
  • Inception (2014): Multi-scale parallel convolutions
  • DenseNet (2017): Dense connections between layers
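The skip connection behind ResNet is simple to state: a block outputs F(x) + x rather than F(x), so gradients can always flow through the identity path. A minimal NumPy illustration of the idea (the weights are arbitrary small values, just to show the shape of the computation, not a real ResNet block):

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0)

def residual_block(x, w1, w2):
    """out = relu(F(x) + x): the identity path bypasses the learned transform."""
    return relu(w2 @ relu(w1 @ x) + x)

rng = np.random.default_rng(0)
x = rng.standard_normal(8)
w1 = rng.standard_normal((8, 8)) * 0.01  # near-zero "learned" weights
w2 = rng.standard_normal((8, 8)) * 0.01

out = residual_block(x, w1, w2)
# With near-zero weights, the block is close to the identity function:
print(np.max(np.abs(out - relu(x))))  # close to zero
```

This is why depth stops hurting: a freshly initialized residual block barely perturbs its input, so a very deep stack of them starts out close to the identity and only has to learn deviations from it.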

Batch Normalization: Stabilizing Training

Batch normalization normalizes layer inputs, accelerating convergence:

class CNNWithBatchNorm(nn.Module):
    def __init__(self):
        super(CNNWithBatchNorm, self).__init__()

        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu1 = nn.ReLU()

        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.relu2 = nn.ReLU()

        self.pool = nn.MaxPool2d(2)
        self.fc = nn.Linear(64 * 16 * 16, 10)  # assumes 64x64 RGB input: 64 -> 32 -> 16 after two pools

    def forward(self, x):
        # Conv -> BatchNorm -> ReLU
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.pool(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)
        x = self.pool(x)

        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

During training, batch norm normalizes each batch with its own mean and variance while maintaining running averages; at inference it normalizes with those running statistics instead. Switching between the two is handled by model.train() and model.eval().
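The two modes can be sketched in plain NumPy: training mode normalizes with the current batch's statistics and updates running averages; eval mode reuses the stored averages. This is a simplified per-feature sketch of what nn.BatchNorm2d does per channel (no learnable gamma/beta, and it glosses over PyTorch's biased/unbiased variance distinction).

```python
import numpy as np

class TinyBatchNorm:
    """Per-feature batch norm sketch (no affine parameters)."""
    def __init__(self, n, momentum=0.1, eps=1e-5):
        self.running_mean = np.zeros(n)
        self.running_var = np.ones(n)
        self.momentum, self.eps = momentum, eps

    def forward(self, x, training):
        if training:
            mean, var = x.mean(axis=0), x.var(axis=0)
            # PyTorch convention: new = (1 - m) * old + m * batch_stat
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean
            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var
        else:
            mean, var = self.running_mean, self.running_var
        return (x - mean) / np.sqrt(var + self.eps)

bn = TinyBatchNorm(4)
x = np.random.default_rng(0).standard_normal((32, 4)) * 3 + 5  # shifted, scaled batch

y_train = bn.forward(x, training=True)
print(y_train.mean(axis=0).round(6))    # ~0: training mode normalizes the batch itself
y_eval = bn.forward(x, training=False)  # eval mode uses the running averages instead
```

Because the running averages start at zero mean and unit variance, the first eval-mode output differs noticeably from the training-mode output; in real training they converge toward the data statistics over many batches.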

Transfer Learning: Using Pretrained Models

Training from scratch requires massive data. Transfer learning uses pretrained weights:

from torchvision import models

# Load ResNet50 with ImageNet weights (torchvision >= 0.13; the older
# pretrained=True flag is deprecated)
resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)

# Option 1: Freeze all layers, retrain only last layer
for param in resnet.parameters():
    param.requires_grad = False

# Replace final layer for your task (10 classes instead of ImageNet's 1000)
num_classes = 10
resnet.fc = nn.Linear(resnet.fc.in_features, num_classes)

# Only train the new last layer
optimizer = torch.optim.Adam(resnet.fc.parameters(), lr=0.001)

# Option 2: Fine-tune (train all layers with low learning rate)
for param in resnet.parameters():
    param.requires_grad = True

# Higher learning rate for the new layer, lower for pretrained layers.
# Tensors cannot be tested with `in` (elementwise ==), so match by identity:
fc_param_ids = {id(p) for p in resnet.fc.parameters()}
param_groups = [
    {'params': resnet.fc.parameters(), 'lr': 0.01},
    {'params': [p for p in resnet.parameters() if id(p) not in fc_param_ids],
     'lr': 0.0001}
]
optimizer = torch.optim.Adam(param_groups)

Visualizing What CNNs Learn

Filters in early layers detect edges; deeper layers detect complex patterns:

# Extract first-layer filters (assumes an RGB model such as CNNWithBatchNorm
# above; SimpleCNN's single-channel filters would need a grayscale colormap)
model = CNNWithBatchNorm()
first_conv = model.conv1.weight.data  # Shape: [32, 3, 3, 3]

fig, axes = plt.subplots(4, 8, figsize=(12, 6))
for i, ax in enumerate(axes.flat):
    if i < 32:
        # Normalize to [0, 1] for visualization
        filter_img = first_conv[i].permute(1, 2, 0)
        filter_img = (filter_img - filter_img.min()) / (filter_img.max() - filter_img.min())
        ax.imshow(filter_img.cpu().detach().numpy())
        ax.axis('off')

fig.suptitle('First Layer Filters')
plt.tight_layout()
plt.show()

Image Augmentation: Fighting Overfitting

Augmentation artificially increases training data variety:

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Training augmentations
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225])
])

# Test: no augmentation
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225])
])

# Load CIFAR-10 with augmentation
train_dataset = datasets.CIFAR10(root='./data', train=True, transform=train_transform, download=True)
test_dataset = datasets.CIFAR10(root='./data', train=False, transform=test_transform, download=True)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32)
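transforms.Normalize is just per-channel arithmetic: output = (input - mean) / std, applied after ToTensor has scaled pixels to [0, 1]. The ImageNet statistics used above can be checked directly in NumPy:

```python
import numpy as np

# ImageNet per-channel statistics, broadcast over a CHW image
mean = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
std = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)

img = np.full((3, 4, 4), 0.5)        # a mid-gray image in [0, 1], CHW layout
normalized = (img - mean) / std      # what transforms.Normalize computes

print(normalized[:, 0, 0].round(4))  # one value per channel, all slightly
                                     # positive since every ImageNet mean < 0.5
```

The same statistics must be applied at test time (as in test_transform above); normalizing train and test data differently would silently shift the input distribution.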

Key Takeaway

CNNs leverage spatial structure in images through weight sharing and local connectivity. Transfer learning lets you apply pretrained knowledge to new problems with limited data—a powerful shortcut in practice.

Practical Exercise

Build and train a CNN on CIFAR-10:

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Your task:
# 1. Create a CNN with at least 2 conv blocks
# 2. Include batch normalization
# 3. Use data augmentation for training
# 4. Train for 10 epochs with learning rate scheduling
# 5. Evaluate accuracy on test set
# 6. Compare plain CNN vs with batch normalization

# Expected components:
# - Custom CNN class
# - Data augmentation pipeline
# - Learning rate scheduler
# - Training loop with validation
# - Performance comparison
# - Test set accuracy > 70%

# Bonus:
# - Use transfer learning with pretrained ResNet
# - Compare performance vs training from scratch
# - Visualize learned filters

This exercise teaches CNN fundamentals and shows the practical power of architectural choices.