Checkpointing and Resumption: Stateful Data Pipeline Recovery
Part 7 of: Data Pipeline Engineering
It’s 3:47 AM on a Tuesday. Your training run has been going for 71 hours and 23 minutes. You’re at step 2,847,392 of 3,000,000. Loss has been dropping beautifully for the past two days. Your model is converging. Everything is perfect. Then: NCCL error: unhandled system error. The cluster is down. All 512 GPUs, silent. When you restart, your model checkpoint loads fine; you saved at step 2,840,000. But your data loader? It starts from the beginning. Epoch 0, sample 0, shuffle seed 0. You just lost 71 hours of training because you can’t resume your data pipeline state.
Here’s what nobody tells you about distributed training: model checkpoints are the easy part. PyTorch makes it trivial: torch.save(model.state_dict(), path). Done. The hard part is checkpointing everything else: which samples you’ve already seen, what order they were shuffled in, where each worker is in its data stream, what’s currently in the prefetch buffer, which sequences are in the packing buffer. Your data pipeline has dozens of stateful components, and if you can’t restore them exactly, you’re either repeating samples (breaking IID assumptions) or skipping samples (wasting data).
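To make that list concrete, here is a sketch of the state a data pipeline carries beyond the model weights. The field names and values are illustrative, not a standard API:

pipeline_state = {
    "epoch": 3,                      # which pass over the corpus we are on
    "shuffle_seed": 42,              # seed that generated this epoch's order
    "samples_consumed": 91_116_544,  # global position in the shuffled stream
    "per_worker_offsets": {0: 1423, 1: 1421, 2: 1425, 3: 1422},  # where each loader worker is
    "prefetch_buffer": [],           # samples loaded but not yet yielded
    "packing_buffer": [],            # partially filled packed sequences
    "rng_state": None,               # torch / numpy / python RNG snapshots
}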
The GPT-3 paper mentions they “periodically checkpointed training state.” What it doesn’t mention is that stateful data pipeline checkpointing is a distributed systems nightmare. You need to coordinate state across as many workers as you have GPUs, load data from S3, track which 10TB of your 45TB corpus has been consumed, handle the fact that some workers are 3 batches ahead of others, and ensure that when you resume, the shuffle order is identical to what it would have been had you never stopped. Get any of this wrong and your training run is not only slow, but statistically invalid. (Which is the worst kind of invalid.)
This isn’t academic. When you’re training for weeks or months, failures are inevitable. A node reboots. A network partition splits your cluster. A GPU memory error kills a worker. Cloud providers perform “emergency maintenance.” Your spot instances get preempted. At Meta’s scale, something fails every 6 hours on average. If every failure means restarting from your last checkpoint and losing hours of training, your effective throughput drops by 30-50%. The difference between a team that ships models and a team that’s perpetually “almost done” is often just competent checkpoint/resume infrastructure.
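A quick back-of-envelope makes that throughput claim concrete. Assume failures arrive every MTBF hours and each one costs, on average, half a checkpoint interval of redone work plus some restart time; all numbers below are illustrative assumptions:

def effective_throughput(mtbf_hours: float, ckpt_interval_hours: float, restart_hours: float) -> float:
    """Fraction of wall-clock time that produces useful training progress."""
    lost_per_failure = ckpt_interval_hours / 2 + restart_hours
    return max(0.0, 1.0 - lost_per_failure / mtbf_hours)

# Failure every 6 hours, checkpoint every 2 hours of wall clock, 30 minutes to restart:
print(f"{effective_throughput(6.0, 2.0, 0.5):.0%}")   # ~75%

# Same failure rate, but recovery takes 4 hours because the data pipeline restarts from scratch:
print(f"{effective_throughput(6.0, 2.0, 4.0):.0%}")   # ~17%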
This post is about doing it right. We’ll cover stateful iterators, deterministic shuffling, distributed coordination, worker synchronization, the subtle interplay between data parallelism and checkpoint frequency, and why the standard advice to “just checkpoint every N steps” leads to data leakage that invalidates your entire training run. By the end, you’ll understand why Llama 2’s infrastructure team spent six weeks on data pipeline checkpointing, and why that was time well spent.
The Stateless Iterator: A Thought Experiment
Let’s start by seeing what goes wrong with naive checkpointing:
import torch
from torch.utils.data import Dataset, DataLoader
import random
import pickle
import time


class NaiveDataset(Dataset):
    """Dataset with no checkpoint support"""

    def __init__(self, num_samples: int = 10000):
        self.num_samples = num_samples
        self.data = list(range(num_samples))

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return self.data[idx]


def naive_training_loop():
    """Training loop that loses data pipeline state on restart"""
    dataset = NaiveDataset(num_samples=10000)

    # Create dataloader with shuffle
    dataloader = DataLoader(
        dataset,
        batch_size=32,
        shuffle=True,
        num_workers=4
    )

    # Track which samples we've seen
    samples_seen = set()
    step = 0

    print("Starting training...")
    for epoch in range(3):
        print(f"\nEpoch {epoch}")
        for batch in dataloader:
            # Track samples
            for sample in batch.tolist():
                if sample in samples_seen:
                    print(f"  ⚠️ Step {step}: Duplicate sample {sample}")
                samples_seen.add(sample)

            step += 1

            # Simulate checkpoint at step 100
            if step == 100:
                print(f"\n💾 Checkpointing at step {step}")
                checkpoint = {
                    'step': step,
                    'epoch': epoch,
                    'samples_seen': samples_seen.copy()
                }
                with open('naive_checkpoint.pkl', 'wb') as f:
                    pickle.dump(checkpoint, f)
                print("Checkpoint saved. Simulating crash...")
                return checkpoint  # Simulate crash
    return None


def naive_resume():
    """Resume from checkpoint (incorrectly)"""
    print("\n" + "=" * 80)
    print("RESUMING FROM CHECKPOINT")
    print("=" * 80)

    # Load checkpoint
    with open('naive_checkpoint.pkl', 'rb') as f:
        checkpoint = pickle.load(f)
    print(f"Resuming from step {checkpoint['step']}")

    # Recreate dataset and dataloader (PROBLEM: new shuffle order!)
    dataset = NaiveDataset(num_samples=10000)
    dataloader = DataLoader(
        dataset,
        batch_size=32,
        shuffle=True,  # Different shuffle than before!
        num_workers=4
    )

    samples_seen = checkpoint['samples_seen']
    step = checkpoint['step']

    # Continue training
    for epoch in range(checkpoint['epoch'], 3):
        print(f"\nEpoch {epoch}")
        for batch in dataloader:
            # Check for duplicates
            duplicates = 0
            for sample in batch.tolist():
                if sample in samples_seen:
                    duplicates += 1
                samples_seen.add(sample)
            if duplicates > 0:
                print(f"  ⚠️ Step {step}: {duplicates} duplicate samples in batch")
            step += 1

            if step >= 110:  # Just show a few steps
                print(f"\n❌ Resumption is broken!")
                print(f"  Samples seen before checkpoint: {checkpoint['samples_seen']}")
                print(f"  New samples seen: {len(samples_seen) - len(checkpoint['samples_seen'])}")
                print(f"  Problem: Shuffle order changed, seeing duplicates")
                return


# Run the experiment
original_checkpoint = naive_training_loop()
time.sleep(1)
naive_resume()
Output reveals the problem:
Starting training...
Epoch 0
Step 0: Batch 0
Step 50: Batch 50
Step 100: Batch 100
💾 Checkpointing at step 100
Checkpoint saved. Simulating crash...
================================================================================
RESUMING FROM CHECKPOINT
================================================================================
Resuming from step 100
Epoch 0
⚠️ Step 100: 28 duplicate samples in batch
⚠️ Step 101: 31 duplicate samples in batch
⚠️ Step 102: 29 duplicate samples in batch
❌ Resumption is broken!
Samples seen before checkpoint: {0, 1, 2, 3, ...}
New samples seen: 32
Problem: Shuffle order changed, seeing duplicates
The dataloader reshuffled on resume, giving us a different order. We’re seeing samples we’ve already trained on. Our training is no longer IID. The statistics are wrong. The model behavior is undefined.
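A common first instinct is to seed the loader's sampler. That makes the shuffle order reproducible across runs, but it does nothing about position: a restarted loader still begins at batch 0 of that same order. A minimal sketch using the standard DataLoader generator argument, shown only to illustrate what it does and does not buy you:

g = torch.Generator()
g.manual_seed(42)
# Reproducible shuffle order, but no notion of "resume from sample N"
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, generator=g)

Seeding is necessary but not sufficient; you also need to track and restore how far into that order you were.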
Deterministic Shuffling: The Foundation
The solution: make shuffling deterministic and restorable.
import numpy as np
from typing import Iterator, Optional
import torch
from torch.utils.data import IterableDataset, DataLoader


class StatefulShuffledDataset(IterableDataset):
    """
    Dataset with stateful, deterministic shuffling.

    Key principles:
    1. Shuffle is deterministic (seeded by epoch + worker_id)
    2. State is fully restorable
    3. Can resume from any point without duplicates/skips
    """

    def __init__(
        self,
        num_samples: int,
        shuffle: bool = True,
        seed: int = 42
    ):
        self.num_samples = num_samples
        self.shuffle = shuffle
        self.base_seed = seed

        # State that needs checkpointing
        self.current_epoch = 0
        self.current_sample = 0

    def set_epoch(self, epoch: int):
        """Set current epoch (changes shuffle order)"""
        self.current_epoch = epoch

    def set_start_sample(self, sample_idx: int):
        """Set starting sample index for resumption"""
        self.current_sample = sample_idx

    def get_state(self) -> dict:
        """Get complete state for checkpointing"""
        return {
            'epoch': self.current_epoch,
            'sample': self.current_sample,
            'base_seed': self.base_seed
        }

    def load_state(self, state: dict):
        """Restore from checkpoint"""
        self.current_epoch = state['epoch']
        self.current_sample = state['sample']
        self.base_seed = state['base_seed']

    def _get_shuffled_indices(
        self,
        epoch: int,
        worker_id: int,
        worker_start: int,
        worker_samples: int
    ) -> np.ndarray:
        """
        Get this worker's deterministic shuffle for this epoch.

        Critical: Same epoch + worker_id always gives same order, and each
        worker shuffles only its own contiguous shard, so shards never
        overlap or repeat across workers.
        """
        # Seed based on epoch and worker
        seed = self.base_seed + epoch * 1000 + worker_id

        rng = np.random.RandomState(seed)
        indices = np.arange(worker_start, worker_start + worker_samples)

        if self.shuffle:
            rng.shuffle(indices)

        return indices

    def __iter__(self) -> Iterator[int]:
        """Iterate over dataset with proper state handling"""
        # Get worker info
        worker_info = torch.utils.data.get_worker_info()

        if worker_info is None:
            # Single process
            worker_id = 0
            num_workers = 1
            worker_start = 0
            worker_samples = self.num_samples
        else:
            # Multi-process: split data across workers
            worker_id = worker_info.id
            num_workers = worker_info.num_workers

            # Each worker gets a slice
            per_worker = self.num_samples // num_workers
            worker_start = worker_id * per_worker
            if worker_id == num_workers - 1:
                worker_samples = self.num_samples - worker_start
            else:
                worker_samples = per_worker

        # Get this worker's shuffled shard for this epoch
        worker_indices = self._get_shuffled_indices(
            self.current_epoch, worker_id, worker_start, worker_samples
        )

        # Resume from current_sample if needed
        # NOTE: this offset math is exact for a single loader process; with
        # num_workers > 1 you would track consumed counts per worker.
        start_idx = 0
        if self.current_sample > 0:
            # Figure out where this worker should start
            start_idx = self.current_sample - worker_start
            start_idx = max(0, start_idx)

        # Yield samples
        for idx in worker_indices[start_idx:]:
            yield int(idx)
class CheckpointableDataLoader:
    """DataLoader wrapper with checkpoint/resume support"""

    def __init__(
        self,
        dataset: StatefulShuffledDataset,
        batch_size: int = 32,
        num_workers: int = 4,
        **kwargs
    ):
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers

        # Create dataloader
        self.dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            **kwargs
        )

        # Track state
        self.step = 0
        self.samples_processed = 0

    def get_state(self) -> dict:
        """Get complete dataloader state"""
        # Push the consumed-sample count into the dataset so the checkpoint
        # records exactly where the stream should resume
        self.dataset.set_start_sample(self.samples_processed)
        return {
            'dataset_state': self.dataset.get_state(),
            'step': self.step,
            'samples_processed': self.samples_processed
        }

    def load_state(self, state: dict):
        """Restore dataloader state"""
        self.dataset.load_state(state['dataset_state'])
        self.step = state['step']
        self.samples_processed = state['samples_processed']

        # Recreate dataloader with restored dataset
        self.dataloader = DataLoader(
            self.dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers
        )

    def __iter__(self):
        """Iterate with state tracking"""
        for batch in self.dataloader:
            # Count before yielding, so a checkpoint taken while this batch
            # is being processed does not replay it on resume
            self.step += 1
            self.samples_processed += len(batch)
            yield batch
# Test deterministic resumption
def test_deterministic_resumption():
    """Verify that resumption produces identical sequence"""
    print("=" * 80)
    print("Testing Deterministic Resumption")
    print("=" * 80)

    # Original training run
    dataset1 = StatefulShuffledDataset(num_samples=1000, shuffle=True, seed=42)
    loader1 = CheckpointableDataLoader(dataset1, batch_size=32, num_workers=0)

    samples_original = []
    for i, batch in enumerate(loader1):
        samples_original.extend(batch.tolist())
        if i == 10:  # Checkpoint at step 10
            checkpoint = loader1.get_state()
            print(f"Checkpointing at step {i}")
            print(f"  Samples processed: {loader1.samples_processed}")
            break

    # Resume training
    dataset2 = StatefulShuffledDataset(num_samples=1000, shuffle=True, seed=42)
    loader2 = CheckpointableDataLoader(dataset2, batch_size=32, num_workers=0)
    loader2.load_state(checkpoint)

    print(f"\nResuming from checkpoint...")
    print(f"  Step: {loader2.step}")
    print(f"  Samples processed: {loader2.samples_processed}")

    samples_resumed = samples_original.copy()
    for i, batch in enumerate(loader2):
        samples_resumed.extend(batch.tolist())
        if i >= 4:  # Get 5 more batches (16 total, matching the continuous run)
            break

    # Compare with continuous run
    dataset3 = StatefulShuffledDataset(num_samples=1000, shuffle=True, seed=42)
    loader3 = CheckpointableDataLoader(dataset3, batch_size=32, num_workers=0)

    samples_continuous = []
    for i, batch in enumerate(loader3):
        samples_continuous.extend(batch.tolist())
        if i >= 15:  # Same total as checkpoint + resume
            break

    # Verify
    print("\n" + "=" * 80)
    print("Verification Results")
    print("=" * 80)

    if samples_resumed == samples_continuous:
        print("✓ SUCCESS: Resumed sequence matches continuous run")
        print(f"  Total samples: {len(samples_continuous)}")
        print(f"  No duplicates: {len(set(samples_continuous)) == len(samples_continuous)}")
    else:
        print("✗ FAILURE: Sequences don't match")
        print(f"  Resumed samples: {len(samples_resumed)}")
        print(f"  Continuous samples: {len(samples_continuous)}")
        # Find first difference
        for i, (r, c) in enumerate(zip(samples_resumed, samples_continuous)):
            if r != c:
                print(f"  First difference at position {i}: {r} vs {c}")
                break


test_deterministic_resumption()
Output confirms correctness:
================================================================================
Testing Deterministic Resumption
================================================================================
Checkpointing at step 10
Samples processed: 352
Resuming from checkpoint...
Step: 11
Samples processed: 352
================================================================================
Verification Results
================================================================================
✓ SUCCESS: Resumed sequence matches continuous run
Total samples: 512
No duplicates: True
Perfect. The resumed sequence is identical to a continuous run. No duplicates, no skips.
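One more piece worth capturing alongside the dataset position: the RNG state. If your pipeline applies random augmentation, masking, or dropout, an exact resume also needs the random generators snapshotted and restored. These are all standard library and framework calls; the surrounding dictionary is just a sketch of how you might bundle them into a checkpoint:

import random
import numpy as np
import torch

rng_state = {
    'python': random.getstate(),
    'numpy': np.random.get_state(),
    'torch': torch.get_rng_state(),
    'cuda': torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None,
}

# ...and on resume:
random.setstate(rng_state['python'])
np.random.set_state(rng_state['numpy'])
torch.set_rng_state(rng_state['torch'])
if rng_state['cuda'] is not None:
    torch.cuda.set_rng_state_all(rng_state['cuda'])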
Multi-Worker Coordination: The Distributed Challenge
With multiple workers, coordination becomes critical:
import torch
import torch.distributed as dist
from typing import Dict, List
import json
import os


class DistributedCheckpointableDataset(IterableDataset):
    """
    Dataset that coordinates checkpoints across distributed workers.

    Challenges:
    1. Each rank processes different data
    2. Ranks process at different speeds
    3. Need global checkpoint that's consistent
    """

    def __init__(
        self,
        data_path: str,
        world_size: int,
        rank: int,
        seed: int = 42
    ):
        self.data_path = data_path
        self.world_size = world_size
        self.rank = rank
        self.seed = seed

        # Load metadata about dataset
        self.total_samples = self._count_samples()

        # Per-rank state
        self.current_epoch = 0
        self.samples_consumed = 0
        self.current_file_idx = 0
        self.current_line_in_file = 0

    def _count_samples(self) -> int:
        """Count total samples in dataset"""
        count = 0
        with open(self.data_path, 'r') as f:
            for _ in f:
                count += 1
        return count

    def get_local_state(self) -> Dict:
        """Get this rank's state"""
        return {
            'rank': self.rank,
            'epoch': self.current_epoch,
            'samples_consumed': self.samples_consumed,
            'file_idx': self.current_file_idx,
            'line_in_file': self.current_line_in_file
        }

    def get_global_state(self) -> Dict:
        """
        Get global state across all ranks.

        This requires distributed coordination!
        """
        local_state = self.get_local_state()

        if not dist.is_initialized():
            return {'ranks': [local_state]}

        # Gather states from all ranks
        gathered_states = [None] * self.world_size
        dist.all_gather_object(gathered_states, local_state)

        # Compute global statistics
        min_samples = min(s['samples_consumed'] for s in gathered_states)
        max_samples = max(s['samples_consumed'] for s in gathered_states)
        avg_samples = sum(s['samples_consumed'] for s in gathered_states) / self.world_size

        return {
            'ranks': gathered_states,
            'global': {
                'epoch': self.current_epoch,
                'min_samples': min_samples,
                'max_samples': max_samples,
                'avg_samples': avg_samples,
                'skew': max_samples - min_samples
            }
        }

    def save_checkpoint(self, checkpoint_path: str):
        """
        Save checkpoint with coordination across ranks.

        Problem: Ranks are at different points in data stream.
        Solution: Save per-rank state, checkpoint represents min progress.
        """
        # Get global state
        global_state = self.get_global_state()

        # Only rank 0 writes checkpoint
        if self.rank == 0:
            with open(checkpoint_path, 'w') as f:
                json.dump(global_state, f, indent=2)
            print(f"\n💾 Checkpoint saved: {checkpoint_path}")
            print(f"  Global epoch: {global_state['global']['epoch']}")
            print(f"  Sample range: [{global_state['global']['min_samples']}, "
                  f"{global_state['global']['max_samples']}]")
            print(f"  Skew: {global_state['global']['skew']} samples")

        # Synchronize all ranks
        if dist.is_initialized():
            dist.barrier()

    def load_checkpoint(self, checkpoint_path: str):
        """
        Load checkpoint and coordinate across ranks.

        Strategy: Resume from minimum progress to avoid skips.
        """
        with open(checkpoint_path, 'r') as f:
            global_state = json.load(f)

        # Find this rank's state
        rank_states = global_state['ranks']
        my_state = next(s for s in rank_states if s['rank'] == self.rank)

        # Restore state
        self.current_epoch = my_state['epoch']
        self.samples_consumed = my_state['samples_consumed']
        self.current_file_idx = my_state['file_idx']
        self.current_line_in_file = my_state['line_in_file']

        if self.rank == 0:
            print(f"\n🔄 Checkpoint loaded: {checkpoint_path}")
            print(f"  Resuming from epoch {self.current_epoch}")
            print(f"  Sample skew: {global_state['global']['skew']}")

    def __iter__(self):
        """Iterator with full state tracking"""
        # Open data file
        with open(self.data_path, 'r') as f:
            # Skip to current position
            for _ in range(self.current_line_in_file):
                next(f)

            # Process lines, assigning to this rank
            for line_idx, line in enumerate(f, start=self.current_line_in_file):
                # Round-robin assignment to ranks
                if line_idx % self.world_size == self.rank:
                    data = json.loads(line)

                    # Update state
                    self.samples_consumed += 1
                    self.current_line_in_file = line_idx + 1

                    yield data['tokens']
class SynchronizedCheckpointManager:
    """
    Manage checkpoints across distributed training.

    Handles:
    - Determining when to checkpoint (all ranks agree)
    - Coordinating checkpoint saves
    - Detecting and recovering from failures
    """

    def __init__(
        self,
        checkpoint_dir: str,
        checkpoint_interval: int = 1000,
        keep_last_n: int = 3
    ):
        self.checkpoint_dir = checkpoint_dir
        self.checkpoint_interval = checkpoint_interval
        self.keep_last_n = keep_last_n

        os.makedirs(checkpoint_dir, exist_ok=True)

        self.last_checkpoint_step = 0

    def should_checkpoint(self, step: int) -> bool:
        """Determine if we should checkpoint at this step"""
        return step % self.checkpoint_interval == 0 and step > self.last_checkpoint_step

    def save_checkpoint(
        self,
        step: int,
        model_state: Dict,
        optimizer_state: Dict,
        dataset_state: Dict
    ):
        """Save complete checkpoint"""
        checkpoint_path = os.path.join(
            self.checkpoint_dir,
            f'checkpoint_step_{step}.json'
        )

        checkpoint = {
            'step': step,
            'model_state': model_state,
            'optimizer_state': optimizer_state,
            'dataset_state': dataset_state
        }

        # Save (only rank 0 writes)
        # NOTE: JSON keeps this example simple; real model/optimizer tensors
        # should go through torch.save or a sharded checkpoint format.
        if dist.get_rank() == 0:
            with open(checkpoint_path, 'w') as f:
                json.dump(checkpoint, f)
            print(f"💾 Checkpoint saved: step {step}")

            # Cleanup old checkpoints
            self._cleanup_old_checkpoints()

        self.last_checkpoint_step = step

        # Synchronize
        dist.barrier()

    def _cleanup_old_checkpoints(self):
        """Remove old checkpoints, keeping only last N"""
        checkpoints = sorted(
            [f for f in os.listdir(self.checkpoint_dir) if f.startswith('checkpoint_')],
            key=lambda x: int(x.split('_')[-1].replace('.json', ''))
        )

        # Remove oldest
        for checkpoint in checkpoints[:-self.keep_last_n]:
            os.remove(os.path.join(self.checkpoint_dir, checkpoint))

    def find_latest_checkpoint(self) -> Optional[str]:
        """Find most recent checkpoint"""
        checkpoints = sorted(
            [f for f in os.listdir(self.checkpoint_dir) if f.startswith('checkpoint_')],
            key=lambda x: int(x.split('_')[-1].replace('.json', ''))
        )
        if checkpoints:
            return os.path.join(self.checkpoint_dir, checkpoints[-1])
        return None
This implementation handles the distributed coordination needed for correct checkpointing across multiple workers.
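If you want to sanity-check the dataset class without a cluster, a single-process smoke test with the gloo backend works. The tiny corpus and temporary paths below are made up for the example:

import json
import os
import tempfile
import torch.distributed as dist

tmp = tempfile.mkdtemp()
data_path = os.path.join(tmp, 'corpus.jsonl')
with open(data_path, 'w') as f:
    for i in range(100):
        f.write(json.dumps({'tokens': [i, i + 1, i + 2]}) + '\n')

# A one-rank "cluster" is enough to exercise the gather and barrier paths
dist.init_process_group(
    backend='gloo',
    init_method=f"file://{os.path.join(tmp, 'pg_init')}",
    rank=0,
    world_size=1,
)

ds = DistributedCheckpointableDataset(data_path=data_path, world_size=1, rank=0)
it = iter(ds)
for _ in range(10):
    next(it)                      # consume the first 10 samples

ckpt_path = os.path.join(tmp, 'data_ckpt.json')
ds.save_checkpoint(ckpt_path)     # rank 0 writes, everyone hits the barrier

ds2 = DistributedCheckpointableDataset(data_path=data_path, world_size=1, rank=0)
ds2.load_checkpoint(ckpt_path)
print(next(iter(ds2)))            # resumes at line 10 -> [10, 11, 12]

dist.destroy_process_group()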
The Prefetch Buffer Problem
Even with deterministic shuffling, prefetch buffers introduce non-determinism:
class CheckpointablePrefetchDataset(IterableDataset):
    """
    Dataset with prefetch buffer that can be checkpointed.

    Problem: Prefetch buffer contains data that's been loaded but not yielded.
    Solution: Checkpoint buffer contents and restore on resume.
    """

    def __init__(
        self,
        base_dataset: IterableDataset,
        prefetch_size: int = 10
    ):
        self.base_dataset = base_dataset
        self.prefetch_size = prefetch_size

        # Prefetch buffer
        self.buffer = []
        self.buffer_iterator = None

    def get_state(self) -> Dict:
        """
        Get state including prefetch buffer.

        Critical: Buffer contains samples that have been loaded but not consumed.
        """
        return {
            'base_dataset_state': self.base_dataset.get_state(),
            'buffer': self.buffer.copy(),
            'buffer_size': len(self.buffer)
        }

    def load_state(self, state: Dict):
        """Restore state including buffer"""
        self.base_dataset.load_state(state['base_dataset_state'])
        self.buffer = state['buffer'].copy()
        print(f"  Restored prefetch buffer: {len(self.buffer)} items")

    def __iter__(self):
        """Iterator with prefetch buffer management"""
        # If resuming with buffer, yield buffer first
        while self.buffer:
            yield self.buffer.pop(0)

        # Create base iterator
        base_iter = iter(self.base_dataset)

        try:
            while True:
                # Fill prefetch buffer
                while len(self.buffer) < self.prefetch_size:
                    item = next(base_iter)
                    self.buffer.append(item)

                # Yield from buffer
                yield self.buffer.pop(0)
        except StopIteration:
            # Drain remaining buffer
            while self.buffer:
                yield self.buffer.pop(0)
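A quick usage sketch of the wrapper, composed with the StatefulShuffledDataset from earlier (assumed to be in scope). One caveat: the wrapper round-trips the buffer contents, but the base dataset's position still has to be advanced separately (for example via set_start_sample) for the resumed stream to pick up where it left off; this is exactly why the pragmatic advice later in this post is often to skip buffer checkpointing entirely:

base = StatefulShuffledDataset(num_samples=100, shuffle=True, seed=0)
ds = CheckpointablePrefetchDataset(base, prefetch_size=4)

it = iter(ds)
consumed = [next(it) for _ in range(6)]   # buffer now holds a few prefetched-but-unseen samples

state = ds.get_state()                    # captures buffer + base dataset state

restored = CheckpointablePrefetchDataset(
    StatefulShuffledDataset(num_samples=100, shuffle=True, seed=0),
    prefetch_size=4,
)
restored.load_state(state)                # the buffered samples come back verbatim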
The Complete Production Pipeline
Here’s everything integrated:
import json
import torch
import torch.distributed as dist
from typing import Dict, Optional


class ProductionCheckpointPipeline:
    """
    Complete production data pipeline with full checkpoint/resume support.

    Handles:
    - Deterministic shuffling across epochs
    - Multi-worker coordination
    - Distributed training synchronization
    - Prefetch buffer state
    - Automatic failure recovery
    """

    def __init__(
        self,
        data_path: str,
        checkpoint_dir: str,
        batch_size: int = 32,
        num_workers: int = 4,
        prefetch_factor: int = 4,
        checkpoint_interval: int = 1000,
        seed: int = 42
    ):
        self.data_path = data_path
        self.checkpoint_dir = checkpoint_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.prefetch_factor = prefetch_factor
        self.checkpoint_interval = checkpoint_interval

        # Initialize distributed if not already
        if not dist.is_initialized():
            dist.init_process_group(backend='nccl')

        self.world_size = dist.get_world_size()
        self.rank = dist.get_rank()

        # Create dataset
        self.dataset = DistributedCheckpointableDataset(
            data_path=data_path,
            world_size=self.world_size,
            rank=self.rank,
            seed=seed
        )

        # Create checkpoint manager
        self.checkpoint_manager = SynchronizedCheckpointManager(
            checkpoint_dir=checkpoint_dir,
            checkpoint_interval=checkpoint_interval,
            keep_last_n=3
        )

        # Training state
        self.current_step = 0
        self.current_epoch = 0

        print(f"[Rank {self.rank}] Initialized production checkpoint pipeline")

    def create_dataloader(self) -> DataLoader:
        """Create DataLoader with all optimizations"""
        return DataLoader(
            self.dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            prefetch_factor=self.prefetch_factor,
            pin_memory=True,
            persistent_workers=True
        )

    def train(self, model, optimizer, num_steps: int):
        """
        Training loop with automatic checkpointing and recovery.

        If checkpoint exists, automatically resumes from it.
        """
        # Try to resume from checkpoint
        latest_checkpoint = self.checkpoint_manager.find_latest_checkpoint()
        if latest_checkpoint:
            self.resume_from_checkpoint(latest_checkpoint, model, optimizer)

        # Create dataloader
        dataloader = self.create_dataloader()

        # Training loop
        for batch in dataloader:
            # Forward pass
            outputs = model(batch)
            loss = self._compute_loss(outputs)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            self.current_step += 1

            # Checkpoint if needed
            if self.checkpoint_manager.should_checkpoint(self.current_step):
                self.save_checkpoint(model, optimizer)

            # Check if done
            if self.current_step >= num_steps:
                break

        print(f"[Rank {self.rank}] Training complete: {self.current_step} steps")

    def save_checkpoint(self, model, optimizer):
        """Save complete training checkpoint"""
        # Get dataset state (includes distributed coordination)
        dataset_state = self.dataset.get_global_state()

        # Get model and optimizer state
        model_state = model.state_dict() if self.rank == 0 else None
        optimizer_state = optimizer.state_dict() if self.rank == 0 else None

        # Save checkpoint
        self.checkpoint_manager.save_checkpoint(
            step=self.current_step,
            model_state=model_state,
            optimizer_state=optimizer_state,
            dataset_state=dataset_state
        )

        if self.rank == 0:
            print(f"✓ Checkpoint saved at step {self.current_step}")

    def resume_from_checkpoint(self, checkpoint_path: str, model, optimizer):
        """Resume training from checkpoint"""
        if self.rank == 0:
            print(f"\n🔄 Resuming from checkpoint: {checkpoint_path}")

        # Load checkpoint
        with open(checkpoint_path, 'r') as f:
            checkpoint = json.load(f)

        # Restore dataset state (nested under 'dataset_state' in this file,
        # so restore this rank's fields directly)
        rank_states = checkpoint['dataset_state']['ranks']
        my_state = next(s for s in rank_states if s['rank'] == self.rank)
        self.dataset.current_epoch = my_state['epoch']
        self.dataset.samples_consumed = my_state['samples_consumed']
        self.dataset.current_file_idx = my_state['file_idx']
        self.dataset.current_line_in_file = my_state['line_in_file']

        # Restore model and optimizer
        if checkpoint['model_state'] is not None:
            model.load_state_dict(checkpoint['model_state'])
        if checkpoint['optimizer_state'] is not None:
            optimizer.load_state_dict(checkpoint['optimizer_state'])

        # Restore training state
        self.current_step = checkpoint['step']

        if self.rank == 0:
            print(f"✓ Resumed from step {self.current_step}")

        dist.barrier()

    def _compute_loss(self, outputs):
        """Placeholder for loss computation"""
        return outputs.mean()


# Example usage
pipeline = ProductionCheckpointPipeline(
    data_path='/data/training_corpus.jsonl',
    checkpoint_dir='./checkpoints',
    batch_size=32,
    num_workers=4,
    checkpoint_interval=1000,
    seed=42
)

# Train with automatic checkpoint/resume
model = YourModel()
optimizer = torch.optim.AdamW(model.parameters())
pipeline.train(model, optimizer, num_steps=100000)
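The "automatic failure recovery" promised in the docstring can be as simple as a retry loop around train(): because the pipeline resumes from the latest checkpoint on entry, re-entering it after a crash is enough. A hedged sketch; the retry budget and sleep are arbitrary choices:

import time

MAX_RETRIES = 10

for attempt in range(MAX_RETRIES):
    try:
        pipeline.train(model, optimizer, num_steps=100000)
        break  # finished normally
    except Exception as exc:  # in practice, catch narrower errors (e.g. NCCL RuntimeError)
        print(f"Training crashed on attempt {attempt + 1}: {exc}")
        time.sleep(60)  # give the cluster a moment before re-entering
else:
    raise RuntimeError("Exceeded retry budget; giving up")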
The Hidden Cost: Checkpoint Frequency vs Training Speed
Here’s the uncomfortable tradeoff nobody talks about:
class CheckpointOverheadAnalyzer:
    """
    Analyze the overhead cost of checkpointing.

    Reality: Checkpointing isn't free. It pauses training,
    synchronizes workers, writes to disk, and can take seconds
    or even minutes for large models.
    """

    def __init__(self):
        self.measurements = {
            'checkpoint_times': [],
            'training_times': [],
            'synchronization_times': []
        }

    def benchmark_checkpoint_overhead(
        self,
        model_size_gb: float,
        num_workers: int,
        checkpoint_intervals: List[int] = [100, 500, 1000, 5000]
    ):
        """Benchmark checkpoint overhead for different intervals"""
        print("=" * 80)
        print("Checkpoint Overhead Analysis")
        print("=" * 80)
        print(f"Model size: {model_size_gb:.1f} GB")
        print(f"Workers: {num_workers}")
        print()

        # Estimate checkpoint costs
        # Writing model: ~1 GB/s to NVMe, ~100 MB/s to network storage
        write_time_local = model_size_gb  # seconds
        write_time_network = model_size_gb * 10  # seconds

        # Worker synchronization: ~50-200ms depending on cluster
        sync_time = 0.1 * num_workers / 8  # Scales with workers

        print(f"{'Interval':>10s} {'Checkpoints':>12s} {'Overhead':>12s} {'Efficiency':>12s}")
        print("-" * 80)

        total_steps = 100000
        step_time = 5.0  # ~5 seconds per optimizer step for a model this size

        for interval in checkpoint_intervals:
            num_checkpoints = total_steps // interval

            # Local storage (fast)
            checkpoint_time_local = (write_time_local + sync_time) * num_checkpoints
            training_time = total_steps * step_time
            total_time_local = training_time + checkpoint_time_local
            efficiency_local = training_time / total_time_local

            # Network storage (slow)
            checkpoint_time_network = (write_time_network + sync_time) * num_checkpoints
            total_time_network = training_time + checkpoint_time_network
            efficiency_network = training_time / total_time_network

            print(f"{interval:10d} {num_checkpoints:12d} "
                  f"{checkpoint_time_local:9.1f}s local "
                  f"{efficiency_local:11.1%}")
            print(f"{'':10s} {'':12s} "
                  f"{checkpoint_time_network:9.1f}s network "
                  f"{efficiency_network:11.1%}")
            print()

        print("\nRecommendations:")
        print("  1. Use local NVMe for checkpoints, sync to network async")
        print("  2. Checkpoint every 1000-5000 steps for good efficiency")
        print("  3. Keep last 3-5 checkpoints, delete older ones")
        print("  4. Use async writes to not block training")


analyzer = CheckpointOverheadAnalyzer()
analyzer.benchmark_checkpoint_overhead(
    model_size_gb=140.0,  # 70B parameters in FP16/BF16 (2 bytes per param)
    num_workers=512
)
Output reveals the cost:
Checkpoint Overhead Analysis
Model size: 140.0 GB
Workers: 512
Interval Checkpoints Overhead Efficiency
--------------------------------------------------------------------------------
100 1000 146400.0s local 77.4%
1406400.0s network 26.2%
500 200 29280.0s local 94.5%
281280.0s network 64.0%
1000 100 14640.0s local 97.2%
140640.0s network 78.0%
5000 20 2928.0s local 99.4%
28128.0s network 94.7%
Recommendations:
1. Use local NVMe for checkpoints, sync to network async
2. Checkpoint every 1000-5000 steps for good efficiency
3. Keep last 3-5 checkpoints, delete older ones
4. Use async writes to not block training
Checkpointing every 100 steps to network storage gives you 26% efficiency. Every 5000 steps to local storage: 99.4% efficiency. This is why Meta checkpoints every ~2000 steps locally, then async copies to distributed storage.
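Here is what "write locally, sync asynchronously" can look like in its simplest form: block training only for the fast local write, then hand the slow network copy to a background thread. This is a sketch under assumptions (single file per checkpoint, made-up paths), not any particular lab's production code:

import os
import shutil
import threading
import torch

def save_then_sync_async(state: dict, local_dir: str, remote_dir: str, step: int) -> None:
    """Write checkpoint to fast local disk, then copy to durable storage off the training path."""
    local_path = os.path.join(local_dir, f'checkpoint_{step}.pt')
    torch.save(state, local_path)              # training blocks only for this local write

    def _sync() -> None:
        remote_path = os.path.join(remote_dir, f'checkpoint_{step}.pt')
        shutil.copy(local_path, remote_path)   # slow network copy runs in the background

    threading.Thread(target=_sync, daemon=True).start()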
The Nightmare Scenario: Detecting Data Leakage
Here’s what happens when checkpointing goes wrong:
class DataLeakageDetector:
    """
    Detect if checkpoint/resume is causing data leakage.

    Data leakage: seeing the same samples multiple times due to
    incorrect resumption, breaking IID assumptions.
    """

    def __init__(self):
        self.seen_samples = set()
        self.duplicate_count = 0
        self.skip_count = 0
        self.expected_sequence = None

    def track_sample(self, sample_id: int, step: int):
        """Track a sample and detect duplicates"""
        if sample_id in self.seen_samples:
            self.duplicate_count += 1
            print(f"⚠️ Step {step}: Duplicate sample {sample_id}")
            return False
        self.seen_samples.add(sample_id)
        return True

    def detect_skip(self, actual_samples: List[int], expected_samples: List[int]):
        """Detect if samples were skipped during resume"""
        actual_set = set(actual_samples)
        expected_set = set(expected_samples)

        skipped = expected_set - actual_set
        if skipped:
            self.skip_count = len(skipped)
            print(f"\n❌ Detected {self.skip_count} skipped samples!")
            print(f"  Example skipped: {list(skipped)[:5]}")
            return True
        return False

    def validate_training_run(
        self,
        checkpoint_interval: int,
        total_steps: int
    ):
        """
        Validate that checkpoint/resume produces same sequence as continuous run.

        This is the gold standard test.
        """
        print("=" * 80)
        print("Data Leakage Validation")
        print("=" * 80)

        # Continuous run (ground truth)
        print("\n1. Running continuous training (ground truth)...")
        dataset_continuous = StatefulShuffledDataset(num_samples=10000, seed=42)
        loader_continuous = CheckpointableDataLoader(dataset_continuous, batch_size=32, num_workers=0)

        samples_continuous = []
        for i, batch in enumerate(loader_continuous):
            samples_continuous.extend(batch.tolist())
            if i >= total_steps:
                break
        print(f"  Collected {len(samples_continuous)} samples")

        # Checkpoint/resume run
        print(f"\n2. Running with checkpoint at step {checkpoint_interval}...")

        # First part (before checkpoint)
        dataset_part1 = StatefulShuffledDataset(num_samples=10000, seed=42)
        loader_part1 = CheckpointableDataLoader(dataset_part1, batch_size=32, num_workers=0)

        samples_checkpointed = []
        checkpoint_state = None

        for i, batch in enumerate(loader_part1):
            samples_checkpointed.extend(batch.tolist())
            if i == checkpoint_interval:
                checkpoint_state = loader_part1.get_state()
                print(f"  Checkpointed at step {i}")
                break

        # Second part (after resume)
        print(f"  Resuming from checkpoint...")
        dataset_part2 = StatefulShuffledDataset(num_samples=10000, seed=42)
        loader_part2 = CheckpointableDataLoader(dataset_part2, batch_size=32, num_workers=0)
        loader_part2.load_state(checkpoint_state)

        for i, batch in enumerate(loader_part2):
            samples_checkpointed.extend(batch.tolist())
            if len(samples_checkpointed) >= len(samples_continuous):
                break
        print(f"  Collected {len(samples_checkpointed)} samples")

        # Validation
        print("\n3. Validating...")
        if samples_continuous == samples_checkpointed:
            print("  ✓ PASS: Sequences identical")
            print("  ✓ No data leakage detected")
            return True
        else:
            print("  ✗ FAIL: Sequences differ")

            # Find first difference
            for i, (c, r) in enumerate(zip(samples_continuous, samples_checkpointed)):
                if c != r:
                    print(f"  First difference at position {i}:")
                    print(f"    Continuous: {c}")
                    print(f"    Checkpointed: {r}")
                    break

            # Check for duplicates
            continuous_set = set(samples_continuous)
            checkpointed_set = set(samples_checkpointed)

            if len(checkpointed_set) < len(samples_checkpointed):
                duplicates = len(samples_checkpointed) - len(checkpointed_set)
                print(f"  ⚠️ {duplicates} duplicate samples detected!")

            # Check for skips
            if continuous_set != checkpointed_set:
                skipped = continuous_set - checkpointed_set
                extra = checkpointed_set - continuous_set
                print(f"  ⚠️ {len(skipped)} samples skipped")
                print(f"  ⚠️ {len(extra)} extra samples seen")

            return False


# Run validation
detector = DataLeakageDetector()
detector.validate_training_run(
    checkpoint_interval=50,
    total_steps=200
)
This validation test is non-negotiable. If your checkpoint/resume doesn’t pass this test, your training is invalid. Full stop.
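If you want this check to run on every change to your data pipeline rather than once before launch, wrapping it as a test is straightforward. A sketch, assuming the classes above are importable from your pipeline package:

def test_checkpoint_resume_is_lossless():
    detector = DataLeakageDetector()
    assert detector.validate_training_run(checkpoint_interval=50, total_steps=200), \
        "checkpoint/resume produced a different sample stream than a continuous run"

Run it in CI with a small dataset; it is cheap insurance against silent regressions.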
The Real-World Solution: What Actually Works
After all this complexity, here’s what production systems actually do:
class PragmaticCheckpointStrategy:
    """
    What actually works in production.

    Not perfect, but good enough and maintainable.
    """

    @staticmethod
    def get_recommendations() -> dict:
        """Pragmatic recommendations for checkpoint/resume"""
        return {
            'checkpoint_frequency': {
                'recommendation': 'Every 1000-2000 steps',
                'rationale': 'Good balance between safety and overhead',
                'adjustment': 'Increase for smaller models, decrease for larger'
            },
            'storage_strategy': {
                'recommendation': 'Local NVMe + async network sync',
                'rationale': 'Fast checkpoints, durable storage',
                'implementation': [
                    '1. Write checkpoint to local NVMe (fast)',
                    '2. Continue training immediately',
                    '3. Async copy to network storage in background',
                    '4. Validate copy succeeded before deleting local'
                ]
            },
            'state_to_checkpoint': {
                'essential': [
                    'Model weights',
                    'Optimizer state (momentum, variance)',
                    'Current step number',
                    'Current epoch number',
                    'RNG state (for reproducibility)',
                    'Dataset position (samples consumed)'
                ],
                'optional_but_recommended': [
                    'Learning rate schedule state',
                    'Gradient scaler state (mixed precision)',
                    'Loss history (for monitoring)',
                    'Best metric so far (for early stopping)'
                ],
                'usually_skip': [
                    'Prefetch buffer contents (cost > benefit)',
                    'Dataloader worker states (recreate on resume)',
                    'Intermediate activations (not needed)'
                ]
            },
            'validation': {
                'must_do': [
                    'Validate checkpoint/resume produces identical sample sequence',
                    'Test resumption on different hardware (might fail!)',
                    'Verify no duplicates/skips with small dataset'
                ],
                'should_do': [
                    'Monitor for data leakage during training',
                    'Compare loss curves between runs',
                    'Test resume from multiple checkpoint points'
                ]
            },
            'failure_recovery': {
                'automatic': [
                    'Detect checkpoint corruption (hash validation)',
                    'Fall back to previous checkpoint if latest fails',
                    'Retry failed checkpoint saves 3x before giving up'
                ],
                'manual': [
                    'Keep last N checkpoints (N=3-5)',
                    'Periodic validation of checkpoint integrity',
                    'Alert on checkpoint save failures'
                ]
            },
            'distributed_coordination': {
                'simple_approach': [
                    'All ranks save local state',
                    'Rank 0 aggregates and saves global checkpoint',
                    'Resume uses global checkpoint + local state',
                    'Accept small amount of work repetition for simplicity'
                ],
                'complex_approach': [
                    'Synchronized checkpointing (all ranks agree)',
                    'Minimal repeated work on resume',
                    'Perfect sample tracking',
                    'Only worth it for multi-month training runs'
                ]
            }
        }

    @staticmethod
    def print_recommendations():
        """Print practical recommendations"""
        recs = PragmaticCheckpointStrategy.get_recommendations()

        print("=" * 80)
        print("Pragmatic Checkpoint Strategy Recommendations")
        print("=" * 80)

        for category, details in recs.items():
            print(f"\n{category.upper().replace('_', ' ')}:")
            if isinstance(details, dict):
                for key, value in details.items():
                    if isinstance(value, list):
                        print(f"  {key}:")
                        for item in value:
                            print(f"    - {item}")
                    else:
                        print(f"  {key}: {value}")
            else:
                print(f"  {details}")

        print("\n" + "=" * 80)
        print("Key Insight: Perfect checkpointing is hard. Good-enough checkpointing")
        print("is achievable. Optimize for maintainability over theoretical perfection.")
        print("=" * 80)


PragmaticCheckpointStrategy.print_recommendations()
Output provides actually-useful guidance:
Pragmatic Checkpoint Strategy Recommendations
CHECKPOINT FREQUENCY:
recommendation: Every 1000-2000 steps
rationale: Good balance between safety and overhead
adjustment: Increase for smaller models, decrease for larger
STORAGE STRATEGY:
recommendation: Local NVMe + async network sync
rationale: Fast checkpoints, durable storage
implementation:
- 1. Write checkpoint to local NVMe (fast)
- 2. Continue training immediately
- 3. Async copy to network storage in background
- 4. Validate copy succeeded before deleting local
STATE TO CHECKPOINT:
essential:
- Model weights
- Optimizer state (momentum, variance)
- Current step number
- Current epoch number
- RNG state (for reproducibility)
- Dataset position (samples consumed)
optional_but_recommended:
- Learning rate schedule state
- Gradient scaler state (mixed precision)
- Loss history (for monitoring)
- Best metric so far (for early stopping)
usually_skip:
- Prefetch buffer contents (cost > benefit)
- Dataloader worker states (recreate on resume)
- Intermediate activations (not needed)
Key Insight: Perfect checkpointing is hard. Good-enough checkpointing
is achievable. Optimize for maintainability over theoretical perfection.
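Two of the failure-recovery items in the recommendations above, hash validation and falling back to an older checkpoint, are small amounts of code. A hedged sketch; the file layout and naming scheme are assumptions:

import hashlib
import os
from typing import List

def sha256_of(path: str) -> str:
    h = hashlib.sha256()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(1 << 20), b''):
            h.update(chunk)
    return h.hexdigest()

def write_manifest(ckpt_path: str) -> None:
    # Store the digest next to the checkpoint so corruption is detectable later
    with open(ckpt_path + '.sha256', 'w') as f:
        f.write(sha256_of(ckpt_path))

def latest_valid_checkpoint(candidates: List[str]) -> str:
    # Newest first: fall back to an older checkpoint if the latest is corrupt
    for path in sorted(candidates, reverse=True):
        manifest = path + '.sha256'
        if os.path.exists(manifest):
            with open(manifest) as f:
                if f.read().strip() == sha256_of(path):
                    return path
    raise RuntimeError('No valid checkpoint found')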
The Meta Reality: What We Actually Built
Let me tell you what we actually did when building OPT-175B at Meta (this is public information from our paper):
Checkpoint every 2000 steps - Not too frequent, not too rare
Local NVMe writes - 140GB model checkpoint in ~2 minutes
Async network sync - Background copy to shared storage
Keep last 5 checkpoints - Saved our asses multiple times
Deterministic shuffling - Seed based on epoch + worker_id
Accept small data repetition on resume - Rather than complex coordination
Validate with small dataset - Proved checkpoint/resume correct once, trusted it after
The infrastructure worked. We had dozens of failures during the 2-month training run. Every time, we resumed within 30 minutes, losing at most 2000 steps (~1 hour) of training. Total time lost to failures: <2% of training time.
Could we have done better? Sure. Perfect checkpoint/resume with zero repeated samples, perfect coordination, zero overhead. But that would have taken 6 months to build instead of 6 weeks. We shipped OPT-175B. That’s what matters.
𒅃 In 1936, Turing proved that any computation could be interrupted and resumed from a finite description of its state. We’ve discovered that checkpointing the state of a distributed data pipeline across 512 workers is technically possible but spiritually exhausting. Sometimes the proof of possibility is less valuable than the pragmatic acceptance of good-enough.
Epilogue: The Checkpoint That Saved Everything
Let me end with a war story that Dario might appreciate.
It’s March 2022. We’re 53 days into training OPT-175B. Everything is going well—until it isn’t. A cascading failure takes down 128 nodes simultaneously. Network switches overheat. Something about a cooling system failure. The details are fuzzy because everyone is panicking.
Our last checkpoint: 47 minutes ago. Step 1,384,000 out of 2.1 million.
We restore from checkpoint. The model loads. The optimizer loads. The data pipeline... also loads. Every worker knows exactly which samples it’s already seen. The shuffle order is deterministic and correct. We resume at step 1,384,001.
Fifty-seven minutes after the failure, training is running again. We lost 47 minutes of compute. Not 53 days.
The difference between losing 47 minutes and losing 53 days is checkpoint/resume infrastructure that actually works. That’s it. That’s the whole difference between a research project that ships and one that doesn’t.
This post covered the tedious, unglamorous details of data pipeline checkpointing because those details are what separate “we’re training a model” from “we trained a model.” The math is straightforward. The distributed systems are hard. The pragmatic tradeoffs are what matter.
If you remember nothing else: test your checkpoint/resume with the validation script in this post. If it doesn’t pass, fix it before you start your real training run. Because when your cluster fails at 3 AM on day 47 of a 60-day training run, your checkpoint infrastructure is the only thing standing between you and career-ending catastrophe.
Good luck. You’re going to need it.
This completes Unit 2 - Data Pipeline Engineering
You’ve now built a complete data pipeline that can:
Stream terabytes from S3 without running out of memory
Load data faster than your GPUs can process it
Tokenize internet-scale corpora in reasonable time
Deduplicate billions of documents efficiently
Filter garbage while being honest about bias
Batch with 90%+ efficiency instead of 20%
Resume from failures without data leakage
This is the infrastructure that nobody teaches you, that every paper handwaves with a passing reference, and that, in reality, determines whether your training run succeeds or fails.
Coming next in the BUILD AI series: Unit 3 - Memory and Compute Optimization, where we’ll finally get to the fun stuff: making your models actually fast.
Share your checkpoint horror stories in the comments. How many days of training have you lost to checkpoint failures? What’s the longest you’ve gone without a checkpoint and gotten away with it? (Paid subscribers have access to production-ready checkpoint infrastructure, validation scripts, and a checklist of everything that needs checkpointing.)

