Distributed Training and Scaling

Lesson 1 of 4 · Estimated Time: 60 min

Training large models requires distributed strategies across multiple GPUs and nodes. This lesson covers data parallelism, model parallelism, DeepSpeed ZeRO, FSDP, and gradient accumulation techniques for efficient large-scale training.

Core Concepts

Data Parallelism

Replicate the model on every GPU; each GPU processes a different slice of the batch:

Batch → Split across N GPUs → Compute gradients → Average gradients → Update

Simple DDP (Distributed Data Parallel):

  • Each GPU holds a full copy of the model
  • Gradients are synchronized (averaged) after the backward pass, as sketched below
  • Near-linear scaling with GPU count, up to interconnect bandwidth limits
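
To make the averaging step concrete, here is a minimal sketch of what DDP automates, written with torch.distributed.all_reduce (assumes an initialized process group):

import torch
import torch.distributed as dist

def average_gradients(model: torch.nn.Module) -> None:
    # Sum each gradient across all ranks, then divide by the world size;
    # this is the synchronization DDP performs for you after backward()
    world_size = dist.get_world_size()
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad /= world_size

In practice DDP buckets gradients and overlaps this communication with the rest of the backward pass, which is why it outperforms a naive post-hoc all-reduce.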

Model Parallelism

Split model across GPUs:

Layer 1-10 → GPU1 → Layer 11-20 → GPU2 → Layer 21-30 → GPU3

Types:

  • Tensor parallelism: Split tensors within layers
  • Pipeline parallelism: Split layers across stages
  • Sequence parallelism: Split sequence dimension
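
As a toy illustration (not a production pipeline), a model can be split by placing submodules on different GPUs and moving activations between them; the layer sizes and device names below are arbitrary:

import torch
import torch.nn as nn

class TwoStageModel(nn.Module):
    # Naive model parallelism: first stage on cuda:0, second on cuda:1
    def __init__(self):
        super().__init__()
        self.stage1 = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU()).to("cuda:0")
        self.stage2 = nn.Linear(4096, 1024).to("cuda:1")

    def forward(self, x):
        x = self.stage1(x.to("cuda:0"))
        return self.stage2(x.to("cuda:1"))  # activations cross devices here

Real pipeline parallelism additionally splits each batch into micro-batches so the stages work concurrently instead of idling while waiting on each other.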

DeepSpeed ZeRO

Partition optimizer states, gradients, and parameters:

ZeRO-1: Partitions optimizer states across GPUs (~4x memory saving with Adam)
ZeRO-2: Also partitions gradients (~8x saving)
ZeRO-3: Also partitions parameters (saving scales with GPU count and can exceed 100x at scale)
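
A back-of-the-envelope calculation shows where these savings come from, using the ZeRO paper's accounting for mixed-precision Adam (2 bytes each for fp16 parameters and gradients, 12 bytes of optimizer state per parameter); the model size and GPU count below are illustrative:

params, num_gpus = 7e9, 8  # 7B-parameter model on 8 GPUs (illustrative)

fp16_params = 2 * params  # model weights
fp16_grads = 2 * params   # gradients
opt_states = 12 * params  # fp32 master weights + Adam momentum and variance

baseline = fp16_params + fp16_grads + opt_states
zero1 = fp16_params + fp16_grads + opt_states / num_gpus
zero2 = fp16_params + (fp16_grads + opt_states) / num_gpus
zero3 = baseline / num_gpus

for name, b in [("Baseline", baseline), ("ZeRO-1", zero1),
                ("ZeRO-2", zero2), ("ZeRO-3", zero3)]:
    print(f"{name}: {b / 1e9:.1f} GB per GPU")
# Baseline: 112.0, ZeRO-1: 38.5, ZeRO-2: 26.2, ZeRO-3: 14.0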

FSDP: Fully Sharded Data Parallel

PyTorch's native implementation of ZeRO-style sharding, which distributes parameters, gradients, and optimizer state:

  • Memory efficient: per-GPU usage is roughly O(model_size / num_gpus) when fully sharded
  • Communication overlap: parameter all-gathers are prefetched to hide latency behind compute
  • Flexible sharding strategies, as sketched below
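
The sharding strategy is chosen at wrap time; a brief sketch of how FSDP's ShardingStrategy options line up with the ZeRO stages above:

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy

# ShardingStrategy maps roughly onto ZeRO stages:
#   FULL_SHARD    ~ ZeRO-3: shard parameters, gradients, and optimizer state
#   SHARD_GRAD_OP ~ ZeRO-2: shard gradients and optimizer state only
#   NO_SHARD      ~ plain DDP: replicate everything
model = FSDP(model, sharding_strategy=ShardingStrategy.FULL_SHARD)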

Practical Implementation

DDP Training

import os

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler

# Initialize distributed (launch with torchrun, which sets LOCAL_RANK)
dist.init_process_group("nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

# Model and DDP wrap (GPT2, config, dataset, and epochs assumed defined elsewhere)
model = GPT2(config).to(local_rank)
model = DDP(model, device_ids=[local_rank])

# Distributed sampler gives each rank a distinct shard of the dataset
train_sampler = DistributedSampler(dataset, num_replicas=dist.get_world_size())
train_loader = DataLoader(dataset, batch_size=32, sampler=train_sampler)

# Training loop
criterion = nn.CrossEntropyLoss()  # example loss
optimizer = torch.optim.Adam(model.parameters())

for epoch in range(epochs):
    train_sampler.set_epoch(epoch)  # Reshuffle data each epoch

    for batch in train_loader:
        inputs, labels = batch
        inputs, labels = inputs.to(local_rank), labels.to(local_rank)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()  # DDP all-reduces gradients during backward
        optimizer.step()

dist.destroy_process_group()
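
Launch with torchrun so each process receives its LOCAL_RANK, for example: torchrun --nproc_per_node=8 train_ddp.py (the script name here is a placeholder).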

DeepSpeed Training

import deepspeed

# DeepSpeed config: ZeRO stage 2, CPU optimizer offload, fp16,
# and activation checkpointing (partition_activations is the documented key)
ds_config = {
    "train_batch_size": 32,
    "gradient_accumulation_steps": 4,
    "optimizer": {
        "type": "AdamW",
        "params": {"lr": 2e-5, "betas": [0.9, 0.95]}
    },
    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {"device": "cpu"}
    },
    "fp16": {"enabled": True},
    "activation_checkpointing": {"partition_activations": True}
}

# Initialize with DeepSpeed (sets up distributed state, optimizer, and dataloader)
model_engine, optimizer, train_loader, _ = deepspeed.initialize(
    args=None,
    model=model,
    model_parameters=model.parameters(),
    training_data=dataset,
    config=ds_config,
)

# Training: the engine handles loss scaling and gradient accumulation internally
for epoch in range(epochs):
    for batch_idx, batch in enumerate(train_loader):
        batch = batch.to(model_engine.local_rank)
        loss = model_engine(batch)  # assumes the model's forward returns the loss
        model_engine.backward(loss)
        model_engine.step()
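
Launch with the DeepSpeed launcher, for example: deepspeed --num_gpus=8 train_ds.py (script name is a placeholder); it spawns one process per GPU and initializes the process group.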

FSDP Training

import functools

import torch
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    BackwardPrefetch,
    MixedPrecision,
)
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

# Wrap model with FSDP; the size-based policy shards submodules above a
# parameter-count threshold so all-gathers can overlap with compute
model = FSDP(
    model,
    auto_wrap_policy=functools.partial(
        size_based_auto_wrap_policy, min_num_params=1_000_000
    ),
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
    mixed_precision=MixedPrecision(
        param_dtype=torch.float16,
        reduce_dtype=torch.float16,
        buffer_dtype=torch.float16,
    ),
    device_id=torch.cuda.current_device(),
)

# Create the optimizer after wrapping so it references the sharded parameters
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

# Training as usual
for batch in train_loader:
    optimizer.zero_grad()
    loss = model(batch)  # assumes the model's forward returns the loss
    loss.backward()
    optimizer.step()

Advanced Techniques

Gradient Accumulation

accumulation_steps = 4
optimizer.zero_grad()

for i, batch in enumerate(train_loader):
    outputs = model(batch)
    # Scale the loss so the accumulated gradient matches one large-batch step
    loss = outputs.loss / accumulation_steps
    loss.backward()

    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
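
The effective batch size is per-GPU batch × accumulation steps × number of GPUs (e.g. 32 × 4 × 8 = 1024). Under DDP, every backward() also triggers a gradient all-reduce; DDP's no_sync() context manager skips that communication on intermediate micro-batches. A minimal sketch, assuming model is DDP-wrapped and its forward output exposes .loss:

from contextlib import nullcontext

for i, batch in enumerate(train_loader):
    is_update_step = (i + 1) % accumulation_steps == 0
    # no_sync() defers the gradient all-reduce until the final micro-batch
    sync_ctx = nullcontext() if is_update_step else model.no_sync()

    with sync_ctx:
        loss = model(batch).loss / accumulation_steps
        loss.backward()

    if is_update_step:
        optimizer.step()
        optimizer.zero_grad()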

Mixed Precision Training

from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

for batch in train_loader:
    optimizer.zero_grad()

    with autocast():
        outputs = model(batch)
        loss = outputs.loss

    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)  # unscale first so the clip threshold is meaningful
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)
    scaler.update()
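
On GPUs with native bfloat16 support (e.g. A100), autocast(dtype=torch.bfloat16) is a common alternative: bfloat16 keeps float32's exponent range, so the GradScaler and its loss scaling can typically be dropped.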

Production Considerations

Monitoring Training

import wandb

# In multi-GPU runs, log from rank 0 only to avoid duplicate entries
is_main = dist.get_rank() == 0
if is_main:
    wandb.init(project="llm-training")

for epoch in range(epochs):
    for batch_idx, batch in enumerate(train_loader):
        loss = train_step(batch)

        if is_main and batch_idx % 100 == 0:
            wandb.log({
                "loss": loss,
                "epoch": epoch,
                "batch": batch_idx,
                "lr": optimizer.param_groups[0]["lr"],
            })

Key Takeaway

Distributed training with DDP, DeepSpeed, or FSDP enables efficient large-scale model training. Choose based on memory constraints and communication bottlenecks: plain DDP when the model fits on a single GPU and the interconnect is fast; ZeRO-3 or FSDP when parameters, gradients, and optimizer states must be sharded to fit in memory.

Practical Exercise

Task: Scale training from single GPU to multi-GPU distributed setup.

Requirements:

  1. Implement a DDP wrapper
  2. Add DeepSpeed integration
  3. Profile communication overhead (see the profiler sketch after this list)
  4. Optimize batch sizes and gradient accumulation
  5. Benchmark speedup
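
A possible starting point for the communication profiling in step 3, using torch.profiler (train_step and batch stand in for your own code):

import torch
from torch.profiler import profile, ProfilerActivity

# Capture one training step; NCCL collectives appear as CUDA kernels,
# separating communication time from compute time
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    train_step(batch)

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))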

Evaluation:

  • Linear scaling efficiency > 85%
  • Training time reduction with 8 GPUs
  • Memory efficiency comparison
  • Communication overhead analysis