16. Distributed Training and FSDP

Bryan Tegomoh

2025-04-06

Welcome to the Lecture

Section III: Foundations for Modern Language Modeling

Focus:

  • Chapter 15: Pretraining Recipes

    • Goal: Understand the choices in pretraining LLMs

    • Builds on: Transformers, Tokenization, Positional encoding

  • Chapter 16: Distributed Training and FSDP

    • Goal: Learn how to scale LLM training

Primary Objective: Understand key concepts for training modern LLMs

Why Distributed Training?

The Scale Challenge:

  • Modern LLMs: 7B to 1T+ parameters
  • Single GPU memory (A100): only 40-80GB
  • 7B model memory requirements:
    • Weights: ~14GB (FP16)
    • Gradients: ~14GB
    • Optimizer states: ~28GB (Adam)
    • Activations: variable with batch size

Solution: Split computation across multiple devices
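To make these numbers concrete, here is a back-of-the-envelope calculation. The per-value byte sizes are assumptions chosen to match the figures above (FP16 weights and gradients, and ~4 bytes per parameter of Adam state to match the ~28GB estimate; FP32 Adam states would cost more).

# Rough memory estimate for a 7B-parameter model (sketch; matches the figures above)
params = 7e9

weights_gb = params * 2 / 1e9   # ~14 GB of FP16 weights
grads_gb   = params * 2 / 1e9   # ~14 GB of FP16 gradients
adam_gb    = params * 4 / 1e9   # ~28 GB of Adam state (slide's estimate)

total_gb = weights_gb + grads_gb + adam_gb
print(f"~{total_gb:.0f} GB before activations")   # ~56 GB: over a 40 GB A100, most of an 80 GB one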

Hardware Requirements for LLM Training

GPU Memory Requirements

Parallelism Strategies

Four primary approaches:

  1. Data Parallelism:
    • Same model, different data batches
    • Simple but limited by model size (see the DDP sketch after this list)
  2. Model Parallelism:
    • Different layers on different devices
    • Enables larger models but sequential processing
  3. Pipeline Parallelism:
    • Stages of model on different devices
    • Multiple batches in pipeline
  4. Tensor Parallelism:
    • Split individual operations (matrices)
    • Efficient for Transformer operations

Improving Optimization Performance with Parallel Computing

Fully Sharded Data Parallel (FSDP)

Core Concept: Shard model parameters, gradients, and optimizer states across the data-parallel workers

How it works:

  • Each device stores only a portion (shard) of the full model
  • During the forward/backward pass:
    • Gather the needed parameters (all-gather)
    • Compute with the gathered parameters
    • Re-shard the parameters after use

Benefits:

  • ~N-fold memory reduction with N devices
  • Enables training much larger models
  • Preserves the compute efficiency of data parallelism
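The gather-compute-reshard cycle can be illustrated with raw collectives. This is only a conceptual sketch of the pattern, not FSDP's actual implementation; it assumes an initialized process group and that each rank holds a flattened 1/world_size shard of one weight matrix.

import torch
import torch.distributed as dist

def fsdp_style_matmul(shard: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # Conceptual sketch only (not FSDP internals): gather, compute, re-shard
    world_size = dist.get_world_size()

    # 1. All-gather: temporarily materialize the full parameter on every rank
    pieces = [torch.empty_like(shard) for _ in range(world_size)]
    dist.all_gather(pieces, shard)
    weight = torch.cat(pieces).view(x.shape[-1], -1)   # assumes shapes divide evenly

    # 2. Compute with the full parameter
    out = x @ weight

    # 3. "Re-shard": free the full copy; only the local shard persists
    del pieces, weight
    return out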


Critical Memory Optimizations

Combine these techniques for maximum efficiency (a combined sketch follows this list):

  • Mixed Precision Training:
    • Use BF16/FP16 for most operations
    • Maintain FP32 master weights for stability
    • 2x memory reduction
  • Activation Checkpointing:
    • Discard activations during forward pass
    • Recompute during backward pass
    • Trade computation for memory
  • Gradient Accumulation:
    • Update less frequently
    • Process more data with same memory

FSDP Implementation (PyTorch)

Basic implementation:

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy

# Initialize the distributed process group (one process per GPU)
torch.distributed.init_process_group("nccl")

# Mixed precision: BF16 compute/communication; the optimizer keeps FP32 master weights
mp_policy = MixedPrecision(param_dtype=torch.bfloat16,
                           reduce_dtype=torch.bfloat16,
                           buffer_dtype=torch.bfloat16)

# Wrap the model with FSDP; FULL_SHARD shards parameters, gradients, and optimizer states
model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    mixed_precision=mp_policy,
    device_id=torch.cuda.current_device(),
)

# Build the optimizer after wrapping so it operates on the sharded parameters
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Train as usual
for batch in dataloader:
    loss = model(batch).loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
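In practice such a script is launched with one process per GPU, for example torchrun --nproc_per_node=8 train.py (train.py being a placeholder name); torchrun sets the rank and world-size environment variables that init_process_group("nccl") reads.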

Hardware & Communication Considerations

Hardware Topology:

  • Intra-node: high-speed GPU connections (NVLink)
  • Inter-node: network connections (InfiniBand)

Communication Bottlenecks:

  • All-gather: collect sharded parameters
  • Reduce-scatter: aggregate gradients
  • Communication volume grows with model size

Optimization Strategies:

  • Match the parallelism strategy to the hardware topology
  • Hierarchical communication patterns
  • Overlap computation with communication
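As one concrete example of matching the strategy to the topology, PyTorch FSDP offers a HYBRID_SHARD strategy that shards within each node (all-gathers stay on NVLink) and replicates across nodes (only gradient reduction crosses the slower network). A minimal sketch, assuming the process group is already initialized and the model is already built:

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy

# Shard parameters within each node; replicate across nodes so only gradient
# reduction has to cross the inter-node network
model = FSDP(
    model,                                   # placeholder: an already-built model
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,
    device_id=torch.cuda.current_device(),
)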

3D Parallelism: State-of-the-Art

Combining complementary strategies:

  • Data Parallelism/FSDP: Across node groups
  • Pipeline Parallelism: Between GPU clusters
  • Tensor Parallelism: Within GPU clusters

Example (175B model training):

  • 8-way tensor parallelism
  • 12-way pipeline parallelism
  • 32-way data parallelism
  • Total: 8 × 12 × 32 = 3,072 GPUs

Used by major labs for largest models
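The three degrees of parallelism multiply into the total GPU count, and each rank maps to a coordinate in the 3D grid. The ordering below (tensor parallelism innermost) is an illustrative assumption; frameworks such as Megatron-LM define their own rank layouts.

# Map a global rank to (data, pipeline, tensor) coordinates for the 175B example
TP, PP, DP = 8, 12, 32
WORLD_SIZE = TP * PP * DP            # 8 × 12 × 32 = 3,072 GPUs

def coords(rank: int) -> tuple[int, int, int]:
    tp = rank % TP                   # tensor-parallel index (fastest-varying, kept intra-node)
    pp = (rank // TP) % PP           # pipeline stage
    dp = rank // (TP * PP)           # data-parallel replica
    return dp, pp, tp

print(coords(0))      # (0, 0, 0)
print(coords(2021))   # a rank somewhere in the middle of the grid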


Practical Training Challenges

  • Checkpoint Management:
    • TB-sized checkpoints
    • Distributed saving/loading (see the sketch after this list)
  • Fault Tolerance:
    • Hardware failures increasingly likely at scale
    • Need robust recovery mechanisms
  • Performance Debugging:
    • Complex interactions between strategies
    • Profiling distributed workloads
  • Cost Management:
    • Training budgets in millions of dollars
    • Efficiency optimizations have major impact
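A minimal sketch of distributed (sharded) checkpoint saving with torch.distributed.checkpoint, so that each rank writes only its own shards rather than gathering a TB-scale state dict on one rank. This assumes a recent PyTorch and an FSDP-wrapped model as above; the exact entry points of this module have shifted across versions, and the checkpoint path is a placeholder.

import torch.distributed.checkpoint as dcp
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType

# Ask FSDP for a *sharded* state dict so no rank materializes the full model
FSDP.set_state_dict_type(model, StateDictType.SHARDED_STATE_DICT)
state_dict = {"model": model.state_dict()}

# Each rank writes its own shards to a shared filesystem path (placeholder below);
# loading is the mirror image via dcp.load into the same structure
dcp.save(state_dict, checkpoint_id="/shared/checkpoints/step_1000")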

Frameworks and Resources

Popular Frameworks:

  1. PyTorch FSDP: Native PyTorch implementation

  2. DeepSpeed: Microsoft’s ZeRO implementation

  3. Megatron-LM: NVIDIA’s specialized framework

  4. JAX/Flax: Google’s functional approach

Learning Resources:

Key Takeaways

  1. Distributed training is essential for modern LLMs

  2. FSDP provides memory-efficient data parallelism

  3. Combine multiple parallelism strategies for best results

  4. Memory optimizations are as important as compute scaling

  5. Hardware topology should inform parallelism strategy

  6. LLM training represents an extreme engineering challenge