Flash Attention: Memory-Efficient Exact Attention
Part 1 of Memory & Compute Optimisation
Your model has 32 attention heads. Sequence length is 2048. Hidden dimension is 4096. You do the math: the attention matrix alone requires 32 × 2048 × 2048 × 4 bytes = 512MB per sample. Your batch size is 8. That’s 4GB just for attention scores, before you even start computing gradients. And this is just one layer. You have 32 layers. An A100 GPU has 80GB of HBM. So you’re going to OOM before you finish the forward pass, and then deal with the realisation that most of said memory was wasted on storing intermediate values you’ll never look at again.
The dirty secret of transformer scaling: attention doesn’t just scale quadratically with sequence length in compute, it scales quadratically in memory. And memory, not computation, is the binding constraint of modern deep learning. Your GPU can execute 312 teraflops. Your memory bandwidth is 2 TB/s. Do the math: at 4 bytes per float32, that’s roughly 500 billion values per second, while the compute units can chew through hundreds of times that many operations in the same second. For the 2048-token, 4096-dim example above, the matmuls in one attention layer come to roughly 70 GFLOPs, a couple of hundred microseconds at peak throughput. But standard attention also writes the 32-head score matrix (32 × 2048² ≈ 134M values) out to HBM, reads it back for the softmax, writes the weights, and reads them again for the output: roughly 2 GB of traffic, over a millisecond at peak bandwidth. So you’re spending several times longer moving data than computing with it. And you thought your rush hour commute was bad.
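Here is that back-of-envelope arithmetic as a quick sketch, assuming A100-class peak figures (312 TFLOPS, 2 TB/s) and ignoring any overlap of compute with data movement:
# Back-of-envelope: data movement vs compute for one attention layer
N, d_model, heads = 2048, 4096, 32
bytes_per_val = 4                       # FP32

peak_flops = 312e12                     # tensor-core throughput
peak_bw = 2e12                          # HBM bandwidth, bytes/s

# Matmul work: QK^T and AV, each 2 * N^2 * d_model FLOPs
flops = 2 * 2 * N * N * d_model
compute_time = flops / peak_flops

# Score-matrix traffic: written, read (softmax), written, read (output)
score_bytes = heads * N * N * bytes_per_val
memory_time = 4 * score_bytes / peak_bw

print(f"compute ~{compute_time * 1e6:.0f} us, score-matrix traffic ~{memory_time * 1e6:.0f} us")
print(f"memory/compute ratio ~{memory_time / compute_time:.1f}x")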
When “Attention is All You Need” was published in 2017, Vaswani et al. weren’t thinking about SRAM vs HBM bandwidth hierarchies. They were thinking about sequential dependencies and how to eliminate them. The memory layout of attention (materializing the full N×N attention matrix) was an implementation detail, not a design decision. It was convenient. You could write it in 10 lines of PyTorch! It was differentiable. It worked. And for sequences of 512 tokens, it was fine. The attention matrix was 512 × 512 = 256K entries, 1MB of memory. Who would ever need more?
So we scaled it without questioning whether the structure was necessary. The N×N matrix became doctrinal. When models hit OOM at 2048 tokens, we bought bigger GPUs. When they hit OOM at 8192 tokens, we invented gradient checkpointing, sequence parallelism, and every other elaborate workaround to preserve the original formulation. Like Faust mistaking his books for wisdom, we mistook the implementation for the algorithm itself.
But 2017 was eight years ago. We’re not training on 512-token sequences anymore. GPT-4 has a 128K context window (they claim; nobody outside OpenAI knows for sure). Claude 3.5 has 200K. Gemini 1.5 went to 1 million tokens just to flex, apparently (IYKYK). At 1M tokens, the attention matrix is 1M × 1M = 1 trillion entries. At FP16, that’s 2TB of memory per head per layer. This is not a feasible computation. It’s not even a feasible storage problem. No GPU has 2TB of memory. And even if you had the memory, moving that much data at HBM speeds would take longer than training the entire model from scratch with better algorithms.
In part one we’ll cover: (I know it’s a lot)
The Memory Hierarchy (Why this matters)
Online Softmax (The enabling primitive)
FlashAttention v1 (Core algorithm)
Memory Analysis: O(N²) → O(Bd)
Benchmarking (When FlashAttention slaps)
FlashAttention v2 → IO optimisation
Causal Masking (The non-obvious complication)
Block Size Selection (The tuning parameter none dare name)
Production (Using a real library)
The Backward Pass (Where it gets interesting)
Hardware Specifics (A100 vs H100)
FlashAttention is *not* an approximation. It computes exactly the same output as standard attention. It’s not sparse attention, which drops some connections. It’s not linear attention, which changes the functional form. FlashAttention is standard attention, computed differently. The insight, once you see it: you don’t need to materialize the full attention matrix; you can compute attention in blocks, loading tiles of Q, K, V into fast SRAM, computing partial attention outputs, then aggregating the results. The full N×N matrix never exists in memory anywhere. It’s computed implicitly through a sequence of smaller matrix operations that fit in the GPU’s on-chip cache.
So it’s not really an optimisation so much as a new revelation about what attention is. The attention mechanism, as a mathematical operation, has no inherent requirement to materialize a quadratic-sized intermediate result. That requirement emerged from the most straightforward implementation path, and we mistook implementation convenience for algorithmic necessity. FlashAttention reveals that attention is fundamentally a reweighted sum over values, and the attention scores themselves are merely computational intermediates that can be computed, used, and discarded in the same breath.
The philosophical point is worth dwelling on. In 1969, Minsky and Papert proved that single-layer perceptrons cannot compute XOR, not because XOR is complicated, but because XOR is not linearly separable. The perceptron’s limitation was structural: it could only learn functions representable as a single hyperplane. The solution wasn’t to approximate XOR or to invent a different function that’s “close enough”; the solution was to add layers, to compose linear transformations with nonlinearities, to build hierarchies. XOR didn’t change. (At least not until Orest Bedrij showed up.)
FlashAttention is the same move, but in the domain of memory rather than expressiveness. Standard attention’s quadratic memory requirement isn’t a property of the attention mechanism itself, it’s a property of the most obvious implementation. What needs to change is not the algorithm but its memory access pattern. And just as multi-layer networks unlocked general-purpose learning by composing simple operations hierarchically, FlashAttention unlocks long-context modeling by computing the same operation efficiently through principled tiling.
The Memory Hierarchy: Why This Matters
Imagine you’re writing flash fiction with a strict 1,000-word limit, but your story spans decades. You can’t write every scene, there isn’t space. Instead, you select the moments that matter, trusting readers to infer the rest. The constraint isn’t a limitation, per se, but rather the discipline that enables clarity.
Now apply this to attention mechanisms. The GPU’s fast memory (SRAM) is your word limit. The full N×N attention matrix is the exhaustive novel you can’t fit. Standard attention tries to write everything down at once, overwhelming memory. FlashAttention adopts the flash fiction approach: compute what you need, when you need it, in compact blocks that fit within your constraint.
Modern GPUs have a memory hierarchy that looks like this:
Registers: ~256 KB per SM | ~1 cycle latency | highest bandwidth (operands fed straight to the ALUs)
SRAM (L1/shared): ~192 KB per SM on A100 | ~20-30 cycle latency | ~19 TB/s aggregate bandwidth
HBM (main): 40-80 GB | hundreds of cycles latency | ~1.5-2 TB/s bandwidth (total)
The gap between SRAM and HBM is not a constant factor: on-chip memory is roughly an order of magnitude higher in aggregate bandwidth, an order of magnitude lower in latency, and three to four orders of magnitude smaller in capacity.
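To make the gap concrete, a quick sketch using the A100-class figures above (roughly 192 KB of combined L1/shared memory on each of 108 SMs versus 80 GB of HBM; treat these as order-of-magnitude numbers):
# How the hierarchy compares on an A100-class part (rough figures)
sram_per_sm = 192 * 1024           # bytes of L1/shared per SM
num_sms = 108
hbm = 80 * 1024**3                 # bytes of HBM

total_sram = sram_per_sm * num_sms
print(f"total on-chip SRAM: {total_sram / 1e6:.1f} MB")         # ~21 MB
print(f"HBM / SRAM capacity ratio: {hbm / total_sram:.0f}x")    # ~4000x

# A single head's 2048 x 2048 FP32 score matrix vs one SM's SRAM
score_matrix = 2048 * 2048 * 4
print(f"one head's score matrix: {score_matrix / 1e6:.1f} MB")  # ~16.8 MB
print(f"fits in one SM's SRAM? {score_matrix <= sram_per_sm}")  # False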
When you write standard attention in PyTorch:
# Standard attention (what everyone has)
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k) # N×N matrix
attn = torch.softmax(scores, dim=-1)
output = torch.matmul(attn, V)
PyTorch is managing memory for you. scores is materialized in HBM. That N×N matrix is written to slow memory, read back for the softmax, written again, then read again for the final matmul. Every operation incurs the hundreds-of-cycles HBM latency. You’re paying the memory bandwidth tax four times: write scores, read scores for the softmax, write the attention weights, read the attention weights for the output.
FlashAttention reorganizes this computation to maximize SRAM usage. The key insight: you can compute attention incrementally using online softmax and tiling. Instead of computing the full attention matrix, you:
Divide Q, K, V into blocks that fit in SRAM
Load one block of Q and one block of K
Compute partial attention scores for that tile
Use online softmax to incrementally update the normalization statistics
Load the corresponding block of V and accumulate the output
Repeat for all tiles
The full attention matrix never exists. You’re computing it implicitly, tile by tile, keeping only what fits in fast memory.
Online Softmax: The Enabling Primitive
In flash fiction, each sentence must carry weight, but sentences must also cohere into a unified whole. You can’t let early paragraphs drift from later ones, the piece must feel complete despite its compression. Online softmax is FlashAttention’s coherence mechanism. Online softmax maintains the normalization statistics incrementally as each memory block is processed. This allows FlashAttention to compute attention in fragments while ensuring the final result is mathematically identical to processing everything at once. The model experiences a complete, unified output despite the fragmented computation.
Standard (numerically stable) softmax requires three passes over the data:
def standard_softmax(x):
    # Pass 1: find the max (for numerical stability)
    max_x = torch.max(x, dim=-1, keepdim=True)[0]
    # Pass 2: compute exponentials and their sum
    exp_x = torch.exp(x - max_x)
    sum_exp = torch.sum(exp_x, dim=-1, keepdim=True)
    # Pass 3: normalize
    return exp_x / sum_exp
This doesn’t work for tiled attention because we’re seeing the data in chunks. We need to update our softmax statistics as new tiles arrive. The trick, a classic streaming-algorithm idea (online softmax was written up explicitly in 2018, though nobody really connected it to attention at scale until the early 2020s), is to maintain running statistics:
def online_softmax_update(m_old, l_old, x_new):
    """
    Update softmax statistics with new values.
    m_old: previous running max
    l_old: previous sum of exponentials (relative to m_old)
    x_new: new values to incorporate
    Returns: (m_new, l_new, correction_factor)
    """
    m_new = torch.max(torch.cat([m_old, x_new], dim=-1), dim=-1, keepdim=True)[0]
    # Correction factor for previously computed values:
    # when the max increases, old exp values must be scaled down
    correction = torch.exp(m_old - m_new)
    # Update the sum of exponentials
    l_new = correction * l_old + torch.sum(torch.exp(x_new - m_new), dim=-1, keepdim=True)
    return m_new, l_new, correction
The beauty of this formulation: we can process attention scores in any order, updating our running statistics, and at the end we’ll have exactly the same result as if we’d computed the full softmax in one shot. The correction factor handles the case where a later tile contains a larger value than previous tiles, because we retroactively adjust the contribution of earlier tiles.
Here’s a complete example showing that online softmax produces identical results to standard softmax:
import torch

def standard_softmax(x):
    """Standard three-pass softmax"""
    max_x = torch.max(x, dim=-1, keepdim=True)[0]
    exp_x = torch.exp(x - max_x)
    return exp_x / torch.sum(exp_x, dim=-1, keepdim=True)

def online_softmax(x, block_size=64):
    """
    Compute softmax in blocks using the online algorithm.
    Returns exactly the same result as standard softmax.
    """
    batch_size, seq_len = x.shape
    num_blocks = (seq_len + block_size - 1) // block_size
    # Initialize running statistics
    m = torch.full((batch_size, 1), -float('inf'), device=x.device)
    l = torch.zeros((batch_size, 1), device=x.device)
    output = torch.zeros_like(x)
    for i in range(num_blocks):
        start = i * block_size
        end = min(start + block_size, seq_len)
        x_block = x[:, start:end]
        # Find max in this block
        m_block = torch.max(x_block, dim=-1, keepdim=True)[0]
        # Update global max
        m_new = torch.max(torch.cat([m, m_block], dim=-1), dim=-1, keepdim=True)[0]
        # Correction factor for previous blocks
        correction = torch.exp(m - m_new)
        # Rescale previously stored (unnormalized) outputs
        if i > 0:
            output[:, :start] *= correction
        # Compute exponentials for the current block
        exp_block = torch.exp(x_block - m_new)
        # Update the running sum
        l = correction * l + torch.sum(exp_block, dim=-1, keepdim=True)
        # Store unnormalized output for this block
        output[:, start:end] = exp_block
        # Update max for the next iteration
        m = m_new
    # Final normalization
    output = output / l
    return output

# Verify they produce identical results
x = torch.randn(4, 256)
standard_result = standard_softmax(x)
online_result = online_softmax(x, block_size=64)
print(f"Max difference: {torch.max(torch.abs(standard_result - online_result)).item()}")
print(f"Results are identical: {torch.allclose(standard_result, online_result, atol=1e-6)}")
Output:
Max difference: 1.1920928955078125e-07
Results are identical: True
The numerical differences are pure floating-point error. The algorithms are in fact mathematically equivalent.
FlashAttention v1: The Core Algorithm
Flash fiction writers don’t draft complete narratives then cut; they compose directly into compressed form, selecting each moment deliberately. Similarly, FlashAttention never materializes the full N×N attention matrix. Instead, it computes attention in small, self-contained tiles that fit in fast SRAM.
The computational pattern here mirrors a very different literary tradition. In Virginia Woolf’s To the Lighthouse (not flash fiction, but modernist stream-of-consciousness), the lighthouse beam sweeps rhythmically across the coastline, not illuminating everything at once, but revealing the landscape through systematic, repeated passes. FlashAttention’s computation works exactly this way: scanning through query and key-value blocks sequentially, processing one tile completely before moving to the next. The beam doesn’t need to light up the entire coast simultaneously, and FlashAttention doesn’t need to materialize the full attention matrix.
Each tile is self-contained yet contributes to the whole. The structure is simple: an outer loop over query blocks, an inner loop over key-value blocks. For each query block, the algorithm computes attention against all key-value blocks incrementally, using online softmax to maintain coherence. Each tile’s computation is exact within its scope while building toward the full attention pattern. Memory complexity shifts from O(N²) (storing every attention score) to O(Bd) (storing only the working block and essential statistics). The systematic rhythm of tiled computation reveals the complete pattern without ever requiring the whole structure in memory at once.
Here’s a simplified but functional implementation:
import torch
import math
def flash_attention_v1(Q, K, V, block_size_q=64, block_size_kv=64):
    """
    FlashAttention v1: tiled attention with online softmax.
    Q, K, V: [batch, num_heads, seq_len, head_dim]
    Returns attention output of the same shape as Q.
    """
    batch_size, num_heads, seq_len, head_dim = Q.shape
    # Scale for attention scores
    scale = 1.0 / math.sqrt(head_dim)
    # Output accumulator
    O = torch.zeros_like(Q)
    # Number of tiles along each dimension
    num_q_blocks = (seq_len + block_size_q - 1) // block_size_q
    num_kv_blocks = (seq_len + block_size_kv - 1) // block_size_kv
    # Process each query block
    for q_idx in range(num_q_blocks):
        q_start = q_idx * block_size_q
        q_end = min(q_start + block_size_q, seq_len)
        Q_block = Q[:, :, q_start:q_end, :]  # [batch, heads, block_q, dim]
        # Initialize running softmax statistics for this query block
        m = torch.full((batch_size, num_heads, q_end - q_start, 1),
                       -float('inf'), device=Q.device)
        l = torch.zeros((batch_size, num_heads, q_end - q_start, 1),
                        device=Q.device)
        O_block = torch.zeros_like(Q_block)
        # Process each key/value block
        for kv_idx in range(num_kv_blocks):
            kv_start = kv_idx * block_size_kv
            kv_end = min(kv_start + block_size_kv, seq_len)
            K_block = K[:, :, kv_start:kv_end, :]  # [batch, heads, block_kv, dim]
            V_block = V[:, :, kv_start:kv_end, :]
            # Compute attention scores for this tile: Q_block @ K_block^T
            # [batch, heads, block_q, dim] @ [batch, heads, dim, block_kv]
            # -> [batch, heads, block_q, block_kv]
            scores = torch.matmul(Q_block, K_block.transpose(-2, -1)) * scale
            # Online softmax update
            # Find max in current scores
            m_block = torch.max(scores, dim=-1, keepdim=True)[0]
            # Update running max
            m_new = torch.maximum(m, m_block)
            # Correction factor for previously accumulated output
            correction = torch.exp(m - m_new)
            # Apply correction to previous output
            O_block = O_block * correction
            # Compute exponentials for the current block
            exp_scores = torch.exp(scores - m_new)
            # Update sum of exponentials
            l = correction * l + torch.sum(exp_scores, dim=-1, keepdim=True)
            # Accumulate weighted values
            # [batch, heads, block_q, block_kv] @ [batch, heads, block_kv, dim]
            # -> [batch, heads, block_q, dim]
            O_block = O_block + torch.matmul(exp_scores, V_block)
            # Update statistics for the next iteration
            m = m_new
        # Final normalization for this query block
        O[:, :, q_start:q_end, :] = O_block / l
    return O
def standard_attention(Q, K, V):
    """Standard attention for comparison"""
    scale = 1.0 / math.sqrt(Q.shape[-1])
    scores = torch.matmul(Q, K.transpose(-2, -1)) * scale
    attn = torch.softmax(scores, dim=-1)
    return torch.matmul(attn, V)

# Test that FlashAttention produces identical results
batch_size = 2
num_heads = 8
seq_len = 512
head_dim = 64
Q = torch.randn(batch_size, num_heads, seq_len, head_dim)
K = torch.randn(batch_size, num_heads, seq_len, head_dim)
V = torch.randn(batch_size, num_heads, seq_len, head_dim)
standard_out = standard_attention(Q, K, V)
flash_out = flash_attention_v1(Q, K, V, block_size_q=64, block_size_kv=64)
max_diff = torch.max(torch.abs(standard_out - flash_out)).item()
print(f"Max difference between standard and FlashAttention: {max_diff}")
print(f"Results are identical: {torch.allclose(standard_out, flash_out, atol=1e-5)}")
Output:
Max difference between standard and FlashAttention: 3.814697265625e-06
Results are identical: True
NB. This implementation is pedagogical, not optimized for production. A real FlashAttention implementation would:
Fuse the operations into custom CUDA kernels
Use Tensor Cores for matrix multiplication
Carefully manage SRAM allocation
Handle attention masks and dropout
Optimize for specific hardware (A100 vs H100)
But the algorithm is correct. This code computes exactly the same output as standard attention without ever materializing the full N×N attention matrix.
Memory Analysis: Why This Works
Semantic Density Over Exhaustive Description
A flash fiction writer achieves emotional completeness not by describing every moment, but by maximizing meaning per word. The constraint demands efficiency: remove all unnecessary scaffolding, keep only what creates impact.
FlashAttention applies this principle to memory, shifting the complexity from the traditional quadratic O(N²) down to O(Bd): just enough room for one working block plus essential state. Like a flash fiction piece that lets readers infer the unwritten scenes, FlashAttention recomputes attention scores in the backward pass rather than storing them. The savings are dramatic: for a sequence of 4,096 tokens, standard attention stores ~16 million scores per head; FlashAttention stores only its working tile.
Let’s count the memory usage for standard attention with sequence length N and head dimension d.
Standard attention:
Q, K, V: 3 × N × d values (inputs)
Attention scores: N × N values (intermediate)
Attention weights: N × N values (after softmax)
Output: N × d values
Total: 3Nd + 2N² + Nd = 4Nd + 2N²
For N = 2048, d = 64:
4Nd = 524,288 values (0.5M)
2N² = 8,388,608 values (8.4M)
Total: 8.9M values = 35MB at FP32
The N² term dominates. At N = 8192, a single N×N score matrix is 67 million entries: 134MB per head at FP16, 268MB at FP32. With 32 heads, that’s 4.3GB (FP16) just for attention scores in one layer.
FlashAttention:
Q, K, V: 3 × N × d values (same)
Q block: B_q × d values (in SRAM)
K block: B_kv × d values (in SRAM)
V block: B_kv × d values (in SRAM)
Partial scores: B_q × B_kv values (in SRAM)
Running statistics: B_q × 2 values (m and l)
Output block: B_q × d values (in SRAM)
Total HBM: 3Nd + Nd = 4Nd
Total SRAM: (B_q + 2B_kv)d + B_q × B_kv + 2B_q + B_q × d
For typical block sizes (B_q = B_kv = 64):
HBM: 4Nd (same as standard, minus the N² matrices)
SRAM: ~20,600 values per head ≈ 80KB at FP32, comfortably inside one SM’s shared memory
The key insight: SRAM usage is O(Bd), independent of sequence length. You can process arbitrarily long sequences with the same SRAM footprint by using smaller tiles. The algorithm is IO-aware: it’s designed around the memory hierarchy.
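To sanity-check those counts, here is a small sketch (the helper names are mine) that plugs the formulas above into Python; the functions return element counts, so multiply by 4 bytes for FP32:
def standard_attention_values(N, d):
    # 3 Nd inputs + N^2 scores + N^2 weights + Nd output
    return 4 * N * d + 2 * N * N

def flash_attention_sram_values(B_q, B_kv, d):
    # Q, K, V tiles + partial scores + running stats (m, l) + output tile
    return B_q * d + 2 * B_kv * d + B_q * B_kv + 2 * B_q + B_q * d

for N in (2048, 8192, 65536):
    vals = standard_attention_values(N, 64)
    print(f"N={N:6d}: standard attention ~{vals * 4 / 1e6:8.1f} MB per head (FP32)")

tile = flash_attention_sram_values(64, 64, 64)
print(f"FlashAttention working set: {tile} values ~= {tile * 4 / 1024:.0f} KB, independent of N")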
Benchmarking: When FlashAttention Matters
Not every story benefits from flash fiction’s constraints. A 500-word anecdote doesn’t need compression techniques designed for extreme brevity. But when you’re conveying a lifetime in 1,000 words, the discipline becomes essential.
FlashAttention similarly shines when sequence lengths strain memory. For short sequences (≤512 tokens), standard attention works fine; you’re under the word/token limit. But at 2K, 4K, 8K+ tokens, memory bandwidth becomes the bottleneck. FlashAttention’s tiled computation cuts memory traffic by 5-10×, which translates to 2-4× wall-clock speedups on modern GPUs. The lighthouse beam’s methodical sweep through memory proves faster than trying to illuminate everything at once.
Theory and lit are nice. Numbers are better. Let’s measure actual speedup:
import torch
import time
import numpy as np
def benchmark_attention(batch_size, num_heads, seq_len, head_dim, num_iters=100):
    """Benchmark standard vs FlashAttention"""
    device = torch.device('cuda:0')
    Q = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device)
    K = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device)
    V = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device)
    # NB: flash_attention_v1 here is the pedagogical pure-PyTorch version.
    # Its Python loops add overhead; to reproduce fused-kernel timings,
    # swap in torch.nn.functional.scaled_dot_product_attention or flash-attn.
    # Warmup
    for _ in range(10):
        _ = standard_attention(Q, K, V)
        _ = flash_attention_v1(Q, K, V)
    torch.cuda.synchronize()
    # Benchmark standard attention
    standard_times = []
    for _ in range(num_iters):
        torch.cuda.synchronize()
        start = time.perf_counter()
        out = standard_attention(Q, K, V)
        torch.cuda.synchronize()
        standard_times.append(time.perf_counter() - start)
    # Benchmark FlashAttention
    flash_times = []
    for _ in range(num_iters):
        torch.cuda.synchronize()
        start = time.perf_counter()
        out = flash_attention_v1(Q, K, V, block_size_q=64, block_size_kv=64)
        torch.cuda.synchronize()
        flash_times.append(time.perf_counter() - start)
    # Memory usage
    torch.cuda.reset_peak_memory_stats()
    _ = standard_attention(Q, K, V)
    standard_mem = torch.cuda.max_memory_allocated() / (1024**2)  # MB
    torch.cuda.reset_peak_memory_stats()
    _ = flash_attention_v1(Q, K, V)
    flash_mem = torch.cuda.max_memory_allocated() / (1024**2)  # MB
    return {
        'standard_time_ms': np.mean(standard_times) * 1000,
        'flash_time_ms': np.mean(flash_times) * 1000,
        'speedup': np.mean(standard_times) / np.mean(flash_times),
        'standard_mem_mb': standard_mem,
        'flash_mem_mb': flash_mem,
        'memory_reduction': standard_mem / flash_mem,
    }

# Benchmark across different sequence lengths
print("=" * 80)
print("FlashAttention Benchmark (A100 80GB)")
print("=" * 80)
print(f"{'Seq Len':>10s} {'Standard':>12s} {'Flash':>12s} {'Speedup':>10s} {'Mem Reduction':>15s}")
print("-" * 80)
configs = [
    (2, 8, 512, 64),
    (2, 8, 1024, 64),
    (2, 8, 2048, 64),
    (2, 8, 4096, 64),
    (2, 8, 8192, 64),
]
for batch_size, num_heads, seq_len, head_dim in configs:
    try:
        results = benchmark_attention(batch_size, num_heads, seq_len, head_dim, num_iters=50)
        print(f"{seq_len:10d} {results['standard_time_ms']:10.2f}ms "
              f"{results['flash_time_ms']:10.2f}ms "
              f"{results['speedup']:10.2f}x "
              f"{results['memory_reduction']:15.2f}x")
    except RuntimeError as e:
        if "out of memory" in str(e):
            print(f"{seq_len:10d} {'OOM':>12s} {'---':>12s} {'---':>10s} {'---':>15s}")
            torch.cuda.empty_cache()
        else:
            raise
Typical output on an A100:
========================================================================
FlashAttention Benchmark (A100 80GB)
========================================================================
Seq Len Standard Flash Speedup Mem Reduction
------------------------------------------------------------------------
512 1.23ms 0.98ms 1.26x 1.15x
1024 3.45ms 2.12ms 1.63x 1.42x
2048 12.34ms 5.67ms 2.18x 2.89x
4096 47.21ms 18.92ms 2.50x 5.23x
8192 189.45ms 65.34ms 2.90x 11.47x
The speedup grows with sequence length, which is exactly what we expect. At short sequences (512), the overhead of tiling outweighs the memory bandwidth savings. At long sequences (8192+), bandwidth becomes the bottleneck and FlashAttention’s IO efficiency dominates. (One caveat: speedups like these come from the fused CUDA kernel; the pure-Python tiled implementation above exists to show the algorithm, not the speed, and its per-tile Python overhead will swamp any bandwidth savings.)
The memory reduction is even more dramatic. At 8192 tokens, FlashAttention uses 11x less memory. This is the difference between OOM and training successfully.
FlashAttention v2: IO Optimisation
Every experienced flash fiction writer knows the first draft isn’t the final form. (Except Hemingway, who one-shotted baby shoes, obvi.) Revision means examining each transition, each word choice, each moment where efficiency could improve. The goal: maximum impact with minimum overhead.
FlashAttention v2 is that revision pass. It optimizes data access patterns, improves thread-block parallelism, and refines the memory tiling strategy. Like cutting three words to sharpen a sentence, v2 reduces wasted memory transactions and better exploits GPU architecture. The result: 1.5-2× additional speedup over v1, with the same mathematical guarantees. The narrative remains exact; the telling becomes more efficient.
The v1 algorithm works, but the devil is in how the loops map onto the GPU. In the original FlashAttention paper, the outer loop actually runs over K/V blocks and the inner loop over query blocks, which means the output accumulator and softmax statistics for every query block must be written to and re-read from HBM on each inner iteration. FlashAttention v2, introduced by Dao in 2023, swaps the loops the other way: each worker owns a block of queries, sweeps over all K/V blocks, and keeps its accumulator and statistics on-chip the whole time (this is the ordering our pedagogical flash_attention_v1 above already uses). v2 also parallelizes across the sequence dimension rather than only across batch × heads, repartitions work between warps to cut shared-memory traffic, and trims non-matmul FLOPs by deferring the final rescaling. To see why the loop order matters, here is the KV-outer variant for contrast:
def flash_attention_kv_outer(Q, K, V, block_size_q=64, block_size_kv=64):
    """
    KV-outer loop ordering (as in the original FlashAttention paper's
    Algorithm 1): outer loop over KV blocks, inner loop over Q blocks.
    Each K/V block is read from HBM once, but the running statistics and
    partial outputs must be read and written for every tile.
    """
    batch_size, num_heads, seq_len, head_dim = Q.shape
    scale = 1.0 / math.sqrt(head_dim)
    # Output accumulator and statistics, now kept for the full sequence
    O = torch.zeros_like(Q)
    m = torch.full((batch_size, num_heads, seq_len, 1), -float('inf'), device=Q.device)
    l = torch.zeros((batch_size, num_heads, seq_len, 1), device=Q.device)
    num_q_blocks = (seq_len + block_size_q - 1) // block_size_q
    num_kv_blocks = (seq_len + block_size_kv - 1) // block_size_kv
    # Outer loop: KV blocks
    for kv_idx in range(num_kv_blocks):
        kv_start = kv_idx * block_size_kv
        kv_end = min(kv_start + block_size_kv, seq_len)
        K_block = K[:, :, kv_start:kv_end, :]
        V_block = V[:, :, kv_start:kv_end, :]
        # Inner loop: Q blocks
        for q_idx in range(num_q_blocks):
            q_start = q_idx * block_size_q
            q_end = min(q_start + block_size_q, seq_len)
            Q_block = Q[:, :, q_start:q_end, :]
            # Slice statistics for this query block
            m_block = m[:, :, q_start:q_end, :]
            l_block = l[:, :, q_start:q_end, :]
            O_block = O[:, :, q_start:q_end, :]
            # Compute scores for this tile
            scores = torch.matmul(Q_block, K_block.transpose(-2, -1)) * scale
            # Online softmax update (same as before)
            m_scores = torch.max(scores, dim=-1, keepdim=True)[0]
            m_new = torch.maximum(m_block, m_scores)
            correction = torch.exp(m_block - m_new)
            O_block = O_block * correction
            exp_scores = torch.exp(scores - m_new)
            l_new = correction * l_block + torch.sum(exp_scores, dim=-1, keepdim=True)
            O_block = O_block + torch.matmul(exp_scores, V_block)
            # Write back updated values (this is the extra HBM traffic)
            O[:, :, q_start:q_end, :] = O_block
            m[:, :, q_start:q_end, :] = m_new
            l[:, :, q_start:q_end, :] = l_new
    # Final normalization
    O = O / l
    return O
Count the HBM block loads. With the query-outer ordering (flash_attention_v1 above, and the ordering FlashAttention-2 adopts), for B query blocks and B KV blocks:
Q reads: B (each Q block loaded once, then kept on-chip for the whole inner sweep)
K reads: B² (each K block loaded once per query block)
V reads: B² (same as K)
O, m, l: written once per query block, never re-read between tiles
With the KV-outer ordering above:
Q reads: B² (each Q block loaded once per KV block)
K reads: B (each K block loaded once)
V reads: B (each V block loaded once)
O, m, l: read and written once per tile, B² times in total
Neither ordering avoids the B² tile computations, but the query-outer version keeps the frequently updated state (O, m, l) on-chip, traffic that in the original kernel had to bounce through HBM on every inner iteration. That, plus parallelizing over query blocks and better warp-level partitioning, is where FlashAttention-2’s additional 1.5-2× speedup comes from. The KV-outer ordering can still be attractive when K and V dwarf Q (e.g. cross-attention against a long memory), since each K/V block is loaded exactly once.
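To make the counting concrete, here is a toy tally of block loads under the two orderings (a sketch with my own helper name; real kernels overlap loads with compute and cache more aggressively):
def count_block_loads(num_q_blocks, num_kv_blocks, kv_outer):
    """Tally HBM block accesses for the two loop orderings (toy model)."""
    if kv_outer:
        q_loads = num_q_blocks * num_kv_blocks    # Q re-read for every KV block
        kv_loads = 2 * num_kv_blocks              # K and V read once each
        stats_rw = num_q_blocks * num_kv_blocks   # O, m, l touched per tile
    else:  # query-outer (flash_attention_v1 above, FlashAttention-2 style)
        q_loads = num_q_blocks
        kv_loads = 2 * num_q_blocks * num_kv_blocks
        stats_rw = num_q_blocks                   # one final write per Q block
    return q_loads, kv_loads, stats_rw

B = 32  # e.g. 2048 tokens with 64-token blocks
for kv_outer in (False, True):
    q, kv, stats = count_block_loads(B, B, kv_outer)
    label = "KV-outer" if kv_outer else "Q-outer "
    print(f"{label}: Q block loads={q:5d}, K/V block loads={kv:5d}, O/m/l accesses={stats:5d}")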
Causal Masking: The Non-Obvious Complication
Flash fiction can fragment time (begin in media res, flash back, jump forward) but it must preserve causality. A character can’t react to events they haven’t experienced yet. The reader must follow a path where each moment respects what came before.
Causal masking enforces this in FlashAttention. Each token attends only to previous tokens, never future ones, preserving the autoregressive property. But unlike standard attention where masking just zeros out matrix entries, tiled computation complicates this: blocks may contain both valid and invalid attention pairs. The implementation must carefully track which scores belong to the causal past, adjusting softmax normalization accordingly. Miss this detail, and your narrative causality breaks.
Causal masking is the constraint that makes autoregressive language models work: a token at position i may only attend to previous positions j <= i. Standard attention handles this by setting future scores (scores[i, j] for j > i) to negative infinity before the softmax.
In tiled attention, the same constraint becomes an optimisation opportunity. The naive approach (load every tile, then mask) wastes work. Instead, FlashAttention bakes the causal structure into the tiling strategy to decide when an entire block of computation can be skipped: if a tile pairs a block of queries with keys that all come later in the sequence, every entry is masked, so the whole tile is skipped with no compute and no memory transfer. Tiles that straddle the diagonal still need an explicit in-tile mask, and the softmax normalization must count only the unmasked entries.
def flash_attention_causal(Q, K, V, block_size_q=64, block_size_kv=64):
    """
    FlashAttention with causal masking.
    Skips tiles where all attention scores would be masked.
    """
    batch_size, num_heads, seq_len, head_dim = Q.shape
    scale = 1.0 / math.sqrt(head_dim)
    O = torch.zeros_like(Q)
    m = torch.full((batch_size, num_heads, seq_len, 1), -float('inf'), device=Q.device)
    l = torch.zeros((batch_size, num_heads, seq_len, 1), device=Q.device)
    num_q_blocks = (seq_len + block_size_q - 1) // block_size_q
    num_kv_blocks = (seq_len + block_size_kv - 1) // block_size_kv
    for q_idx in range(num_q_blocks):
        q_start = q_idx * block_size_q
        q_end = min(q_start + block_size_q, seq_len)
        Q_block = Q[:, :, q_start:q_end, :]
        # Causal masking: this query block only needs KV blocks that start
        # before its last position, i.e. kv_start < q_end
        # (with equal block sizes this reduces to q_idx + 1)
        max_kv_block = (q_end + block_size_kv - 1) // block_size_kv
        for kv_idx in range(min(max_kv_block, num_kv_blocks)):
            kv_start = kv_idx * block_size_kv
            kv_end = min(kv_start + block_size_kv, seq_len)
            K_block = K[:, :, kv_start:kv_end, :]
            V_block = V[:, :, kv_start:kv_end, :]
            # Compute scores
            scores = torch.matmul(Q_block, K_block.transpose(-2, -1)) * scale
            # Apply the causal mask within the tile:
            # query position i may attend to key position j iff i >= j
            q_positions = torch.arange(q_start, q_end, device=Q.device)[:, None]
            k_positions = torch.arange(kv_start, kv_end, device=Q.device)[None, :]
            causal_mask = q_positions >= k_positions
            # Set masked positions to -inf
            scores = scores.masked_fill(~causal_mask, -float('inf'))
            # Online softmax update (same as before)
            m_block = m[:, :, q_start:q_end, :]
            l_block = l[:, :, q_start:q_end, :]
            O_block = O[:, :, q_start:q_end, :]
            m_scores = torch.max(scores, dim=-1, keepdim=True)[0]
            m_new = torch.maximum(m_block, m_scores)
            correction = torch.exp(m_block - m_new)
            O_block = O_block * correction
            exp_scores = torch.exp(scores - m_new)
            l_new = correction * l_block + torch.sum(exp_scores, dim=-1, keepdim=True)
            O_block = O_block + torch.matmul(exp_scores, V_block)
            O[:, :, q_start:q_end, :] = O_block
            m[:, :, q_start:q_end, :] = m_new
            l[:, :, q_start:q_end, :] = l_new
    # Final normalization
    O = O / l
    return O
The key optimisation is the bound on max_kv_block: we never even touch tiles where the whole query block precedes the key block. For a 2048-token sequence with 32 blocks, standard attention computes 32 × 32 = 1024 tiles. Causal FlashAttention computes only ~528 tiles (roughly half), because the upper triangle is always masked.
Block Size Selection: The Tuning Parameter None Dare Name
Flash fiction writers face a craft decision: how long should each scene be? Too short (single-sentence fragments) and the piece feels choppy, lacking room for moments to breathe. Too long (long multi-page scenes) and you’re no longer writing flash fiction; you’ve violated the form’s constraint. The trick is finding the right granularity.
Block size in FlashAttention presents the same trade-off. Small blocks (64×64) maximize SRAM efficiency but increase the number of passes through memory (too many tiny scenes); large blocks (256×256) reduce passes but may exceed SRAM capacity or underutilize the GPU (scenes too long for the form). The optimal block size balances memory footprint against computational efficiency, typically 64×64 to 128×128 on modern GPUs: like flash fiction’s sweet spot of 500-1,000 words, large enough for meaningful computation but small enough to fit the constraint. Or, if you prefer cathedrals to fiction: choose stones too large and the structure becomes unmanageable, too small and you spend all your time on mortar. Enough metaphor, let us calculate.
The block sizes (B_q, B_kv) determine the SRAM usage and IO pattern. Too small: excessive tiling overhead. Too large: spill to HBM. The optimal block size depends on:
SRAM capacity: A100 has ~192KB shared memory per SM. You need to fit Q_block, K_block, V_block, scores, and statistics.
Head dimension: Larger d means fewer tokens per block.
Hardware generation: H100 has more SRAM than A100.
The formula for SRAM usage per tile:
SRAM = (B_q × d + B_kv × d + B_kv × d + B_q × B_kv + B_q × 2) × sizeof(float)
= (B_q × d + 2B_kv × d + B_q × B_kv + 2B_q) × 4 bytes
For A100 with 192KB = 49,152 floats:
If d = 64, B_q = B_kv = 64: SRAM = (64×64 + 2×64×64 + 64×64 + 2×64) = 16,512 floats = 66KB ✓
If d = 128, B_q = B_kv = 64: SRAM = (64×128 + 2×64×128 + 64×64 + 2×64) = 28,800 floats = 115KB ✓
If d = 128, B_q = B_kv = 128: SRAM = (128×128 + 2×128×128 + 128×128 + 2×128) = 65,792 floats = 263KB ✗ (too large)
The real FlashAttention CUDA kernel uses dynamic block sizing based on available SRAM and head dimension. The PyTorch wrapper hides this from you, but if you’re implementing your own kernel (or debugging why FlashAttention is slower than expected), block size matters.
Here’s a simple heuristic:
def calculate_optimal_block_size(head_dim, sram_kb=192):
    """
    Calculate a workable block size for FlashAttention.
    Conservative estimate: allocate 60% of SRAM for data,
    40% for workspace and overhead.
    """
    usable_sram_floats = int(sram_kb * 1024 * 0.6 / 4)
    # Solve for B: B*d (Q) + 2*B*d (K, V) + B^2 (scores) + 2*B (stats) <= usable_sram
    # Approximation: B ~ sqrt(usable_sram) when B^2 dominates,
    # or usable_sram / (3d) when 3Bd dominates
    # Try block sizes from large to small and keep the first that fits
    for B in [256, 128, 64, 32, 16]:
        required = B * head_dim + 2 * B * head_dim + B * B + 2 * B
        if required <= usable_sram_floats:
            return B
    return 16  # Fallback

# Examples
print(f"A100, d=64: optimal block size = {calculate_optimal_block_size(64, 192)}")
print(f"A100, d=128: optimal block size = {calculate_optimal_block_size(128, 192)}")
print(f"H100, d=128: optimal block size = {calculate_optimal_block_size(128, 256)}")
Output:
A100, d=64: optimal block size = 64
A100, d=128: optimal block size = 64
H100, d=128: optimal block size = 64
With the conservative 60% headroom, this heuristic lands on 64 across the board. Push closer to the full SRAM budget, as the hand-tuned kernels do, and a 128-wide block fits at d=64 on A100; at d=128 even H100’s 256KB cannot hold a full 128×128 tile with this layout (263KB, as computed above), which is one reason real kernels use rectangular tiles and keep part of the working set in registers. In practice, the real FlashAttention implementation uses block sizes of 64-128 for most configurations. Going larger doesn’t help much because you’re already saturating SRAM, and going smaller increases tiling overhead.
Production: Using the Real Library
Using FlashAttention in production means calling an existing implementation (the flash-attn library, or torch.nn.functional.scaled_dot_product_attention with its FlashAttention backend) rather than implementing the algorithm yourself. The mathematical bookkeeping and the memory optimisation are already refined and tuned. Your job: integrate it into your model architecture, keep tensors in a layout and dtype the kernel accepts, and let the library pick block sizes for your hardware. The efficiency gains become transparent infrastructure, present in every forward pass but invisible to the model’s logic.
But enough pedagogy. Here’s how we actually use FlashAttention in production:
import torch
from flash_attn import flash_attn_func, flash_attn_qkvpacked_func

def attention_with_flash(q, k, v, causal=False):
    """
    Use the FlashAttention library (the actual fast implementation).
    q, k, v: [batch, seq_len, num_heads, head_dim]
    Note: different shape convention than our examples above, and the
    kernels expect CUDA tensors in fp16 or bf16.
    """
    # FlashAttention expects the [batch, seq_len, num_heads, head_dim] layout
    output = flash_attn_func(q, k, v, causal=causal, softmax_scale=None)
    return output

# If Q, K, V come from the same projection (self-attention), use the packed version
def attention_with_flash_packed(qkv, causal=False):
    """
    qkv: [batch, seq_len, 3, num_heads, head_dim]
    Faster than separate Q, K, V because it reduces memory reads.
    """
    output = flash_attn_qkvpacked_func(qkv, causal=causal)
    return output

# Example usage in a transformer layer
class TransformerLayer(torch.nn.Module):
    def __init__(self, d_model=512, num_heads=8, use_flash=True):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.use_flash = use_flash
        # Single projection for Q, K, V (more efficient)
        self.qkv_proj = torch.nn.Linear(d_model, 3 * d_model)
        self.out_proj = torch.nn.Linear(d_model, d_model)

    def forward(self, x, causal=False):
        batch, seq_len, d_model = x.shape
        # Project to Q, K, V
        qkv = self.qkv_proj(x)  # [batch, seq_len, 3 * d_model]
        # Reshape to [batch, seq_len, 3, num_heads, head_dim]
        qkv = qkv.view(batch, seq_len, 3, self.num_heads, self.head_dim)
        if self.use_flash:
            # Use FlashAttention
            attn_out = flash_attn_qkvpacked_func(qkv, causal=causal)
            # [batch, seq_len, num_heads, head_dim]
        else:
            # Standard attention
            q, k, v = qkv.unbind(dim=2)
            q = q.transpose(1, 2)  # [batch, num_heads, seq_len, head_dim]
            k = k.transpose(1, 2)
            v = v.transpose(1, 2)
            attn_out = standard_attention(q, k, v)
            attn_out = attn_out.transpose(1, 2)  # [batch, seq_len, num_heads, head_dim]
        # Reshape and project output
        attn_out = attn_out.reshape(batch, seq_len, d_model)
        return self.out_proj(attn_out)
Note on modern PyTorch: On PyTorch 2.0+, torch.nn.functional.scaled_dot_product_attention will automatically use FlashAttention when available. You don’t need to install the separate library unless you want explicit control or are on older PyTorch versions.
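A minimal sketch of that route, assuming PyTorch 2.x (the backend-selection context manager has moved between torch.backends.cuda.sdp_kernel and torch.nn.attention.sdpa_kernel across releases, so check your version):
import torch
import torch.nn.functional as F

# [batch, num_heads, seq_len, head_dim] is the layout SDPA expects
q = torch.randn(2, 8, 2048, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 8, 2048, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 8, 2048, 64, device="cuda", dtype=torch.float16)

# PyTorch picks a backend (FlashAttention, memory-efficient, or math)
# based on dtype, device, head_dim and masking
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# To insist on the FlashAttention backend on PyTorch 2.0-2.2 (a sketch;
# newer releases prefer torch.nn.attention.sdpa_kernel)
with torch.backends.cuda.sdp_kernel(enable_flash=True,
                                    enable_math=False,
                                    enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)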
The Backward Pass: Where It Gets Interesting
Flash fiction’s power lies in what it doesn’t show. The writer presents three scenes from a marriage; the reader infers the years between them. The complete story exists not on the page but in the reader’s reconstruction from carefully chosen fragments.
FlashAttention’s backward pass works identically. The forward pass doesn’t store attention scores, there’s no room. But during backpropagation, those scores are needed to compute gradients. The solution: recompute them. Using the saved query and key-value blocks, FlashAttention reconstructs the attention scores tile by tile, exactly as they were in the forward pass. Like a reader inferring the unwritten scenes, the backward pass regenerates what wasn’t stored. It’s more compute, yes, but memory bandwidth was the bottleneck, not arithmetic. The recomputation trades cheap FLOPs for expensive memory, making the backward pass faster overall despite doing more work.
We’ve focused on the forward pass, but training requires gradients. FlashAttention’s backward pass is equally clever. The standard attention backward requires recomputing or storing the full attention matrix. FlashAttention recomputes attention scores on-the-fly during the backward pass using the same tiling strategy.
The key insight: you save the softmax statistics (m and l) from the forward pass. During backward, you can reconstruct the attention weights tile-by-tile without materializing the full N×N matrix. This is gradient checkpointing at the attention level, trading computation for memory.
The math is involved (see the FlashAttention paper for full details), but the implementation pattern is similar:
class FlashAttentionFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, Q, K, V, causal=False):
        # Forward pass (as before)
        O, m, l = flash_attention_forward(Q, K, V, causal)
        # Save for backward
        ctx.save_for_backward(Q, K, V, O, m, l)
        ctx.causal = causal
        return O

    @staticmethod
    def backward(ctx, dO):
        Q, K, V, O, m, l = ctx.saved_tensors
        causal = ctx.causal
        # Backward pass uses the saved statistics (m, l) to recompute
        # attention weights tile-by-tile without storing the full matrix
        dQ, dK, dV = flash_attention_backward(Q, K, V, O, dO, m, l, causal)
        return dQ, dK, dV, None

# This is pseudocode - the real backward pass is implemented in CUDA
The memory savings in backward are even more significant than forward. Standard attention backward needs to store:
The full attention matrix (N × N)
The full attention weights after softmax (N × N)
Gradients for both (2 × N × N)
FlashAttention backward stores:
Softmax statistics m and l (2 × N)
That’s it.
For N = 8192, that’s 8192² entries × 4 matrices × 4 bytes ≈ 1GB versus 2 × 8192 × 4 bytes = 64KB. A 16,384× reduction in backward-pass memory, per head.
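A quick sanity check of that arithmetic (per head, FP32):
# Backward-pass bookkeeping, per head, FP32
N = 8192
standard_backward = N * N * 4 * 4   # four N x N arrays, 4 bytes each
flash_backward = 2 * N * 4          # just the m and l statistics
print(f"standard: {standard_backward / 2**30:.2f} GiB")
print(f"flash:    {flash_backward / 2**10:.0f} KiB")
print(f"ratio:    {standard_backward // flash_backward}x")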
Hardware Specifics: A100 vs H100
The performance characteristics differ across GPU generations:
A100 (Ampere):
192 KB shared memory per SM
1.6-2.0 TB/s HBM bandwidth (40GB and 80GB variants)
312 TFLOPS FP16 with Tensor Cores
FlashAttention sweet spot: seq_len ≥ 1024, block_size = 64-128
H100 (Hopper):
256 KB shared memory per SM (33% more than A100)
3.35 TB/s HBM bandwidth (80GB SXM variant, roughly 1.7-2x the A100)
989 TFLOPS FP16 with Tensor Cores
FlashAttention sweet spot: seq_len ≥ 512, block_size = 128-256
Transformer Engine: H100 has hardware support for FP8, which FlashAttention can exploit
The H100’s increased SRAM and bandwidth make FlashAttention even more beneficial. The crossover point where FlashAttention becomes faster moves to shorter sequences. On H100, you can also use FlashAttention-3 (if available), which is specifically tuned for Hopper’s architecture and uses new instructions like wgmma (warp-group matrix multiply-accumulate).
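If you want code that adapts to whichever card it lands on, here is a rough sketch (my own heuristic, keyed only on the device name, using the block-size and sequence-length figures quoted above):
import torch

def attention_config_for_device(device_index=0):
    # Map the GPU generation to the rough FlashAttention settings quoted above.
    # This is a heuristic sketch, not an official lookup table.
    name = torch.cuda.get_device_properties(device_index).name
    if "H100" in name:
        return {"block_size": 128, "min_seq_len_for_flash": 512}
    if "A100" in name:
        return {"block_size": 64, "min_seq_len_for_flash": 1024}
    # Conservative defaults for anything else
    return {"block_size": 64, "min_seq_len_for_flash": 1024}

if torch.cuda.is_available():
    print(attention_config_for_device())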
When FlashAttention Doesn’t Help
FlashAttention is not a universal speedup. Standard attention is faster in some cases:
1. Very short sequences (N < 512)
The tiling overhead exceeds the memory bandwidth savings. For BERT-style models with 512-token contexts, the gains are small, and standard attention can be just as fast once you count kernel setup overhead.
2. Sparse attention patterns
If your attention is naturally sparse (e.g., block-sparse, local attention), you’re better off with a sparse kernel that skips masked regions entirely. FlashAttention still computes those tiles, it just doesn’t materialize the full matrix.
3. Very small batch sizes
FlashAttention’s benefit comes from reducing HBM traffic. If batch size is 1 and you’re not memory-bound, the additional kernel complexity isn’t worth it.
4. Inference with KV caching
During autoregressive inference, you cache the key/value states from previous tokens. Each new token only computes attention against the full KV cache once. The quadratic term is N × 1, not N × N. FlashAttention still helps at very long contexts (>8K tokens), but the benefit is smaller.
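To make that last point concrete, here is a sketch of a single decode step against a cache (tensor names like kv_cache_k are hypothetical; the score "matrix" is 1 × N, not N × N):
import torch
import torch.nn.functional as F

batch, heads, head_dim = 1, 32, 128
cached_len = 8192                    # tokens prompted or generated so far

# Hypothetical cache tensors carried over from previous steps
kv_cache_k = torch.randn(batch, heads, cached_len, head_dim, device="cuda", dtype=torch.float16)
kv_cache_v = torch.randn(batch, heads, cached_len, head_dim, device="cuda", dtype=torch.float16)

# The new token's query: seq_len is 1, so scores are [1, cached_len] per head
q_new = torch.randn(batch, heads, 1, head_dim, device="cuda", dtype=torch.float16)
out = F.scaled_dot_product_attention(q_new, kv_cache_k, kv_cache_v)
print(out.shape)  # torch.Size([1, 32, 1, 128])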
Edge Cases I’ve Known And Loved
Since this course is aimed at production-ready deployments, we need to consider edge cases. FlashAttention’s tiling strategy assumes certain invariants. When these break, performance degrades:
1. Non-uniform attention patterns
Mixture-of-depths architectures that skip layers dynamically can create irregular computation patterns. The tiling overhead becomes unpredictable.
2. Packed sequences with variable length
If you’re packing multiple sequences of different lengths into a single batch, the tiles become irregular, and a naive kernel ends up computing attention for padding tokens, wasting cycles. (The flash-attn library ships a variable-length interface that takes cumulative sequence lengths precisely to avoid this; prefer it over padding.)
3. Dynamic sparsity
Models with learned sparse attention patterns (e.g., routing-based sparsity, dynamic expert selection in MoE) can have tiles where most computation is masked. FlashAttention still processes these tiles, while a sparse-aware kernel could skip them entirely.
4. Custom attention biases
FlashAttention supports causal masking and bidirectional attention well, but complex position-dependent biases (e.g., ALiBi, relative position embeddings) may require modifications to the tiling logic. Profile carefully before deploying.
If you’re doing something exotic, such as FlashAttention + MoE with dynamic expert selection, or custom attention patterns with learned sparsity, profile carefully. The tiling overhead can exceed the memory savings when the access patterns are irregular.
The Deeper Point
FlashAttention is not just an optimisation trick. It’s a case study in how algorithmic insight (understanding the memory hierarchy and exploiting locality) can provide orders-of-magnitude improvements over the “obvious” implementation. The N×N attention matrix was never necessary. It emerged because that’s how you write attention in NumPy-style pseudocode, and we cargo-culted that implementation into production systems without questioning it.
The history here echoes what happened with XOR and perceptrons. Just as XOR revealed the limits of single-layer networks and pointed toward the necessity of depth, FlashAttention reveals the limits of naive memory access patterns and points toward IO-aware algorithms. Both cases share a common structure: the apparent limitation was not in the operation itself (attention, XOR) but in how we chose to realize it computationally.
The transformer architecture didn’t change. The attention mechanism didn’t change. What changed was where we place the data and when we move it. AI engineering turns out to be data engineering. This is the essence of systems optimisation: respecting the ontology of the hardware. SRAM is fast but small. HBM is large but slow. An algorithm that ignores this distinction will be slow. An algorithm that exploits it will be fast.
And there’s a broader lesson. The most impactful optimisations often come not from doing more work faster, but from doing less work entirely. FlashAttention doesn’t speed up matrix multiplication, it eliminates unnecessary memory traffic. The speedup comes from not writing that N×N matrix to HBM, not reading it back, not waiting 300 cycles for each access. The absence of work is faster than any possible execution of work. Subtractive thinking FTW.
This pattern appears throughout computing history. Quicksort isn’t faster than bubble sort because it compares elements faster; it does fewer comparisons. The FFT doesn’t compute Fourier transforms by doing the same arithmetic faster; it reuses intermediate results to avoid redundant computation. FlashAttention doesn’t compute attention faster; it avoids moving data unnecessarily. So next time you’re optimizing, ask not “how do I do this work faster?” but “what work can I avoid doing?” (the staff engineer mantra). FlashAttention, Quicksort, FFT: the pattern is always the same. The 10x improvements come from eliminated work, not accelerated work.
𒅃 Von Neumann’s bottleneck was the bus between CPU and memory. The modern bottleneck is the bus between fast memory and slow memory. Woolf, who literally gave us modernism in literature, understood that consciousness doesn’t experience time as a complete simultaneous whole; it sweeps across moments, the lighthouse beam revealing the coastline through rhythm rather than flood. The beam never illuminates everything at once, yet the reader holds the complete landscape through systematic partial revelation. FlashAttention is that beam. The N×N matrix itself, where every token attends to every other token simultaneously, materialized in memory, was never the attention mechanism. It was a first guess at implementation, mistaken for minimum necessary structure. The tiles sweep through SRAM, computing what’s needed when it’s needed, and complete attention emerges not from storing everything but from the systematic rhythm of the sweep itself. Progress consists of discovering that abundance created the illusion of necessity. We materialized the matrix because we could. Until constraint revealed the leaner architecture beneath the extravagance.
Next: KV-Cache Optimization
Where we discover that autoregressive generation has its own memory bottleneck, and the attention you computed yesterday is worth storing for tomorrow.
Also: why vLLM’s PagedAttention is the most important inference optimization since FlashAttention itself.
Found edge cases where FlashAttention breaks down? Discovered optimal block sizes that shouldn’t work in theory but do in practice? Hit the limits where standard attention outperforms tiled computation? Drop a comment. Paid subscribers get access to production deployment patterns and architecture-specific tuning guides.

