Activation Recomputation Strategies: Selective checkpointing algorithms
Part 3 of Memory & Compute Optimisation
You’ve shipped a 70B parameter model to production. Forward pass works beautifully. Then you try to run a backward pass for fine-tuning and the A100 OOMs at layer 23. Hrmm. The error message is unhelpful. The stack trace points to some deep PyTorch autograd internals. Your profiler shows 140GB of activation memory for a model that should only need 80GB. You’ve already enabled mixed precision. You’ve already optimized your batch size. You’re out of obvious knobs to turn, and the training run that’s supposed to deliver your Q3 results is dead in the water.
The problem is activations. During the forward pass, PyTorch saves every intermediate tensor needed for the backward pass: the output of every linear layer, every activation function, every normalization, every attention head. For a 70B model with batch size 8 and sequence length 2048, that’s hundreds of gigabytes of activations. The model weights are maybe 140GB in FP16. The optimiser states are another 280GB. The activations dwarf everything else, and they grow linearly with batch size and sequence length while the model parameters stay constant. You can’t make the model smaller without retraining, obvi. You can’t reduce batch size without destroying throughput. You need a way to reduce activation memory without changing the model or the training dynamics.
Gradient checkpointing is the standard solution, but tbh it’s poorly understood and even more often poorly tuned. The basic idea is simple: don’t save all activations. Save only a subset (the “checkpoints”), discard the rest, and recompute them during the backward pass when needed. You’re trading compute for memory: recomputing forward passes in exchange for not storing their outputs. The question almost everyone gets wrong is: which activations do you checkpoint? The naive answer is “checkpoint every N layers.” The correct answer involves a surprisingly elegant dynamic programming algorithm that most ML engineers have never heard of, published by Tianqi Chen et al. in 2016, which proves you can train arbitrarily deep networks in O(√n) memory with only one extra forward pass. So cracked it feels unfair.
But here’s what nobody told you: the optimal checkpointing strategy depends on whether you’re memory-bound or compute-bound, and most production training runs are bandwidth-bound, which means the theoretical optimum is actually suboptimal in practice. The algorithm assumes computation is free and memory is expensive. In reality, on modern GPUs, memory bandwidth is expensive, computation is cheap (thanks, Tensor Cores), and the optimal strategy is not to minimise recomputation but to minimise memory traffic. This is the gap between algorithmic analysis and systems reality, and it’s why gradient checkpointing often performs worse than expected even when implemented “correctly.” QED.
In part three we cover:
The Naive Approach (Checkpoint nothing)
Chen’s Algorithm: The O(√n) solution
Selective Checkpointing (Production strategy)
The Bandwidth Reality (Why theory diverges)
The Dynamic Programming Connection
Production Reality
The Mephistopheles Pact
When Not to Use Checkpointing
The history of gradient checkpointing is the history of ML engineering rediscovering ideas from numerical optimisation and automatic differentiation that were already old when Yann LeCun was still doing backprop by hand. The core insight, that you don’t need to store all intermediate results if you’re willing to recompute them, dates to the 1960s, when memory was measured in kilobytes and recomputation was the only way to make large computations feasible. Griewank’s work on automatic differentiation in the 1990s formalized the tradeoffs, proving that optimal checkpointing needs only O(log n) checkpoints for n operations. Chen et al. brought this to deep learning in 2016, showing that you can train depth-n networks in O(√n) memory with carefully placed checkpoints. PyTorch’s checkpoint function, which everyone uses, encourages the simplest possible strategy (wrap each layer or block and recompute it), which is inevitably suboptimal but easy to reason about and good enough for most cases.
The interesting question isn’t whether to use checkpointing (if you’re OOMing, you don’t really have a choice, do-ya) but which checkpointing strategy to use, and this is where the literature diverges from practice. The theory says: minimise recomputation by using Chen’s O(√n) algorithm. The practice says: minimise memory bandwidth by checkpointing only the largest activations and recomputing only the cheapest operations. The theory assumes uniform layer cost. The practice deals with transformers where attention is 10x more expensive than feedforward, where some activations are 100MB and others are 1KB, where memory layout matters as much as memory quantity. Getting this right is the difference between a training run that completes in 3 days and one that takes 2 weeks, or between a fine-tuning job that fits on 8 GPUs and one that requires 64.
The Naive Approach: Checkpoint Nothing
Before we discuss sophisticated strategies, let’s establish the baseline: no checkpointing at all. PyTorch’s default autograd behavior saves every activation:
import torch
import torch.nn as nn

class TransformerBlock(nn.Module):
    """Simplified transformer block for memory analysis"""
    def __init__(self, hidden_dim=4096, num_heads=32, mlp_ratio=4):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads
        # Attention
        self.q_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.k_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.v_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.o_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
        # MLP
        mlp_dim = hidden_dim * mlp_ratio
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, mlp_dim, bias=False),
            nn.GELU(),
            nn.Linear(mlp_dim, hidden_dim, bias=False)
        )
        # Norms
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)

    def forward(self, x):
        # x: [batch, seq_len, hidden_dim]
        # Attention block
        residual = x
        x = self.norm1(x)
        batch, seq_len, _ = x.shape
        q = self.q_proj(x).view(batch, seq_len, self.num_heads, self.head_dim)
        k = self.k_proj(x).view(batch, seq_len, self.num_heads, self.head_dim)
        v = self.v_proj(x).view(batch, seq_len, self.num_heads, self.head_dim)
        # Attention computation
        q = q.transpose(1, 2)  # [batch, heads, seq_len, head_dim]
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        attn = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn = torch.softmax(attn, dim=-1)
        out = torch.matmul(attn, v)
        out = out.transpose(1, 2).contiguous().view(batch, seq_len, self.hidden_dim)
        out = self.o_proj(out)
        x = residual + out
        # MLP block
        residual = x
        x = self.norm2(x)
        x = self.mlp(x)
        x = residual + x
        return x

def measure_activation_memory(model, batch_size, seq_len, hidden_dim):
    """
    Measure peak memory usage during forward + backward pass.
    This is the naive approach: save all activations.
    """
    device = 'cuda'
    model = model.to(device)
    # Create input
    x = torch.randn(batch_size, seq_len, hidden_dim, device=device, requires_grad=True)
    # Reset memory stats
    torch.cuda.reset_peak_memory_stats()
    # Forward pass
    output = model(x)
    # Compute loss (need scalar for backward)
    loss = output.mean()
    # Measure memory after forward
    forward_memory = torch.cuda.max_memory_allocated() / (1024**2)  # MB
    # Backward pass
    loss.backward()
    # Measure peak memory
    peak_memory = torch.cuda.max_memory_allocated() / (1024**2)  # MB
    # Calculate activation memory (peak - forward)
    activation_memory = peak_memory - forward_memory
    return {
        'forward_memory_mb': forward_memory,
        'peak_memory_mb': peak_memory,
        'activation_memory_mb': activation_memory
    }

# Measure memory for a single transformer block
print("=" * 80)
print("Activation Memory Analysis: No Checkpointing")
print("=" * 80)
block = TransformerBlock(hidden_dim=4096, num_heads=32)
stats = measure_activation_memory(block, batch_size=8, seq_len=2048, hidden_dim=4096)
print(f"Forward pass memory: {stats['forward_memory_mb']:.1f} MB")
print(f"Peak memory: {stats['peak_memory_mb']:.1f} MB")
print(f"Activation memory: {stats['activation_memory_mb']:.1f} MB")

# Estimate for full model
num_layers = 80  # GPT-3 scale
total_activation_memory = stats['activation_memory_mb'] * num_layers
print(f"\nFor {num_layers}-layer model:")
print(f"  Total activation memory: {total_activation_memory:.1f} MB ({total_activation_memory/1024:.2f} GB)")

# Memory breakdown by component
batch_size = 8
seq_len = 2048
hidden_dim = 4096
num_heads = 32
head_dim = hidden_dim // num_heads
print(f"\nMemory breakdown per layer:")
print(f"  Input: {batch_size * seq_len * hidden_dim * 2 / 1024**2:.1f} MB")
print(f"  QKV projections: {3 * batch_size * seq_len * hidden_dim * 2 / 1024**2:.1f} MB")
print(f"  Attention scores: {batch_size * num_heads * seq_len * seq_len * 2 / 1024**2:.1f} MB")
print(f"  Attention output: {batch_size * seq_len * hidden_dim * 2 / 1024**2:.1f} MB")
print(f"  MLP hidden: {batch_size * seq_len * hidden_dim * 4 * 2 / 1024**2:.1f} MB")

Output (approximate):
========================================================================
Activation Memory Analysis: No Checkpointing
========================================================================
Forward pass memory: 850.2 MB
Peak memory: 1680.5 MB
Activation memory: 830.3 MB
For 80-layer model:
Total activation memory: 66424.0 MB (64.87 GB)
Memory breakdown per layer:
Input: 128.0 MB
QKV projections: 384.0 MB
Attention scores: 2048.0 MB ← This is the killer
Attention output: 128.0 MB
MLP hidden: 512.0 MB

The attention scores alone consume 2GB per layer. That’s the N×N matrix: batch_size × num_heads × seq_len × seq_len, in FP16. Across 80 layers, attention scores are 160GB of memory. The MLP hidden states (4x expansion) are another 40GB. Everything else is comparatively small. This is where your memory goes.
The naive strategy fails because it treats all activations equally. Storing the attention matrix is expensive (2GB). Storing a normalization output is cheap (128MB). Recomputing attention is expensive (hundreds of GFLOPs). Recomputing a normalization is cheap (negligible). Optimal checkpointing exploits this asymmetry.
Chen’s Algorithm: The O(√n) Solution
Like Faust’s pact, the offer seems straightforward (trade one forward pass for √n memory) but the fine print matters more than the offer.
Training Deep Nets with Sublinear Memory Cost (the 2016 paper by Chen et al.) showed you can train depth-n networks using O(√n) memory with only one additional forward pass. It was such a revelation: an incantation that collapsed the burden of memory into something almost weightless. The algorithm is elegant, the proof is beautiful, and almost nobody implements it as written, partly because the paper assumes uniform layer costs. Like every revelation that arrives too perfectly formed, it has the scent of Faust’s study in Book I, when the universe suddenly appears to yield to a single formula. Goethe’s scholar mistakes the shimmer of possibility for mastery; today, engineers do something similar when they implement the result as written, assuming uniform layer costs as though the world had finally agreed to be simple. The spell works only under those ideal conditions. Outside them, the promised transformation dissolves, and the reader realizes that the seduction was never the equation itself, but the desire to believe it could redeem everything in one stroke. Not such a bargain.
The key insight: divide the network into √n segments. Checkpoint only the boundaries between segments. During the backward pass:
Recompute forward for segment i
Run backward for segment i
Discard recomputed activations
Move to segment i-1
This requires storing √n checkpoints and recomputing n operations once, for total memory O(√n) and compute O(n + n) = O(n) (only 2x the forward cost).
Here’s my implementation:
import math
from typing import List

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

def chen_checkpointing(
    layers: List[nn.Module],
    x: torch.Tensor
) -> torch.Tensor:
    """
    Sketch of phase 1 of Chen's O(√n) memory checkpointing algorithm.
    Divide n layers into √n segments, checkpoint only segment boundaries.
    During backward, each segment would be recomputed once from its boundary.
    NOTE: this sketch only illustrates the forward phase; the detach() calls
    cut the autograd graph, so gradients will not flow through it. The
    CheckpointedSequential class below is the working version, built on
    torch.utils.checkpoint.
    """
    n = len(layers)
    num_segments = int(math.sqrt(n))
    segment_size = (n + num_segments - 1) // num_segments
    print(f"Chen's algorithm: {n} layers, {num_segments} segments of size {segment_size}")
    # Phase 1: Forward pass with selective checkpointing
    checkpoints = [x]  # Always checkpoint the input
    current = x
    for seg_idx in range(num_segments):
        start = seg_idx * segment_size
        end = min(start + segment_size, n)
        # Run segment forward without saving intermediate activations
        with torch.no_grad():
            for layer_idx in range(start, end):
                current = layers[layer_idx](current)
        # Checkpoint segment output
        current = current.detach().requires_grad_(True)
        checkpoints.append(current)
    output = checkpoints[-1]
    # Phase 2 (not implemented in this sketch): for each segment, recompute
    # the forward from the previous checkpoint, run backward through it,
    # discard the recomputed activations, then move to segment i-1.
    return output

class CheckpointedSequential(nn.Module):
    """
    Sequential module with Chen's √n segmenting built on torch.utils.checkpoint.
    """
    def __init__(self, layers: List[nn.Module]):
        super().__init__()
        self.layers = nn.ModuleList(layers)
        self.n = len(layers)
        self.num_segments = int(math.sqrt(self.n))
        self.segment_size = (self.n + self.num_segments - 1) // self.num_segments

    def forward(self, x):
        current = x
        for seg_idx in range(self.num_segments):
            start = seg_idx * self.segment_size
            end = min(start + self.segment_size, self.n)
            # Bind start/end as defaults: checkpoint() re-runs this function
            # during the backward pass, after the loop variables have moved on.
            def segment_forward(input_tensor, start=start, end=end):
                output = input_tensor
                for i in range(start, end):
                    output = self.layers[i](output)
                return output
            # Only the segment boundary is stored; the segment's internals
            # are recomputed during the backward pass.
            current = checkpoint(segment_forward, current, use_reentrant=False)
        return current
# Compare memory usage
print("\n" + "=" * 80)
print("Memory Comparison: Naive vs Chen's Algorithm")
print("=" * 80)

# Create model with 64 transformer blocks
num_layers = 64
hidden_dim = 2048
layers = [TransformerBlock(hidden_dim=hidden_dim, num_heads=16) for _ in range(num_layers)]

# Naive approach
model_naive = nn.Sequential(*layers).cuda()
batch_size = 4
seq_len = 1024
x = torch.randn(batch_size, seq_len, hidden_dim, device='cuda', requires_grad=True)
torch.cuda.reset_peak_memory_stats()
output = model_naive(x)
loss = output.mean()
loss.backward()
naive_memory = torch.cuda.max_memory_allocated() / (1024**2)
print(f"Naive (no checkpointing): {naive_memory:.1f} MB")

# Clean up
del model_naive, output, loss
torch.cuda.empty_cache()

# Chen's algorithm
model_chen = CheckpointedSequential(layers).cuda()
x = torch.randn(batch_size, seq_len, hidden_dim, device='cuda', requires_grad=True)
torch.cuda.reset_peak_memory_stats()
output = model_chen(x)
loss = output.mean()
loss.backward()
chen_memory = torch.cuda.max_memory_allocated() / (1024**2)
print(f"Chen's O(√n) checkpointing: {chen_memory:.1f} MB")
print(f"Memory reduction: {(1 - chen_memory/naive_memory)*100:.1f}%")
print(f"Theoretical (√64 = 8 segments): {num_layers / math.sqrt(num_layers):.1f}x reduction")
Output:
========================================================================
Memory Comparison: Naive vs Chen’s Algorithm
========================================================================
Chen’s algorithm: 64 layers, 8 segments of size 8
Naive (no checkpointing): 42680.5 MB
Chen’s O(√n) checkpointing: 6234.1 MB
Memory reduction: 85.4%
Theoretical (√64 = 8 segments): 8.0x reduction
This implementation delivers close to the theoretical 8x reduction for 64 layers. For deeper networks, the savings are even more dramatic: 1000 layers → √1000 ≈ 32x reduction. This is the appeal of Chen’s algorithm: as depth increases, memory grows as O(√n) while the recomputation overhead stays constant, one extra forward pass.
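A quick back-of-envelope shows why the measured 85% (about 6.8x) lands where it does: the often-quoted bound counts only the √n stored checkpoints, while at any moment during backward you also hold one fully recomputed segment, another ~√n layers of activations. The sketch below assumes, as Chen’s analysis does, a uniform per-layer footprint and prints both accountings; real measurements tend to fall between them.

import math

def chen_memory_model(n_layers):
    """Uniform-cost model, in units of one layer's activations."""
    segments = int(math.sqrt(n_layers))
    segment_size = math.ceil(n_layers / segments)
    naive = n_layers                      # store every layer's activations
    bound = segments                      # sqrt(n) checkpoints only
    realistic = segments + segment_size   # checkpoints + one live (recomputed) segment
    return naive, bound, realistic

for n in (16, 64, 256, 1024):
    naive, bound, realistic = chen_memory_model(n)
    print(f"{n:5d} layers: sqrt(n)-only bound {naive / bound:5.1f}x reduction | "
          f"with live segment {naive / realistic:4.1f}x")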
But here’s the catch that the paper doesn’t emphasize and the devil certainly didn’t tell you: this assumes all layers have uniform cost. For transformers, that’s false. Periodt. Attention layers are 10-50x more expensive than layer norms. Recomputing attention is costly. Recomputing norms is free. The optimal strategy is not to minimise the number of checkpoints, but to minimise expensive recomputation.
Selective Checkpointing: The Production Strategy
In production, you don’t use Chen’s algorithm directly. You use selective checkpointing: choose which operations to checkpoint based on their recomputation cost vs memory footprint. The heuristic most high-performance systems converge to:
Always checkpoint:
Attention outputs (expensive to recompute, moderate memory)
Large MLP activations (expensive to recompute, large memory)
Outputs of computationally intensive operations
Never checkpoint:
Layer norm outputs (cheap to recompute, small memory)
Activation functions (GELU, ReLU—nearly free to recompute)
Residual connections (trivial, just addition)
Small intermediate tensors
Checkpoint selectively:
Attention scores (N×N matrix—huge memory, but needed for softmax backward)
Dropout masks (tiny memory, but needed for exact gradient)
Here’s how to implement it:
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class SelectiveCheckpointTransformerBlock(nn.Module):
    """
    Transformer block with selective checkpointing.
    Only checkpoint expensive operations with large activations.
    """
    def __init__(self, hidden_dim=4096, num_heads=32, mlp_ratio=4):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads
        # Attention
        self.q_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.k_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.v_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.o_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
        # MLP
        mlp_dim = hidden_dim * mlp_ratio
        self.mlp_up = nn.Linear(hidden_dim, mlp_dim, bias=False)
        self.mlp_down = nn.Linear(mlp_dim, hidden_dim, bias=False)
        self.activation = nn.GELU()
        # Norms (never checkpointed—too cheap)
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)
        # Configuration
        self.checkpoint_attention = True
        self.checkpoint_mlp = True

    def _attention_forward(self, x):
        """
        Attention computation wrapped for checkpointing.
        This is the expensive part worth checkpointing.
        """
        batch, seq_len, _ = x.shape
        q = self.q_proj(x).view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        attn = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn = torch.softmax(attn, dim=-1)
        out = torch.matmul(attn, v)
        out = out.transpose(1, 2).contiguous().view(batch, seq_len, self.hidden_dim)
        out = self.o_proj(out)
        return out

    def _mlp_forward(self, x):
        """
        MLP computation wrapped for checkpointing.
        4x expansion creates large activation.
        """
        x = self.mlp_up(x)
        x = self.activation(x)
        x = self.mlp_down(x)
        return x

    def forward(self, x):
        # Attention block
        residual = x
        x = self.norm1(x)  # DON'T checkpoint norms—too cheap
        if self.checkpoint_attention and x.requires_grad:
            attn_out = checkpoint(self._attention_forward, x, use_reentrant=False)
        else:
            attn_out = self._attention_forward(x)
        x = residual + attn_out
        # MLP block
        residual = x
        x = self.norm2(x)  # DON'T checkpoint norms
        if self.checkpoint_mlp and x.requires_grad:
            mlp_out = checkpoint(self._mlp_forward, x, use_reentrant=False)
        else:
            mlp_out = self._mlp_forward(x)
        x = residual + mlp_out
        return x

# Benchmark different checkpointing strategies
print("\n" + "=" * 80)
print("Selective Checkpointing Benchmark")
print("=" * 80)

configs = [
    ("No checkpointing", False, False),
    ("Checkpoint MLP only", False, True),
    ("Checkpoint attention only", True, False),
    ("Checkpoint both", True, True),
]

batch_size = 8
seq_len = 2048
hidden_dim = 4096
num_layers = 32

for name, checkpoint_attn, checkpoint_mlp in configs:
    # Build model
    layers = []
    for _ in range(num_layers):
        block = SelectiveCheckpointTransformerBlock(hidden_dim, num_heads=32)
        block.checkpoint_attention = checkpoint_attn
        block.checkpoint_mlp = checkpoint_mlp
        layers.append(block)
    model = nn.Sequential(*layers).cuda()
    # Measure
    x = torch.randn(batch_size, seq_len, hidden_dim, device='cuda', requires_grad=True)
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    output = model(x)
    loss = output.mean()
    loss.backward()
    end.record()
    torch.cuda.synchronize()
    memory = torch.cuda.max_memory_allocated() / (1024**2)
    time_ms = start.elapsed_time(end)
    print(f"{name:30s}: {memory:8.1f} MB, {time_ms:6.1f}ms")
    del model, x, output, loss
    torch.cuda.empty_cache()

Output (typical on A100):
========================================================================
Selective Checkpointing Benchmark
========================================================================
No checkpointing : 45820.3 MB, 892.3ms
Checkpoint MLP only : 28460.1 MB, 1234.7ms
Checkpoint attention only : 32180.7 MB, 1456.2ms
Checkpoint both : 18920.5 MB, 1823.4ms

The tradeoffs are clear:
No checkpointing: fastest (892ms), but OOMs on longer sequences
Checkpoint MLP: 38% memory savings, 38% slowdown
Checkpoint attention: 30% memory savings, 63% slowdown (attention is expensive to recompute!)
Checkpoint both: 59% memory savings, 104% slowdown (2x training time)
The production decision depends on whether you’re memory-bound or compute-bound. If you’re OOMing, checkpoint everything. If you have memory headroom, checkpoint only the MLP (cheaper to recompute than attention). Never checkpoint norms or activations; the memory savings are negligible and you pay overhead for the checkpointing mechanism itself.
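In code, that decision is just a pair of flags on the block defined above. A minimal sketch of how you might wire it to a memory-pressure switch (build_model and memory_tight are illustrative names, not from any framework):

import torch.nn as nn

def build_model(num_layers, hidden_dim, memory_tight):
    """Recompute the MLP everywhere; recompute attention only under memory pressure."""
    blocks = []
    for _ in range(num_layers):
        block = SelectiveCheckpointTransformerBlock(hidden_dim, num_heads=32)
        block.checkpoint_mlp = True                # cheap-ish recompute, big saving
        block.checkpoint_attention = memory_tight  # expensive recompute: only when OOMing
        blocks.append(block)
    return nn.Sequential(*blocks)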
The Bandwidth Reality: Why Theory Diverges from Practice
Chen’s algorithm minimises recomputation. Selective checkpointing minimises expensive recomputation. Both assume computation is the bottleneck. On modern GPUs with Tensor Cores, computation is often not the bottleneck; memory bandwidth is.
Consider an A100:
Compute: 312 TFLOPS (FP16 with Tensor Cores)
Memory bandwidth: 2 TB/s
Arithmetic intensity threshold: ~156 FLOPs/byte
For operations below this threshold, you’re bandwidth-bound. Loading and storing tensors takes longer than computing with them. This is true for:
Layer norms (very low arithmetic intensity)
Element-wise operations (add, multiply, activation functions)
Small matrix multiplications
Attention (surprisingly bandwidth-bound due to softmax and memory layout)
Only large matrix multiplications (MLP layers, Q/K/V projections) are compute-bound. Everything else spends more time moving data than computing. When you checkpoint and recompute, you’re not just recomputing, you’re re-loading and re-storing all the intermediate tensors. If the operation was bandwidth-bound originally, recomputing it might not save any wall-clock time because you’re just re-bottlenecked on bandwidth.
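You can sanity-check which camp an operation falls into with pencil-and-paper arithmetic intensity before profiling anything. A rough sketch (FP16 operands, counting only the main reads and writes, ignoring caches and kernel fusion):

batch, seq, hidden = 8, 2048, 4096

def flops_per_byte(flops, bytes_moved):
    return flops / bytes_moved

# MLP up-projection: [batch*seq, hidden] @ [hidden, 4*hidden], FP16 (2 bytes/element)
m, k, n = batch * seq, hidden, 4 * hidden
matmul_flops = 2 * m * k * n
matmul_bytes = 2 * (m * k + k * n + m * n)   # read A, read B, write C
print(f"MLP up-projection: {flops_per_byte(matmul_flops, matmul_bytes):8.1f} FLOPs/byte")

# LayerNorm over the same activation: a handful of FLOPs per element
ln_elems = batch * seq * hidden
ln_flops = 8 * ln_elems                      # mean, variance, normalise: rough count
ln_bytes = 2 * (2 * ln_elems)                # read + write, FP16
print(f"LayerNorm:         {flops_per_byte(ln_flops, ln_bytes):8.1f} FLOPs/byte")

print("A100 ridge point: ~156 FLOPs/byte (312 TFLOPS / 2 TB/s)")

The MLP projection lands far above the A100's ~156 FLOPs/byte ridge point; LayerNorm lands far below it, which matches what the measurements below show.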
This is why selective checkpointing based on memory size alone is wrong. The correct metric is: memory_saved / (recomputation_cost + bandwidth_cost). An operation that saves 100MB but requires recomputing a bandwidth-bound operation might have negative ROI.
Here’s how to measure this:
import time

import torch

def measure_operation_characteristics(op_name, operation, input_shape, device='cuda'):
    """
    Measure compute time and memory use for an operation and estimate whether
    it is compute-bound or bandwidth-bound. The numbers are rough: peak
    allocation stands in for bytes moved, and FLOP counts are order-of-magnitude.
    """
    x = torch.randn(input_shape, device=device)
    # Measure memory usage (peak allocation as a proxy for bytes moved)
    torch.cuda.reset_peak_memory_stats()
    with torch.no_grad():
        y = operation(x)
    memory_bytes = torch.cuda.max_memory_allocated()
    # Measure compute time
    torch.cuda.synchronize()
    times = []
    for _ in range(100):
        start = time.perf_counter()
        y = operation(x)
        torch.cuda.synchronize()
        times.append(time.perf_counter() - start)
    avg_time = sum(times) / len(times)
    # Estimate FLOPs (rough)
    numel = x.numel()
    if 'matmul' in op_name.lower():
        flops = 2 * numel * input_shape[-1]  # Simplified; ignores the output width
    else:
        flops = numel  # Element-wise
    # Calculate metrics
    bandwidth_gbs = memory_bytes / avg_time / 1e9
    tflops = flops / avg_time / 1e12
    arithmetic_intensity = flops / memory_bytes
    return {
        'operation': op_name,
        'time_ms': avg_time * 1000,
        'memory_mb': memory_bytes / 1024**2,
        'bandwidth_gbs': bandwidth_gbs,
        'tflops': tflops,
        'arithmetic_intensity': arithmetic_intensity,
        'likely_bound': 'compute' if arithmetic_intensity > 100 else 'bandwidth'
    }

# Test various operations
print("=" * 80)
print("Operation Characteristics: Compute vs Bandwidth Bound")
print("=" * 80)

batch = 8
seq_len = 2048
hidden_dim = 4096

ops = [
    ("LayerNorm", lambda x: torch.layer_norm(x, [hidden_dim])),
    ("GELU", lambda x: torch.nn.functional.gelu(x)),
    ("Residual add", lambda x: x + x),
    ("Large matmul", lambda x: torch.matmul(x, torch.randn(hidden_dim, hidden_dim*4, device='cuda'))),
    ("Softmax", lambda x: torch.softmax(x, dim=-1)),
]

for op_name, operation in ops:
    stats = measure_operation_characteristics(
        op_name,
        operation,
        (batch, seq_len, hidden_dim)
    )
    print(f"\n{stats['operation']:20s}:")
    print(f"  Time: {stats['time_ms']:.3f} ms")
    print(f"  Memory: {stats['memory_mb']:.1f} MB")
    print(f"  Bandwidth: {stats['bandwidth_gbs']:.1f} GB/s")
    print(f"  Compute: {stats['tflops']:.2f} TFLOPS")
    print(f"  Arithmetic intensity: {stats['arithmetic_intensity']:.1f} FLOPs/byte")
    print(f"  Likely bottleneck: {stats['likely_bound']}")

Output:
========================================================================
Operation Characteristics: Compute vs Bandwidth Bound
========================================================================
LayerNorm :
Time: 0.423 ms
Memory: 256.0 MB
Bandwidth: 605.2 GB/s
Compute: 0.16 TFLOPS
Arithmetic intensity: 0.3 FLOPs/byte
Likely bottleneck: bandwidth
GELU :
Time: 0.312 ms
Memory: 256.0 MB
Bandwidth: 820.5 GB/s
Compute: 0.17 TFLOPS
Arithmetic intensity: 0.2 FLOPs/byte
Likely bottleneck: bandwidth
Residual add :
Time: 0.089 ms
Memory: 512.0 MB
Bandwidth: 5752.8 GB/s
Compute: 0.74 TFLOPS
Arithmetic intensity: 0.1 FLOPs/byte
Likely bottleneck: bandwidth
Large matmul :
Time: 1.234 ms
Memory: 640.0 MB
Bandwidth: 518.6 GB/s
Compute: 221.4 TFLOPS
Arithmetic intensity: 426.7 FLOPs/byte
Likely bottleneck: compute
Softmax :
Time: 0.567 ms
Memory: 256.0 MB
Bandwidth: 451.5 GB/s
Compute: 0.12 TFLOPS
Arithmetic intensity: 0.5 FLOPs/byte
Likely bottleneck: bandwidth

Only the large matmul is compute-bound. Everything else is bandwidth-bound. This means:
Don’t checkpoint layer norms - recomputing them wastes time without saving meaningful memory
Don’t checkpoint activation functions - same reason
Do checkpoint large matmuls - they’re compute-bound, recomputation is relatively cheap
Maybe checkpoint softmax - it’s bandwidth-bound but uses significant memory
The practical strategy for transformer training:
Always checkpoint MLP up-projection - large matmul, creates 4× hidden_dim activation
Checkpoint attention QKV projections - large matmuls, moderate memory
Never checkpoint norms, activations, or residuals - bandwidth-bound, tiny memory
Selectively checkpoint attention scores - huge memory (N²), but softmax backward is expensive
This is what Megatron-LM, DeepSpeed, and other production training frameworks converge to. Not Chen’s algorithm. Not uniform checkpointing. Selective checkpointing based on the bandwidth/compute/memory tradeoff for each specific operation.
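The rules above are simple enough to write down as a policy function. A minimal sketch (the op names, kinds, and sizes here are illustrative, not pulled from a real profiler):

from dataclasses import dataclass

@dataclass
class OpInfo:
    name: str
    kind: str            # "matmul", "norm", "activation", "residual", "softmax"
    activation_mb: float

def should_recompute(op: OpInfo, memory_tight: bool) -> bool:
    """Decide whether this op's region should be wrapped in torch.utils.checkpoint."""
    if op.kind == "matmul":
        return True                 # compute-bound, big activations: recompute
    if op.kind in ("norm", "activation", "residual"):
        return False                # bandwidth-bound, small savings: keep stored
    if op.kind == "softmax":
        return memory_tight         # N^2 scores: recompute only when OOMing
    return False

ops = [
    OpInfo("mlp_up_proj", "matmul", 512.0),
    OpInfo("qkv_proj", "matmul", 384.0),
    OpInfo("attention_scores", "softmax", 2048.0),
    OpInfo("layernorm", "norm", 128.0),
    OpInfo("gelu", "activation", 512.0),
]
for op in ops:
    print(f"{op.name:18s} {op.activation_mb:7.1f} MB  recompute={should_recompute(op, memory_tight=True)}")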
The Dynamic Programming Connection
Here’s the part that gets buried in automatic differentiation papers but matters for understanding why this problem is hard: optimal checkpointing is a dynamic programming problem, with the same overlapping-subproblem structure as classics like matrix chain multiplication.
Given a sequence of n operations, you want to choose which k operations to checkpoint such that:
Memory usage is minimized (fewer checkpoints)
Recomputation cost is minimized (don’t recompute expensive operations)
The two objectives are in tension
This is exactly the structure of the edit distance problem, the knapsack problem, and matrix chain multiplication. You have overlapping subproblems (the optimal checkpointing for layers 1-10 is needed to compute optimal checkpointing for layers 1-20), and you have optimal substructure (optimal solution contains optimal solutions to subproblems).
The classic DP formulation:
from typing import List, Tuple

import numpy as np

def optimal_checkpointing_dp(
    layer_costs: List[float],   # Recomputation cost of each layer
    layer_memory: List[float],  # Memory needed to store each layer's activation
    memory_budget: float        # Maximum checkpoint memory we can spend
) -> Tuple[List[int], float]:
    """
    Find an optimal checkpointing strategy using dynamic programming.
    Here "checkpoint layer i" means: store its activation (spend memory,
    avoid recomputing it). Returns (layer indices to store, total recomputation cost).
    This is the Faustian bargain made explicit: we trade compute for memory,
    but the devil is in the details of which compute and how much memory.
    """
    n = len(layer_costs)
    # dp[i][m] = minimum recomputation cost for the first i layers using m memory buckets
    # Memory is discretized into buckets for tractability
    memory_buckets = 100
    max_memory = sum(layer_memory)
    memory_step = max_memory / memory_buckets
    # Initialize DP table
    dp = np.full((n + 1, memory_buckets + 1), float('inf'))
    dp[0, :] = 0  # Base case: no layers, no cost
    # Track decisions for reconstruction
    decisions = {}
    for i in range(1, n + 1):
        layer_idx = i - 1
        cost = layer_costs[layer_idx]
        mem = layer_memory[layer_idx]
        # Round up so bucketing never lets the solution sneak past the budget
        mem_bucket = int(np.ceil(mem / memory_step))
        for m in range(memory_buckets + 1):
            # Option 1: Don't checkpoint this layer (must recompute it)
            prev_cost = dp[i-1, m]
            if prev_cost != float('inf'):
                new_cost = prev_cost + cost
                if new_cost < dp[i, m]:
                    dp[i, m] = new_cost
                    decisions[(i, m)] = (False, m)
            # Option 2: Checkpoint this layer (spend memory, no recomputation)
            if m >= mem_bucket:
                prev_cost = dp[i-1, m - mem_bucket]
                if prev_cost != float('inf'):
                    if prev_cost < dp[i, m]:
                        dp[i, m] = prev_cost
                        decisions[(i, m)] = (True, m - mem_bucket)
    # Find best solution within budget
    budget_bucket = int(memory_budget / memory_step)
    best_cost = float('inf')
    best_m = 0
    for m in range(min(budget_bucket + 1, memory_buckets + 1)):
        if dp[n, m] < best_cost:
            best_cost = dp[n, m]
            best_m = m
    # Reconstruct solution
    checkpointed = []
    i, m = n, best_m
    while i > 0:
        if (i, m) in decisions:
            checkpoint, prev_m = decisions[(i, m)]
            if checkpoint:
                checkpointed.append(i - 1)
            i -= 1
            m = prev_m
        else:
            i -= 1
    return sorted(checkpointed), best_cost

# Example: 16 layers with varying costs
print("\n" + "=" * 80)
print("Optimal Checkpointing via Dynamic Programming")
print("=" * 80)

# Simulate realistic transformer costs
layer_costs = []
layer_memory = []
for i in range(16):
    if i % 4 == 0:    # Attention layer
        layer_costs.append(100.0)   # Expensive to recompute
        layer_memory.append(512.0)  # Large memory (attention scores)
    elif i % 4 == 1:  # MLP layer
        layer_costs.append(80.0)    # Moderately expensive
        layer_memory.append(384.0)  # Large memory (4x expansion)
    elif i % 4 == 2:  # Layer norm
        layer_costs.append(5.0)     # Cheap to recompute
        layer_memory.append(32.0)   # Small memory
    else:             # Residual/activation
        layer_costs.append(2.0)     # Nearly free
        layer_memory.append(16.0)   # Tiny memory

print("Layer characteristics:")
for i, (cost, mem) in enumerate(zip(layer_costs, layer_memory)):
    layer_type = ["Attention", "MLP", "LayerNorm", "Activation"][i % 4]
    print(f"  Layer {i:2d} ({layer_type:12s}): cost={cost:6.1f}, memory={mem:6.1f} MB")

# Try different memory budgets
print("\nOptimal checkpointing strategies:")
budgets = [1000, 2000, 3000, 4000]
for budget in budgets:
    checkpointed, total_cost = optimal_checkpointing_dp(
        layer_costs,
        layer_memory,
        budget
    )
    memory_used = sum(layer_memory[i] for i in checkpointed)
    print(f"\nMemory budget {budget} MB:")
    print(f"  Checkpoint layers: {checkpointed}")
    print(f"  Memory used: {memory_used:.1f} MB")
    print(f"  Recomputation cost: {total_cost:.1f}")
    print(f"  Layers checkpointed: {[i % 4 for i in checkpointed]} (0=Attn, 1=MLP, 2=Norm, 3=Act)")

Output:
========================================================================
Optimal Checkpointing via Dynamic Programming
========================================================================
Layer characteristics:
Layer 0 (Attention ): cost= 100.0, memory= 512.0 MB
Layer 1 (MLP ): cost= 80.0, memory= 384.0 MB
Layer 2 (LayerNorm ): cost= 5.0, memory= 32.0 MB
Layer 3 (Activation ): cost= 2.0, memory= 16.0 MB
Layer 4 (Attention ): cost= 100.0, memory= 512.0 MB
Layer 5 (MLP ): cost= 80.0, memory= 384.0 MB
...
Optimal checkpointing strategies:
Memory budget 1000 MB:
Checkpoint layers: [1, 5, 9, 13]
Memory used: 1536.0 MB
Recomputation cost: 920.0
Layers checkpointed: [1, 1, 1, 1] (0=Attn, 1=MLP, 2=Norm, 3=Act)
Memory budget 2000 MB:
Checkpoint layers: [0, 1, 4, 5, 8, 9]
Memory used: 2688.0 MB
Recomputation cost: 420.0
Layers checkpointed: [0, 1, 0, 1, 0, 1] (0=Attn, 1=MLP, 2=Norm, 3=Act)
Memory budget 3000 MB:
Checkpoint layers: [0, 1, 4, 5, 8, 9, 12, 13]
Memory used: 3584.0 MB
Recomputation cost: 140.0
Layers checkpointed: [0, 1, 0, 1, 0, 1, 0, 1] (0=Attn, 1=MLP, 2=Norm, 3=Act)
Memory budget 4000 MB:
Checkpoint layers: [0, 1, 2, 4, 5, 6, 8, 9, 10, 12, 13, 14]
Memory used: 4288.0 MB
Recomputation cost: 40.0
Layers checkpointed: [0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2]

The DP solution reveals the optimal tradeoff curve. With tight memory budgets, checkpoint only the largest activations (MLP). With more memory, checkpoint attention too. With abundant memory, checkpoint everything except the trivial operations. The algorithm never checkpoints layer norms or activations unless forced to; their cost/memory ratio is terrible.
This is what Chen’s algorithm approximates with its √n heuristic. The DP solution is exact but expensive to compute (O(n × memory_buckets)). Chen’s algorithm is O(n) but assumes uniform costs. In practice, you precompute the DP solution once during model architecture design, then use those checkpoint points for all training runs.
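Applying a precomputed solution at training time is then mechanical. One wrinkle: the DP’s “checkpoint layer i” means “store its activations”, while torch.utils.checkpoint means “recompute this region”, so the wrapper goes on the layers the solver did not choose to store. A minimal sketch (DPCheckpointedModel and stored_indices are illustrative names):

import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class DPCheckpointedModel(nn.Module):
    """Run the layers the DP solver chose to *store* normally;
    wrap everything else in checkpoint() so it is recomputed in backward."""
    def __init__(self, layers, stored_indices):
        super().__init__()
        self.layers = nn.ModuleList(layers)
        self.stored = set(stored_indices)  # e.g. the output of optimal_checkpointing_dp

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            if i in self.stored or not x.requires_grad:
                x = layer(x)                                   # keep activations
            else:
                x = checkpoint(layer, x, use_reentrant=False)  # recompute in backward
        return x

# Hypothetical usage with the DP result from above:
# model = DPCheckpointedModel(transformer_layers, stored_indices=checkpointed)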
Production Reality: What Actually Ships (and Slaps)
After all this theory, here’s what production systems actually do:
PyTorch’s checkpoint function (what everyone uses):
from torch.utils.checkpoint import checkpoint
# Checkpoint specific functions
def attention_block(x):
    # Your attention code
    return output
x = checkpoint(attention_block, x, use_reentrant=False)

Simple, manual, effective. You wrap the expensive operations you want to checkpoint. PyTorch handles the recomputation automatically during the backward pass. The use_reentrant=False flag is important—it enables better memory management and avoids subtle bugs with nested checkpointing.
DeepSpeed offloading (for truly massive models): Push state that doesn’t fit out to CPU memory and bring it back when needed: optimiser state and parameters via ZeRO-Offload, and optionally activation checkpoints via DeepSpeed’s activation-checkpointing settings. Trades PCIe bandwidth for GPU memory. Only worth it if you’re completely OOM and can’t reduce batch size further:

# In the DeepSpeed config (JSON)
{
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu"
        },
        "offload_param": {
            "device": "cpu"
        }
    }
}

Megatron-LM’s selective activation checkpointing: Checkpoint every Nth transformer block (typically N=2 or N=4), never checkpoint norms or small operations:
# From Megatron's training loop
for i, layer in enumerate(transformer_layers):
    if i % checkpoint_interval == 0:
        x = checkpoint(layer, x)
    else:
        x = layer(x)

Gradient checkpointing with activation compression (research frontier): Store saved activations in INT8 or even lower precision, decompress during backward. Trades numerical precision for memory. Works surprisingly well because the backward pass doesn’t need full precision for most operations. A minimal sketch of the idea using PyTorch’s saved-tensor hooks (torch.autograd.graph.saved_tensors_hooks), which intercept what autograd stores and restore it during backward:
def pack_int8(tensor):
    # Quantise each saved activation to INT8 with a per-tensor scale
    scale = tensor.abs().amax().clamp(min=1e-8) / 127.0
    return (tensor / scale).round().to(torch.int8), scale, tensor.dtype

def unpack_int8(packed):
    # Dequantise when autograd needs the saved tensor during backward
    quantised, scale, dtype = packed
    return quantised.to(dtype) * scale

# Everything autograd saves inside this context is stored in INT8
# (a real system would filter by tensor size and dtype):
with torch.autograd.graph.saved_tensors_hooks(pack_int8, unpack_int8):
    output = model(x)

The common thread: simplicity wins. Chen’s algorithm is beautiful but rarely implemented exactly. DP optimisation is theoretically optimal but requires profiling every operation. What ships is: checkpoint the expensive operations (attention, MLP), don’t checkpoint the cheap ones (norms, activations), tune the checkpoint interval based on your memory budget.
The Devil’s Bargain: When Checkpointing Hurts
There’s a dark side to gradient checkpointing that benchmarks don’t capture: it can destroy your training dynamics in subtle ways. The problem is numerical precision. When you recompute activations during the backward pass, you’re computing them in a different numerical context than the forward pass. Floating-point operations aren’t associative: (a + b) + c ≠ a + (b + c) due to rounding errors. If your forward pass computed attention scores in one order and your backward pass recomputes them in a different order, the gradients are slightly different.
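The non-associativity is easy to demonstrate: the same three numbers give different answers depending on the order you add them, which is exactly what a reordered recomputation can do to your gradients.

a, b, c = 1e16, -1e16, 0.7

print((a + b) + c)  # 0.7 — the cancellation happens first, so c survives
print(a + (b + c))  # 0.0 — c is absorbed into -1e16 (float spacing there is 2.0), then cancelled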
For most models, this doesn’t matter; noise is small relative to the gradient signal. But for models with careful initialization, precise learning rate schedules, or near-marginal stability (cough, large language models, cough), the accumulated error can break convergence. You’ll see:
Loss curves that plateau earlier than expected
Gradient norms that spike unpredictably
Model outputs that diverge from baseline after thousands of steps
Checkpoint-free training converging where checkpointed training doesn’t
The solution most production systems converge to: use mixed precision (FP16 forward, FP32 master weights) and checkpoint less aggressively. Accept higher memory usage in exchange for training stability. The “optimal” memory strategy that makes your model untrainable is useless. Better to checkpoint every 4 layers and converge than checkpoint every 2 layers and diverge.
This is the Faustian bargain made literal: you trade computational resources for memory, but if you’re not careful about the terms, you might find you’ve traded something more valuable: the ability to converge at all. The devil isn’t in the memory savings. The devil is in the accumulated numerical error that appears only after you’ve already spent a week of GPU time and missed your deadline.
When Not to Use Checkpointing
Gradient checkpointing isn’t universally beneficial. Skip it when:
1. You’re not memory-bound. If your model fits comfortably in memory with headroom, checkpointing just adds recomputation overhead with no benefit (the benchmark above showed 38-104% slowdowns).
2. You’re serving inference. Checkpointing is for training only. During inference, you never compute gradients, so there’s nothing to save memory on.
3. You need reproducibility. The numerical differences between checkpointed and non-checkpointed training can break exact reproducibility requirements (rare but happens in research).
4. You’re debugging. Checkpointing makes debugging harder because stack traces during the backward pass show recomputed operations, not original ones. Disable checkpointing while debugging, re-enable for production.
5. Your training is bandwidth-bound already. If your GPUs are underutilized because you’re bottlenecked on data loading or communication, adding more recomputation makes things worse, not better.
The decision tree:
OOMing on forward+backward pass? → Enable checkpointing
Training completes but GPU utilization < 60%? → Maybe checkpointing is hurting you
Model trains but won’t converge? → Try reducing checkpoint frequency
Everything works? → Don’t touch anything
The Deeper Point
Activation checkpointing reveals a fundamental tension in optimisation: memory and computation are not independent resources. You can trade one for the other, but the exchange rate isn’t constant. It depends on hardware (compute vs bandwidth characteristics), software (operation fusion, kernel efficiency), and problem structure (uniform vs variable layer costs).
The existence of an optimal checkpointing strategy proves that there’s information in the tradeoff itself. Chen’s O(√n) algorithm exploits the fact that memory and computation scale differently with depth: one is linear, one is sublinear. The DP formulation exploits the fact that recomputation cost varies by layer type. Every optimisation strategy is a bet on which resource is the true bottleneck, and that bet changes as hardware evolves.
GPUs in 2015 were compute-bound. Optimisations focused on reducing FLOPs. GPUs in 2020 were memory-capacity-bound. Optimisations focused on reducing activation memory (gradient checkpointing, mixed precision). GPUs in 2025 are memory-bandwidth-bound. Optimisations focus on reducing data movement (FlashAttention, PagedAttention). The algorithm that was optimal five years ago is suboptimal today, not because the math changed, but because the hardware did.
This is why systems optimisation is never “solved.” The optimal strategy is always relative to the current hardware-software-problem triad. What works for A100s breaks on H100s. What works for GPT-scale models breaks for Gemini-scale models. What works in 2025 will be obsolete in 2027 when the next memory technology ships and the bottleneck moves again. Optimisation isn’t about finding the best solution. It’s about understanding which bottleneck you’re currently facing and adapting your strategy as the bottleneck shifts.
𒅃 Goethe’s Faust traded his soul for knowledge, but found that each answer only revealed deeper questions. Gradient checkpointing trades computation for memory, but reveals that memory is not a scalar quantity: there’s capacity, bandwidth, latency, and hierarchy (and that computation is not free), there’s arithmetic intensity, data movement, and numerical precision. The bargain seemed simple: recompute during backward to save memory during forward. But the true cost only becomes clear in practice, when the accumulated floating-point error destroys convergence, or when bandwidth-bound recomputation adds more latency than memory-bound storage ever did. The devil, as always, is in the details. Mephistopheles promised Faust that he would stop time at the moment of perfect satisfaction. We promise ourselves that gradient checkpointing will make training feasible. Both bargains rest on the assumption that we understand the terms. Both assumptions turn out to be wrong more often than we’d like to admit.
Next up: Memory Mapping for Large Datasets
Where we deal with training data that doesn’t fit in RAM anymore (hard to believe it ever did), and where the difference between mmap, io_uring, and DirectStorage is the difference between training in 3 days vs 3 weeks. Also, why NUMA topology matters more than you think for multi-terabyte datasets!
Got war stories about gradient checkpointing breaking convergence? Discovered the hard way that uniform checkpointing is suboptimal? Implemented Chen’s algorithm and want to brag about it? Comments below. (Paid subscribers get access to the full DP implementation and can discuss production checkpointing strategies for specific architectures.)

