Pipeline Parallelism: Forward/Backward Passes Across Layers
Part 2 of Distributed Training from Scratch
In the previous post, we carved up attention heads like a distributed sushi chef in a Murakami story, splitting individual operations across GPUs, praying our gradients would synchronize properly, and, like a Murakami protagonist, hoping their disparate worlds would somehow align in the end. Today, we're going vertical. Instead of splitting computations within layers, we're going to split the layers themselves across devices, creating a computational assembly line that would make Henry Ford weep with joy. Or at least his digital twin.
Welcome to pipeline parallelism, where your x-billion-parameter transformer becomes a factory floor with each GPU responsible for a different stage of production. GPU 0 handles layers 1-10, GPU 1 takes layers 11-20, and so on. Simple, elegant, and absolutely riddled with the kind of subtle timing dependencies that will have you debugging phantom deadlocks at 3 AM while questioning your life choices. (I’m literally speaking from experience.)
The beautiful irony is that pipeline parallelism is simultaneously the most intuitive form of model parallelism (obviously you put different layers on different devices!) and the most diabolically complex to get right. It's like explaining how a car engine works versus actually building one that doesn't explode on the first try. (That joke will probably be lost on the next generation, since motors generally don’t explode.)
The Assembly Line: Layers as Pipeline Stages
The core insight behind pipeline parallelism is deceptively simple. Your transformer has, say, 80 layers. You have 8 GPUs. Put 10 layers on each GPU, and boom, you've got yourself a pipeline. Data flows from GPU 0 through GPU 7, each device applying its layers in sequence.
import torch
import torch.nn as nn
from typing import List, Optional

class PipelineStage(nn.Module):
    """A single stage in our pipeline (multiple transformer layers)"""
    def __init__(self, layers: List[nn.Module], stage_id: int):
        super().__init__()
        self.layers = nn.ModuleList(layers)
        self.stage_id = stage_id

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

class NaivePipelineParallelTransformer(nn.Module):
    """The simplest possible pipeline parallel transformer"""
    def __init__(self, total_layers=80, num_stages=8, d_model=512):
        super().__init__()
        self.num_stages = num_stages
        self.layers_per_stage = total_layers // num_stages

        # Create all transformer layers
        all_layers = [TransformerLayer(d_model) for _ in range(total_layers)]

        # Split layers across stages
        self.stages = nn.ModuleList([
            PipelineStage(
                all_layers[i * self.layers_per_stage:(i + 1) * self.layers_per_stage],
                stage_id=i
            ) for i in range(num_stages)
        ])

    def forward(self, x):
        # Sequential execution (this is wrong, but let's start here)
        for stage in self.stages:
            x = stage(x)
        return x
This naive implementation misses the entire point. We're still running everything sequentially on a single device. The magic happens when we actually distribute these stages across GPUs and start overlapping computation.
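To make "actually distribute these stages" concrete before we go multi-process, here's a minimal sketch of the very first step: one Python process, one visible GPU per stage, activations hopped between devices with .to(). The helper below is hypothetical (not from any library) and still strictly sequential, so there's no overlap yet, but the model no longer has to fit on a single card.

# Hypothetical helper: place each stage of the naive model on its own GPU and
# run the forward sequentially across devices. Assumes one GPU per stage.
def run_on_multiple_gpus(model: NaivePipelineParallelTransformer, x: torch.Tensor):
    devices = [torch.device(f"cuda:{i}") for i in range(len(model.stages))]
    for stage, device in zip(model.stages, devices):
        stage.to(device)
    for stage, device in zip(model.stages, devices):
        x = x.to(device, non_blocking=True)  # hop the activation to this stage's GPU
        x = stage(x)
    return x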
The Bubble Problem: Why Pipeline Parallelism is Hard
Here's where pipeline parallelism gets interesting in all the wrong ways. Unlike data parallelism (where all GPUs work independently) or tensor parallelism (where GPUs collaborate on the same computation), pipeline parallelism creates dependencies. GPU 1 can't start working until GPU 0 finishes. GPU 2 waits for GPU 1. And so on.
This creates what we lovingly call "pipeline bubbles": periods where GPUs sit idle waiting for data from upstream stages. In the worst case, only one GPU is working at any moment, so your eight-GPU cluster delivers roughly the throughput of a single device, an eighth of what you're paying for. Not exactly the scaling we were hoping for.
The solution is micro-batching: instead of sending one large batch through the pipeline, we split it into smaller micro-batches that can overlap in different stages.
class MicroBatch:
    """A single micro-batch flowing through the pipeline"""
    def __init__(self, data: torch.Tensor, micro_batch_id: int):
        self.data = data
        self.micro_batch_id = micro_batch_id
        self.stage_outputs = {}  # Cache outputs at each stage

    def __repr__(self):
        return f"MicroBatch(id={self.micro_batch_id}, shape={self.data.shape})"

class PipelineScheduler:
    """Manages the flow of micro-batches through pipeline stages"""
    def __init__(self, num_stages: int, num_micro_batches: int):
        self.num_stages = num_stages
        self.num_micro_batches = num_micro_batches

        # Each stage maintains a queue of micro-batches
        self.stage_queues = [[] for _ in range(num_stages)]
        self.completed_forward = []
        self.completed_backward = []

    def create_schedule(self):
        """Generate the optimal schedule for forward/backward passes"""
        schedule = []

        # Forward passes: stagger micro-batches across stages
        for step in range(self.num_stages + self.num_micro_batches - 1):
            for stage_id in range(self.num_stages):
                micro_batch_id = step - stage_id
                if 0 <= micro_batch_id < self.num_micro_batches:
                    schedule.append(('forward', stage_id, micro_batch_id))

        # Backward passes: reverse order
        for step in range(self.num_stages + self.num_micro_batches - 1):
            for stage_id in range(self.num_stages - 1, -1, -1):
                micro_batch_id = step - (self.num_stages - 1 - stage_id)
                if 0 <= micro_batch_id < self.num_micro_batches:
                    schedule.append(('backward', stage_id, micro_batch_id))

        return schedule
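To make the schedule concrete, here's a small usage sketch: chop a batch into micro-batches with torch.chunk and print the first few schedule entries for a tiny 2-stage, 4-micro-batch pipeline. The tensor sizes are made up for illustration.

# Illustrative only: inspect the schedule for a tiny 2-stage, 4-micro-batch pipeline.
batch = torch.randn(32, 128, 512)                    # (batch, seq, d_model) -- made-up sizes
micro_batches = torch.chunk(batch, chunks=4, dim=0)  # four micro-batches of 8 sequences each
print([tuple(mb.shape) for mb in micro_batches])

scheduler = PipelineScheduler(num_stages=2, num_micro_batches=4)
for op, stage_id, mb_id in scheduler.create_schedule()[:6]:
    print(f"{op:>8} | stage {stage_id} | micro-batch {mb_id}")
# Stage 0 starts forward on micro-batch 1 while stage 1 is still on micro-batch 0;
# that overlap is the whole point of micro-batching.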
GPipe vs PipeDream: The Scheduling Wars
There are two main approaches to pipeline scheduling, and choosing between them is like choosing between different types of controlled chaos.
GPipe (Google's approach): Complete all forward passes before starting backward passes. Simple, synchronous, but creates large pipeline bubbles. Think of it as a batch processing system.
PipeDream (Microsoft's approach): Interleave forward and backward passes to minimize bubbles. More complex, but much better hardware utilization. Think of it as a streaming system.
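To see the difference at a glance, here's a toy sketch (not a real scheduler) that prints the order of operations a single stage would execute under each policy. The warm-up count is a free parameter here; in a real 1F1B schedule it would be num_stages - stage_id - 1 for that stage.

# Toy comparison of the two schedules from one stage's point of view.
# GPipe: run all forwards, then all backwards.
# 1F1B (PipeDream-style): after a short warm-up, alternate forward and backward,
# which caps how many activations the stage has to keep alive at once.
def one_stage_order(style: str, num_micro_batches: int, warmup: int):
    if style == "gpipe":
        return [f"F{i}" for i in range(num_micro_batches)] + \
               [f"B{i}" for i in range(num_micro_batches)]
    order = [f"F{i}" for i in range(warmup)]          # warm-up forwards
    for i in range(num_micro_batches - warmup):
        order += [f"F{i + warmup}", f"B{i}"]          # steady state: one F, one B
    order += [f"B{i}" for i in range(num_micro_batches - warmup, num_micro_batches)]
    return order

print("GPipe:", one_stage_order("gpipe", 8, warmup=3))
print("1F1B: ", one_stage_order("1f1b", 8, warmup=3))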
Those of you who know me know I have a lot to say about batching vs. streaming, as if it were a profound philosophical choice and not merely an architectural decision. Here's a simplified PipeDream implementation:
import torch.distributed as dist
from collections import deque
import asyncio

class PipeDreamStage:
    """A single stage in the PipeDream pipeline"""
    def __init__(self, model: nn.Module, stage_id: int, num_stages: int):
        self.model = model.cuda()
        self.stage_id = stage_id
        self.num_stages = num_stages
        self.rank = dist.get_rank()

        # Queues for managing micro-batches
        self.forward_queue = deque()
        self.backward_queue = deque()

        # Version management for weight updates
        self.version = 0
        self.weight_versions = {}

    async def forward_pass(self, micro_batch: MicroBatch):
        """Execute forward pass for a micro-batch"""
        # Receive input from previous stage (except first stage)
        if self.stage_id > 0:
            input_tensor = await self.receive_activation(micro_batch.micro_batch_id)
            # Track gradients on the received activation so backward_pass can
            # read micro_batch.data.grad and ship it upstream
            input_tensor.requires_grad_(True)
            micro_batch.data = input_tensor

        # Store current weights version for this micro-batch
        self.weight_versions[micro_batch.micro_batch_id] = self.version

        # Forward computation (keep the autograd graph alive -- backward_pass
        # will call output.backward on it later)
        output = self.model(micro_batch.data)

        # Send output to next stage (except last stage)
        if self.stage_id < self.num_stages - 1:
            await self.send_activation(output, micro_batch.micro_batch_id)

        # Store for backward pass
        micro_batch.stage_outputs[self.stage_id] = output
        self.backward_queue.append(micro_batch)
        return output

    async def backward_pass(self, micro_batch: MicroBatch):
        """Execute backward pass for a micro-batch"""
        # Receive gradient from next stage (except last stage)
        if self.stage_id < self.num_stages - 1:
            grad_output = await self.receive_gradient(micro_batch.micro_batch_id)
        else:
            # Last stage computes loss
            grad_output = self.compute_loss_gradient(micro_batch)

        # Restore weights to the version used in forward pass
        forward_version = self.weight_versions[micro_batch.micro_batch_id]
        if forward_version != self.version:
            await self.restore_weights(forward_version)

        # Backward computation
        output = micro_batch.stage_outputs[self.stage_id]
        output.backward(grad_output)

        # Send gradient to previous stage (except first stage)
        if self.stage_id > 0:
            input_grad = micro_batch.data.grad
            await self.send_gradient(input_grad, micro_batch.micro_batch_id)

        # Clean up
        del self.weight_versions[micro_batch.micro_batch_id]
        del micro_batch.stage_outputs[self.stage_id]

    async def send_activation(self, tensor: torch.Tensor, micro_batch_id: int):
        """Send activation to next stage"""
        next_rank = (self.rank + 1) % dist.get_world_size()
        # In real implementation, use dist.isend with proper serialization
        dist.send(tensor.contiguous(), dst=next_rank, tag=micro_batch_id)

    async def receive_activation(self, micro_batch_id: int):
        """Receive activation from previous stage"""
        prev_rank = (self.rank - 1) % dist.get_world_size()
        # In real implementation, use dist.irecv with proper deserialization.
        # expected_input_shape is assumed to be negotiated ahead of time.
        tensor = torch.empty(self.expected_input_shape, device='cuda')
        dist.recv(tensor, src=prev_rank, tag=micro_batch_id)
        return tensor

    async def send_gradient(self, tensor: torch.Tensor, micro_batch_id: int):
        """Send gradient to previous stage"""
        prev_rank = (self.rank - 1) % dist.get_world_size()
        dist.send(tensor.contiguous(), dst=prev_rank, tag=micro_batch_id + 10000)

    async def receive_gradient(self, micro_batch_id: int):
        """Receive gradient from next stage"""
        next_rank = (self.rank + 1) % dist.get_world_size()
        tensor = torch.empty(self.expected_output_shape, device='cuda')
        dist.recv(tensor, src=next_rank, tag=micro_batch_id + 10000)
        return tensor
This simplified PipeDreamStage code provides a high-level conceptual view of how interleaved forward and backward passes work. In a production environment, however, this approach has more hidden complexity than a smile in a Dostoyevsky novel. The async keywords hint at the real challenge: coordinating these concurrent tasks and managing weight versions without deadlocks or stale gradients. While torch.distributed.isend and irecv provide non-blocking communication, using them naively still requires manual management of the request handles, which is a reliable recipe for race conditions and deadlocks. And you can't simply await them in a standard Python asyncio loop, since torch.distributed operations are blocking by default. The solution lies in a more robust scheduling and communication framework that handles these details for us, which we will explore next.
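One concrete building block such frameworks use is torch.distributed's batched point-to-point API. The helper below is a minimal sketch under the assumption that activation shapes are agreed on ahead of time and a process group is already initialized; it isn't the PipeDream implementation, just the kind of primitive that schedulers in frameworks like Megatron-LM layer their pipelines on top of.

# Sketch: batched point-to-point exchange for one pipeline hop.
import torch
import torch.distributed as dist

def exchange(send_tensor, recv_shape, prev_rank, next_rank, dtype=torch.float16):
    """Send activations downstream and receive from upstream in one batched call."""
    ops, recv_buffer = [], None
    if send_tensor is not None:                      # not the last stage
        ops.append(dist.P2POp(dist.isend, send_tensor.contiguous(), next_rank))
    if recv_shape is not None:                       # not the first stage
        recv_buffer = torch.empty(recv_shape, dtype=dtype, device="cuda")
        ops.append(dist.P2POp(dist.irecv, recv_buffer, prev_rank))
    if ops:
        for req in dist.batch_isend_irecv(ops):      # one request per op
            req.wait()                               # NCCL ignores tags; matching is by posting order
    return recv_buffer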
Memory Management: The Hidden Complexity
Pipeline parallelism creates a subtle but critical memory management problem. During forward passes, you need to store activations for the backward pass. But with micro-batching, you might have multiple forward passes in flight before the first backward pass starts.
This means your memory usage grows linearly with the number of micro-batches. Too few micro-batches and you get pipeline bubbles. Too many and you run out of memory. Welcome to the optimization nightmare.
class ActivationCheckpointing:
    """Manage activation storage with checkpointing to save memory"""
    def __init__(self, checkpoint_every_n_layers: int = 4):
        self.checkpoint_every_n_layers = checkpoint_every_n_layers
        self.activation_cache = {}
        self.memory_pool = []

    def should_checkpoint(self, layer_id: int) -> bool:
        """Decide whether to checkpoint this layer's activations"""
        return layer_id % self.checkpoint_every_n_layers == 0

    def store_activation(self, layer_id: int, micro_batch_id: int, activation: torch.Tensor):
        """Store activation with optional checkpointing"""
        key = (layer_id, micro_batch_id)
        if self.should_checkpoint(layer_id):
            # Store on CPU to save GPU memory
            self.activation_cache[key] = activation.cpu()
        else:
            # Keep on GPU for faster access
            self.activation_cache[key] = activation

    def retrieve_activation(self, layer_id: int, micro_batch_id: int) -> torch.Tensor:
        """Retrieve activation, recomputing if necessary"""
        key = (layer_id, micro_batch_id)
        if key in self.activation_cache:
            activation = self.activation_cache[key]
            if activation.device.type == 'cpu':
                activation = activation.cuda()
            del self.activation_cache[key]  # Free memory immediately
            return activation
        else:
            # Recompute from last checkpoint
            return self.recompute_activation(layer_id, micro_batch_id)

    def recompute_activation(self, layer_id: int, micro_batch_id: int) -> torch.Tensor:
        """Recompute activation from the last checkpoint"""
        # Find the last checkpointed layer
        checkpoint_layer = (layer_id // self.checkpoint_every_n_layers) * self.checkpoint_every_n_layers

        # Retrieve checkpoint and recompute forward
        checkpoint = self.retrieve_activation(checkpoint_layer, micro_batch_id)

        # Forward through layers until we reach the target layer
        # (assumes self.model has been attached by the owning pipeline stage)
        x = checkpoint
        for i in range(checkpoint_layer + 1, layer_id + 1):
            x = self.model.layers[i](x)
        return x
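Worth knowing before you hand-roll any of this: PyTorch ships activation checkpointing in torch.utils.checkpoint, which recomputes a wrapped segment during the backward pass instead of storing its intermediates. Here's a minimal sketch of wrapping a stage's layers in checkpointed groups; the class name and grouping policy are my own for illustration, not part of any library.

# Minimal sketch: built-in activation checkpointing instead of a hand-rolled cache.
# Each group of layers is recomputed during backward rather than stored.
import torch.utils.checkpoint as checkpoint

class CheckpointedStage(nn.Module):
    def __init__(self, layers, checkpoint_every_n_layers=4):
        super().__init__()
        self.groups = nn.ModuleList([
            nn.Sequential(*layers[i:i + checkpoint_every_n_layers])
            for i in range(0, len(layers), checkpoint_every_n_layers)
        ])

    def forward(self, x):
        for group in self.groups:
            # use_reentrant=False is the recommended mode in recent PyTorch versions
            x = checkpoint.checkpoint(group, x, use_reentrant=False)
        return x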
The Communication Pattern: Point-to-Point Chaos
Unlike the all-reduce patterns in data and tensor parallelism, pipeline parallelism uses point-to-point communication. GPU 0 only talks to GPU 1, GPU 1 talks to GPU 0 and GPU 2, and so on. This creates a beautiful linear communication pattern that scales perfectly... until it doesn't.
The problem is that point-to-point communication is much harder to optimize than collective operations. NCCL is heavily tuned for all-reduce, while peer-to-peer transfers leave more on the table: with fast interconnects, per-transfer latency and scheduling overhead often matter more than raw bandwidth, and a single late send stalls every stage downstream.
class OptimizedP2PCommunication:
    """Optimized point-to-point communication for pipeline parallelism"""
    def __init__(self, rank: int, world_size: int):
        self.rank = rank
        self.world_size = world_size

        # Pre-allocate communication buffers
        self.send_buffers = {}
        self.recv_buffers = {}

        # Track in-flight operations
        self.pending_sends = {}
        self.pending_recvs = {}

    def async_send(self, tensor: torch.Tensor, dst: int, tag: int):
        """Non-blocking send with buffer management"""
        # Stage into a reusable buffer so the caller's tensor can be freed or
        # overwritten while the send is still in flight
        if tag not in self.send_buffers:
            self.send_buffers[tag] = torch.empty_like(tensor)
        send_buffer = self.send_buffers[tag]
        send_buffer.copy_(tensor, non_blocking=True)

        # Start async send
        req = dist.isend(send_buffer, dst=dst, tag=tag)
        self.pending_sends[tag] = req
        return req

    def async_recv(self, shape: tuple, dtype: torch.dtype, src: int, tag: int):
        """Non-blocking receive with buffer management"""
        # Use a pre-allocated GPU buffer (NCCL expects device tensors)
        if tag not in self.recv_buffers:
            self.recv_buffers[tag] = torch.empty(shape, dtype=dtype, device='cuda')
        recv_buffer = self.recv_buffers[tag]

        # Start async receive
        req = dist.irecv(recv_buffer, src=src, tag=tag)
        self.pending_recvs[tag] = req
        return req, recv_buffer

    def wait_send(self, tag: int):
        """Wait for send to complete"""
        if tag in self.pending_sends:
            self.pending_sends[tag].wait()
            del self.pending_sends[tag]

    def wait_recv(self, tag: int):
        """Wait for receive to complete"""
        if tag in self.pending_recvs:
            self.pending_recvs[tag].wait()
            del self.pending_recvs[tag]
        return self.recv_buffers[tag]
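For orientation, here's roughly how a middle stage would use this wrapper for one micro-batch hop. The shape, dtype, micro_batch_id, and stage_layers below are illustrative placeholders; the real values come from your model and scheduler.

# Usage sketch for one micro-batch hop on a middle stage (rank r talks to r-1 and r+1).
comm = OptimizedP2PCommunication(rank=dist.get_rank(), world_size=dist.get_world_size())
micro_batch_id = 0  # illustrative

req, buf = comm.async_recv(shape=(4, 2048, 512), dtype=torch.float16,
                           src=comm.rank - 1, tag=micro_batch_id)
x = comm.wait_recv(micro_batch_id)      # blocks until the activation has landed in buf
out = stage_layers(x)                   # stage_layers: this stage's nn.Sequential (defined elsewhere)
comm.async_send(out, dst=comm.rank + 1, tag=micro_batch_id)
# ... later, before reusing the send buffer for this tag:
comm.wait_send(micro_batch_id)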
Load Balancing: When Layers Aren't Created Equal
Here's a dirty secret about transformer layers: they're not all the same computational cost. The embedding and output layers add weight to the first and last stages, attention cost varies with sequence length and what you cache, and some layers simply carry more parameters than others. If you naively split 80 layers into 8 groups of 10, you can end up with serious load imbalance, and the slowest stage sets the pace for the whole pipeline.
import time
from typing import Dict, List

class LayerProfiler:
    """Profile computational cost of individual layers"""
    def __init__(self, model: nn.Module, sample_input: torch.Tensor):
        self.model = model
        self.sample_input = sample_input
        self.layer_costs = {}

    def profile_layers(self, num_runs: int = 100) -> Dict[int, float]:
        """Profile the computational cost of each layer"""
        with torch.no_grad():
            x = self.sample_input
            for layer_id, layer in enumerate(self.model.layers):
                # Warmup
                for _ in range(10):
                    _ = layer(x)
                torch.cuda.synchronize()

                # Actual timing
                start_time = time.perf_counter()
                for _ in range(num_runs):
                    output = layer(x)
                torch.cuda.synchronize()
                end_time = time.perf_counter()

                avg_time = (end_time - start_time) / num_runs
                self.layer_costs[layer_id] = avg_time

                # Use output as input for next layer
                x = output
        return self.layer_costs

    def optimal_partitioning(self, num_stages: int) -> List[List[int]]:
        """Find optimal layer partitioning based on computational cost"""
        if not self.layer_costs:
            self.profile_layers()

        # Sort layers by cost (greedy approximation)
        sorted_layers = sorted(self.layer_costs.items(),
                               key=lambda x: x[1], reverse=True)

        # Initialize stages
        stages = [[] for _ in range(num_stages)]
        stage_costs = [0.0] * num_stages

        # Greedy assignment to least loaded stage
        for layer_id, cost in sorted_layers:
            min_stage = min(range(num_stages), key=lambda i: stage_costs[i])
            stages[min_stage].append(layer_id)
            stage_costs[min_stage] += cost

        # Sort layers within each stage to maintain model order
        for stage in stages:
            stage.sort()

        print("Stage costs:", stage_costs)
        print("Load balance ratio:", max(stage_costs) / min(stage_costs))
        return stages
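A quick usage sketch, with one honest caveat: the greedy bin-packing above can scatter non-adjacent layers across stages, and a real pipeline needs each stage to own a contiguous block of layers, so production planners (PipeDream's partitioner, for example) balance contiguous splits instead. Treat the output as a cost map that tells you where to cut, not as a drop-in assignment. The model and sample shape below are assumptions.

# Illustrative usage: profile per-layer cost, then decide where to cut.
# Assumes `model.layers` is an ordered list of transformer blocks living on the GPU.
sample = torch.randn(4, 2048, 512, device="cuda")
profiler = LayerProfiler(model, sample)

costs = profiler.profile_layers(num_runs=20)          # {layer_id: seconds per forward}
stages = profiler.optimal_partitioning(num_stages=8)  # cost-balanced (but non-contiguous!) groups
print("Most expensive layers:", sorted(costs, key=costs.get, reverse=True)[:5])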
Putting It All Together: A Real Pipeline
Here's a complete pipeline parallelism implementation that actually works:
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.multiprocessing as mp
from dataclasses import dataclass
from typing import Optional, List
import asyncio
import logging
import time

@dataclass
class PipelineConfig:
    num_stages: int = 4
    num_micro_batches: int = 8
    checkpoint_every_n_layers: int = 4
    use_async_communication: bool = True
    profile_layers: bool = True

class ProductionPipelineStage:
    """Production-ready pipeline stage with all optimizations"""
    def __init__(self,
                 layers: List[nn.Module],
                 stage_id: int,
                 config: PipelineConfig):
        self.layers = nn.Sequential(*layers).cuda()
        self.stage_id = stage_id
        self.config = config
        self.rank = dist.get_rank()

        # Communication optimizations
        self.comm = OptimizedP2PCommunication(self.rank, dist.get_world_size())

        # Memory optimizations
        self.activation_checkpointer = ActivationCheckpointing(
            config.checkpoint_every_n_layers
        )

        # Profiling
        self.forward_times = []
        self.backward_times = []
        self.communication_times = []

    def forward_micro_batch(self, micro_batch_id: int, input_tensor: Optional[torch.Tensor] = None):
        """Process one micro-batch through this stage"""
        start_time = time.perf_counter()

        # Receive input from previous stage
        # (expected_input_shape / expected_output_shape are assumed to be set
        # during initialization, once activation shapes are known)
        if self.stage_id > 0 and input_tensor is None:
            comm_start = time.perf_counter()
            req, recv_buffer = self.comm.async_recv(
                shape=self.expected_input_shape,
                dtype=torch.float16,
                src=self.rank - 1,
                tag=micro_batch_id
            )
            self.comm.wait_recv(micro_batch_id)
            input_tensor = recv_buffer.clone()
            self.communication_times.append(time.perf_counter() - comm_start)

        # Forward pass
        output = self.layers(input_tensor)

        # Send to next stage
        if self.stage_id < self.config.num_stages - 1:
            comm_start = time.perf_counter()
            self.comm.async_send(output, dst=self.rank + 1, tag=micro_batch_id)
            self.communication_times.append(time.perf_counter() - comm_start)

        # Store for backward pass
        self.activation_checkpointer.store_activation(
            self.stage_id, micro_batch_id, input_tensor
        )

        forward_time = time.perf_counter() - start_time
        self.forward_times.append(forward_time)
        return output

    def backward_micro_batch(self, micro_batch_id: int, grad_output: Optional[torch.Tensor] = None):
        """Backward pass for one micro-batch"""
        start_time = time.perf_counter()

        # Receive gradient from next stage
        if self.stage_id < self.config.num_stages - 1 and grad_output is None:
            req, recv_buffer = self.comm.async_recv(
                shape=self.expected_output_shape,
                dtype=torch.float16,
                src=self.rank + 1,
                tag=micro_batch_id + 10000  # Offset to avoid conflicts
            )
            self.comm.wait_recv(micro_batch_id + 10000)
            grad_output = recv_buffer.clone()

        # Retrieve stored activation
        input_activation = self.activation_checkpointer.retrieve_activation(
            self.stage_id, micro_batch_id
        )

        # Backward pass (re-run the forward to rebuild the autograd graph)
        input_activation.requires_grad_(True)
        output = self.layers(input_activation)
        output.backward(grad_output)

        # Send gradient to previous stage
        if self.stage_id > 0:
            grad_input = input_activation.grad
            self.comm.async_send(grad_input, dst=self.rank - 1, tag=micro_batch_id + 10000)

        backward_time = time.perf_counter() - start_time
        self.backward_times.append(backward_time)

    def run_pipeline_schedule(self, schedule: List[tuple]):
        """Execute the pipeline schedule"""
        for operation, stage_id, micro_batch_id in schedule:
            if stage_id == self.stage_id:
                if operation == 'forward':
                    self.forward_micro_batch(micro_batch_id)
                elif operation == 'backward':
                    self.backward_micro_batch(micro_batch_id)

        # Wait for all communication to complete
        for tag in list(self.comm.pending_sends.keys()):
            self.comm.wait_send(tag)

    def print_profiling_stats(self):
        """Print performance statistics"""
        if self.forward_times:
            avg_forward = sum(self.forward_times) / len(self.forward_times)
            avg_backward = sum(self.backward_times) / len(self.backward_times)
            avg_comm = sum(self.communication_times) / len(self.communication_times)
            print(f"Stage {self.stage_id} Performance:")
            print(f"  Average Forward: {avg_forward*1000:.2f}ms")
            print(f"  Average Backward: {avg_backward*1000:.2f}ms")
            print(f"  Average Comm: {avg_comm*1000:.2f}ms")
            print(f"  Bubble Time: {self.estimate_bubble_time():.1f}%")

    def estimate_bubble_time(self) -> float:
        """Estimate percentage of time spent in pipeline bubbles"""
        if not self.forward_times:
            return 0.0
        total_computation = sum(self.forward_times) + sum(self.backward_times)
        total_communication = sum(self.communication_times)

        # Rough estimate: bubbles = idle time / total time
        bubble_ratio = total_communication / (total_computation + total_communication)
        return bubble_ratio * 100

# Usage example
def train_with_pipeline_parallelism(rank, world_size):
    """Main training function with pipeline parallelism"""
    # Setup distributed training (setup_distributed / cleanup_distributed are
    # assumed helpers wrapping init_process_group / destroy_process_group)
    setup_distributed(rank, world_size)

    # Create pipeline configuration
    config = PipelineConfig(
        num_stages=world_size,
        num_micro_batches=16,
        checkpoint_every_n_layers=2
    )

    # Create transformer layers (simplified)
    total_layers = 32
    layers_per_stage = total_layers // world_size
    my_layers = []
    for i in range(layers_per_stage):
        layer_id = rank * layers_per_stage + i
        my_layers.append(TransformerLayer(d_model=512, layer_id=layer_id))

    # Create pipeline stage
    stage = ProductionPipelineStage(my_layers, rank, config)

    # Create schedule
    scheduler = PipelineScheduler(world_size, config.num_micro_batches)
    schedule = scheduler.create_schedule()

    print(f"Rank {rank}: Running {len(schedule)} pipeline operations")

    # Training loop
    for epoch in range(10):
        stage.run_pipeline_schedule(schedule)

        if rank == 0:
            print(f"Completed epoch {epoch}")

    # Print performance stats
    stage.print_profiling_stats()
    cleanup_distributed()

if __name__ == "__main__":
    world_size = 4
    mp.spawn(train_with_pipeline_parallelism, args=(world_size,), nprocs=world_size)
As should be obvi, a production-ready pipeline parallel implementation is more complex than Nijinsky tryna choreograph Stravinsky while the orchestra is tuning.
It combines strategic model partitioning with sophisticated scheduling, memory management, and asynchronous communication. Each component, from the micro-batch scheduler to the activation checkpointer, is designed to chip away at the fundamental inefficiencies (pipeline bubbles and memory bottlenecks) that plague a naive approach. This architecture represents the culmination of a decade of research into distributed training, transforming a simple idea into a system that can train the world's largest AI models. And now you have the power.
The Performance Reality Check
Ah but the power of pipeline parallelism is a double-edged sword. When it works well, you get excellent memory efficiency and can train models that would never fit on a single device. When it works poorly, you get pipeline bubbles that make your expensive GPU cluster perform worse than a single machine.
The efficiency formula is roughly:
Pipeline Efficiency = (num_micro_batches) / (num_micro_batches + num_stages - 1)
(cf. Huang et al., 2018; Narayanan et al., 2019)
With 8 stages and 16 micro-batches, you get about 70% efficiency. Not terrible, but not great either. The key is finding the sweet spot between memory usage and pipeline bubbles. We'll circle back to some more advanced approaches later in the series.
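Working through the arithmetic: 16 / (16 + 8 - 1) = 16/23 ≈ 0.70, so roughly 30% of each step is bubble. Double the micro-batches to 32 and you get 32/39 ≈ 0.82, at the price of keeping twice as many activations in flight.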
The Memory vs Bubble Time Trade-off
Here's the fundamental tension in pipeline parallelism: more micro-batches means better pipeline utilization but higher memory usage. Fewer micro-batches means lower memory usage but worse utilization. There's no free lunch (contrary to what we tell our new hires). Sorry about that.
def calculate_pipeline_efficiency(num_stages: int, num_micro_batches: int) -> dict:
    """Calculate various pipeline efficiency metrics"""
    # Theoretical efficiency (ignoring communication)
    theoretical_efficiency = num_micro_batches / (num_micro_batches + num_stages - 1)

    # Memory scaling (linear with micro-batches)
    memory_multiplier = num_micro_batches

    # Bubble time (time when not all stages are active)
    bubble_steps = num_stages - 1
    total_steps = num_micro_batches + num_stages - 1
    bubble_percentage = (bubble_steps / total_steps) * 100

    return {
        'efficiency': theoretical_efficiency,
        'memory_multiplier': memory_multiplier,
        'bubble_percentage': bubble_percentage,
        'throughput_ratio': theoretical_efficiency,
    }

# Find optimal micro-batch count
for num_micro_batches in [4, 8, 16, 32]:
    metrics = calculate_pipeline_efficiency(8, num_micro_batches)
    print(f"Micro-batches: {num_micro_batches}")
    print(f"  Efficiency: {metrics['efficiency']:.1%}")
    print(f"  Memory: {metrics['memory_multiplier']}x")
    print(f"  Bubble time: {metrics['bubble_percentage']:.1f}%")
    print()
What's Next: The Hybrid Future
Pure pipeline parallelism is rarely the answer for modern large models. The real magic happens when you combine it with tensor parallelism and data parallelism in clever ways. Imagine splitting your model both horizontally (tensor parallelism within layers) and vertically (pipeline parallelism across layers), then replicating the whole thing across multiple data-parallel groups. This is essentially data layout coherence for AI models, viz. maintaining alignment between your data organization and access patterns across the entire stack, from storage through compute kernels. Later we'll explore how data type enforcement across the entire stack intersects with data layout coherence. The key insight is designing sharding strategies where horizontal and vertical data flows naturally align, creating end-to-end data locality that maximizes bandwidth utilization at every level of the compute hierarchy.
But that's a story for another post (the last one in this unit) where we'll dive into 3D parallelism and learn how to orchestrate tensor, pipeline, and data parallelism simultaneously without losing our sanity.
For now, you have the tools to build production-ready pipeline parallel training. The code above will actually run and give you meaningful performance insights. Just don't blame me when you're debugging communication deadlocks at 3 AM while your GPUs burn electricity doing nothing. I’ll be busy stacking NVDA.
𒅃 The universal Turing machine processed one symbol at a time on an infinite tape. We've decided that's too slow, so we've built an assembly line of Turing machines, each processing different parts of the computation in parallel. Progress, or madness? It’s a real question. And honestly at this point, who is left to disagree that it’s probably both?
Next up: Data Parallelism and Gradient Synchronization - Where we learn to love AllReduce as much as we hate communication bottlenecks.
Questions about pipeline parallelism? Ready to share your own micro-batching horror stories? Paid subscribers drop them in the comments. Nothing builds engineering camaraderie like shared suffering over distributed training bugs. (Subscription is free, paid level has perks.)

