KV-Cache Sharding & Distributed Inference
Part 8 of Memory & Compute Optimisation
By the end of Part 7, the model finally fits. Your checkpoint loads across multiple devices, layers are placed intentionally, and forward passes complete without immediate OOMs. The system runs. Then context length increases, batch size creeps up, and memory usage starts climbing again, this time from a different direction. Two steps forward, one step back.
Single-GPU KV-cache optimization is largely a solved problem. Techniques like PagedAttention eliminate fragmentation and FlashAttention dramatically reduces memory pressure. On paper, the numbers look great… until you push context length and model size at the same time. Ask for 128K context on a 70B model at batch size 32 and the math stops being your friend. Model weights alone take ~140 GB in FP16. The KV cache adds tens to hundreds of gigabytes per sequence, depending on attention type and context length, and batching multiplies that. No current GPU (or TPU) can hold that.
During inference, the fastest-growing memory object is the KV cache. It expands linearly with sequence length and batch size, is accessed on every token generation step, and is allocated wherever the attention layers execute. In a sharded setup, this often means that large portions of the context accumulate on a single device even while the model itself is distributed. Longer contexts amplify this imbalance until the same failure mode returns like the ghost of context past: memory exhaustion or unacceptable latency.
Sharding the KV cache follows naturally from sharding the model, but the cost profile changes. KV cache is not static state loaded once at startup; it is continuously read and written, and distributed access turns inference into a bandwidth-bound workload. Token generation becomes gated by inter-device memory movement, and throughput tracks network characteristics more closely than compute capability.
Once KV cache is sharded, interconnect bandwidth becomes the limiting factor.
High-bandwidth interconnects, attention variants like GQA, and newly emerging RDMA-capable consumer interconnects are turning KV-cache sharding from a theoretical exercise into something you can actually run yourself. Widespread adoption has exposed the real bottlenecks: network bandwidth, not FLOPs, is now the primary factor determining which inference workloads are feasible.
This critical shift is precisely where many distributed inference designs fail. If the process requires gigabytes of KV data to move for each token across a 25GB/s link, communication time will quickly dwarf computation time. Instead of achieving linear scaling, the result is a distributed memory system that is ultimately slower than a single GPU. Whether sharding helps or hurts depends entirely on how much data moves per token and how fast your hardware can move it. Understanding this communication constraint is the necessary prerequisite to designing truly scalable systems. In this, the final part of Unit 3, we examine what this shift enables in practice, where the real bottlenecks are, which sharding strategies to use, and how to move inference at scale out of the cloud and onto your desktop.
The Memory Problem at Scale
LLaMA 2 70B shapes at 128K context (the released model was trained for 4K, but the KV math depends only on the shapes):
num_layers = 80
num_kv_heads = 8 # GQA: 64 query heads, 8 KV heads
head_dim = 128
seq_len = 128 * 1024
batch_size = 32
kv_cache_per_layer = 2 * num_kv_heads * seq_len * head_dim * 2 # K,V,FP16
total_kv_cache = kv_cache_per_layer * num_layers * batch_size
print(f"Total KV cache: {total_kv_cache / 1024**3:.1f} GB") # 1280.0 GB
1,280GB (about 1.25TB) just for KV cache: 40GB per sequence, times 32 sequences. That’s before model weights (140GB) and activations (20-40GB). Single-GPU inference is impossible.
How GQA Made This Viable
Grouped Query Attention is more than a memory optimization; it is why distributed inference became practical. LLaMA 1 (65B) used Multi-Head Attention with 64 KV heads. LLaMA 2 (70B) uses GQA with 8 KV heads. For distributed inference at 128K context:
MHA (64 KV heads): 76.8GB transfer per token
GQA (8 KV heads): 9.6GB transfer per token
Nearly an order of magnitude difference.
(NB. These figures assume aggregate all-gather traffic across the cluster, not per-GPU local movement.)
That 8× reduction transforms the bandwidth equation. At 25GB/s (NVLink 2.0), MHA requires 3.1 seconds per token. GQA requires 384ms. This is the difference between unusable and production-viable. So GQA didn’t just save memory, it enabled our network math to actually work.
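The 8× factor falls straight out of the head count. A minimal sketch, reusing the shapes from the KV-cache calculation above (80 layers, head_dim 128, FP16), that compares per-sequence cache size under MHA and GQA and then converts this section’s per-token transfer figures into milliseconds on a 25 GB/s link. The helper name `kv_cache_gb` is illustrative, not from any library.

```python
def kv_cache_gb(num_layers, num_kv_heads, head_dim, seq_len,
                batch_size=1, bytes_per_elem=2):
    """Size of the K and V caches in GiB (leading 2 = one K plus one V tensor)."""
    total_bytes = (2 * num_layers * num_kv_heads * head_dim
                   * seq_len * batch_size * bytes_per_elem)
    return total_bytes / 1024**3


seq_len = 128 * 1024
mha = kv_cache_gb(num_layers=80, num_kv_heads=64, head_dim=128, seq_len=seq_len)
gqa = kv_cache_gb(num_layers=80, num_kv_heads=8, head_dim=128, seq_len=seq_len)
print(f"MHA: {mha:.0f} GB/sequence, GQA: {gqa:.0f} GB/sequence, "
      f"ratio {mha / gqa:.0f}x")

# Per-token transfer figures quoted in this section, over a 25 GB/s link
link_gb_s = 25
for name, transfer_gb in [("MHA", 76.8), ("GQA", 9.6)]:
    print(f"{name}: {transfer_gb / link_gb_s * 1000:.0f} ms per token")
```

Whatever the exact transfer volume turns out to be for a given topology, it scales with the number of KV heads, so the MHA/GQA ratio carries over unchanged.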
Sharding Strategies
There are four approaches, each with different communication patterns:
Shard by batch: Split sequences across GPUs. No communication during attention. Perfect for large batches, useless for batch_size=1. Every production system supports this because it’s free parallelism when workload permits.
Shard by head: Split KV heads across GPUs. Works with batch_size=1. Requires all-gather during attention, which is bandwidth-intensive but predictable. This is the common baseline in large-scale production inference. Natural extension of tensor parallelism from training.
Shard by sequence length (Ring Attention): Split context into chunks across GPUs. Each GPU keeps the queries for its chunk, computes attention against the key/value block it currently holds, then passes that KV block to the next GPU in the ring while receiving one from the previous GPU. After N steps (N = number of GPUs), every GPU has attended over the full context and holds the complete attention output for its chunk. Enables contexts that don’t fit any other way, but requires custom attention kernels and careful synchronization. We’ll cover Ring Attention in Unit 5 when we discuss algorithmic innovations for million-token contexts.
Hybrid: Combine strategies to shard by batch when possible (eliminates communication) or fall back to head sharding for single requests. Complex topology management but flexible enough for mixed workloads.
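Ring Attention works because softmax can be accumulated block by block. Below is a minimal single-process sketch in plain Python, assuming non-causal attention and lists standing in for tensors: each simulated device keeps its own queries, K/V blocks rotate around the ring, and an online-softmax accumulator (running max, running denominator, running weighted sum) folds in one block per step. After N steps the result matches ordinary attention over the full context. Function names are hypothetical; real implementations fuse this into custom kernels and overlap each transfer with compute.

```python
import math
import random


def attend_full(q, keys, values):
    """Reference: ordinary softmax attention for one query over all keys."""
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    z = sum(weights)
    return [sum(w * v[d] for w, v in zip(weights, values)) / z
            for d in range(len(values[0]))]


def ring_attention(q_chunks, k_chunks, v_chunks):
    """Simulate N devices in a ring; device i owns q_chunks[i] and K/V block i."""
    n = len(q_chunks)
    dim_v = len(v_chunks[0][0])
    # Online-softmax state per (device, query): running max, denominator, numerator
    state = [[(-math.inf, 0.0, [0.0] * dim_v) for _ in q_chunks[i]] for i in range(n)]
    held = list(range(n))  # held[i] = which K/V block device i currently holds
    for _ in range(n):
        for i in range(n):
            ks, vs = k_chunks[held[i]], v_chunks[held[i]]
            for j, q in enumerate(q_chunks[i]):
                m, z, acc = state[i][j]
                for k, v in zip(ks, vs):
                    s = sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q))
                    m_new = max(m, s)
                    scale = math.exp(m - m_new)  # rescale earlier sums to the new max
                    w = math.exp(s - m_new)
                    z = z * scale + w
                    acc = [x * scale + w * vd for x, vd in zip(acc, v)]
                    m = m_new
                state[i][j] = (m, z, acc)
        # Ring step: each device passes its K/V block on and receives its neighbour's
        held = [held[(i - 1) % n] for i in range(n)]
    return [[[x / z for x in acc] for (_, z, acc) in dev] for dev in state]


random.seed(0)
n_devices, chunk, dim = 4, 3, 5


def rand_matrix(rows):
    return [[random.gauss(0, 1) for _ in range(dim)] for _ in range(rows)]


def split(rows):
    return [rows[i * chunk:(i + 1) * chunk] for i in range(n_devices)]


Q, K, V = (rand_matrix(n_devices * chunk) for _ in range(3))
ring_out = [row for dev in ring_attention(split(Q), split(K), split(V)) for row in dev]
reference = [attend_full(q, K, V) for q in Q]
max_err = max(abs(a - b) for r1, r2 in zip(ring_out, reference) for a, b in zip(r1, r2))
print(f"max difference vs full attention: {max_err:.2e}")
```

No device ever materialises the full K/V; each holds one block at a time, which is what lets the context grow with the number of devices.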
Production systems use head sharding with speculative decoding (which we’ll cover in Unit 4) and custom NVLink topologies optimized for all-gather patterns. The fastest inference clusters aren’t the ones with the most GPUs; they’re the ones with the lowest-latency all-gather operations.
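Why all-gather latency, not GPU count, sets the ceiling: a common back-of-envelope (alpha-beta) model for a ring all-gather charges each of the N-1 steps a fixed per-hop latency α plus the time to forward one 1/N shard at link bandwidth B. A minimal sketch under that assumption; the α values and the 9.6 GB payload are illustrative inputs, not measurements.

```python
def ring_all_gather_seconds(total_gb, num_gpus, link_gb_s, latency_s):
    """Alpha-beta cost of a ring all-gather of `total_gb` spread evenly over `num_gpus`.

    Each of the N-1 steps forwards one shard (total/N) to the next neighbour,
    paying one hop of latency plus shard/bandwidth of transfer time.
    """
    steps = num_gpus - 1
    shard_gb = total_gb / num_gpus
    return steps * (latency_s + shard_gb / link_gb_s)


# Illustrative: 9.6 GB gathered per token over 25 GB/s links
for n in (2, 4, 8):
    for alpha_us in (5, 100):
        t = ring_all_gather_seconds(9.6, n, 25, alpha_us * 1e-6)
        print(f"{n} GPUs, alpha={alpha_us:>3} us: {t * 1000:7.2f} ms")
```

The bandwidth term approaches total/B no matter how many GPUs you add, because every GPU must still receive (N-1)/N of the data, while the latency term grows with every hop.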
Head Sharding Implementation
import torch
import torch.distributed as dist

class ShardedKVCache:
    """KV cache sharded by attention head across GPUs."""

    def __init__(self, num_layers, num_kv_heads, head_dim, max_seq_len,
                 max_batch_size, world_size, rank, device='cuda'):
        self.world_size = world_size
        self.rank = rank
        self.device = device
        assert num_kv_heads % world_size == 0, "KV heads must split evenly across GPUs"
        self.local_kv_heads = num_kv_heads // world_size
        # Each GPU holds cache for its contiguous slice of KV heads:
        # (layers, batch, local_kv_heads, seq, head_dim)
        self.k_cache = torch.zeros(
            num_layers, max_batch_size, self.local_kv_heads,
            max_seq_len, head_dim, dtype=torch.float16, device=device
        )
        self.v_cache = torch.zeros_like(self.k_cache)
        # Filled length per layer: update() runs once per layer per token,
        # so a single per-sequence counter would advance num_layers times
        self.seq_lens = torch.zeros(num_layers, max_batch_size,
                                    dtype=torch.long, device=device)

    def update(self, layer_idx, k_new, v_new, batch_indices=None):
        """Append one token's K/V, shaped (batch, local_kv_heads, head_dim)."""
        if batch_indices is None:
            batch_indices = torch.arange(k_new.shape[0], device=self.device)
        positions = self.seq_lens[layer_idx, batch_indices]
        for i, (batch_idx, pos) in enumerate(zip(batch_indices.tolist(),
                                                 positions.tolist())):
            self.k_cache[layer_idx, batch_idx, :, pos] = k_new[i]
            self.v_cache[layer_idx, batch_idx, :, pos] = v_new[i]
        self.seq_lens[layer_idx, batch_indices] += 1

    def get_all_heads(self, layer_idx, batch_indices=None):
        """
        All-gather one layer's KV cache from all GPUs.
        This is the expensive step that determines throughput.
        Every rank must call this with the same shapes.
        """
        if batch_indices is None:
            batch_indices = torch.arange(self.k_cache.shape[1], device=self.device)
        max_seq_len = int(self.seq_lens[layer_idx, batch_indices].max().item())
        # Advanced indexing copies; collectives need contiguous tensors
        k_local = self.k_cache[layer_idx, batch_indices, :, :max_seq_len].contiguous()
        v_local = self.v_cache[layer_idx, batch_indices, :, :max_seq_len].contiguous()
        # All-gather across GPUs
        k_gathered = [torch.empty_like(k_local) for _ in range(self.world_size)]
        v_gathered = [torch.empty_like(v_local) for _ in range(self.world_size)]
        dist.all_gather(k_gathered, k_local)
        dist.all_gather(v_gathered, v_local)
        # Rank r holds heads [r*local, (r+1)*local), so concatenating in rank
        # order along the head dim restores the full (batch, heads, seq, dim)
        return torch.cat(k_gathered, dim=1), torch.cat(v_gathered, dim=1)
Every token generation requires all-gather. At 128K context with GQA, that’s 9.6GB transferred per token. Your network determines your throughput more than your compute does.
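The head-to-rank layout the class relies on can be checked without any GPUs. A minimal pure-Python sketch (hypothetical helper names, lists standing in for tensors): each simulated rank keeps a contiguous slice of KV heads, and concatenating the slices in rank order, which is what `torch.cat(..., dim=1)` does after `dist.all_gather` in the class above, reconstructs the full head dimension. It also shows how much of each layer’s cache every rank has to receive.

```python
def shard_heads(heads, world_size):
    """Split a per-head list into contiguous slices, one per rank."""
    assert len(heads) % world_size == 0, "KV heads must divide evenly"
    per_rank = len(heads) // world_size
    return [heads[r * per_rank:(r + 1) * per_rank] for r in range(world_size)]


def simulated_all_gather(shards):
    """Every rank ends up with all shards, concatenated in rank order."""
    gathered = [h for shard in shards for h in shard]
    return [list(gathered) for _ in shards]


num_kv_heads, world_size = 8, 4
heads = [f"head{h}" for h in range(num_kv_heads)]
shards = shard_heads(heads, world_size)
views = simulated_all_gather(shards)
assert all(view == heads for view in views)

# Each rank already holds 1/world_size of the heads and must receive the rest
received_fraction = (world_size - 1) / world_size
print(f"each rank receives {received_fraction:.0%} of the layer's KV cache per gather")
```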
The Bandwidth Bottleneck
Token generation is where parallelism goes to die: there’s no fast path, no pipeline trick, no clever scheduler that gives you token N+1 before N exists. Every step pays for all three:
Forward pass through model (compute, ~50ms for 70B)
All-gather KV cache (bandwidth-bound)
Update local cache (memory write, ~5ms)
Network comparison at 128K context with GQA (9.6GB transfer per token):
networks = {  # link bandwidth in GB/s
    'NVLink 2.0': 25,
    'NVLink 3.0': 50,
    'NVLink 4.0 (H100)': 112.5,
    'Ethernet 100GbE': 12.5,
    'Thunderbolt 4': 5,
    'Thunderbolt 5': 15,
}
kv_transfer = 9.6  # GB per token (128K context, GQA)
compute_time = 0.050  # 50ms
for name, bw in networks.items():
    transfer_time = kv_transfer / bw
    total = compute_time + transfer_time
    throughput = 1.0 / total
    print(f"{name:<25} {throughput:>6.2f} tokens/sec")
Output:
NVLink 2.0                  2.30 tokens/sec
NVLink 3.0                  4.13 tokens/sec
NVLink 4.0 (H100)           7.39 tokens/sec
Ethernet 100GbE             1.22 tokens/sec
Thunderbolt 4               0.51 tokens/sec
Thunderbolt 5               1.45 tokens/sec
Thunderbolt 4 is unusable: 0.51 tokens/sec is slower than typing. Thunderbolt 5 is marginal at 128K context but becomes acceptable at 32K (4× less transfer = about 4.8 tokens/sec). Of course, NVLink 4.0 is production-viable even at extreme context lengths. But we already knew that.
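Because the KV cache grows linearly with context, the per-token transfer can be scaled down from the 128K figure. A minimal sketch, assuming transfer volume is exactly proportional to context length and compute stays at the 50 ms used above; `tokens_per_sec` is a hypothetical helper.

```python
def tokens_per_sec(context_tokens, link_gb_s, compute_s=0.050, gb_at_128k=9.6):
    """Decode throughput when per-token KV transfer scales linearly with context."""
    transfer_gb = gb_at_128k * context_tokens / (128 * 1024)
    return 1.0 / (compute_s + transfer_gb / link_gb_s)


for ctx_k in (8, 16, 32, 64, 128):
    tb4 = tokens_per_sec(ctx_k * 1024, 5)
    tb5 = tokens_per_sec(ctx_k * 1024, 15)
    print(f"{ctx_k:>4}K  TB4 {tb4:5.2f} tok/s   TB5 {tb5:5.2f} tok/s")
```

At short contexts the fixed compute time dominates and the links converge; at long contexts the transfer term takes over and the bandwidth gap shows through almost one-to-one.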
The Inference Scaling Laws Nobody Publishes
Research papers show ideal scaling: 2× GPUs = 2× throughput.
Production systems see something else:
import math

def real_world_efficiency(num_gpus):
    """What actually happens in production inference clusters."""
    # Communication overhead scales logarithmically
    comm_factor = 1 - (0.15 * math.log2(num_gpus))
    # Synchronization overhead (barrier waits, stragglers)
    sync_factor = 1 - (0.08 * num_gpus / 8)
    # Network congestion becomes severe >16 nodes
    congestion_factor = 0.7 if num_gpus > 16 else 1.0
    return comm_factor * sync_factor * congestion_factor

base_throughput = 20  # tokens/sec on single GPU
for n in [2, 4, 8, 16, 32]:
    ideal = base_throughput * n
    eff = real_world_efficiency(n)
    actual = ideal * eff
    print(f"{n:2d} GPUs: {eff*100:4.1f}% efficient, "
          f"{actual:5.1f} tok/s (ideal: {ideal:5.1f})")
Output:
 2 GPUs: 83.3% efficient,  33.3 tok/s (ideal:  40.0)
 4 GPUs: 67.2% efficient,  53.8 tok/s (ideal:  80.0)
 8 GPUs: 50.6% efficient,  81.0 tok/s (ideal: 160.0)
16 GPUs: 33.6% efficient, 107.5 tok/s (ideal: 320.0)
32 GPUs: 11.9% efficient,  76.2 tok/s (ideal: 640.0)
NB. This is not a universal law, but a model that matches observed behavior in production clusters.
At 16 GPUs, you’re at about a third of ideal efficiency, yet 107 tokens/sec is still more than 5× a single GPU. Past 16, the congestion term bites: at 32 GPUs efficiency falls to about 12% and throughput actually drops to 76 tokens/sec. Companies build massive clusters not because they’re efficient but because absolute throughput matters more than efficiency when you’re serving millions of requests. The economics justify two-thirds wasted capacity when the alternative is turning away users. Modern technology is about waste.
The Mac Mini Moment
November 20, 2025: MLX PR #2808 for the RDMA backend (still open)
December 12, 2025: macOS Tahoe 26.2 ships with RDMA over Thunderbolt 5.
Imminent: distributed inference on consumer hardware becomes possible.
Eight Mac Minis with M4 Max chips, 64GB each: 512GB total, 240GB/s aggregate, $24k. No quotas. No export restrictions. Hardware you own, on your desk, running 400B+ models with long context. Desktop inference cluster porn incoming. (Spoiler: it’s just a stack of minis)
On December 12, 2025, macOS Tahoe 26.2 shipped with RDMA over Thunderbolt 5, and an RDMA backend for MLX is under review. With that, low-latency remote memory access moved out of the datacenter and onto consumer hardware.
This is where I spare you the sidebar about cost optimization versus infrastructure sovereignty. Cloud H100 clusters are supply-constrained. Export controls limit access (though some controls were significantly relaxed this week). Quota systems restrict usage. Mac Minis have none of these constraints. You’re not optimizing for price per token but for straight-up permission to iterate.
Remote Memory Becomes Usable
RDMA over Thunderbolt 5 gives you 15GB/s remote memory access with 100μs latency. Compare to local DRAM (200GB/s, 50ns) or NVMe SSDs (3.5GB/s, 100μs). The memory hierarchy just gained a new tier. Traditional inference assumes the KV cache is local. This architecture accepts that the KV cache is remote but predictable.
Thunderbolt 5 is marginal at 128K context, but acceptable at 32K. Think of it as “slowly tolerable” rather than “fast.” Your cluster is doing work no single M4 Max could, but it’s reminding you who’s boss (physics.)
The bandwidth hierarchy inverted: remote unified memory over Thunderbolt 5 is now faster than local PCIe 3.0 storage! “Local” and “remote” stopped meaning what they meant in 2020. Network has become memory.
What You Actually Get
Performance at 32K context, batch size 8:
context_length = 32 * 1024
batch_size = 8
num_kv_heads = 8
head_dim = 128
# For scale: one layer's K+V cache for the whole batch in FP16 is 1.0 GB
kv_cache_per_layer = (
    2 * num_kv_heads * context_length * head_dim * 2 * batch_size
) / 1024**3
kv_per_token = 2.56  # GB moved per generated token, the figure assumed below
tb5_bandwidth = 15 * 2 * 0.8  # 2 ports, 80% efficiency = 24 GB/s
transfer_time = kv_per_token / tb5_bandwidth  # 106.7 ms
compute_time = 0.080  # 80ms on M4 Max
total_time = compute_time + transfer_time  # 186.7 ms
throughput = 1.0 / total_time  # 5.36 tokens/sec
5.36 tokens/sec at 32K context. Not competitive with H100 clusters (7-10 tokens/sec) but sufficient for research, prototyping, and non-latency-critical production. A single M4 Max can’t run 70B at 32K; it OOMs before the first token. The cluster makes impossible workloads possible, not fast workloads faster.
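The same arithmetic shows how sensitive that figure is to the link. A minimal sketch that recomputes throughput for a given per-port bandwidth, assuming the section’s 2.56 GB per-token transfer, two bonded ports at 80% efficiency, and 80 ms of compute; the 12 and 10 GB/s rows correspond to the thermal-throttling case described below. `cluster_tokens_per_sec` is a hypothetical helper.

```python
def cluster_tokens_per_sec(port_gb_s, ports=2, efficiency=0.8,
                           transfer_gb=2.56, compute_s=0.080):
    """One decode step = compute + KV transfer over bonded Thunderbolt 5 ports."""
    effective_gb_s = port_gb_s * ports * efficiency
    return 1.0 / (compute_s + transfer_gb / effective_gb_s)


for port_gb_s in (15, 12, 10):
    print(f"{port_gb_s:>2} GB/s per port: {cluster_tokens_per_sec(port_gb_s):.2f} tok/s")
```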
Cloud H100 clusters cost $26/hour (4× H100s). At 8 hours/day usage, Mac Minis break even in 3.8 months. After that, marginal cost is electricity (~$50/month for eight machines). No quota request. No availability zone selection. No “your region doesn’t support this instance type yet.”
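The break-even figure is simple arithmetic: $24k of hardware against $26/hour of cloud time at 8 hours/day. A minimal sketch; electricity (~$50/month) is an optional parameter and resale value is ignored. `break_even_months` is a hypothetical helper.

```python
def break_even_months(hardware_cost, cloud_per_hour, hours_per_day,
                      electricity_per_month=0.0, days_per_month=365.25 / 12):
    """Months until owning the hardware is cheaper than renting cloud time."""
    cloud_per_month = cloud_per_hour * hours_per_day * days_per_month
    monthly_saving = cloud_per_month - electricity_per_month
    return hardware_cost / monthly_saving


print(f"{break_even_months(24_000, 26, 8):.1f} months")      # cloud cost only
print(f"{break_even_months(24_000, 26, 8, 50):.1f} months")  # minus ~$50/month power
```

Electricity barely moves the result; utilisation does. Halve the hours per day and the break-even point doubles.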
What Actually Breaks
This is still rough around the edges, and the failure modes show up quickly:
Thermal throttling: Thunderbolt 5 cables get hot under sustained 15GB/s load. After 10 minutes, bandwidth drops to 12GB/s. After 20 minutes, 10GB/s. Your 5.36 tokens/sec becomes roughly 4.2 tokens/sec. Active cooling on cables helps but you’re fighting physics.
macOS being macOS: Time Machine backups. Spotlight reindexing. Software update notifications. Background tasks you forgot were enabled. Your distributed inference cluster becomes a distributed annoyance. Single-user mode helps. Disabling everything helps more. Running Linux on Mac hardware would help most but MLX doesn’t support that yet.
Consensus failures: You’re generating token 847 of 1000. Mac Mini #3 kernel panics. Your all-gather hangs forever waiting for a node that will never respond. Distributed systems fail in creative ways. Checkpointing every N tokens mitigates this. Coordinator processes with health checks and retry logic solve it properly. But now you’re building distributed systems infrastructure, not just running inference.
The PR isn’t merged: As of December 13, 2025, RDMA over Thunderbolt lives in PR #2808, which is still open. It’s not production-ready. You’re a jungle pioneer, building on bleeding-edge code that might break between commits. In six months, when MLX RDMA is stable and documented, this becomes mundane. Right now, it’s an adventure.
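A minimal sketch of the checkpoint-and-retry mitigation for consensus failures, in plain Python with a simulated cluster. `flaky_step` is a hypothetical stand-in for one all-gather plus forward pass, and a dead node is modelled as a `TimeoutError`; in a real PyTorch job the collective timeout is configured through the `timeout` argument of `torch.distributed.init_process_group`. On failure the loop rolls back to the last checkpoint instead of hanging forever, and gives up after a bounded number of failures.

```python
import random


def generate_with_checkpoints(num_tokens, checkpoint_every, step, max_failures=3):
    """Generate tokens, snapshotting every `checkpoint_every` tokens.

    On TimeoutError, roll back to the last snapshot and retry; re-raise once
    more than `max_failures` failures have occurred in total.
    """
    tokens, checkpoint, failures = [], [], 0
    while len(tokens) < num_tokens:
        try:
            tokens.append(step(len(tokens)))
        except TimeoutError:
            failures += 1
            if failures > max_failures:
                raise  # escalate to a coordinator / operator
            tokens = list(checkpoint)  # discard work since the last snapshot
            continue
        if len(tokens) % checkpoint_every == 0:
            checkpoint = list(tokens)
    return tokens


random.seed(1)


def flaky_step(position, failure_rate=0.01):
    """Hypothetical decode step: occasionally a node stops responding."""
    if random.random() < failure_rate:
        raise TimeoutError(f"node unresponsive at token {position}")
    return position  # stand-in for the sampled token id


out = generate_with_checkpoints(1000, checkpoint_every=50, step=flaky_step,
                                max_failures=100)
print(f"generated {len(out)} tokens")
```

The checkpoint interval trades snapshot overhead against how much work a single node failure throws away.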
Cloud inference is opaque as are its failures: you click “submit ticket” and wait, hoping someone else figures it out. But run it yourself, and every kernel panic, stalled all-gather, or thermal throttle is immediately visible. You break things, you fix things, you understand them. Distributed systems stop being abstract and start slapping hard.
The Sophistication Gradient
Right now, this setup gets you 5-6 tokens/sec with basic ring topology and crude all-gather primitives. In six months, when MLX RDMA is production-stable and someone implements proper multi-hop ring topologies with pipelined communication, it’ll be 8-10 tokens/sec. In twelve months, with Thunderbolt 6 (rumored 2026) and better unified memory primitives, maybe 15 tokens/sec. (Someone will price this. Never bet against prediction markets.)
You’re not so much purchasing hardware as buying a learning curve. The first distributed inference clusters will be crude. Six months from now they’ll be sophisticated. Twelve months from now they’ll be off-the-shelf kits, turnkey setup, production-grade reliability. But by then it won’t matter.
Right now, today, they’re possible. And possible is where every infrastructure shift begins. Someone will run the first 70B (or even 400B) model across eight Mac Minis this week (with or without that PR merged). Not because it’s optimal but because it’s accessible; not because it’s fast (hardly) but because it can be done.
Production Decision Tree
The general question “should I distribute?” really means “at what context length does my network bandwidth make distribution viable?” A single GPU is always fastest when the workload fits. Distribution is what you do when a single GPU is impossible.
For 70B models: (8 minis can handle 400B+ models in theory)
<16K context: single GPU (don’t distribute)
16K-64K context: distribute across 2-4 GPUs with NVLink or Thunderbolt 5
64K-128K context: requires NVLink 4.0 or accept slower generation
>128K context: Ring Attention (Unit 5) or wait for next-gen hardware
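The tree can be written as a small lookup. A minimal sketch for 70B-class models using the thresholds listed; the function name, the interconnect labels, the treatment of exactly 64K and 128K, and the “needs NVLink or Thunderbolt 5” branch for other links are illustrative assumptions.

```python
# Interconnect labels are illustrative, not from any library
NVLINK_OR_TB5 = {"nvlink2", "nvlink3", "nvlink4", "thunderbolt5"}


def inference_plan(context_tokens, interconnect):
    """Return a placement for a 70B-class model following the tree above."""
    k = 1024
    if context_tokens < 16 * k:
        return "single GPU"
    if context_tokens < 64 * k:
        if interconnect in NVLINK_OR_TB5:
            return "distribute across 2-4 GPUs"
        return "needs NVLink or Thunderbolt 5"
    if context_tokens <= 128 * k:
        if interconnect == "nvlink4":
            return "distribute over NVLink 4.0"
        return "distribute and accept slower generation"
    return "Ring Attention or next-gen hardware"


print(inference_plan(8 * 1024, "thunderbolt4"))
print(inference_plan(32 * 1024, "thunderbolt5"))
print(inference_plan(96 * 1024, "thunderbolt5"))
print(inference_plan(256 * 1024, "nvlink4"))
```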
Network as Memory Hierarchy
In the 1970s, Jim Gray observed that disk was the new tape: orders of magnitude faster, requiring algorithmic rethinking. Sequential algorithms designed for tape failed catastrophically on disk’s random-access patterns. In the 1990s, RAM became the new disk. In-memory databases demolished decades of B-tree optimization wisdom.
In 2025, network became the new memory tier.
Memory hierarchy in 2025:
L1 Cache: 1 cycle, 1000+ GB/s
Local SRAM: 5 cycles, 500 GB/s
Local HBM: 300 cycles, 200 GB/s
Remote HBM via NVLink 4.0: 10,000 cycles, 112 GB/s
Remote unified memory via TB5: 50,000 cycles, 15 GB/s
Local NVMe (PCIe 3.0): 500,000 cycles, 3.5 GB/s
Adjacent levels differ by roughly 5-60× in latency. But look at the ordering: remote memory over Thunderbolt 5 is now faster than local NVMe storage! The memory hierarchy isn’t a strict local-to-remote gradient anymore. It’s a bandwidth-and-latency space where “remote” can outperform “local” depending on what you’re comparing.
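The gaps between adjacent tiers are easy to check from the cycle counts listed. A short sketch; the numbers are the ones in the list, not fresh measurements, and the 2.56 GB payload is the per-token transfer used in the Mac Mini estimate.

```python
# (tier, latency in cycles, bandwidth in GB/s), copied from the list above
tiers = [
    ("L1 cache", 1, 1000),
    ("Local SRAM", 5, 500),
    ("Local HBM", 300, 200),
    ("Remote HBM via NVLink 4.0", 10_000, 112),
    ("Remote unified memory via TB5", 50_000, 15),
    ("Local NVMe (PCIe 3.0)", 500_000, 3.5),
]

# Latency step between each tier and the one above it
ratios = [cur[1] / prev[1] for prev, cur in zip(tiers, tiers[1:])]
for (name, _, _), ratio in zip(tiers[1:], ratios):
    print(f"{name:<30} {ratio:5.1f}x the latency of the tier above")

# Bandwidth-only time to stream a 2.56 GB KV payload from each tier
for name, _, bw in tiers:
    print(f"{name:<30} {2.56 / bw * 1000:8.1f} ms")
```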
KV-cache sharding is the algorithmic rethinking that Gray described. Traditional inference assumes KV cache is local (10ns access). Distributed inference accepts KV cache is remote (100μs access). That’s 10,000× slower, but it unlocks context lengths that don’t fit locally. The algorithm changes because the hardware changed.
𒅃 Jim Gray told us that “RAM is the new disk… Disk is the new tape.” This week a fundamental primitive effectively escaped the datacenter. Network is now the new memory tier. MLX has a working backend in PR. Remote memory over Thunderbolt 5 operates at 50,000 cycles and 15 GB/s. Local NVMe storage runs at 500,000 cycles and 3.5 GB/s. The hierarchy inverted. Remote unified memory is now faster than local storage. “Local” and “remote” stopped meaning what they used to.
Accordingly, the KV cache can no longer be thought of as a local detail. Grokking KV-cache sharding means understanding that the optimal implementation already exists in your network’s bandwidth characteristics. Your job is knowing those characteristics well enough to know when distribution helps and when it hurts. When the per-token all-gather takes ~107ms (2.56GB over bonded Thunderbolt 5 at 32K context), distribution enables very large (70-400B) inference. When it takes 1,920ms (9.6GB over Thunderbolt 4 at 128K), distribution makes inference unusable. Distributed inference forces the rethinking Gray described: accept that cache is remote (100μs access instead of 10ns), and in return, unlock context lengths that don’t fit locally. The algorithm changes because the substrate (hardware) revealed a different grain. This is τέχνη, revisited.
Up Next: Module II - Training Large Models
We’ve charted the crisis of memory and compute optimization. Next we turn our attention (literally) back to the architecture we are struggling to contain. A perfect device map is useless without a deep understanding of the layers it partitions. In Unit 4, we’ll build transformers from first principles and wire these memory optimizations into a complete training loop, with multi-head attention, positional encodings, layer norms, and all the details papers gloss over. Part 1, Speculative Decoding, will finally give you multiple tokens per forward pass, making distributed inference not just possible, but practical. See you then!

