KV-Cache Optimisation: Memory Layouts for Inference Serving
Part 2 of Memory & Compute Optimisation
You’ve optimized your training pipeline. FlashAttention saves you from quadratic memory blowup during the forward pass. Your data loaders are feeding GPUs at 95% utilization. The model trains beautifully. Then you deploy it to production for inference, and the entire memory calculus changes. Because inference isn’t training. Inference is a different ontology with different constraints, different bottlenecks, and different failure modes that nobody warned you about until your first production outage at 3am when a single long-context request OOMs your entire serving cluster.
Inference is a different ontology with different constraints, different bottlenecks, and different failure modes
The problem emerges from autoregressive generation itself. In training, you process entire sequences in parallel: compute attention once for all tokens, backprop through the whole thing, move on. During inference, you generate one token at a time. Token 1 generates token 2. Tokens 1-2 generate token 3. Tokens 1-3 generate token 4. By token 1000, you’re computing attention against 1000 previous tokens. Every. Single. Step. And if you’re naively recomputing the key and value projections for all previous tokens at each step, you’re doing the same matrix multiplications 1000 times. You can call this optimisation debt but that’s a funny way to say architectural malpractice.
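To make the waste concrete, here's a back-of-envelope sketch (mine, not from any serving codebase; d_model and seq_len are made-up illustrative numbers) comparing the K/V projection work per decoding step with and without a cache, for a single attention layer:

d_model = 4096          # hidden size (illustrative)
seq_len = 1000          # tokens generated

def matmul_flops(tokens: int) -> int:
    """FLOPs to project `tokens` rows through W_k and W_v (2 * d_model^2 each)."""
    return tokens * 2 * (2 * d_model * d_model)

# Without a cache: at step t we re-project all t previous tokens.
recompute = sum(matmul_flops(t) for t in range(1, seq_len + 1))

# With a cache: at step t we project only the newest token.
cached = sum(matmul_flops(1) for _ in range(seq_len))

print(f"Recompute everything: {recompute / 1e12:.1f} TFLOPs")
print(f"Cache K/V:            {cached / 1e12:.3f} TFLOPs")
print(f"Ratio: {recompute / cached:.0f}x")   # ~(seq_len + 1) / 2, about 500x here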
The obvious solution: cache the K and V projections. Compute them once, store them, reuse them. Problem solved, right? Except now you’ve traded one problem for three worse ones. First, the cache grows linearly with sequence length: a 2048-token sequence with 32 attention heads and 128-dimensional heads requires 32 × 2048 × 128 × 2 (K and V) × 2 bytes (FP16) = 33MB per layer. With 32 layers, that’s over 1GB per request. Vroom vroom. Second, you’re allocating this memory up front for the maximum possible sequence length because you don’t know how long the generation will be. Most requests are short; you’re wasting memory on capacity you’ll never use. Third, when you’re serving batched requests with variable lengths, you’ve got a memory fragmentation nightmare: request A needs 512 tokens, request B needs 2048, request C needs 128, and you’re trying to pack them into contiguous memory blocks like a degen Tetris where the pieces keep changing shape.
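The arithmetic above generalizes to a one-line formula. Here's a small helper (my own, not part of any library) that reproduces the 32-heads/2048-tokens numbers and lets you plug in your own model shape:

def kv_cache_bytes(num_layers, num_heads, head_dim, seq_len, dtype_bytes=2):
    """Bytes of KV cache for one sequence: K and V, every layer, every position."""
    return 2 * num_layers * seq_len * num_heads * head_dim * dtype_bytes

# The example from the text: 32 heads x 128 dims x 2048 tokens, FP16.
per_layer = kv_cache_bytes(1, 32, 128, 2048) / 2**20
per_request = kv_cache_bytes(32, 32, 128, 2048) / 2**30
print(f"Per layer:   {per_layer:.0f} MiB")    # ~32 MiB (the text rounds to 33 MB)
print(f"Per request: {per_request:.2f} GiB")  # ~1 GiB across 32 layers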
This is the dirty secret of LLM serving at scale. Training is compute-bound. Inference is memory-bound. The bottleneck isn’t how fast you can multiply matrices but how efficiently you can manage the ever-growing cache of keys and values that each request drags behind it like Marley’s chains. And unlike training, where you control the batch size and sequence length, inference is adversarial: users send you variable-length requests at unpredictable times, and you have to serve them all without OOMing or letting latency explode. “How do we make this faster?”
Try: “How do we keep this running at all?”
Whoever figures out how to manage this cache most efficiently wins the inference game.
The evolution of KV-cache management is a case study in how systems optimisation forces architectural rethinking. The naive approach (allocate worst-case memory for every request) works until you try to serve more than a handful of concurrent users. The slightly-less-naive approach (dynamic allocation) works until memory fragmentation kills you. The actually-good approach, which took the industry embarrassingly long to figure out, is to stop treating memory as a contiguous block and start treating it like an operating system: virtual memory, paging, non-contiguous allocation, garbage collection. PagedAttention, introduced by the vLLM team in 2023, isn’t just an optimisation. It’s a recognition that LLM serving is a memory management problem that happens to involve neural networks, not a neural network problem that happens to involve memory.
This matters because the next generation of models isn’t getting smaller. As context windows grow and models get deeper, the KV cache stops being an implementation detail and becomes the primary constraint on how many requests you can serve simultaneously. Whoever figures out how to manage this cache most efficiently wins the inference game. Not by having better models. By having better memory layouts.
In part two we’ll cover:
Naive KV Cache (How to OOM)
Continuous Batching (The first real improvement)
PagedAttention (Treating memory like an OS)
Copy-on-Write (Killer feature nobody talks about)
The vLLM Architecture (Putting It All Together)
Push to Production (vLLM vs TGI vs Ray Serve)
Naive KV Cache: Or, How to OOM in Three Easy Steps
Let’s start with the obvious implementation so we can understand why it’s wrong. Here’s what everyone writes first:
import torch
import math

class NaiveKVCache:
    """
    The implementation everyone writes first.
    Also the implementation that OOMs in production.
    """
    def __init__(self, max_batch_size, max_seq_len, num_layers, num_heads, head_dim):
        self.max_batch_size = max_batch_size
        self.max_seq_len = max_seq_len
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.head_dim = head_dim

        # Pre-allocate maximum possible cache
        # Shape: [num_layers, max_batch, max_seq_len, num_heads, head_dim]
        self.k_cache = torch.zeros(
            num_layers, max_batch_size, max_seq_len, num_heads, head_dim,
            dtype=torch.float16, device='cuda'
        )
        self.v_cache = torch.zeros(
            num_layers, max_batch_size, max_seq_len, num_heads, head_dim,
            dtype=torch.float16, device='cuda'
        )

        # Track actual sequence lengths per batch element
        self.seq_lens = torch.zeros(max_batch_size, dtype=torch.long)

    def update(self, layer_idx, batch_idx, k_new, v_new):
        """
        Add new K,V projections for a single token.
        k_new, v_new: [num_heads, head_dim]
        Note: seq_lens is tracked per batch element, not per layer, so call this
        once per token (the demo below only exercises layer 0).
        """
        pos = self.seq_lens[batch_idx].item()
        if pos >= self.max_seq_len:
            raise RuntimeError(f"Sequence length {pos} exceeds max {self.max_seq_len}")

        # Write new K,V at current position
        self.k_cache[layer_idx, batch_idx, pos] = k_new
        self.v_cache[layer_idx, batch_idx, pos] = v_new
        self.seq_lens[batch_idx] += 1

    def get(self, layer_idx, batch_idx):
        """
        Retrieve all cached K,V for this layer and batch element.
        Returns: (k, v) each of shape [seq_len, num_heads, head_dim]
        """
        seq_len = self.seq_lens[batch_idx].item()
        k = self.k_cache[layer_idx, batch_idx, :seq_len]
        v = self.v_cache[layer_idx, batch_idx, :seq_len]
        return k, v

    def memory_usage_mb(self):
        """Calculate memory usage in MB"""
        # 2 caches (K and V) × elements × bytes per element
        total_elements = (
            2 * self.num_layers * self.max_batch_size *
            self.max_seq_len * self.num_heads * self.head_dim
        )
        bytes_used = total_elements * 2  # FP16 = 2 bytes
        return bytes_used / (1024 ** 2)

# Example: GPT-2 sized model
cache = NaiveKVCache(
    max_batch_size=8,
    max_seq_len=2048,
    num_layers=12,
    num_heads=12,
    head_dim=64
)

print(f"Memory allocated: {cache.memory_usage_mb():.1f} MB")
print(f"Memory per request: {cache.memory_usage_mb() / cache.max_batch_size:.1f} MB")

# Simulate generation
layer_idx = 0
batch_idx = 0
for pos in range(10):
    k_new = torch.randn(cache.num_heads, cache.head_dim, device='cuda', dtype=torch.float16)
    v_new = torch.randn(cache.num_heads, cache.head_dim, device='cuda', dtype=torch.float16)
    cache.update(layer_idx, batch_idx, k_new, v_new)

k, v = cache.get(layer_idx, batch_idx)
print(f"Retrieved cache shape: {k.shape}")

Output:
Memory allocated: 576.0 MB
Memory per request: 72.0 MB
Retrieved cache shape: torch.Size([10, 12, 64])

This looks reasonable until you scale it. That’s GPT-2 (12 layers, 12 heads). Let’s try a LLaMA-70B-sized configuration (80 layers, 64 heads, 128 head_dim; the real LLaMA-2 70B uses grouped-query attention with far fewer KV heads, but assume full multi-head attention to keep the arithmetic simple):
cache_llama = NaiveKVCache(
    max_batch_size=8,
    max_seq_len=4096,  # Typical serving length
    num_layers=80,
    num_heads=64,
    head_dim=128
)

print(f"Memory allocated: {cache_llama.memory_usage_mb():.1f} MB")
print(f"Memory per request: {cache_llama.memory_usage_mb() / cache_llama.max_batch_size:.1f} MB")
print(f"Memory per request (GB): {cache_llama.memory_usage_mb() / cache_llama.max_batch_size / 1024:.2f} GB")

Output:
Memory allocated: 81920.0 MB
Memory per request: 10240.0 MB
Memory per request (GB): 10.00 GB

80GB of cache for a batch of eight 4096-token requests, 10GB per request. 😂 And this is just the cache: not the model weights (roughly 140GB in FP16), not the activations, not anything else. An A100 has 80GB of HBM. An H100 has 80-96GB (depending on variant). The cache for this one modest batch fills an entire GPU before you’ve loaded a single weight. And remember: this is for 4096-token sequences. If you want to support longer contexts, multiply.
The problems are obvious once you see them:
Massive over-allocation: Most requests don’t use the full context window. A typical chat message is 50-200 tokens. You’re allocating 4096 slots and using 100.
No memory sharing: Batch element 0 and batch element 7 have completely independent caches even if they’re identical prompts (which happens constantly in practice, think multiple users asking “what is the weather?”)
Fragmentation: When request 3 finishes and frees its cache slot, you can’t use that memory for a new longer request without reshuffling everything.
No eviction: Once allocated, memory stays allocated for the entire batch lifetime. No way to evict old tokens from long-running requests.
This is why early LLM serving systems had such low throughput. Not because the models were slow. Because the memory management was catastrophically naive. You could serve maybe 4-8 concurrent users on an A100, and that was considered “acceptable” because nobody knew better. Then vLLM showed up and demonstrated 10-20x higher throughput with the same hardware, and suddenly everyone’s “acceptable” serving infrastructure looked like a tire fire. Which it kinda was.
Continuous Batching: The First Real Improvement
Before we get to PagedAttention, we need to understand continuous batching, because it’s the prerequisite that makes everything else possible. Traditional batching (what TensorFlow Serving, TorchServe, and most ML frameworks do) works like this:
Collect requests until you have a full batch (or timeout)
Process the entire batch
Return all results
Start collecting the next batch
This is catastrophically wrong for autoregressive generation. Request A might generate 10 tokens. Request B might generate 1000 tokens. If you batch them together, request A sits finished in memory for 990 generation steps waiting for B to complete. All that allocated cache memory is wasted. Your GPU sits idle while B finishes because you can’t start processing new requests until the entire batch completes.
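The waste is easy to quantify. A sketch with illustrative numbers, not a benchmark:

# Static batching: the whole batch is held until the longest request finishes.
lengths = [10, 1000]                      # request A and request B, in generated tokens
steps = max(lengths)                      # batch occupies the GPU for 1000 steps
useful = sum(lengths)                     # tokens actually produced
utilization = useful / (steps * len(lengths))
print(f"Slot utilization: {utilization:.1%}")   # ~50.5%: A's slot idles for 990 steps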
Continuous batching (sometimes called “iteration-level batching”) says: batch at the token level, not the request level.
Start processing requests as they arrive
At each generation step, batch together all requests that need a new token
When a request finishes, immediately remove it from the batch and add a new request
The batch size is dynamic: grows as new requests arrive, shrinks as requests complete
Here’s a simplified implementation:
import torch
from dataclasses import dataclass
from typing import List, Optional
import time
@dataclass
class Request:
request_id: int
prompt_tokens: List[int]
max_tokens: int
generated_tokens: List[int]
finished: bool = False
class ContinuousBatchScheduler:
“”“
Continuous batching: add/remove requests dynamically at each step.
“”“
def __init__(self, max_batch_size=32):
self.max_batch_size = max_batch_size
self.active_requests: List[Request] = []
self.waiting_requests: List[Request] = []
self.completed_requests: List[Request] = []
# Statistics
self.total_tokens_generated = 0
self.steps_taken = 0
def add_request(self, request: Request):
“”“Add a new request to the waiting queue”“”
if len(self.active_requests) < self.max_batch_size:
self.active_requests.append(request)
else:
self.waiting_requests.append(request)
def step(self) -> int:
“”“
Process one token generation step for all active requests.
Returns number of requests still active.
“”“
if not self.active_requests:
return 0
# Simulate token generation (in real code, this calls the model)
for req in self.active_requests:
# Generate next token (simplified: just random)
next_token = torch.randint(0, 50000, (1,)).item()
req.generated_tokens.append(next_token)
self.total_tokens_generated += 1
# Check if request finished
if len(req.generated_tokens) >= req.max_tokens:
req.finished = True
# Remove completed requests
still_active = []
for req in self.active_requests:
if req.finished:
self.completed_requests.append(req)
else:
still_active.append(req)
self.active_requests = still_active
# Add new requests from waiting queue to fill batch
while self.waiting_requests and len(self.active_requests) < self.max_batch_size:
self.active_requests.append(self.waiting_requests.pop(0))
self.steps_taken += 1
return len(self.active_requests)
def get_current_batch_size(self) -> int:
return len(self.active_requests)
def stats(self):
return {
‘steps_taken’: self.steps_taken,
‘total_tokens’: self.total_tokens_generated,
‘tokens_per_step’: self.total_tokens_generated / max(1, self.steps_taken),
‘completed’: len(self.completed_requests),
‘active’: len(self.active_requests),
‘waiting’: len(self.waiting_requests)
}
# Simulate continuous batching vs static batching
print(”=” * 80)
print(”Continuous Batching Simulation”)
print(”=” * 80)
# Create scheduler
scheduler = ContinuousBatchScheduler(max_batch_size=8)
# Add requests with varying lengths
requests = [
Request(0, [], max_tokens=10, generated_tokens=[]), # Short
Request(1, [], max_tokens=50, generated_tokens=[]), # Medium
Request(2, [], max_tokens=100, generated_tokens=[]), # Long
Request(3, [], max_tokens=20, generated_tokens=[]), # Short
Request(4, [], max_tokens=80, generated_tokens=[]), # Medium-long
Request(5, [], max_tokens=5, generated_tokens=[]), # Very short
Request(6, [], max_tokens=40, generated_tokens=[]), # Medium
Request(7, [], max_tokens=90, generated_tokens=[]), # Long
]
for req in requests:
scheduler.add_request(req)
# Run until all requests complete
step = 0
while scheduler.get_current_batch_size() > 0 or scheduler.waiting_requests:
active = scheduler.step()
step += 1
if step % 10 == 0:
stats = scheduler.stats()
print(f”Step {step:3d}: batch_size={active:2d}, “
f”completed={stats[’completed’]:2d}, “
f”tokens/step={stats[’tokens_per_step’]:.2f}”)
final_stats = scheduler.stats()
print(”\nFinal Statistics:”)
print(f” Total steps: {final_stats[’steps_taken’]}”)
print(f” Total tokens: {final_stats[’total_tokens’]}”)
print(f” Average tokens/step: {final_stats[’tokens_per_step’]:.2f}”)
print(f” Average batch size: {final_stats[’tokens_per_step’]:.2f}”)
# Compare to static batching
static_batch_steps = max(req.max_tokens for req in requests) # Longest request
static_batch_tokens = sum(req.max_tokens for req in requests)
static_batch_wasted = static_batch_steps * len(requests) - static_batch_tokens
print(f”\nStatic Batching (for comparison):”)
print(f” Steps needed: {static_batch_steps} (wait for longest request)”)
print(f” Useful tokens: {static_batch_tokens}”)
print(f” Wasted slot-steps: {static_batch_wasted}”)
print(f” Efficiency: {static_batch_tokens / (static_batch_steps * len(requests)) * 100:.1f}%”)
print(f”\nContinuous Batching:”)
print(f” Steps needed: {final_stats[’steps_taken’]}”)
print(f” Efficiency: {final_stats[’tokens_per_step’] / len(requests) * 100:.1f}%”)
print(f"  Speedup: {static_batch_steps / final_stats['steps_taken']:.2f}x")

Output:
========================================================================
Continuous Batching Simulation
========================================================================
Step  10: batch_size= 6, completed= 2, tokens/step=7.50
Step  20: batch_size= 5, completed= 3, tokens/step=6.75
Step  30: batch_size= 5, completed= 3, tokens/step=6.17
Step  40: batch_size= 4, completed= 4, tokens/step=5.88
Step  50: batch_size= 3, completed= 5, tokens/step=5.50
Step  60: batch_size= 3, completed= 5, tokens/step=5.08
Step  70: batch_size= 3, completed= 5, tokens/step=4.79
Step  80: batch_size= 2, completed= 6, tokens/step=4.56
Step  90: batch_size= 1, completed= 7, tokens/step=4.28
Step 100: batch_size= 0, completed= 8, tokens/step=3.95

Final Statistics:
  Total steps: 100
  Total tokens: 395
  Average tokens/step: 3.95
  Average batch size: 3.95

Static Batching (for comparison):
  Steps needed: 100 (wait for longest request)
  Useful tokens: 395
  Wasted slot-steps: 405
  Efficiency: 49.4%

Continuous Batching:
  Steps needed: 100
  Efficiency: 49.4%
  Speedup: 1.00x

Wait, that doesn’t look like a win. The speedup is 1.00x? That’s because in this specific example, the longest request (100 tokens) determines the runtime for both approaches, and no new work arrives to fill the freed slots. But watch what happens when we add new requests dynamically:
# More realistic: requests arrive over time
scheduler2 = ContinuousBatchScheduler(max_batch_size=8)
# Initial batch
# Fresh Request objects (the ones from the first simulation have already run to completion)
initial_requests = [
    Request(0, [], max_tokens=10, generated_tokens=[]),
    Request(1, [], max_tokens=50, generated_tokens=[]),
    Request(2, [], max_tokens=100, generated_tokens=[]),
    Request(3, [], max_tokens=20, generated_tokens=[]),
]
for req in initial_requests:
    scheduler2.add_request(req)
step = 0
while scheduler2.get_current_batch_size() > 0 or scheduler2.waiting_requests or step < 50:
# Add new requests mid-execution (simulating arrival)
if step == 15:
scheduler2.add_request(Request(8, [], max_tokens=30, generated_tokens=[]))
if step == 25:
scheduler2.add_request(Request(9, [], max_tokens=40, generated_tokens=[]))
if step == 35:
scheduler2.add_request(Request(10, [], max_tokens=25, generated_tokens=[]))
active = scheduler2.step()
step += 1
if step % 10 == 0:
stats = scheduler2.stats()
print(f”Step {step:3d}: batch_size={active:2d}, completed={stats[’completed’]:2d}”)
final_stats2 = scheduler2.stats()
print(f”\nWith dynamic arrivals:”)
print(f” Total steps: {final_stats2[’steps_taken’]}”)
print(f” Requests completed: {final_stats2[’completed’]}”)
print(f"  Average batch utilization: {final_stats2['tokens_per_step']:.2f}")

Output:
Step  10: batch_size= 3, completed= 1
Step  20: batch_size= 3, completed= 2
Step  30: batch_size= 4, completed= 2
Step  40: batch_size= 5, completed= 2
Step  50: batch_size= 3, completed= 4
Step  60: batch_size= 2, completed= 5
Step  70: batch_size= 1, completed= 6
Step  80: batch_size= 1, completed= 6
Step  90: batch_size= 1, completed= 6
Step 100: batch_size= 0, completed= 7

With dynamic arrivals:
  Total steps: 100
  Requests completed: 7
  Average batch utilization: 2.75

The key insight: continuous batching keeps the GPU busy. When a short request finishes, you immediately fill that batch slot with a newly arrived request. With static batching, short requests sit idle waiting for long requests to complete, and you can’t accept new work until the entire batch finishes.
In practice, continuous batching gives you 2-3x higher throughput on real workloads because:
You’re not waiting for the slowest request to complete before starting new work
Batch size stays high even as individual requests complete
You can handle bursty traffic without queueing requests forever
But continuous batching doesn’t solve the memory problem. You still need to allocate and manage KV cache for every active request. If anything, it makes the memory problem worse, because now you have more concurrent requests in flight at once. This is where PagedAttention enters.
the KV cache problem is identical to the virtual memory problem operating systems solved in the 1960s
PagedAttention: Treating Memory Like an OS
PagedAttention, introduced by Woosuk Kwon, Zhuohan Li, and the vLLM team in 2023, makes a simple but profound observation: the KV cache problem is identical to the virtual memory problem operating systems solved in the 1960s. Plus ça change… When you have more data than physical memory, you divide it into fixed-size pages, allocate non-contiguously, and use a page table to map logical addresses to physical locations.
Apply this to KV cache:
Logical cache: The conceptual sequence of K,V pairs for a request (token 0, token 1, ..., token N)
Physical cache: Fixed-size blocks of memory scattered across the GPU
Page table: Maps logical token positions to physical memory blocks
Block size: Typically 16-64 tokens per block
Instead of allocating a contiguous 2048-slot array for each request, you allocate memory in 64-token chunks as needed. Request only generated 100 tokens? You only allocate 2 blocks (128 slots). Request needs 2000 tokens? You allocate 32 blocks, but they don’t need to be contiguous in memory.
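Before the full class below, the core mechanic fits in a few lines: a page table per request, and a two-line address translation from a token position to a (block, offset) pair. This is a standalone sketch of the lookup, not vLLM's actual data structures:

BLOCK_SIZE = 64                      # tokens per physical block (illustrative)
page_table = [7, 42, 3]              # logical block 0 lives in physical block 7, etc.

def locate(token_pos: int) -> tuple[int, int]:
    """Translate a logical token position to (physical block id, offset in block)."""
    logical_block = token_pos // BLOCK_SIZE
    offset = token_pos % BLOCK_SIZE
    return page_table[logical_block], offset

print(locate(0))     # (7, 0)   first token, start of physical block 7
print(locate(100))   # (42, 36) token 100 lands 36 slots into physical block 42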
Here’s the conceptual implementation:
import torch
import math
from typing import Dict, List, Optional, Tuple
class BlockAllocator:
“”“
Manages a pool of fixed-size memory blocks.
Like malloc but for GPU tensors.
“”“
def __init__(self, num_blocks, block_size, num_heads, head_dim, dtype=torch.float16):
self.num_blocks = num_blocks
self.block_size = block_size # tokens per block
self.num_heads = num_heads
self.head_dim = head_dim
# Physical memory pool: all blocks pre-allocated
# Shape: [num_blocks, block_size, num_heads, head_dim]
self.k_blocks = torch.zeros(
num_blocks, block_size, num_heads, head_dim,
dtype=dtype, device=’cuda’
)
self.v_blocks = torch.zeros(
num_blocks, block_size, num_heads, head_dim,
dtype=dtype, device=’cuda’
)
# Free list: which blocks are available
self.free_blocks = list(range(num_blocks))
self.allocated_blocks = set()
def allocate(self) -> Optional[int]:
“”“Allocate a single block, return block ID”“”
if not self.free_blocks:
return None
block_id = self.free_blocks.pop(0)
self.allocated_blocks.add(block_id)
return block_id
def free(self, block_id: int):
“”“Free a block, return it to the pool”“”
if block_id in self.allocated_blocks:
self.allocated_blocks.remove(block_id)
self.free_blocks.append(block_id)
def get_block(self, block_id: int, kv: str) -> torch.Tensor:
“”“Get a reference to a physical block”“”
if kv == ‘k’:
return self.k_blocks[block_id]
else:
return self.v_blocks[block_id]
def utilization(self) -> float:
“”“Fraction of blocks currently allocated”“”
return len(self.allocated_blocks) / self.num_blocks
class PagedKVCache:
“”“
KV cache using paged memory allocation.
Each request gets a page table mapping logical positions to physical blocks.
“”“
def __init__(self, block_allocator: BlockAllocator, num_layers: int):
self.block_allocator = block_allocator
self.num_layers = num_layers
self.block_size = block_allocator.block_size
# Page tables: request_id -> layer_idx -> list of block IDs
self.page_tables: Dict[int, Dict[int, List[int]]] = {}
# Track sequence length per request
self.seq_lens: Dict[int, int] = {}
def create_request(self, request_id: int):
“”“Initialize page table for a new request”“”
self.page_tables[request_id] = {layer: [] for layer in range(self.num_layers)}
self.seq_lens[request_id] = 0
def append_token(self, request_id: int, layer_idx: int, k_new: torch.Tensor, v_new: torch.Tensor):
“”“
Append a new token’s K,V to the cache.
Allocates a new block if current block is full.
k_new, v_new: [num_heads, head_dim]
“”“
if request_id not in self.page_tables:
self.create_request(request_id)
page_table = self.page_tables[request_id][layer_idx]
seq_len = self.seq_lens[request_id]
# Determine which block and offset within block
block_idx = seq_len // self.block_size
offset = seq_len % self.block_size
# Allocate new block if needed
if block_idx >= len(page_table):
new_block = self.block_allocator.allocate()
if new_block is None:
raise RuntimeError(”Out of memory: no free blocks”)
page_table.append(new_block)
# Write to physical block
physical_block_id = page_table[block_idx]
k_block = self.block_allocator.get_block(physical_block_id, ‘k’)
v_block = self.block_allocator.get_block(physical_block_id, ‘v’)
k_block[offset] = k_new
v_block[offset] = v_new
# Update sequence length (only increment once per token, not per layer)
if layer_idx == self.num_layers - 1:
self.seq_lens[request_id] += 1
def get_kv(self, request_id: int, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
“”“
Retrieve all K,V for this request and layer.
Returns: (k, v) each of shape [seq_len, num_heads, head_dim]
“”“
page_table = self.page_tables[request_id][layer_idx]
seq_len = self.seq_lens[request_id]
# Gather from physical blocks
k_parts = []
v_parts = []
for block_idx, physical_block_id in enumerate(page_table):
k_block = self.block_allocator.get_block(physical_block_id, ‘k’)
v_block = self.block_allocator.get_block(physical_block_id, ‘v’)
# Last block might be partially filled
if block_idx == len(page_table) - 1:
tokens_in_block = seq_len % self.block_size
if tokens_in_block == 0:
tokens_in_block = self.block_size
else:
tokens_in_block = self.block_size
k_parts.append(k_block[:tokens_in_block])
v_parts.append(v_block[:tokens_in_block])
k = torch.cat(k_parts, dim=0)
v = torch.cat(v_parts, dim=0)
return k, v
def free_request(self, request_id: int):
“”“Free all blocks associated with a request”“”
if request_id not in self.page_tables:
return
for layer_idx in range(self.num_layers):
for block_id in self.page_tables[request_id][layer_idx]:
self.block_allocator.free(block_id)
del self.page_tables[request_id]
del self.seq_lens[request_id]
def memory_usage(self) -> Dict[str, float]:
“”“Calculate memory statistics”“”
total_blocks = self.block_allocator.num_blocks
used_blocks = len(self.block_allocator.allocated_blocks)
block_size_mb = (
2 * self.block_size * self.block_allocator.num_heads *
self.block_allocator.head_dim * 2 # 2 bytes for FP16
) / (1024 ** 2)
return {
‘total_memory_mb’: total_blocks * block_size_mb,
‘used_memory_mb’: used_blocks * block_size_mb,
‘utilization’: self.block_allocator.utilization(),
‘num_requests’: len(self.page_tables)
}
# Example usage
print(”=” * 80)
print(”PagedAttention KV Cache”)
print(”=” * 80)
# Create block allocator
# Instead of allocating per-request, we have a global pool
# (blocks are per-layer here, so 4 requests x 32 layers needs ~1000 blocks)
num_blocks = 4096
block_size = 16  # tokens per block
num_layers = 32
num_heads = 32
head_dim = 128

allocator = BlockAllocator(num_blocks, block_size, num_heads, head_dim)
cache = PagedKVCache(allocator, num_layers)

print(f"Total memory pool: {cache.memory_usage()['total_memory_mb']:.1f} MB")
print(f"Block size: {block_size} tokens")
print(f"Number of blocks: {num_blocks}")
print()
# Simulate multiple requests with varying lengths
requests = [
(0, 50), # Request 0: 50 tokens
(1, 200), # Request 1: 200 tokens
(2, 30), # Request 2: 30 tokens
(3, 150), # Request 3: 150 tokens
]
for request_id, num_tokens in requests:
cache.create_request(request_id)
# Generate tokens
for token_idx in range(num_tokens):
for layer in range(num_layers):
k = torch.randn(num_heads, head_dim, dtype=torch.float16, device=’cuda’)
v = torch.randn(num_heads, head_dim, dtype=torch.float16, device=’cuda’)
cache.append_token(request_id, layer, k, v)
stats = cache.memory_usage()
print(f”Request {request_id} ({num_tokens} tokens):”)
print(f” Blocks allocated: {len(cache.page_tables[request_id][0])}”)
print(f” Memory used: {stats[’used_memory_mb’]:.1f} MB ({stats[’utilization’]*100:.1f}%)”)
print()
print(”Retrieving cached K,V for request 1, layer 0:”)
k, v = cache.get_kv(request_id=1, layer_idx=0)
print(f” Shape: {k.shape}”)
# Free a request and show memory is reclaimed
print()
print(”Freeing request 0...”)
cache.free_request(0)
stats = cache.memory_usage()
print(f” Memory used: {stats[’used_memory_mb’]:.1f} MB ({stats[’utilization’]*100:.1f}%)”)
print(f” Active requests: {stats[’num_requests’]}”)
# Compare to naive allocation
naive_memory_per_request = (
2 * num_layers * 2048 * num_heads * head_dim * 2
) / (1024 ** 2)
naive_total = naive_memory_per_request * 4
print()
print(”Memory comparison:”)
print(f” Naive cache (4 requests, max_len=2048): {naive_total:.1f} MB”)
print(f” Paged cache (4 requests, actual lens): {cache.memory_usage()[’used_memory_mb’]:.1f} MB”)
print(f"  Memory saved: {(naive_total - cache.memory_usage()['used_memory_mb']) / naive_total * 100:.1f}%")

Output:
========================================================================
PagedAttention KV Cache
========================================================================
Total memory pool: 1024.0 MB
Block size: 16 tokens
Number of blocks: 4096

Request 0 (50 tokens):
  Blocks allocated: 4
  Memory used: 32.0 MB (3.1%)

Request 1 (200 tokens):
  Blocks allocated: 13
  Memory used: 136.0 MB (13.3%)

Request 2 (30 tokens):
  Blocks allocated: 2
  Memory used: 152.0 MB (14.8%)

Request 3 (150 tokens):
  Blocks allocated: 10
  Memory used: 232.0 MB (22.7%)

Retrieving cached K,V for request 1, layer 0:
  Shape: torch.Size([200, 32, 128])

Freeing request 0...
  Memory used: 200.0 MB (19.5%)
  Active requests: 3

Memory comparison:
  Naive cache (4 requests, max_len=2048): 4096.0 MB
  Paged cache (4 requests, actual lens): 232.0 MB
  Memory saved: 94.3%

A 94.3% memory saving, with only four requests of modest length. By only allocating blocks as needed and packing them efficiently, we use under 6% of the memory the naive approach reserves. This is why vLLM can serve 10-20x more concurrent requests than earlier systems. Not because the model is faster. Because the memory management is vastly more efficient.
Copy-on-Write: The Killer Feature Nobody Talks About
But PagedAttention has a second trick that’s even more powerful: copy-on-write (CoW) for shared prefixes. Consider these three requests:
Request A: “Translate to French: Hello, how are you?”
Request B: “Translate to French: Hello, what’s your name?”
Request C: “Translate to French: Hello, nice to meet you.”
All three share the prefix “Translate to French: Hello, “. In a naive cache, you’d compute and store the KV projections for this prefix three times, once per request. With copy-on-write, you compute it once and share the physical blocks:
class CopyOnWriteKVCache(PagedKVCache):
“”“
Enhanced PagedKVCache with copy-on-write for shared prefixes.
Multiple requests can share the same physical blocks until they diverge.
“”“
def __init__(self, block_allocator: BlockAllocator, num_layers: int):
super().__init__(block_allocator, num_layers)
# Track reference counts per block
self.block_refcounts: Dict[int, int] = {}
# Track which blocks are read-only (shared)
self.readonly_blocks: Dict[int, set] = {} # request_id -> set of block IDs
def fork_request(self, parent_id: int, child_id: int, shared_prefix_len: int):
“”“
Create a new request that shares blocks with parent up to shared_prefix_len.
This is the CoW magic: child initially points to parent’s blocks.
When child generates new tokens, it allocates its own blocks.
“”“
if parent_id not in self.page_tables:
raise ValueError(f”Parent request {parent_id} not found”)
# Create child page table
self.page_tables[child_id] = {}
self.readonly_blocks[child_id] = set()
# Calculate how many blocks to share
num_shared_blocks = (shared_prefix_len + self.block_size - 1) // self.block_size
# Share parent’s blocks
for layer in range(self.num_layers):
parent_blocks = self.page_tables[parent_id][layer][:num_shared_blocks]
self.page_tables[child_id][layer] = parent_blocks.copy()
# Increment refcount for shared blocks
for block_id in parent_blocks:
self.block_refcounts[block_id] = self.block_refcounts.get(block_id, 1) + 1
self.readonly_blocks[child_id].add(block_id)
self.seq_lens[child_id] = shared_prefix_len
def append_token(self, request_id: int, layer_idx: int, k_new: torch.Tensor, v_new: torch.Tensor):
“”“
Append a token. If we’re about to write to a shared block, copy it first.
“”“
if request_id not in self.page_tables:
self.create_request(request_id)
page_table = self.page_tables[request_id][layer_idx]
seq_len = self.seq_lens[request_id]
block_idx = seq_len // self.block_size
offset = seq_len % self.block_size
# Check if we need to allocate a new block
if block_idx >= len(page_table):
new_block = self.block_allocator.allocate()
if new_block is None:
raise RuntimeError(”Out of memory”)
page_table.append(new_block)
self.block_refcounts[new_block] = 1
else:
# Writing to existing block - check if it’s shared
block_id = page_table[block_idx]
if self.block_refcounts.get(block_id, 1) > 1:
# Copy-on-write: allocate new block and copy contents
new_block = self.block_allocator.allocate()
if new_block is None:
raise RuntimeError(”Out of memory”)
# Copy old block’s contents
old_k = self.block_allocator.get_block(block_id, ‘k’)
old_v = self.block_allocator.get_block(block_id, ‘v’)
new_k = self.block_allocator.get_block(new_block, ‘k’)
new_v = self.block_allocator.get_block(new_block, ‘v’)
new_k[:] = old_k
new_v[:] = old_v
# Update page table and refcounts
page_table[block_idx] = new_block
self.block_refcounts[block_id] -= 1
self.block_refcounts[new_block] = 1
if block_id in self.readonly_blocks.get(request_id, set()):
self.readonly_blocks[request_id].remove(block_id)
# Now write the new token
physical_block_id = page_table[block_idx]
k_block = self.block_allocator.get_block(physical_block_id, ‘k’)
v_block = self.block_allocator.get_block(physical_block_id, ‘v’)
k_block[offset] = k_new
v_block[offset] = v_new
if layer_idx == self.num_layers - 1:
self.seq_lens[request_id] += 1
def free_request(self, request_id: int):
“”“Free blocks, but only if refcount drops to zero”“”
if request_id not in self.page_tables:
return
for layer_idx in range(self.num_layers):
for block_id in self.page_tables[request_id][layer_idx]:
self.block_refcounts[block_id] -= 1
if self.block_refcounts[block_id] == 0:
self.block_allocator.free(block_id)
del self.block_refcounts[block_id]
del self.page_tables[request_id]
del self.seq_lens[request_id]
if request_id in self.readonly_blocks:
del self.readonly_blocks[request_id]
# Demonstrate CoW savings
print(”\n” + “=” * 80)
print(”Copy-on-Write KV Cache”)
print(”=” * 80)
allocator_cow = BlockAllocator(512, 16, 32, 128)
cache_cow = CopyOnWriteKVCache(allocator_cow, 32)
# Shared prefix: say a 32-token instruction prefix, i.e. two full 16-token blocks.
# (If the prefix ended mid-block, that partial block would be copied as soon as a
# child wrote its first divergent token into it, and the sharing would be lost.)
shared_prefix_len = 32
# Create parent request
parent_id = 100
cache_cow.create_request(parent_id)
print(f”Generating {shared_prefix_len} tokens for parent request...”)
for token_idx in range(shared_prefix_len):
for layer in range(num_layers):
k = torch.randn(num_heads, head_dim, dtype=torch.float16, device=’cuda’)
v = torch.randn(num_heads, head_dim, dtype=torch.float16, device=’cuda’)
cache_cow.append_token(parent_id, layer, k, v)
stats = cache_cow.memory_usage()
print(f”Memory after parent: {stats[’used_memory_mb’]:.1f} MB”)
# Fork three child requests
children = [101, 102, 103]
print(f”\nForking {len(children)} child requests with shared prefix...”)
for child_id in children:
cache_cow.fork_request(parent_id, child_id, shared_prefix_len)
# Check memory - should be unchanged because blocks are shared
stats = cache_cow.memory_usage()
print(f”Memory after forking: {stats[’used_memory_mb’]:.1f} MB (no increase!)”)
# Now generate divergent tokens for each child
print(f”\nGenerating divergent tokens for children...”)
for child_id in children:
# Generate 20 more tokens per child
for token_idx in range(20):
for layer in range(num_layers):
k = torch.randn(num_heads, head_dim, dtype=torch.float16, device=’cuda’)
v = torch.randn(num_heads, head_dim, dtype=torch.float16, device=’cuda’)
cache_cow.append_token(child_id, layer, k, v)
stats = cache_cow.memory_usage()
print(f”Memory after divergence: {stats[’used_memory_mb’]:.1f} MB”)
# Calculate savings: what the same sequences would cost without sharing.
# Blocks are per layer; each block is 16 tokens x 32 heads x 128 dims x 2 (K,V) x 2 bytes = 0.25 MB.
blocks_per_layer_without_cow = (
    math.ceil(shared_prefix_len / 16) +                          # parent
    len(children) * math.ceil((shared_prefix_len + 20) / 16)     # each child stores its own prefix copy
)
memory_without_cow = blocks_per_layer_without_cow * num_layers * 0.25  # MB
with_cow = stats['used_memory_mb']
savings = (memory_without_cow - with_cow) / memory_without_cow * 100
print(f”\nMemory comparison:”)
print(f” Without CoW: {memory_without_cow:.1f} MB”)
print(f” With CoW: {with_cow:.1f} MB”)
print(f” Savings: {savings:.1f}%”)
print(f"\nShared blocks have refcount: {cache_cow.block_refcounts}")

Output (approximate):
========================================================================
Copy-on-Write KV Cache
========================================================================
Generating 32 tokens for parent request...
Memory after parent: 16.0 MB

Forking 3 child requests with shared prefix...
Memory after forking: 16.0 MB (no increase!)

Generating divergent tokens for children...
Memory after divergence: 64.0 MB

Memory comparison:
  Without CoW: 112.0 MB
  With CoW: 64.0 MB
  Savings: 42.9%

Shared blocks have refcount: {0: 4, 1: 4, 2: 4, ..., 64: 1, 65: 1, ...}

This is huge for batch inference scenarios where you have:
Multiple variations on the same prompt (”Translate to French: X”, “Translate to Spanish: X”)
Beam search (exploring multiple generation branches from the same prefix)
Speculative decoding (generating multiple candidate continuations)
Chat systems (same system prompt for every request)
In production, shared prefixes are everywhere. A well-tuned system with CoW can save 30-50% of memory compared to naive PagedAttention, and 95%+ compared to pre-allocated caches.
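To get a feel for the numbers, here's a rough calculator (my own sketch, using the model shape from the examples above: 32 layers, 32 heads, head_dim 128, FP16; the prompt length and request count are made up) for what sharing a long system prompt saves across concurrent requests:

# Rough savings from sharing a system-prompt prefix across concurrent requests.
bytes_per_token = 2 * 32 * 32 * 128 * 2          # K+V, all 32 layers, FP16: ~0.5 MB per token
system_prompt_tokens = 500
concurrent_requests = 64

unshared = concurrent_requests * system_prompt_tokens * bytes_per_token
shared = system_prompt_tokens * bytes_per_token   # one physical copy, refcounted
print(f"Unshared prefix KV: {unshared / 2**30:.1f} GiB")   # ~15.6 GiB
print(f"Shared prefix KV:   {shared / 2**30:.2f} GiB")     # ~0.24 GiB

One caveat, visible in the CoW code above: the prefix only stays shared up to the last full block boundary; a partially filled final block is copied the moment a request appends its first divergent token into it.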
The vLLM Architecture: Putting It All Together
vLLM combines continuous batching, PagedAttention, and copy-on-write into a cohesive serving system.
Here’s the high-level architecture:
from dataclasses import dataclass
from typing import List, Dict, Optional
import torch
import time
@dataclass
class SequenceGroup:
“”“A group of sequences sharing a prompt (for sampling/beam search)”“”
request_id: int
prompt_tokens: List[int]
sampling_params: Dict # temperature, top_p, etc
sequences: List[’Sequence’] # Multiple beams/samples
arrived_time: float
@dataclass
class Sequence:
“”“A single generation sequence”“”
sequence_id: int
tokens: List[int]
cumulative_logprob: float = 0.0
finished: bool = False
class Scheduler:
“”“
Decides which requests to run in each iteration.
Implements continuous batching + preemption.
“”“
def __init__(self, max_num_seqs: int, block_manager):
self.max_num_seqs = max_num_seqs
self.block_manager = block_manager
self.waiting: List[SequenceGroup] = []
self.running: List[SequenceGroup] = []
self.swapped: List[SequenceGroup] = [] # Preempted sequences
def add_request(self, seq_group: SequenceGroup):
“”“Add new request to waiting queue”“”
self.waiting.append(seq_group)
def schedule(self) -> Dict:
“”“
Decide which sequences to run in this iteration.
Returns metadata about scheduled sequences.
“”“
# Try to admit waiting requests
while self.waiting and len(self.running) < self.max_num_seqs:
seq_group = self.waiting[0]
            # Check if the allocator has enough free blocks for this request
            num_blocks_needed = self._get_num_blocks_needed(seq_group)
            if len(self.block_manager.free_blocks) >= num_blocks_needed:
self.waiting.pop(0)
self.running.append(seq_group)
self._allocate_blocks(seq_group)
else:
# Out of memory - try preemption
if not self._try_preempt():
break # Can’t make space
# Remove finished sequences
self.running = [sg for sg in self.running if not all(s.finished for s in sg.sequences)]
return {
‘num_running’: len(self.running),
‘num_waiting’: len(self.waiting),
‘num_swapped’: len(self.swapped)
}
def _get_num_blocks_needed(self, seq_group: SequenceGroup) -> int:
“”“Calculate blocks needed for this request”“”
# Simplified: just count tokens
max_len = max(len(seq.tokens) for seq in seq_group.sequences)
block_size = self.block_manager.block_size
return (max_len + block_size - 1) // block_size
def _allocate_blocks(self, seq_group: SequenceGroup):
“”“Allocate cache blocks for a sequence group”“”
# In real vLLM, this creates page tables
pass
def _try_preempt(self) -> bool:
“”“
Try to free memory by preempting (swapping) running requests.
Returns True if successfully freed space.
“”“
if not self.running:
return False
# Preempt lowest priority request (e.g., most recent)
victim = self.running[-1]
self.running.remove(victim)
self.swapped.append(victim)
# In real system, would copy blocks to CPU
return True
class vLLMEngine:
“”“
Simplified vLLM inference engine.
Combines scheduler, memory manager, and model execution.
“”“
def __init__(self, model, max_num_seqs=256, block_size=16):
self.model = model
self.max_num_seqs = max_num_seqs
self.block_size = block_size
# Initialize memory management
self.block_manager = BlockAllocator(
num_blocks=1024,
block_size=block_size,
num_heads=32,
head_dim=128
)
self.kv_cache = CopyOnWriteKVCache(self.block_manager, num_layers=32)
# Initialize scheduler
self.scheduler = Scheduler(max_num_seqs, self.block_manager)
# Statistics
self.total_tokens_generated = 0
self.total_requests_served = 0
def add_request(self, prompt_tokens: List[int], sampling_params: Dict) -> int:
“”“Add a new generation request”“”
request_id = self.total_requests_served
self.total_requests_served += 1
seq_group = SequenceGroup(
request_id=request_id,
prompt_tokens=prompt_tokens,
sampling_params=sampling_params,
sequences=[Sequence(sequence_id=request_id, tokens=prompt_tokens.copy())],
arrived_time=time.time()
)
self.scheduler.add_request(seq_group)
return request_id
def step(self):
“”“
Execute 1 iteration: schedule sequences, run model, update cache
“”“
# Schedule which sequences to run
schedule_output = self.scheduler.schedule()
if schedule_output[’num_running’] == 0:
return
# Collect sequences to run
running_seqs = []
for seq_group in self.scheduler.running:
running_seqs.extend(seq_group.sequences)
# Run model (simplified: just generate random tokens)
for seq in running_seqs:
if not seq.finished:
# In real vLLM: model forward pass using cached K,V
next_token = torch.randint(0, 50000, (1,)).item()
seq.tokens.append(next_token)
self.total_tokens_generated += 1
# Check if finished (e.g., EOS token or max length)
                if next_token == 0 or len(seq.tokens) >= 100:
seq.finished = True
def get_stats(self) -> Dict:
“”“Return engine statistics”“”
return {
‘total_tokens’: self.total_tokens_generated,
‘total_requests’: self.total_requests_served,
‘memory_usage_mb’: self.kv_cache.memory_usage()[’used_memory_mb’],
‘memory_utilization’: self.kv_cache.memory_usage()[’utilization’]
}
# Simulate vLLM serving
print(”\n” + “=” * 80)
print(”vLLM Engine Simulation”)
print(”=” * 80)
# Create engine (model is mocked)
engine = vLLMEngine(model=None, max_num_seqs=8, block_size=16)
# Add requests over time
print(”Adding requests...”)
for i in range(12):
prompt_len = torch.randint(10, 50, (1,)).item()
prompt = list(range(prompt_len))
engine.add_request(prompt, sampling_params={’max_tokens’: 50})
if i % 3 == 0:
print(f” Added request {i}”)
# Run engine for 100 steps
print(”\nRunning engine...”)
for step in range(100):
engine.step()
if step % 20 == 0:
stats = engine.get_stats()
print(f”Step {step:3d}: tokens={stats[’total_tokens’]:4d}, “
f”mem={stats[’memory_usage_mb’]:.1f}MB ({stats[’memory_utilization’]*100:.1f}%)”)
final_stats = engine.get_stats()
print(f”\nFinal statistics:”)
print(f” Requests served: {final_stats[’total_requests’]}”)
print(f” Tokens generated: {final_stats[’total_tokens’]}”)
print(f” Memory used: {final_stats[’memory_usage_mb’]:.1f} MB”)
print(f"  Tokens per request: {final_stats['total_tokens'] / final_stats['total_requests']:.1f}")

The real vLLM implementation is significantly more complex (GPU kernel fusion, mixed-precision, actual model execution), but the core architecture is this: scheduler + paged memory + continuous batching. Everything else is optimisation.
When KV-Cache Optimisation Doesn’t Matter
Like FlashAttention, PagedAttention isn’t universally better. There are cases where simpler approaches work fine or even better:
1. Single-user inference. If you’re only serving one request at a time (local deployment, personal assistant), the memory management overhead of PagedAttention exceeds its benefits. Just allocate a fixed-size cache and move on.
2. Very short contexts. For models with small context windows (512 tokens or less), the memory savings don’t justify the complexity. BERT-style models don’t need PagedAttention.
3. Batch inference with uniform lengths. If all your requests are exactly the same length (rare but it happens in some production scenarios), PagedAttention’s dynamic allocation provides no benefit over static batching.
4. Memory-rich environments. If you have enough GPU memory to handle worst-case allocation for all concurrent requests, simple pre-allocated caches might be faster due to lower overhead.
The decision tree:
Serving many concurrent users with variable-length requests? → PagedAttention
Need to handle bursty traffic efficiently? → PagedAttention + continuous batching
Single-user or uniform-length batches? → Simple fixed cache
Research/prototyping? → Whatever’s easiest to implement
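If you want that decision tree in a form you can drop into a design doc or a config check, here's one way to encode it; the function and its thresholds are my own illustrative heuristic, not anything shipped by a serving framework:

def pick_kv_strategy(concurrent_users: int, variable_lengths: bool,
                     max_context: int, memory_headroom: bool) -> str:
    """Illustrative heuristic mirroring the decision tree above. Thresholds are arbitrary."""
    if concurrent_users <= 1 or max_context <= 512:
        return "fixed pre-allocated cache"
    if memory_headroom and not variable_lengths:
        return "static batching with pre-allocated cache"
    return "paged KV cache + continuous batching (vLLM-style)"

print(pick_kv_strategy(concurrent_users=200, variable_lengths=True,
                       max_context=8192, memory_headroom=False))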
Production Deployment: vLLM vs TGI vs Ray Serve
Theory is nice. How do we actually deploy this?
vLLM is the reference implementation. Easy to use, well-optimized, supports most models. If you want PagedAttention, start here:
# Install
pip install vllm
# Serve a model
python -m vllm.entrypoints.api_server \
--model meta-llama/Llama-2-7b-hf \
    --tensor-parallel-size 1

Text Generation Inference (TGI) from Hugging Face is the enterprise-friendly alternative. Better monitoring, more deployment options, integrates with the Hugging Face ecosystem:
docker run --gpus all --shm-size 1g -p 8080:80 \
ghcr.io/huggingface/text-generation-inference:latest \
    --model-id meta-llama/Llama-2-7b-hf

Ray Serve (my personal preference) gives you more control for custom deployment scenarios. Useful if you need complex routing, A/B testing, or integration with existing Ray infrastructure.
The decision:
Quick deployment, maximum throughput → vLLM
Enterprise requirements, monitoring, compliance → TGI
Custom deployment logic, multi-model serving → Ray Serve
All three support PagedAttention (or equivalent memory management). The throughput differences are minor compared to the gap between “using modern serving” vs “naive implementation.”
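Whichever server you pick, it's worth a quick local sanity check before committing. A minimal script using vLLM's offline Python API (the LLM / SamplingParams entry points from the vLLM quickstart; the model name and prompts are just placeholders, and exact behavior can shift between vLLM versions):

# pip install vllm  (requires a CUDA GPU)
from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Llama-2-7b-hf")           # paged KV cache is on by default
params = SamplingParams(temperature=0.8, max_tokens=128)

prompts = [
    "Translate to French: Hello, how are you?",
    "Translate to French: Hello, what's your name?",
    "Translate to French: Hello, nice to meet you.",
]
outputs = llm.generate(prompts, params)               # batched and continuously scheduled
for out in outputs:
    print(out.outputs[0].text[:60])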
The Deeper Point
The KV cache bottleneck isn’t about algorithmic complexity. It’s about memory management, and memory management is invisible to big-O analysis.
The KV cache problem reveals something fundamental about the gap between algorithmic complexity and systems reality. Attention is O(n²) in sequence length. Everyone knows this. The standard response is “use better algorithms”: sparse attention, linear attention, FlashAttention. But the KV cache bottleneck isn’t about algorithmic complexity! It’s about memory management, and memory management is invisible to big-O analysis.
You can have a theoretically optimal algorithm that runs terribly in practice because it thrashes the cache, fragments memory, or allocates inefficiently. You can have a theoretically suboptimal algorithm that runs beautifully because it respects the memory hierarchy, reuses allocations, and avoids copies. The gap between theory and practice is often a memory layout problem masquerading as an algorithm problem.
This is why PagedAttention matters philosophically, not just practically. It’s a recognition that LLM serving is fundamentally a systems problem, not an ML problem. The model is the easy part. The hard part is memory management, scheduling, batching, preemption, resource allocation, all the unsexy systems plumbing that operating systems solved decades ago but that we’ve had to rediscover for GPUs and transformers.
As always, the pattern repeats throughout computing history. Virtual memory wasn’t a better algorithm for addressing RAM. It was a recognition that fixed-size pages and indirection tables solve a whole class of fragmentation and allocation problems that can’t be solved algorithmically. Garbage collection wasn’t a better algorithm for freeing memory. It was a recognition that manual memory management doesn’t scale and that trading CPU cycles for automation is the right tradeoff.
PagedAttention is the same category of insight. It’s not a faster way to compute attention. It’s a recognition that treating the KV cache like virtual memory solves problems that can’t be solved by making attention itself faster. The next token still takes the same number of FLOPs to compute. But you can serve 20x more users because you’re not wasting memory on capacity you’ll never use.
𒅃 Turing’s universal machine could simulate any effective procedure, but it had no concept of memory hierarchy: no notion that different memory tiers exist, nor any awareness that the pattern of memory access matters as much as the computation itself. The formalization of computation abstracted away the substrate, treating memory as uniform and infinite. Modern systems engineering is the process of putting that substrate back in, recognizing that the ontology of memory, its hierarchy, its access patterns, its physical layout, determines what computations are actually feasible. PagedAttention didn’t make attention better. It revealed that the memory waste was never inherent to the KV cache, only to how we chose to lay it out. The contiguous, worst-case-sized cache was always optional. We simply didn’t notice because, for a while, we had memory to spare. Now that memory is the bottleneck, we’re rediscovering what operating systems knew fifty years ago: memory management is never just an implementation detail, but the very constraint that determines computability: the difference between what’s possible, and what fits.
Next up: Activation Recomputation Strategies
Where we discover that gradient checkpointing is a Faustian bargain: trade memory for compute, but how much computation is too much? Also, the surprisingly deep connection between checkpointing strategies and dynamic programming that nobody teaches in ML courses. No, I won’t be naming names.
Got war stories about KV-cache OOMs? Discovered that your serving latency is actually memory bandwidth limited and not compute bound? Built a custom memory allocator and want to brag about it? Comments below. (Paid subscribers get access to the full vLLM-compatible implementation and can discuss production deployment strategies.)
← prev: Flash Attention | next: Activation Recomputation →

