Distributed Training and Scaling
Training large models requires distributed strategies across multiple GPUs and nodes. This lesson covers data parallelism, model parallelism, DeepSpeed ZeRO, FSDP, and gradient accumulation techniques for efficient large-scale training.
Core Concepts
Data Parallelism
Replicate the model on every GPU; each GPU processes a different slice of each batch:
Batch → Split across N GPUs → Compute gradients → Average gradients → Update
Simple DDP (Distributed Data Parallel):
- Each GPU holds a full copy of the model
- Gradients are synchronized (averaged) across GPUs after the backward pass, as sketched below
- Near-linear scaling with the number of GPUs (up to interconnect bandwidth limits)
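DDP performs this gradient averaging automatically during the backward pass; the manual version below is only a sketch to illustrate the collective operation (it assumes the process group is already initialized):

import torch
import torch.distributed as dist

def average_gradients(model: torch.nn.Module) -> None:
    # Sum each gradient across all ranks, then divide by world size.
    # DDP does this for you, overlapped with the backward pass.
    world_size = dist.get_world_size()
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad /= world_size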
Model Parallelism
Split model across GPUs:
Layers 1-10 → GPU 1 → Layers 11-20 → GPU 2 → Layers 21-30 → GPU 3
Types:
- Tensor parallelism: split individual weight tensors within a layer across GPUs
- Pipeline parallelism: split consecutive layers into stages on different GPUs (see the two-GPU sketch below)
- Sequence parallelism: split activations along the sequence dimension
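A minimal sketch of naive layer-wise model parallelism across two GPUs (assumes a machine with at least two CUDA devices; the hidden size and layer counts are placeholders):

import torch
import torch.nn as nn

class TwoGPUModel(nn.Module):
    # First half of the layers lives on cuda:0, second half on cuda:1;
    # activations are moved between devices in forward().
    def __init__(self, hidden: int = 1024):
        super().__init__()
        self.part1 = nn.Sequential(*[nn.Linear(hidden, hidden) for _ in range(10)]).to("cuda:0")
        self.part2 = nn.Sequential(*[nn.Linear(hidden, hidden) for _ in range(10)]).to("cuda:1")

    def forward(self, x):
        x = self.part1(x.to("cuda:0"))
        return self.part2(x.to("cuda:1"))

Pipeline parallelism improves on this naive split by dividing each batch into micro-batches so the GPUs are not idle while waiting for each other.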
DeepSpeed ZeRO
Partition optimizer states, gradients, and parameters:
ZeRO-1: partition optimizer states (~4x memory saving)
ZeRO-2: additionally partition gradients (~8x saving)
ZeRO-3: additionally partition parameters (saving scales with the number of GPUs; a ZeRO-3 config sketch follows)
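A minimal ZeRO-3 configuration sketch with parameter and optimizer offload to CPU; the values are illustrative, not tuned recommendations:

zero3_config = {
    "train_batch_size": 32,
    "bf16": {"enabled": True},
    "zero_optimization": {
        "stage": 3,                              # partition params, grads, and optimizer states
        "offload_param": {"device": "cpu"},      # move parameters to CPU when not in use
        "offload_optimizer": {"device": "cpu"},  # move optimizer states to CPU
        "overlap_comm": True,                    # overlap communication with computation
        "stage3_gather_16bit_weights_on_model_save": True,
    },
}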
FSDP: Fully Sharded Data Parallel
FSDP is PyTorch's native implementation of ZeRO-style sharding; it shards parameters, gradients, and optimizer states across data-parallel ranks:
- Memory efficient: per-GPU footprint is roughly O(model_size / num_gpus)
- Communication overlap: all-gathers and reduce-scatters are overlapped with compute to hide latency
- Flexible sharding strategies (see the enum sketch below)
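The sharding strategies map onto an enum passed to the FSDP wrapper; a brief sketch of the main options from torch.distributed.fsdp:

from torch.distributed.fsdp import ShardingStrategy

# FULL_SHARD    - shard parameters, gradients, and optimizer states (ZeRO-3 style)
# SHARD_GRAD_OP - shard gradients and optimizer states only (ZeRO-2 style)
# NO_SHARD      - no sharding; behaves like plain DDP
# HYBRID_SHARD  - full sharding within a node, replication across nodes
strategy = ShardingStrategy.FULL_SHARD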
Practical Implementation
DDP Training
import os

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler

# Initialize the process group (one process per GPU)
dist.init_process_group("nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

# Build the model and wrap it in DDP
model = GPT2(config).to(local_rank)
model = DDP(model, device_ids=[local_rank])

# Distributed sampler gives each rank a disjoint shard of the dataset
train_sampler = DistributedSampler(dataset, num_replicas=dist.get_world_size())
train_loader = DataLoader(dataset, batch_size=32, sampler=train_sampler)

# Training loop
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())
for epoch in range(epochs):
    train_sampler.set_epoch(epoch)  # reshuffle data differently each epoch
    for batch in train_loader:
        inputs, labels = batch
        inputs, labels = inputs.to(local_rank), labels.to(local_rank)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()  # DDP averages gradients across ranks here
        optimizer.step()
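The script is launched with torchrun, which starts one process per GPU and sets LOCAL_RANK for each, e.g. torchrun --nproc_per_node=8 train_ddp.py for a single node with 8 GPUs (the script name is a placeholder).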
DeepSpeed Training
import deepspeed

# DeepSpeed config: ZeRO-2 with optimizer state offload to CPU
ds_config = {
    "train_batch_size": 32,
    "gradient_accumulation_steps": 4,
    "optimizer": {
        "type": "AdamW",
        "params": {"lr": 2e-5, "betas": [0.9, 0.95]},
    },
    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {"device": "cpu"},
    },
    "fp16": {"enabled": True},
    "activation_checkpointing": {"partition_activations": True},
}

# Initialize with DeepSpeed (wraps model, optimizer, and dataloader)
model_engine, optimizer, train_loader, _ = deepspeed.initialize(
    args=None,
    model=model,
    model_parameters=model.parameters(),
    training_data=dataset,
    config=ds_config,
)

# Training loop: the engine handles gradient accumulation, loss scaling, and the optimizer step
for epoch in range(epochs):
    for batch_idx, batch in enumerate(train_loader):
        batch = batch.to(model_engine.device)
        loss = model_engine(batch)  # assumes the model returns its loss
        model_engine.backward(loss)
        model_engine.step()
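The script is typically started with the DeepSpeed launcher, e.g. deepspeed --num_gpus 8 train_deepspeed.py (the script name is a placeholder); here the config is passed to deepspeed.initialize as a Python dict, but it can equally live in a JSON file.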
FSDP Training
import functools

import torch
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    BackwardPrefetch,
    MixedPrecision,
)
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

# Wrap the model with FSDP (the process group must already be initialized)
model = FSDP(
    model,
    auto_wrap_policy=functools.partial(size_based_auto_wrap_policy, min_num_params=1_000_000),
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
    mixed_precision=MixedPrecision(
        param_dtype=torch.float16,
        reduce_dtype=torch.float16,
        buffer_dtype=torch.float16,
    ),
    device_id=torch.cuda.current_device(),
)

# Training proceeds as usual; FSDP gathers shards for each forward/backward
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
for batch in train_loader:
    optimizer.zero_grad()
    loss = model(batch)  # assumes the model returns its loss
    loss.backward()
    optimizer.step()
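Saving a checkpoint from a sharded model requires gathering the full state dict, usually on rank 0 only; a minimal sketch using the FSDP state-dict API (the file name is a placeholder):

import torch
import torch.distributed as dist
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    FullStateDictConfig,
    StateDictType,
)

# Gather the full (unsharded) state dict on rank 0, offloaded to CPU
save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
    cpu_state = model.state_dict()
if dist.get_rank() == 0:
    torch.save(cpu_state, "checkpoint.pt")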
Advanced Techniques
Gradient Accumulation
accumulation_steps = 4

optimizer.zero_grad()
for i, batch in enumerate(train_loader):
    outputs = model(batch)
    loss = outputs.loss / accumulation_steps  # scale so accumulated gradients average correctly
    loss.backward()
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
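Under DDP, every loss.backward() triggers a gradient all-reduce; with gradient accumulation, that synchronization is only needed on the step that actually updates the weights. A sketch combining the two, assuming the DDP-wrapped model from earlier (effective batch size = per-GPU batch size × accumulation_steps × world_size):

import contextlib

accumulation_steps = 4

optimizer.zero_grad()
for i, batch in enumerate(train_loader):
    is_update_step = (i + 1) % accumulation_steps == 0
    # model.no_sync() skips the gradient all-reduce on non-update steps
    ctx = contextlib.nullcontext() if is_update_step else model.no_sync()
    with ctx:
        loss = model(batch).loss / accumulation_steps
        loss.backward()
    if is_update_step:
        optimizer.step()
        optimizer.zero_grad()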
Mixed Precision Training
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()
for batch in train_loader:
    optimizer.zero_grad()
    with autocast():
        outputs = model(batch)
        loss = outputs.loss
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)  # unscale before clipping so the threshold applies to the true gradients
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)
    scaler.update()
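On hardware that supports bfloat16, autocast can use bf16 instead of fp16; because bf16 keeps the fp32 exponent range, the GradScaler is generally unnecessary. A sketch assuming the same model, loader, and optimizer:

import torch

for batch in train_loader:
    optimizer.zero_grad()
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        loss = model(batch).loss
    loss.backward()  # no loss scaling needed with bf16
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()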
Production Considerations
Monitoring Training
import wandb

# In a multi-GPU run, typically only rank 0 calls wandb.init and wandb.log
wandb.init(project="llm-training")

for epoch in range(epochs):
    for batch_idx, batch in enumerate(train_loader):
        loss = train_step(batch)
        if batch_idx % 100 == 0:
            wandb.log({
                "loss": loss,
                "epoch": epoch,
                "batch": batch_idx,
                "lr": optimizer.param_groups[0]["lr"],
            })
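Memory headroom is usually the first thing to watch when scaling up; a small helper that logs peak GPU memory alongside the training metrics (the wandb key name is illustrative):

import torch
import wandb

def log_memory_stats(step: int) -> None:
    # Peak memory allocated since the last reset, in GiB
    peak_gib = torch.cuda.max_memory_allocated() / 2**30
    wandb.log({"gpu/peak_mem_gib": peak_gib}, step=step)
    torch.cuda.reset_peak_memory_stats()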
Key Takeaway
Distributed training with DDP, DeepSpeed, or FSDP enables efficient large-scale model training. Choose based on memory constraints and communication overhead: plain DDP when the full model, gradients, and optimizer states fit on each GPU and the interconnect is fast; ZeRO-2/3 or FSDP when they do not and sharding is needed for memory efficiency.
Practical Exercise
Task: Scale training from single GPU to multi-GPU distributed setup.
Requirements:
- Implement DDP wrapper
- Add DeepSpeed integration
- Profile communication overhead
- Optimize batch sizes and gradient accumulation
- Benchmark speedup
Evaluation:
- Linear scaling efficiency > 85%
- Training time reduction with 8 GPUs
- Memory efficiency comparison
- Communication overhead analysis