Gradient Checkpointing: Memory-Compute Tradeoffs
Part 5 of Distributed Training from Scratch
In 1945, John von Neumann described the stored-program computer, where instructions and data share the same memory space. Eight decades later, his architecture has hit a wall, first with global-scale data, then with neural networks. The von Neumann bottleneck, where data movement becomes the limiting factor, breaks down entirely (the control plane buckles under the strain) when asked to feed both the cloud's appetite for globally distributed data and modern AI workloads, which is why the industry keeps drifting toward dataflow architectures that abandon the traditional model entirely. For us the bottleneck has a very concrete symptom: we run out of memory storing activations for backprop long before we run out of things to compute. So alongside dataflow hardware (which attacks the overall data-volume problem), we invented gradient checkpointing, the art of selectively forgetting intermediate results and recomputing them when needed. It's like having a conversation where you deliberately forget half of what was said, confident you can reconstruct it from context when necessary. Just don't lose your mind.
Welcome to the memory–compute tradeoff, where we exchange wall-clock time for GPU RAM by abandoning the von Neumann insistence that all temporal semantics must flow through a centralized control plane. A single, global clock works for modestly sized systems, but enforcing one collapses at scale (much like relativistic effects, which only become apparent at vast cosmic distances). The constraint is not just computationally limiting but ontologically sus, with concrete consequences ranging from limit cases in distributed training to the blockchain synchronization problem: across different blockchains there is no true single history, and attempting to impose one leads to the well-known coordination headaches (latency, forks, reorgs). But again, we're doing AI here. By distributing temporal logic across the computation graph itself, we can train larger models faster, not by hacking the clock, but by letting each node, shard, or graph segment schedule its own work. It's like realizing that memorizing an entire textbook is less efficient than bookmarking key chapters and re-reading them when you actually need the information. (Be sure to bookmark this page.)
The Activation Memory Problem
Here's the brutal math: a transformer layer with sequence length 2048 and hidden dimension 4096 generates roughly 64MB of activations per layer per sample in FP16 (a single 2048×4096 FP16 tensor is about 16MB, and each layer keeps several of them alive for backprop). Multiply by 80 layers and you're looking at over 5GB of forward-pass activations before you even add a batch dimension. Add gradients and optimizer states, and your 80GB H100 is suddenly feeling very small.
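If you like to see the arithmetic spelled out, here's a back-of-envelope sketch. The "four tensors per layer" factor is my own rough assumption, not an exact count; the precise number depends on the attention implementation and what your framework decides to cache.

# Back-of-envelope activation estimate (illustrative; the 4-tensors-per-layer
# factor is an assumption, not an exact count)
seq_len, d_model, bytes_fp16 = 2048, 4096, 2

tensor_mb = seq_len * d_model * bytes_fp16 / 1e6      # ~16.8 MB per (seq, hidden) tensor
per_layer_mb = 4 * tensor_mb                          # assume ~4 such tensors kept per layer

num_layers = 80
total_gb = per_layer_mb * num_layers / 1e3
print(f"~{per_layer_mb:.0f} MB/layer, ~{total_gb:.1f} GB for {num_layers} layers (per sample)")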
The traditional solution is "buy more GPUs," which works until you realize you're spending $200K on hardware to store intermediate results you'll only need once during backprop. Gradient checkpointing says: what if we just... didn't store them?
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
import time
class ActivationMemoryDemo:
    """Demonstrate the activation memory explosion"""

    def __init__(self):
        self.memory_stats = []

    def measure_activation_memory(self, model, batch_size=16, seq_len=2048):
        """Measure activation memory growth layer by layer"""
        x = torch.randn(batch_size, seq_len, model.d_model).cuda()

        torch.cuda.reset_peak_memory_stats()
        initial_memory = torch.cuda.memory_allocated()

        # Forward pass through each layer, measuring memory
        for i, layer in enumerate(model.layers):
            x = layer(x)
            current_memory = torch.cuda.memory_allocated()
            layer_memory = (current_memory - initial_memory) / 1e6  # MB

            self.memory_stats.append({
                'layer': i,
                'memory_mb': layer_memory,
                'activation_shape': x.shape
            })

            if i % 10 == 0:
                print(f"Layer {i}: {layer_memory:.1f}MB total, {x.shape}")

        peak_memory = torch.cuda.max_memory_allocated() / 1e6
        print(f"\nTotal peak memory: {peak_memory:.1f}MB")
        return peak_memory
# Simple transformer layer for testing
class TransformerLayer(nn.Module):
    def __init__(self, d_model=4096, nhead=32, dim_feedforward=16384):
        super().__init__()
        self.d_model = d_model
        self.self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        # Self-attention block
        attn_out, _ = self.self_attn(x, x, x)
        x = self.norm1(x + attn_out)

        # Feed-forward block
        ff_out = self.linear2(torch.relu(self.linear1(x)))
        x = self.norm2(x + ff_out)
        return x
class SimpleTransformer(nn.Module):
    def __init__(self, num_layers=20, d_model=4096):
        super().__init__()
        self.d_model = d_model
        self.layers = nn.ModuleList([
            TransformerLayer(d_model) for _ in range(num_layers)
        ])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
# Demo the memory explosion
model = SimpleTransformer(num_layers=20).cuda()
demo = ActivationMemoryDemo()
peak_memory = demo.measure_activation_memory(model)
The numbers are sobering. Each layer accumulates activations that must be kept in memory until backprop reaches that layer. For deep networks, this becomes the dominant memory consumer.
Gradient Checkpointing: Strategic Forgetting
The core insight of gradient checkpointing is simple: instead of storing all intermediate activations, we store only a subset (the "checkpoints") and recompute the rest during backprop. It's a classic time-space tradeoff, but with a twist – modern GPUs are so fast at forward passes that the recomputation often costs less than you'd expect.
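Under the hood, the trick is a custom autograd function that runs the forward pass without recording a graph, then re-runs it (with a graph) during backward. Here's a deliberately naive sketch to build intuition; it ignores RNG state, keyword arguments, multiple outputs, and autocast, all of which torch.utils.checkpoint handles for you, so treat it as a teaching aid rather than something to train with.

class NaiveCheckpoint(torch.autograd.Function):
    """Toy illustration of recompute-in-backward (not a replacement for
    torch.utils.checkpoint, which also restores RNG and autocast state)."""

    @staticmethod
    def forward(ctx, run_fn, x):
        ctx.run_fn = run_fn
        ctx.save_for_backward(x)           # remember only the input...
        with torch.no_grad():              # ...and build no graph for the intermediates
            return run_fn(x)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        x = x.detach().requires_grad_(True)
        with torch.enable_grad():          # recompute forward, this time tracking gradients
            output = ctx.run_fn(x)
        torch.autograd.backward(output, grad_output)
        return None, x.grad                # no gradient for run_fn itself

# Usage sketch: y = NaiveCheckpoint.apply(some_layer, x)

In practice you never write this yourself; torch.utils.checkpoint.checkpoint, used below, does the same thing correctly.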
from torch.utils.checkpoint import checkpoint
class CheckpointedTransformer(nn.Module):
    """Transformer with gradient checkpointing"""

    def __init__(self, num_layers=20, d_model=4096, checkpoint_every=4):
        super().__init__()
        self.d_model = d_model
        self.checkpoint_every = checkpoint_every

        # Group layers for checkpointing
        self.layer_groups = nn.ModuleList([
            nn.ModuleList([
                TransformerLayer(d_model)
                for _ in range(min(checkpoint_every, num_layers - i))
            ])
            for i in range(0, num_layers, checkpoint_every)
        ])

    def forward(self, x):
        for group in self.layer_groups:
            # Checkpoint this group of layers
            x = checkpoint(self._forward_group, x, group, use_reentrant=False)
        return x

    def _forward_group(self, x, layer_group):
        """Forward pass through a group of layers"""
        for layer in layer_group:
            x = layer(x)
        return x
class CheckpointingBenchmark:
    """Compare memory usage with and without checkpointing"""

    def __init__(self):
        pass

    def benchmark_memory_usage(self, num_layers=40):
        """Compare checkpointed vs normal forward pass memory"""
        batch_size, seq_len, d_model = 8, 1024, 2048

        # Normal model
        normal_model = SimpleTransformer(num_layers, d_model).cuda()
        # Checkpointed model
        checkpointed_model = CheckpointedTransformer(num_layers, d_model, checkpoint_every=4).cuda()

        x = torch.randn(batch_size, seq_len, d_model).cuda()

        # Benchmark normal model
        torch.cuda.reset_peak_memory_stats()
        _ = normal_model(x)
        normal_memory = torch.cuda.max_memory_allocated() / 1e9

        # Benchmark checkpointed model
        torch.cuda.reset_peak_memory_stats()
        _ = checkpointed_model(x)
        checkpointed_memory = torch.cuda.max_memory_allocated() / 1e9

        savings = (normal_memory - checkpointed_memory) / normal_memory

        print(f"Memory Usage Comparison ({num_layers} layers):")
        print(f"  Normal: {normal_memory:.2f}GB")
        print(f"  Checkpointed: {checkpointed_memory:.2f}GB")
        print(f"  Savings: {savings:.1%}")
        return savings

    def benchmark_training_speed(self, num_layers=20):
        """Compare training speed with and without checkpointing"""
        batch_size, seq_len, d_model = 4, 512, 1024

        models = {
            'normal': SimpleTransformer(num_layers, d_model).cuda(),
            'checkpointed': CheckpointedTransformer(num_layers, d_model, checkpoint_every=2).cuda()
        }

        x = torch.randn(batch_size, seq_len, d_model).cuda()
        target = torch.randn(batch_size, seq_len, d_model).cuda()

        results = {}
        for name, model in models.items():
            optimizer = torch.optim.Adam(model.parameters())

            # Warmup
            for _ in range(5):
                optimizer.zero_grad()
                output = model(x)
                loss = nn.MSELoss()(output, target)
                loss.backward()
                optimizer.step()
            torch.cuda.synchronize()

            # Benchmark
            start_time = time.perf_counter()
            for _ in range(20):
                optimizer.zero_grad()
                output = model(x)
                loss = nn.MSELoss()(output, target)
                loss.backward()
                optimizer.step()
            torch.cuda.synchronize()

            total_time = time.perf_counter() - start_time
            results[name] = total_time / 20  # Per step

        slowdown = results['checkpointed'] / results['normal']

        print(f"\nTraining Speed Comparison:")
        print(f"  Normal: {results['normal']*1000:.1f}ms/step")
        print(f"  Checkpointed: {results['checkpointed']*1000:.1f}ms/step")
        print(f"  Slowdown: {slowdown:.2f}x")
        return slowdown
benchmark = CheckpointingBenchmark()
memory_savings = benchmark.benchmark_memory_usage(num_layers=40)
speed_cost = benchmark.benchmark_training_speed(num_layers=20)
The tradeoff is clear: significant memory savings at the cost of some training speed. But here's where it gets interesting: the memory savings often allow you to use larger batch sizes, which can more than compensate for the recomputation overhead.
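One crude way to quantify that headroom is to probe for the largest batch that fits, with and without checkpointing. Here's a sketch; find_max_batch_size is a hypothetical helper of mine, and it assumes a recent PyTorch that exposes torch.cuda.OutOfMemoryError (older versions raise a plain RuntimeError).

def find_max_batch_size(model, seq_len=1024, d_model=2048, limit=512):
    """Double the batch size until we hit OOM; returns the last size that fit."""
    batch_size, best = 1, 0
    while batch_size <= limit:
        try:
            torch.cuda.empty_cache()
            x = torch.randn(batch_size, seq_len, d_model, device='cuda')
            model(x).sum().backward()              # forward + backward, so activations count
            model.zero_grad(set_to_none=True)
            best = batch_size
            batch_size *= 2
        except torch.cuda.OutOfMemoryError:
            break
    return best

# e.g. compare the two models from the benchmark above:
#   find_max_batch_size(SimpleTransformer(40, 2048).cuda())
#   find_max_batch_size(CheckpointedTransformer(40, 2048, checkpoint_every=4).cuda())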
Selective Checkpointing: The Art of Strategic Memory
Not all activations are created equal. Some are expensive to compute and cheap to store (attention outputs). Others are cheap to compute but expensive to store (intermediate feed-forward activations). Smart checkpointing strategies exploit this disparity.
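To put rough numbers on the storage side of that disparity, here are the per-sample sizes of two activations from the TransformerLayer defined earlier. Recompute costs vary a lot with sequence length and hidden size, so treat this as orientation, not a rule; the profiler below measures both sides of the ratio directly.

# Per-sample activation sizes in FP16 for one layer (d_model=4096, dim_feedforward=16384)
seq_len, d_model, d_ff, bytes_fp16 = 2048, 4096, 16384, 2

attn_output_mb = seq_len * d_model * bytes_fp16 / 1e6   # ~17 MB stored
ffn_hidden_mb = seq_len * d_ff * bytes_fp16 / 1e6        # ~67 MB stored: 4x wider, 4x bigger

print(f"attention output: {attn_output_mb:.0f} MB, FFN hidden: {ffn_hidden_mb:.0f} MB")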
class SelectiveCheckpointer:
    """More sophisticated checkpointing strategy"""

    def __init__(self, model: nn.Module):
        self.model = model
        self.checkpoint_decisions = {}

    def analyze_layer_costs(self, sample_input, num_runs=10):
        """Profile computation vs memory cost for each layer"""
        layer_stats = {}
        x = sample_input

        for i, layer in enumerate(self.model.layers):
            # Measure computation time
            torch.cuda.synchronize()
            start_time = time.perf_counter()
            for _ in range(num_runs):
                _ = layer(x)
            torch.cuda.synchronize()
            compute_time = (time.perf_counter() - start_time) / num_runs

            # Measure memory added by this layer (delta, not the running total)
            torch.cuda.reset_peak_memory_stats()
            before = torch.cuda.memory_allocated()
            x = layer(x)
            memory_used = (torch.cuda.max_memory_allocated() - before) / 1e6

            # Store statistics
            layer_stats[i] = {
                'compute_time_ms': compute_time * 1000,
                'memory_mb': memory_used,
                'memory_per_ms': memory_used / (compute_time * 1000)
            }

        return layer_stats

    def optimize_checkpoint_placement(self, layer_stats, memory_budget_mb=1000):
        """Decide which layers to checkpoint based on cost-benefit"""
        # Sort layers by memory/compute ratio (higher = better to checkpoint)
        candidates = [(i, stats['memory_per_ms']) for i, stats in layer_stats.items()]
        candidates.sort(key=lambda x: x[1], reverse=True)

        total_memory_saved = 0
        checkpoint_layers = []

        for layer_id, ratio in candidates:
            layer_memory = layer_stats[layer_id]['memory_mb']
            if total_memory_saved + layer_memory <= memory_budget_mb:
                checkpoint_layers.append(layer_id)
                total_memory_saved += layer_memory
            else:
                break

        print(f"Optimal Checkpointing Strategy:")
        print(f"  Checkpoint layers: {checkpoint_layers}")
        print(f"  Memory saved: {total_memory_saved:.1f}MB")
        print(f"  Additional compute: {sum(layer_stats[i]['compute_time_ms'] for i in checkpoint_layers):.1f}ms")

        return checkpoint_layers
class AdaptiveCheckpointWrapper(nn.Module):
    """Wrapper that applies selective checkpointing"""

    def __init__(self, model: nn.Module, checkpoint_layers=None):
        super().__init__()
        self.model = model
        self.checkpoint_layers = set(checkpoint_layers or [])

    def forward(self, x):
        for i, layer in enumerate(self.model.layers):
            if i in self.checkpoint_layers:
                x = checkpoint(layer, x, use_reentrant=False)
            else:
                x = layer(x)
        return x
# Example usage
model = SimpleTransformer(num_layers=10, d_model=1024).cuda()
optimizer_tool = SelectiveCheckpointer(model)
# Analyze costs
sample_input = torch.randn(4, 256, 1024).cuda()
layer_stats = optimizer_tool.analyze_layer_costs(sample_input)
# Optimize placement
optimal_checkpoints = optimizer_tool.optimize_checkpoint_placement(layer_stats)
# Apply selective checkpointing
optimized_model = AdaptiveCheckpointWrapper(model, optimal_checkpoints)
Production Checkpointing: Managing the Complexity
Of course, real-world gradient checkpointing needs to handle edge cases, mixed precision, and distributed training. Here's a production-ready implementation:
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper, CheckpointImpl, CheckpointWrapper
)

class ProductionCheckpointer:
    """Production-grade gradient checkpointing"""

    def __init__(self, checkpoint_ratio=0.25, mixed_precision=True):
        self.checkpoint_ratio = checkpoint_ratio
        self.mixed_precision = mixed_precision

    def apply_checkpointing(self, model: nn.Module, layer_pattern='layers'):
        """Apply checkpointing to model layers"""
        # Find layers to checkpoint
        layers = getattr(model, layer_pattern)
        num_layers = len(layers)
        checkpoint_every = max(1, int(1 / self.checkpoint_ratio))

        print(f"Applying checkpointing every {checkpoint_every} layers ({num_layers} total)")

        # Wrap selected layers
        for i in range(0, num_layers, checkpoint_every):
            if i < num_layers:
                layers[i] = checkpoint_wrapper(
                    layers[i],
                    checkpoint_impl=CheckpointImpl.NO_REENTRANT
                )
        return model

    def memory_efficient_training_step(self, model, batch, targets, optimizer, scaler=None):
        """Training step optimized for checkpointing"""
        optimizer.zero_grad()

        if self.mixed_precision and scaler:
            from torch.cuda.amp import autocast
            with autocast():
                outputs = model(batch)
                loss = nn.CrossEntropyLoss()(outputs, targets)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            outputs = model(batch)
            loss = nn.CrossEntropyLoss()(outputs, targets)
            loss.backward()
            optimizer.step()

        return loss.item()
# Quick demo
model = SimpleTransformer(num_layers=16, d_model=1024).cuda()
checkpointer = ProductionCheckpointer(checkpoint_ratio=0.25)
checkpointed_model = checkpointer.apply_checkpointing(model)
print("Gradient checkpointing applied successfully!")
The Memory-Speed Paradox
Here's the counterintuitive part: gradient checkpointing often makes training faster overall, not just more memory-efficient. By reducing memory pressure, you can use larger batch sizes, which improves GPU utilization and can more than compensate for the recomputation overhead. To add to our pile of similes, it's like discovering that deliberately forgetting half your notes makes you learn faster, because you can focus on the material that actually matters. And the professor still picks your final as best-in-class.
The sweet spot is typically checkpointing every 2-4 layers, giving you 40-60% memory savings with only 10-20% compute overhead. But your mileage will vary based on model architecture, hardware, and batch size constraints.
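The arithmetic behind that claim, with made-up but representative numbers (your measured step times will differ):

# Hypothetical numbers: checkpointing adds 15% per-step overhead, but the freed
# memory lets us double the batch size.
base_batch, base_step_s = 8, 0.200
ckpt_batch, ckpt_step_s = 16, 0.200 * 1.15

print(f"without checkpointing: {base_batch / base_step_s:.0f} samples/s")   # 40 samples/s
print(f"with checkpointing:    {ckpt_batch / ckpt_step_s:.0f} samples/s")   # ~70 samples/s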
The Philosophy of Forgetting
Gradient checkpointing embodies a fascinating principle: strategic forgetting as a form of optimization. We're not just saving memory; we're acknowledging that perfect recall isn't necessary for effective learning. (Though I can also assure you that perfect recall is very nice to have, personally.) The neural network doesn't care that we forgot the intermediate activations; it only cares that we can reconstruct the gradients it needs. The gradients you needed, and the gradients you only thought you needed.
This connects to broader themes in intelligence and memory. Human brains don't store every intermediate thought during reasoning; they maintain only the essential information needed to continue the cognitive process. Gradient checkpointing is our artificial version of this selective attention.
The Practical Playbook
When to Use Gradient Checkpointing:
Deep models (>20 layers) that don't fit in memory
When you want larger batch sizes more than raw speed
Training very long sequences (>4K tokens)
When to Avoid It:
Small models where memory isn't the bottleneck
When training speed is critical and memory is plentiful
During inference (it only helps training)
Optimal Settings (a library-level example follows after this list):
Checkpoint every 2-4 layers for transformers
Use with mixed precision for maximum memory savings
Increase batch size to compensate for speed loss
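In day-to-day work you often don't wire any of this up yourself; many libraries expose the playbook as a switch. For example, Hugging Face transformers models can enable activation checkpointing with one call (the checkpoint name below is a placeholder, not a real model):

from transformers import AutoModelForCausalLM

# Placeholder model name; substitute your own
model = AutoModelForCausalLM.from_pretrained("your-org/your-80-layer-model")
model.config.use_cache = False            # KV caching conflicts with recomputation during training
model.gradient_checkpointing_enable()     # wraps each decoder layer with torch.utils.checkpoint
model.train()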
𒅃 The universal Turing machine had an unbounded tape for its computations. We've discovered that finite, strategically managed memory can actually make our artificial minds more efficient. Sometimes the best way to remember everything is to forget most of it, and trust that you can figure it out again when needed.
Next up: ZeRO Optimizer States – where we learn that even the optimizer's memory can be distributed, sharded, and generally made someone else's problem.
Questions about activation memory? Horror stories about OOM errors? Drop them in the comments. Nothing builds solidarity like shared CUDA out-of-memory trauma at 2 AM. (Yes, I got a whole extra hour of sleep that night.)

