Model Optimization and Deployment
Deploying neural networks in production requires more than accuracy—models must be fast, efficient, and fit within resource constraints. This lesson covers quantization, pruning, knowledge distillation, and export techniques to bridge the gap between research models and deployment-ready systems.
Core Concepts
Quantization: Reducing Precision
Quantization reduces the numerical precision of model weights and activations, decreasing model size and improving inference speed with little accuracy loss.
Quantization Schemes:
- Post-Training Quantization (PTQ): Quantize trained weights after training completes
- Quantization-Aware Training (QAT): Simulate quantization during training for better accuracy
- Mixed Precision: Use different precisions for different layers (FP32 for critical, INT8 elsewhere)
Common Bit-widths:
- INT8: 8-bit integers, 4x model compression
- INT4: 4-bit integers, 8x compression (more aggressive)
- FP16: 16-bit floats, 2x compression, maintains reasonable accuracy
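The compression ratios above follow directly from the bit-widths. A quick sanity check for a single 784x128 linear layer (a rough sketch that ignores packing and metadata overhead):

```python
# Dense storage cost of a 784x128 weight matrix at each precision.
weights = 784 * 128

def size_mb(num_params, bits):
    # bits -> bytes -> MB
    return num_params * bits / 8 / 1024 / 1024

fp32 = size_mb(weights, 32)
fp16 = size_mb(weights, 16)
int8 = size_mb(weights, 8)
int4 = size_mb(weights, 4)

print(f'FP32: {fp32:.3f} MB, FP16: {fp16:.3f} MB, '
      f'INT8: {int8:.3f} MB, INT4: {int4:.3f} MB')
print(f'INT8 compression vs FP32: {fp32 / int8:.0f}x')
```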
Symmetric vs. Asymmetric Quantization:
Symmetric (zero maps exactly to zero): x_q = round(x / scale)
Asymmetric (arbitrary ranges via an offset): x_q = round(x / scale) + zero_point
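A minimal sketch of the asymmetric scheme in plain PyTorch, deriving a per-tensor scale and zero point from the observed min/max (function names here are illustrative, not a library API):

```python
import torch

def quantize_asymmetric(x, num_bits=8):
    # Map the observed float range [min, max] onto the integer range [0, 2^bits - 1]
    qmin, qmax = 0, 2 ** num_bits - 1
    scale = (x.max() - x.min()) / (qmax - qmin)
    zero_point = qmin - torch.round(x.min() / scale)
    x_q = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax)
    return x_q, scale, zero_point

def dequantize(x_q, scale, zero_point):
    # Recover an approximation of the original floats
    return (x_q - zero_point) * scale

torch.manual_seed(0)
x = torch.randn(1000)
x_q, scale, zp = quantize_asymmetric(x)
x_hat = dequantize(x_q, scale, zp)
# Roundtrip error is bounded by roughly one quantization step
err = (x - x_hat).abs().max().item()
print(f'max roundtrip error: {err:.5f}, scale: {scale:.5f}')
```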
Pruning: Removing Redundancy
Pruning removes unimportant weights and neurons, reducing model size and computation without retraining from scratch.
Pruning Strategies:
- Magnitude-based: Remove weights below threshold
- Gradient-based: Remove weights with small gradients
- Structured pruning: Remove entire channels/filters
- Lottery ticket hypothesis: Sparse subnetworks exist from initialization
Practical Impact:
- 50% pruning: ~1.5x speedup, similar accuracy
- 90% pruning: ~3-5x speedup, slight accuracy drop
- Structured pruning outperforms unstructured pruning on real hardware, since removing whole channels yields smaller dense kernels rather than sparse matrices most accelerators cannot exploit
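The magnitude-based strategy can be sketched in a few lines of PyTorch, zeroing the smallest-magnitude weights of a single layer (an illustration only; the `torch.nn.utils.prune` utilities used later in this lesson handle masking and reparameterization properly):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(64, 64)
ratio = 0.5  # fraction of weights to remove

with torch.no_grad():
    w = layer.weight
    # Threshold = magnitude of the k-th smallest weight
    k = int(ratio * w.numel())
    threshold = w.abs().flatten().kthvalue(k).values
    # Keep only weights strictly above the threshold
    mask = (w.abs() > threshold).float()
    w.mul_(mask)

sparsity = (layer.weight == 0).float().mean().item()
print(f'sparsity: {sparsity:.2%}')
```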
Knowledge Distillation
Distillation trains a small “student” model to mimic a large “teacher” model, achieving better accuracy than training small models directly.
Distillation Loss:
L = α * L_CE(student_logits, labels) + (1-α) * T² * L_KL(student_soft, teacher_soft)
Where temperature T softens both probability distributions (the T² factor keeps the soft-target gradients on the same scale as the hard-label gradients):
P_soft = softmax(logits / T)
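A small numeric illustration of how temperature flattens a distribution (the logit values are arbitrary):

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([4.0, 2.0, 1.0])
p_sharp = F.softmax(logits / 1.0, dim=0)  # T = 1: standard softmax
p_soft = F.softmax(logits / 4.0, dim=0)   # T = 4: flatter distribution

# At T = 1 most mass sits on the top logit; at T = 4 the relative
# scores of the non-top classes ("dark knowledge") become visible
print(p_sharp)
print(p_soft)
```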
Practical Implementation
Quantization-Aware Training
```python
import torch
import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Quant/DeQuant stubs mark where tensors enter and leave int8
        # (required by eager-mode quantization)
        self.quant = torch.quantization.QuantStub()
        self.fc1 = nn.Linear(784, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = self.quant(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return self.dequant(x)

# QAT setup
model = SimpleModel()

# Fuse Linear + ReLU into a single module for efficiency (fusion is done in eval mode)
model.eval()
torch.quantization.fuse_modules(model, [['fc1', 'relu']], inplace=True)

# Prepare for QAT: insert fake-quantization observers
model.train()
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
model = torch.quantization.prepare_qat(model)

# Train with quantization simulation (assumes train_loader is defined)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
for epoch in range(5):
    for data, target in train_loader:
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

# Convert to a true INT8 model (conversion runs on CPU, in eval mode)
model.eval()
model = torch.quantization.convert(model, inplace=True)

# Inference
with torch.no_grad():
    output = model(test_data)
```
Post-Training Quantization
```python
import copy
import os
import torch

def post_training_quantize(model, calibration_data, num_batches=10):
    model.eval()
    # Attach the default PTQ config and insert observers
    model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
    torch.quantization.prepare(model, inplace=True)
    # Calibrate observers on a small amount of representative data
    with torch.no_grad():
        for i, (data, _) in enumerate(calibration_data):
            if i >= num_batches:
                break
            _ = model(data)
    # Replace observed modules with quantized equivalents
    torch.quantization.convert(model, inplace=True)
    return model

# Usage: quantize a copy, since conversion is in-place and would
# otherwise destroy the FP32 model we want to compare against
fp32_model = copy.deepcopy(model)
quantized_model = post_training_quantize(model, calibration_loader)

# Check model size on disk
torch.save(fp32_model.state_dict(), 'model_fp32.pt')
torch.save(quantized_model.state_dict(), 'model_int8.pt')
fp32_size = os.path.getsize('model_fp32.pt') / 1024 / 1024
int8_size = os.path.getsize('model_int8.pt') / 1024 / 1024
print(f'FP32 Model Size: {fp32_size:.2f} MB')
print(f'INT8 Model Size: {int8_size:.2f} MB')
print(f'Compression Ratio: {fp32_size / int8_size:.2f}x')
```
Pruning Implementation
```python
import torch.nn as nn
import torch.nn.utils.prune as prune

class PrunerScheduler:
    def __init__(self, model, pruning_ratio=0.9):
        self.model = model
        self.pruning_ratio = pruning_ratio
        self.parameters_to_prune = [
            (module, 'weight')
            for module in model.modules()
            if isinstance(module, (nn.Linear, nn.Conv2d))
        ]

    def prune_magnitude(self):
        """One-shot global magnitude-based pruning."""
        prune.global_unstructured(
            self.parameters_to_prune,
            pruning_method=prune.L1Unstructured,
            amount=self.pruning_ratio,
        )

    def prune_iteratively(self, train_loader, val_loader, num_iterations=10):
        """Iterative pruning with retraining between steps."""
        # Each step prunes a fixed fraction of the *remaining* weights,
        # chosen so that (1 - amount)^num_iterations = 1 - pruning_ratio
        amount = 1 - (1 - self.pruning_ratio) ** (1 / num_iterations)
        for iteration in range(num_iterations):
            # Prune slightly
            prune.global_unstructured(
                self.parameters_to_prune,
                pruning_method=prune.L1Unstructured,
                amount=amount,
            )
            # Retrain to recover accuracy (train_one_epoch and evaluate
            # are placeholders for your own training/validation loops)
            for epoch in range(3):
                train_one_epoch(train_loader)
            acc = evaluate(val_loader)
            sparsity = self.compute_sparsity(self.model)
            print(f'Iteration {iteration}, Accuracy: {acc:.4f}, Sparsity: {sparsity:.2%}')

    def remove_pruning_buffers(self):
        """Make pruning permanent by folding masks into the weights."""
        for module, _ in self.parameters_to_prune:
            prune.remove(module, 'weight')

    @staticmethod
    def compute_sparsity(model):
        total, zeros = 0, 0
        for module in model.modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                total += module.weight.numel()
                zeros += (module.weight == 0).sum().item()
        return zeros / total if total > 0 else 0

# Usage
pruner = PrunerScheduler(model, pruning_ratio=0.9)
pruner.prune_magnitude()
```
Knowledge Distillation
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class DistillationTrainer:
    def __init__(self, teacher_model, student_model, temperature=4.0, alpha=0.7):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.teacher = teacher_model.to(self.device)
        self.student = student_model.to(self.device)
        self.temperature = temperature
        self.alpha = alpha

    def train_epoch(self, train_loader, optimizer, criterion):
        self.teacher.eval()
        self.student.train()
        total_loss = 0
        for data, target in train_loader:
            data, target = data.to(self.device), target.to(self.device)
            optimizer.zero_grad()
            # Student prediction (log_softmax is numerically stabler
            # than log(softmax(...)))
            student_logits = self.student(data)
            student_log_soft = F.log_softmax(student_logits / self.temperature, dim=1)
            # Teacher prediction (no gradients)
            with torch.no_grad():
                teacher_logits = self.teacher(data)
                teacher_soft = F.softmax(teacher_logits / self.temperature, dim=1)
            # Distillation loss
            kl_loss = F.kl_div(student_log_soft, teacher_soft, reduction='batchmean')
            # Cross-entropy with true labels
            ce_loss = criterion(student_logits, target)
            # Combined loss; T^2 rescales the soft-target gradients
            loss = self.alpha * ce_loss + (1 - self.alpha) * kl_loss * (self.temperature ** 2)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        return total_loss / len(train_loader)

# Usage (in practice the student would be a smaller architecture)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
teacher = SimpleModel().to(device)
teacher.load_state_dict(torch.load('teacher_weights.pt'))
student = SimpleModel().to(device)
trainer = DistillationTrainer(teacher, student, temperature=4.0, alpha=0.7)
optimizer = torch.optim.Adam(student.parameters())
criterion = nn.CrossEntropyLoss()
for epoch in range(20):
    loss = trainer.train_epoch(train_loader, optimizer, criterion)
    print(f'Epoch {epoch+1}, Loss: {loss:.4f}')

# Student model is much smaller but performs well
print(f'Teacher params: {sum(p.numel() for p in teacher.parameters()) / 1e6:.2f}M')
print(f'Student params: {sum(p.numel() for p in student.parameters()) / 1e6:.2f}M')
```
Advanced Techniques
Mixed Precision Training
```python
import torch
import torch.nn as nn
from torch.cuda.amp import autocast, GradScaler

device = torch.device('cuda')
model = SimpleModel().to(device)
optimizer = torch.optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()
scaler = GradScaler()

for epoch in range(10):
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        # Forward pass with autocast: FP16 where safe, FP32 where needed
        with autocast():
            output = model(data)
            loss = criterion(output, target)
        # Scale the loss to avoid FP16 gradient underflow
        scaler.scale(loss).backward()
        # Unscale before clipping so the threshold applies to the true gradients
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
```
ONNX Export
```python
import torch
import onnx
import onnxruntime as ort

# Export to ONNX with a dynamic batch dimension
dummy_input = torch.randn(1, 1, 28, 28)
torch.onnx.export(
    model,
    dummy_input,
    'model.onnx',
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={'input': {0: 'batch_size'}},
    opset_version=11,
)

# Validate the exported graph
onnx_model = onnx.load('model.onnx')
onnx.checker.check_model(onnx_model)

# Inference with ONNX Runtime
sess = ort.InferenceSession('model.onnx')
input_name = sess.get_inputs()[0].name
output_name = sess.get_outputs()[0].name

# Run inference
output = sess.run([output_name], {input_name: dummy_input.numpy()})
```
TensorRT Optimization
```python
import tensorrt as trt

def build_engine(onnx_file_path, engine_file_path):
    logger = trt.Logger(trt.Logger.WARNING)
    builder = trt.Builder(logger)
    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    parser = trt.OnnxParser(network, logger)
    with open(onnx_file_path, 'rb') as f:
        if not parser.parse(f.read()):
            for i in range(parser.num_errors):
                print(parser.get_error(i))
            raise RuntimeError('Failed to parse ONNX model')
    config = builder.create_builder_config()
    # 256 MB workspace (note: deprecated in TensorRT >= 8.4 in favor of
    # config.set_memory_pool_limit)
    config.max_workspace_size = 1 << 28
    config.set_flag(trt.BuilderFlag.FP16)  # allow FP16 kernels
    engine = builder.build_engine(network, config)
    with open(engine_file_path, 'wb') as f:
        f.write(engine.serialize())
    return engine

engine = build_engine('model.onnx', 'model.trt')
```
Production Considerations
Benchmarking
```python
import time
import torch

def benchmark_model(model, input_size=(1, 1, 28, 28), num_iters=100, device='cuda'):
    model.eval()
    dummy_input = torch.randn(input_size).to(device)
    # Warmup: let CUDA allocate buffers and select kernels
    with torch.no_grad():
        for _ in range(10):
            _ = model(dummy_input)
    # Benchmark: synchronize so we time GPU work, not just kernel launches
    if device == 'cuda':
        torch.cuda.synchronize()
    start = time.time()
    with torch.no_grad():
        for _ in range(num_iters):
            _ = model(dummy_input)
    if device == 'cuda':
        torch.cuda.synchronize()
    end = time.time()
    avg_time = (end - start) / num_iters * 1000  # ms per inference
    throughput = 1000 / avg_time  # inferences per second
    return avg_time, throughput

# Compare original vs optimized
# (eager-mode INT8 models only run on CPU, so benchmark both on CPU
# for an apples-to-apples comparison)
original_time, original_throughput = benchmark_model(original_model, device='cpu')
quantized_time, quantized_throughput = benchmark_model(quantized_model, device='cpu')
print(f'Original: {original_time:.2f}ms ({original_throughput:.0f} inf/s)')
print(f'Quantized: {quantized_time:.2f}ms ({quantized_throughput:.0f} inf/s)')
print(f'Speedup: {original_time / quantized_time:.2f}x')
```
Memory Analysis
```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Gradient checkpointing trades compute for memory: activations are
# recomputed during the backward pass instead of being stored
class CheckpointedModel(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        return checkpoint(self.model, x, use_reentrant=False)

# Monitor peak GPU memory during a forward pass
def measure_memory(model, input_size):
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.empty_cache()
    dummy_input = torch.randn(input_size).cuda()
    with torch.no_grad():
        _ = model(dummy_input)
    return torch.cuda.max_memory_allocated() / 1024 / 1024  # MB

print(f'Model Memory: {measure_memory(model, (1, 1, 28, 28)):.2f} MB')
```
Key Takeaway
Production deployment demands more than training accuracy. Quantization, pruning, and distillation let you ship models with reduced size and faster inference, which is essential for real-world applications where every millisecond and megabyte counts.
Practical Exercise
Task: Build an optimization pipeline: quantize, prune, and distill a ResNet18 model for CIFAR-10.
Requirements:
- Train a baseline ResNet18 on CIFAR-10
- Apply QAT to reduce to INT8
- Implement iterative magnitude-based pruning (target 80% sparsity)
- Train a lightweight model via knowledge distillation
- Benchmark all variants (original, quantized, pruned, distilled)
- Export best model to ONNX
Evaluation:
- Measure accuracy drop vs. each optimization (target < 2%)
- Compare inference speed (ms/image)
- Compare model sizes (MB)
- Create speedup chart showing trade-offs
- Analyze which combination (Q+P, Q+D, P+D, Q+P+D) works best