Intermediate

Model Optimization and Deployment

Lesson 4 of 4 · Estimated time: 50 minutes

Deploying neural networks in production requires more than accuracy—models must be fast, efficient, and fit within resource constraints. This lesson covers quantization, pruning, knowledge distillation, and export techniques to bridge the gap between research models and deployment-ready systems.

Core Concepts

Quantization: Reducing Precision

Quantization reduces the numerical precision of model weights and activations, shrinking model size and speeding up inference, usually with only a small accuracy loss.

Quantization Schemes:

  • Post-Training Quantization (PTQ): Quantize trained weights after training completes
  • Quantization-Aware Training (QAT): Simulate quantization during training for better accuracy
  • Mixed Precision: Use different precisions for different layers (FP32 for critical, INT8 elsewhere)

Common Bit-widths:

  • INT8: 8-bit integers, 4x compression vs. FP32
  • INT4: 4-bit integers, 8x compression (more aggressive, larger accuracy risk)
  • FP16: 16-bit floats, 2x compression, usually with negligible accuracy loss

Symmetric vs. Asymmetric Quantization:

Symmetric: x_q = round(x / scale)
Asymmetric: x_q = round(x / scale) + zero_point
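
To make the two schemes concrete, here is a minimal sketch of both mappings on a plain tensor (illustrative only, not the kernels PyTorch uses internally):

import torch

def quantize_symmetric(x, num_bits=8):
    """Symmetric: scale maps max |x| onto the largest representable integer."""
    qmax = 2 ** (num_bits - 1) - 1                      # 127 for INT8
    scale = x.abs().max() / qmax
    x_q = torch.clamp(torch.round(x / scale), -qmax - 1, qmax)
    return x_q, scale

def quantize_asymmetric(x, num_bits=8):
    """Asymmetric: a zero_point shifts the grid to cover [min, max] exactly."""
    qmin, qmax = 0, 2 ** num_bits - 1                   # [0, 255] for UINT8
    scale = (x.max() - x.min()) / (qmax - qmin)
    zero_point = qmin - torch.round(x.min() / scale)
    x_q = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax)
    return x_q, scale, zero_point

x = torch.randn(1000)
x_q, scale = quantize_symmetric(x)
x_hat = x_q * scale                                     # dequantize
print(f'Max round-trip error: {(x - x_hat).abs().max():.4f}')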

Pruning: Removing Redundancy

Pruning removes unimportant weights or neurons, reducing model size and computation; a short fine-tuning pass usually recovers accuracy, so there is no need to retrain from scratch.

Pruning Strategies:

  • Magnitude-based: Remove weights below threshold
  • Gradient-based: Remove weights with small gradients
  • Structured pruning: Remove entire channels/filters
  • Lottery ticket hypothesis: sparse subnetworks that can train to full accuracy already exist at initialization

Practical Impact:

  • 50% sparsity: roughly 1.5x speedup (with sparse-aware kernels), usually similar accuracy
  • 90% sparsity: roughly 3-5x speedup, often a slight accuracy drop
  • Structured pruning accelerates better than unstructured pruning on standard hardware, since dense kernels cannot exploit scattered zeros
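
Before reaching for the library utilities shown later in this lesson, it helps to see that magnitude-based pruning is just a thresholding operation. A minimal single-layer sketch (illustrative; the torch.nn.utils.prune workflow below is what you would use in practice):

import torch
import torch.nn as nn

def magnitude_prune_(layer, sparsity=0.5):
    """Zero out the `sparsity` fraction of weights with the smallest magnitude."""
    with torch.no_grad():
        w = layer.weight
        k = int(sparsity * w.numel())
        if k == 0:
            return
        threshold = w.abs().flatten().kthvalue(k).values  # k-th smallest |w|
        w.mul_((w.abs() > threshold).float())             # apply binary mask in place

layer = nn.Linear(784, 128)
magnitude_prune_(layer, sparsity=0.9)
print(f'Layer sparsity: {(layer.weight == 0).float().mean():.2%}')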

Knowledge Distillation

Distillation trains a small “student” model to mimic a large “teacher” model, often reaching better accuracy than training the same small model directly on the labels.

Distillation Loss:

L = α * L_CE(student_logits, labels) + (1-α) * T² * L_KL(student_soft, teacher_soft)

where the temperature T softens both the student and teacher distributions:

P_soft = softmax(logits / T)

The T² factor compensates for the gradient shrinkage caused by dividing the logits by T.
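
To make this concrete, here is a compact sketch of the combined loss as a single helper function (a hypothetical example, not from any library; the full training loop appears later in this lesson):

import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.7):
    """Alpha-weighted sum of hard-label cross-entropy and softened KL divergence."""
    ce = F.cross_entropy(student_logits, labels)
    kl = F.kl_div(
        F.log_softmax(student_logits / T, dim=1),
        F.softmax(teacher_logits / T, dim=1),
        reduction='batchmean',
    ) * (T ** 2)                     # restore gradient scale (see above)
    return alpha * ce + (1 - alpha) * kl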

Practical Implementation

Quantization-Aware Training

import torch
import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()      # quantizes FP32 inputs at the model boundary
        self.fc1 = nn.Linear(784, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)
        self.dequant = torch.quantization.DeQuantStub()  # returns FP32 outputs

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = self.quant(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.dequant(x)
        return x

# QAT setup
model = SimpleModel()

# Fuse operations for efficiency (fusion is defined for eval mode)
model.eval()
torch.quantization.fuse_modules(model, [['fc1', 'relu']], inplace=True)

# Prepare for QAT: insert fake-quantization observers (requires train mode)
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
model = torch.quantization.prepare_qat(model.train())

# Train with quantization simulation
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

for epoch in range(5):
    for data, target in train_loader:
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

# Convert to a real INT8 model (quantized inference runs on CPU)
model.eval()
model = torch.quantization.convert(model, inplace=True)

# Inference (test_data: a batch of FP32 inputs)
with torch.no_grad():
    output = model(test_data)

Post-Training Quantization

import copy

def post_training_quantize(model, calibration_data, num_batches=10):
    # Work on a copy so the original FP32 model stays untouched
    model = copy.deepcopy(model)
    model.eval()

    # Prepare for PTQ: attach observers that record activation ranges
    model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
    torch.quantization.prepare(model, inplace=True)

    # Calibrate on representative data
    with torch.no_grad():
        for i, (data, _) in enumerate(calibration_data):
            if i >= num_batches:
                break
            _ = model(data)

    # Convert to quantized model
    torch.quantization.convert(model, inplace=True)
    return model

# Usage (model is a trained FP32 SimpleModel)
quantized_model = post_training_quantize(model, calibration_loader)

# Check model size
import os
torch.save(model.state_dict(), 'model_fp32.pt')
torch.save(quantized_model.state_dict(), 'model_int8.pt')

fp32_size = os.path.getsize('model_fp32.pt') / 1024 / 1024
int8_size = os.path.getsize('model_int8.pt') / 1024 / 1024
print(f'FP32 Model Size: {fp32_size:.2f} MB')
print(f'INT8 Model Size: {int8_size:.2f} MB')
print(f'Compression Ratio: {fp32_size / int8_size:.2f}x')

Pruning Implementation

import torch.nn.utils.prune as prune

class PrunerScheduler:
    def __init__(self, model, pruning_ratio=0.9):
        self.model = model
        self.pruning_ratio = pruning_ratio
        self.parameters_to_prune = [
            (module, 'weight')
            for module in model.modules()
            if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d)
        ]

    def prune_magnitude(self):
        """Magnitude-based pruning"""
        prune.global_unstructured(
            self.parameters_to_prune,
            pruning_method=prune.L1Unstructured,
            amount=self.pruning_ratio
        )

    def prune_iteratively(self, train_loader, val_loader, num_iterations=10):
        """Iterative pruning with retraining"""
        # Fraction of remaining weights to prune each round, chosen so that
        # (1 - amount)^num_iterations == 1 - pruning_ratio
        pruning_amount = 1 - (1 - self.pruning_ratio) ** (1 / num_iterations)

        for iteration in range(num_iterations):
            # Prune a small fraction of the still-unpruned weights
            prune.global_unstructured(
                self.parameters_to_prune,
                pruning_method=prune.L1Unstructured,
                amount=pruning_amount
            )

            # Retrain (train_one_epoch and evaluate are assumed to be defined elsewhere)
            for epoch in range(3):
                train_one_epoch(train_loader)

            # Evaluate
            acc = evaluate(val_loader)
            sparsity = self.compute_sparsity(self.model)
            print(f'Iteration {iteration}, Accuracy: {acc:.4f}, Sparsity: {sparsity:.2%}')

    def remove_pruning_buffers(self):
        """Make pruning permanent"""
        for module, _ in self.parameters_to_prune:
            prune.remove(module, 'weight')

    @staticmethod
    def compute_sparsity(model):
        total = 0
        sparse = 0
        for module in model.modules():
            if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
                total += module.weight.numel()
                sparse += (module.weight == 0).sum().item()
        return sparse / total if total > 0 else 0

# Usage
pruner = PrunerScheduler(model, pruning_ratio=0.9)
pruner.prune_magnitude()

Knowledge Distillation

class DistillationTrainer:
    def __init__(self, teacher_model, student_model, temperature=4.0, alpha=0.7):
        self.teacher = teacher_model
        self.student = student_model
        self.temperature = temperature
        self.alpha = alpha
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def train_epoch(self, train_loader, optimizer, criterion):
        self.teacher.eval()
        self.student.train()
        total_loss = 0

        for data, target in train_loader:
            data, target = data.to(self.device), target.to(self.device)

            optimizer.zero_grad()

            # Student prediction (log-probabilities of the softened distribution)
            student_logits = self.student(data)
            student_log_soft = torch.nn.functional.log_softmax(
                student_logits / self.temperature, dim=1)

            # Teacher prediction (no gradients)
            with torch.no_grad():
                teacher_logits = self.teacher(data)
                teacher_soft = torch.nn.functional.softmax(
                    teacher_logits / self.temperature, dim=1)

            # Distillation loss (log_softmax is numerically safer than log(softmax))
            kl_loss = torch.nn.functional.kl_div(
                student_log_soft,
                teacher_soft,
                reduction='batchmean'
            )

            # Cross-entropy with true labels
            ce_loss = criterion(student_logits, target)

            # Combined loss
            loss = self.alpha * ce_loss + (1 - self.alpha) * kl_loss * (self.temperature ** 2)

            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        return total_loss / len(train_loader)

# Usage
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

teacher = SimpleModel().to(device)
teacher.load_state_dict(torch.load('teacher_weights.pt', map_location=device))

# A smaller student (32 hidden units instead of 128) so the size comparison below is meaningful
student = nn.Sequential(
    nn.Flatten(),
    nn.Linear(784, 32),
    nn.ReLU(),
    nn.Linear(32, 10),
).to(device)

trainer = DistillationTrainer(teacher, student, temperature=4.0, alpha=0.7)
optimizer = torch.optim.Adam(student.parameters())
criterion = nn.CrossEntropyLoss()

for epoch in range(20):
    loss = trainer.train_epoch(train_loader, optimizer, criterion)
    print(f'Epoch {epoch+1}, Loss: {loss:.4f}')

# Student model is much smaller but performs well
print(f'Teacher params: {sum(p.numel() for p in teacher.parameters()) / 1e6:.2f}M')
print(f'Student params: {sum(p.numel() for p in student.parameters()) / 1e6:.2f}M')

Advanced Techniques

Mixed Precision Training

from torch.cuda.amp import autocast, GradScaler

model = SimpleModel().to(device)
optimizer = torch.optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()
scaler = GradScaler()

for epoch in range(10):
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()

        # Forward pass with autocast (eligible ops run in FP16)
        with autocast():
            output = model(data)
            loss = criterion(output, target)

        # Backward pass
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()

ONNX Export

import onnx
import onnxruntime as ort

# Export to ONNX (model in eval mode, on the same device as the dummy input)
model.eval()
dummy_input = torch.randn(1, 1, 28, 28)
torch.onnx.export(
    model,
    dummy_input,
    'model.onnx',
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}},
    opset_version=11
)

# Validate ONNX model
onnx_model = onnx.load('model.onnx')
onnx.checker.check_model(onnx_model)

# Inference with ONNX Runtime
sess = ort.InferenceSession('model.onnx')
input_name = sess.get_inputs()[0].name
output_name = sess.get_outputs()[0].name

# Run inference
output = sess.run([output_name], {input_name: dummy_input.numpy()})

TensorRT Optimization

import tensorrt as trt

def build_engine(onnx_file_path, engine_file_path):
    logger = trt.Logger(trt.Logger.WARNING)
    builder = trt.Builder(logger)
    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    parser = trt.OnnxParser(network, logger)

    with open(onnx_file_path, 'rb') as f:
        if not parser.parse(f.read()):
            for i in range(parser.num_errors):
                print(parser.get_error(i))
            raise RuntimeError(f'Failed to parse {onnx_file_path}')

    config = builder.create_builder_config()
    config.max_workspace_size = 1 << 28  # 256MB (newer TensorRT releases use config.set_memory_pool_limit instead)
    config.set_flag(trt.BuilderFlag.FP16)

    # build_engine is deprecated on recent TensorRT versions in favor of build_serialized_network
    engine = builder.build_engine(network, config)
    with open(engine_file_path, 'wb') as f:
        f.write(engine.serialize())

    return engine

engine = build_engine('model.onnx', 'model.trt')

Production Considerations

Benchmarking

import time

def benchmark_model(model, input_size=(1, 1, 28, 28), num_iters=100, device='cuda'):
    model.eval()
    model = model.to(device)
    dummy_input = torch.randn(input_size).to(device)

    # Warmup
    with torch.no_grad():
        for _ in range(10):
            _ = model(dummy_input)

    # Benchmark
    if device == 'cuda':
        torch.cuda.synchronize()
    start = time.time()
    with torch.no_grad():
        for _ in range(num_iters):
            _ = model(dummy_input)
    if device == 'cuda':
        torch.cuda.synchronize()
    end = time.time()

    avg_time = (end - start) / num_iters * 1000   # ms per inference
    throughput = 1000 / avg_time                  # inferences per second (batch size 1)
    return avg_time, throughput

# Compare original vs optimized (INT8 models execute on CPU, so benchmark them there)
original_time, original_throughput = benchmark_model(original_model)
quantized_time, quantized_throughput = benchmark_model(quantized_model, device='cpu')

print(f'Original: {original_time:.2f}ms ({original_throughput:.0f} inf/s)')
print(f'Quantized: {quantized_time:.2f}ms ({quantized_throughput:.0f} inf/s)')
print(f'Speedup: {original_time / quantized_time:.2f}x')

Memory Analysis

from torch.utils.checkpoint import checkpoint

# Gradient checkpointing trades compute for memory: activations are recomputed
# in the backward pass instead of stored, reducing peak training memory
class CheckpointedModel(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        return checkpoint(self.model, x, use_reentrant=False)

# Monitor peak GPU memory for a single forward pass
def measure_memory(model, input_size):
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.empty_cache()

    model = model.cuda()
    dummy_input = torch.randn(input_size).cuda()
    with torch.no_grad():
        _ = model(dummy_input)

    peak_memory = torch.cuda.max_memory_allocated() / 1024 / 1024  # MB
    return peak_memory

print(f'Model Memory: {measure_memory(model, (1, 1, 28, 28)):.2f} MB')

Key Takeaway

Production deployment demands more than training accuracy. Quantization, pruning, and distillation let you ship models that are smaller, faster, and cheaper to run, which is essential for real-world applications where every millisecond and megabyte counts.

Practical Exercise

Task: Build an optimization pipeline: quantize, prune, and distill a ResNet18 model for CIFAR-10.

Requirements:

  1. Train a baseline ResNet18 on CIFAR-10
  2. Apply QAT to reduce to INT8
  3. Implement iterative magnitude-based pruning (target 80% sparsity)
  4. Train a lightweight model via knowledge distillation
  5. Benchmark all variants (original, quantized, pruned, distilled)
  6. Export best model to ONNX

Evaluation:

  • Measure accuracy drop vs. each optimization (target < 2%)
  • Compare inference speed (ms/image)
  • Compare model sizes (MB)
  • Create speedup chart showing trade-offs
  • Analyze which combination (Q+P, Q+D, P+D, Q+P+D) works best