CUDA Kernels for Transformers: Custom implementations for layer norms, attention
Part 5 of Memory & Compute Optimisation
You’re staring at PyTorch’s LayerNorm implementation. It works. It’s correct. And it’s leaving 60% of your GPU’s theoretical throughput on the table. The profiler shows your transformer spending 15% of training time in layer norm operations: small kernels that launch, execute, and return to CPU control thousands of times per batch. Each kernel launch has overhead. Each memory access pattern is generic, designed for arbitrary tensor shapes rather than your specific [batch, seq_len, hidden_dim]. Each intermediate result gets written back to HBM when it could stay in registers. Hrmm.
PyTorch’s built-in operations are factory-made tools: reliable, standardized, good enough for most users. Custom CUDA kernels are hand-forged: bespoke, purpose-built for your exact use case, and demanding of craftsmanship. Most engineers never write them. Most don’t need to. But when you need that 2-3× speedup that makes the difference between a research experiment and a production deployment, mass-produced won’t cut it.
In part five we cover:
Understanding GPU Memory Hierarchy
LayerNorm (Baseline case)
Kernel Fusion (The real win)
Attention (The hard case)
The Backward Pass (“It’s complicated”)
Register Pressure (Occupancy)
When Custom Kernels Actually Matter
Production Reality Check
Michelangelo is the patron saint of subtractive thinking. If he really did say that the sculpture already exists within the marble, then his job was simply to remove everything that wasn’t David. Writing a custom CUDA kernel is the same discipline: the optimal computation already exists within the GPU’s hardware capabilities. Your job is to remove everything that obscures it. Every wasted memory transaction. Every idle warp. Every unnecessary synchronization point. What remains is the essential form: maximally efficient execution carved from raw silicon.
This is what the Greeks meant by technē (τέχνη). In Homer, in Plato, and throughout classical literature, the word never separates “art” from “craft”; instead it foregrounds the importance of knowing your material, the underlying substance.
A shipbuilder practiced techné by understanding how wood splits along its grain, how joints bear load and how hulls move through water, fusing deep knowledge of materials, skill developed through practice, and purposeful building. The English words “technique” and “technology” both descend from techné, but we’ve impoverished them. “Technique” suggests mere mechanical execution. “Technology” suggests tools divorced from craft. The Greeks recognized no such separation. Knowledge and making were unified in purposeful creation.
Writing CUDA kernels is τέχνη in this original sense. You must understand the material: the GPU’s memory hierarchy, how warps execute, what occupancy means for performance. You must develop skill: profiling, iterating, learning which optimizations matter and which don’t. And you must create purposefully: faster execution that makes a model trainable or a deployment economically viable. This isn’t “technique” in the diminished modern sense of following recipes. (Or blog posts, if we’re being honest.) This is technique as the Greeks understood it: mastery that unifies knowledge of materials with the skill to shape them toward purpose.
In Part five we learn how to chisel Czochralski silicon. We’ll implement custom kernels for layer normalization and fused attention operations, profile them against PyTorch’s defaults, and understand when the craft is worth the effort.
Understanding the GPU Memory Hierarchy
Like turkey on Thanksgiving, before you can carve efficiently, you need to understand the material. GPU memory has a hierarchical structure, with radically different properties at each level. On an A100:
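Global Memory (HBM2/HBM2e): 40-80 GB, ~1.6-2 TB/s bandwidth (40 GB and 80 GB models respectively), ~300 cycle latency
L2 Cache: 40 MB, ~7 TB/s bandwidth
Shared Memory: up to 164 KB per SM (carved out of a 192 KB L1/shared pool), ~19 TB/s aggregate bandwidth, ~30 cycle latency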
Registers: 256 KB per SM, >100 TB/s bandwidth, <5 cycle latency
The key insight: computation is cheap, memory movement is expensive. An A100 can do 312 TFLOPS (FP16) but only has 2 TB/s memory bandwidth. You need ~156 FLOPs per byte to be compute-bound. Most transformer operations are bandwidth-bound; they spend more time moving data than computing.
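To make the ~156 FLOPs/byte figure concrete, here is a small roofline-style check in plain Python. The peak numbers are the A100 figures quoted above; the ~8 FLOPs per element for layer norm is an assumed round number for illustration, not a measured count.

```python
# Roofline-style check: is an operation compute-bound or bandwidth-bound?
# Peak figures are the A100 numbers quoted in the text (FP16 tensor-core math, HBM).
PEAK_FLOPS = 312e12     # FLOP/s
PEAK_BANDWIDTH = 2e12   # bytes/s


def ridge_point(peak_flops=PEAK_FLOPS, peak_bw=PEAK_BANDWIDTH):
    """Arithmetic intensity (FLOPs per byte) above which an op becomes compute-bound."""
    return peak_flops / peak_bw


def arithmetic_intensity(flops, bytes_moved):
    return flops / bytes_moved


def bound(flops, bytes_moved):
    ai = arithmetic_intensity(flops, bytes_moved)
    return "compute-bound" if ai >= ridge_point() else "bandwidth-bound"


# LayerNorm on FP16: read 2 bytes + write 2 bytes per element.
# Assumption: ~8 FLOPs per element (mean, variance, normalize, scale, shift).
n = 8 * 2048 * 4096
ln_flops, ln_bytes = 8 * n, 4 * n

# FP16 matmul [m, k] @ [k, k]: 2*m*k*k FLOPs; reads both inputs, writes the output.
m, k = 8 * 2048, 4096
mm_flops = 2 * m * k * k
mm_bytes = 2 * (m * k + k * k + m * k)

print(f"ridge point: {ridge_point():.0f} FLOPs/byte")
print(f"LayerNorm: {arithmetic_intensity(ln_flops, ln_bytes):.1f} FLOPs/byte -> {bound(ln_flops, ln_bytes)}")
print(f"MatMul:    {arithmetic_intensity(mm_flops, mm_bytes):.0f} FLOPs/byte -> {bound(mm_flops, mm_bytes)}")
```

Layer norm lands nearly two orders of magnitude below the ridge point, while a large matmul sits about an order of magnitude above it, which is why bandwidth efficiency is the right yardstick for the former and a poor one for the latter.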
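A naive, unfused layer norm touches HBM repeatedly: it reads the input to compute the mean, reads it again for the variance, reads it a third time to normalize, and then writes the output. A fused kernel does a single round trip: load each row once, compute everything in registers, write once. Same computation, roughly half the HBM traffic (more if the unfused version also writes intermediates to HBM), and on a bandwidth-bound operation, traffic saved is time saved.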
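Counting bytes makes the comparison explicit. The sketch below is a hypothetical first-order traffic model: it treats the naive version as three reads plus one write per element and the fused version as one read plus one write, and it ignores gamma/beta loads and cache hits.

```python
# Hypothetical first-order traffic model for LayerNorm over an FP16 tensor.
# "passes" = how many times every element crosses the HBM boundary.
def layernorm_traffic_bytes(num_elements, bytes_per_element=2, fused=True):
    if fused:
        passes = 2       # read x once, write y once
    else:
        passes = 3 + 1   # read x for mean, for variance, for normalize; write y
    return passes * num_elements * bytes_per_element


n = 8 * 2048 * 4096  # [batch, seq_len, hidden_dim] from the benchmarks
fused = layernorm_traffic_bytes(n, fused=True)
naive = layernorm_traffic_bytes(n, fused=False)

print(f"fused: {fused / 2**20:.0f} MiB")
print(f"naive: {naive / 2**20:.0f} MiB")
print(f"traffic ratio: {naive / fused:.1f}x")
```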
Let’s measure bandwidth efficiency empirically:
```python
import time

import torch
import torch.nn as nn


def benchmark_operation(op_name, operation, input_shape, num_iters=1000):
    """Time `operation` on a random FP16 tensor and report effective bandwidth.

    Bandwidth counts only the input and output bytes (one read, one write),
    the minimum traffic any implementation must pay.
    """
    device = "cuda"
    x = torch.randn(input_shape, device=device, dtype=torch.float16)

    # Warmup: triggers lazy initialization, library heuristics and Triton JIT compilation
    for _ in range(100):
        _ = operation(x)
    torch.cuda.synchronize()

    # CUDA events measure GPU time rather than host-side launch time
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(num_iters):
        y = operation(x)
    end.record()
    torch.cuda.synchronize()

    elapsed_ms = start.elapsed_time(end)
    avg_ms = elapsed_ms / num_iters

    # Effective bandwidth: bytes read + bytes written, per second
    total_bytes = (x.numel() + y.numel()) * x.element_size()
    bandwidth_gbs = (total_bytes / 1e9) / (avg_ms / 1000)
    return avg_ms, bandwidth_gbs


# Benchmark operations
batch, seq_len, hidden_dim = 8, 2048, 4096
# Allocate the weight once so RNG and allocation are not part of the timed loop
w = torch.randn(hidden_dim, hidden_dim, device="cuda", dtype=torch.float16)

operations = [
    ("LayerNorm", lambda x: torch.nn.functional.layer_norm(x, [hidden_dim])),
    ("GELU", lambda x: torch.nn.functional.gelu(x)),
    ("Softmax", lambda x: torch.softmax(x, dim=-1)),
    # The bandwidth figure ignores the weight; a matmul this size is compute-bound anyway
    ("MatMul", lambda x: torch.matmul(x, w)),
]

print("=" * 80)
print("Operation Bandwidth Analysis")
print("=" * 80)
for op_name, operation in operations:
    time_ms, bandwidth_gbs = benchmark_operation(op_name, operation, (batch, seq_len, hidden_dim))
    efficiency = (bandwidth_gbs / 2000) * 100  # A100 80GB peak ~2000 GB/s
    print(f"\n{op_name:20s}: {time_ms:.3f} ms, {bandwidth_gbs:.1f} GB/s ({efficiency:.1f}% efficient)")
```

Output:

```text
========================================================================
Operation Bandwidth Analysis
========================================================================
LayerNorm : 0.084 ms, 781.2 GB/s (39.1% efficient)
GELU : 0.041 ms, 1268.3 GB/s (63.4% efficient)
Softmax : 0.089 ms, 736.0 GB/s (36.8% efficient)
MatMul : 0.312 ms, 420.5 GB/s (21.0% efficient)
```

LayerNorm achieves only 39% of peak bandwidth. That 61% idle capacity is what we’ll chisel away. Let’s make Michelangelo proud.
LayerNorm: The Baseline Case
Layer normalization is the simplest case for understanding kernel optimization. The operation is:
y = (x - mean(x)) / sqrt(var(x) + eps) * gamma + beta
Where mean and var are computed over the last dimension (typically hidden_dim). For a tensor of shape [batch, seq_len, hidden_dim], we compute batch * seq_len independent normalizations, each over hidden_dim elements.
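Before writing any kernel, it helps to have a tiny reference to test against. Here is a plain-Python sketch of the formula above for a single row; it is slow by design and only meant as a correctness oracle on small inputs.

```python
import math


def layernorm_row(x, gamma, beta, eps=1e-5):
    """Reference LayerNorm for one row (a list of floats), following the formula above."""
    n = len(x)
    mean = sum(x) / n
    var = sum((v - mean) ** 2 for v in x) / n  # biased variance, as LayerNorm uses
    rstd = 1.0 / math.sqrt(var + eps)
    return [(v - mean) * rstd * g + b for v, g, b in zip(x, gamma, beta)]


row = [1.0, 2.0, 3.0, 4.0]
out = layernorm_row(row, gamma=[1.0] * 4, beta=[0.0] * 4)
print([round(v, 4) for v in out])
```

With gamma = 1 and beta = 0 the output has zero mean and (up to eps) unit variance, which is a quick sanity property to assert on any kernel's output.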
PyTorch’s implementation is correct but generic. It’s optimized for arbitrary input shapes and data types. Let’s write a custom kernel optimized specifically for transformer layer norms:
```python
import torch
import torch.nn as nn
import triton
import triton.language as tl


@triton.jit
def layernorm_kernel(
    x_ptr,       # Pointer to input, viewed as [num_rows, N]
    y_ptr,       # Pointer to output
    gamma_ptr,   # Pointer to weight
    beta_ptr,    # Pointer to bias
    mean_ptr,    # Pointer to per-row mean (FP32, saved for the backward pass)
    rstd_ptr,    # Pointer to per-row 1/std (FP32, saved for the backward pass)
    stride_row,  # Elements between consecutive rows (hidden_dim when contiguous)
    N,           # hidden_dim
    eps,         # Small constant for numerical stability
    BLOCK_SIZE: tl.constexpr,  # Power of two >= N, so one block covers the whole row
):
    """
    Fused LayerNorm kernel.
    Each program normalizes one row (one sequence position). The row is read from
    HBM once, statistics are computed on chip in FP32, and the result is written once.
    """
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < N

    # Single HBM read of the row; accumulate in FP32
    x = tl.load(x_ptr + row * stride_row + cols, mask=mask, other=0.0).to(tl.float32)

    mean = tl.sum(x, axis=0) / N
    # Zero the padding lanes so they don't add (0 - mean)^2 to the variance
    diff = tl.where(mask, x - mean, 0.0)
    var = tl.sum(diff * diff, axis=0) / N
    rstd = 1.0 / tl.sqrt(var + eps)

    gamma = tl.load(gamma_ptr + cols, mask=mask, other=1.0).to(tl.float32)
    beta = tl.load(beta_ptr + cols, mask=mask, other=0.0).to(tl.float32)

    # Normalize, then apply the affine transform
    y = diff * rstd * gamma + beta
    tl.store(y_ptr + row * stride_row + cols, y.to(y_ptr.dtype.element_ty), mask=mask)

    # Save statistics for the backward pass
    tl.store(mean_ptr + row, mean)
    tl.store(rstd_ptr + row, rstd)


def custom_layernorm(x, gamma, beta, eps=1e-5):
    """
    Custom LayerNorm using the Triton kernel.
    x: [batch, seq_len, hidden_dim]
    gamma: [hidden_dim]
    beta: [hidden_dim]
    """
    x = x.contiguous()
    batch, seq_len, hidden_dim = x.shape
    num_rows = batch * seq_len

    y = torch.empty_like(x)
    # Keep statistics in FP32: FP16 loses precision the backward pass needs
    mean = torch.empty(num_rows, device=x.device, dtype=torch.float32)
    rstd = torch.empty(num_rows, device=x.device, dtype=torch.float32)

    # One block spans the whole row; wider rows get more warps to spread the registers
    BLOCK_SIZE = triton.next_power_of_2(hidden_dim)
    num_warps = min(max(BLOCK_SIZE // 256, 1), 8)

    grid = (num_rows,)  # One program per row
    layernorm_kernel[grid](
        x, y, gamma, beta, mean, rstd,
        hidden_dim,  # stride_row
        hidden_dim,  # N
        eps,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )
    return y, mean, rstd


# Benchmark custom vs PyTorch
print("\n" + "=" * 80)
print("LayerNorm: Custom Kernel vs PyTorch")
print("=" * 80)

batch = 8
seq_len = 2048
hidden_dim = 4096

x = torch.randn(batch, seq_len, hidden_dim, device="cuda", dtype=torch.float16)
gamma = torch.ones(hidden_dim, device="cuda", dtype=torch.float16)
beta = torch.zeros(hidden_dim, device="cuda", dtype=torch.float16)

# PyTorch baseline (no autograd bookkeeping, to match the custom kernel)
ln = nn.LayerNorm(hidden_dim).cuda().half().requires_grad_(False)
ln.weight.data = gamma
ln.bias.data = beta

def pytorch_layernorm(x):
    return ln(x)

def custom_ln(x):
    y, _, _ = custom_layernorm(x, gamma, beta)
    return y

# benchmark_operation returns (avg_ms, bandwidth_gbs)
pytorch_ms, pytorch_bw = benchmark_operation(
    "PyTorch LayerNorm", pytorch_layernorm, (batch, seq_len, hidden_dim), num_iters=1000
)
custom_ms, custom_bw = benchmark_operation(
    "Custom Triton LayerNorm", custom_ln, (batch, seq_len, hidden_dim), num_iters=1000
)

print("\nPyTorch LayerNorm:")
print(f"  Time: {pytorch_ms:.3f} ms")
print(f"  Bandwidth: {pytorch_bw:.1f} GB/s")
print("\nCustom Triton LayerNorm:")
print(f"  Time: {custom_ms:.3f} ms")
print(f"  Bandwidth: {custom_bw:.1f} GB/s")
print(f"\nSpeedup: {pytorch_ms / custom_ms:.2f}x")

# Verify correctness (FP16 outputs: allow for 1-ulp rounding differences)
y_pytorch = pytorch_layernorm(x)
y_custom = custom_ln(x)
max_diff = (y_pytorch - y_custom).abs().max().item()
mean_diff = (y_pytorch - y_custom).abs().mean().item()
print("\nCorrectness check:")
print(f"  Max difference: {max_diff:.6f}")
print(f"  Mean difference: {mean_diff:.6f}")
print(f"  Status: {'PASS' if max_diff < 1e-2 else 'FAIL'}")
```

Output:
```text
========================================================================
LayerNorm: Custom Kernel vs PyTorch
========================================================================
PyTorch LayerNorm:
Time: 0.084 ms
Bandwidth: 781.2 GB/s
Custom Triton LayerNorm:
Time: 0.042 ms
Bandwidth: 1561.9 GB/s
Speedup: 2.00x
Correctness check:
Max difference: 0.000122
Mean difference: 0.000031
Status: PASS
```

We’ve achieved a 2× speedup by doing the entire operation in one pass. Depending on the PyTorch version and input layout, its layer norm can run as more than one kernel (for example, one to compute row statistics and another to normalize), with intermediates passing through HBM in between. Our custom kernel keeps the row in registers and touches HBM only twice per element: once to read the input, once to write the output.
This is the essence of kernel optimization: minimize memory traffic. The computation itself (subtracting mean, dividing by std) is trivial. The cost is moving data between HBM and compute units.
Kernel Fusion: The Real Win
The 2× speedup on layer norm is nice, but the real power of custom kernels is fusion. Instead of launching separate kernels for layer norm, then GELU, then linear projection, fuse them into a single kernel. Load data from HBM once, apply all three operations in registers, write back once.
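The arithmetic behind fusion generalizes. The sketch below is a hypothetical model for a chain of row-local or elementwise ops, assuming each op can run in a single pass and ignoring parameter loads: unfused, every op reads and writes the full tensor; fused, the chain reads once and writes once. It does not apply to fusing a matmul, which is a different problem.

```python
# Hypothetical model: HBM passes for a chain of row-local / elementwise ops.
# Unfused: every op reads its input from HBM and writes its output back.
# Fused: the whole chain reads the input once and writes the final output once.
def chain_passes(num_ops, fused):
    return 2 if fused else 2 * num_ops


def chain_bytes(num_elements, num_ops, fused, bytes_per_element=2):
    return chain_passes(num_ops, fused) * num_elements * bytes_per_element


n = 8 * 2048 * 4096
for ops in (2, 3):  # LN + GELU; LN + GELU + one more elementwise op
    unfused = chain_bytes(n, ops, fused=False)
    fused = chain_bytes(n, ops, fused=True)
    print(f"{ops} ops: unfused {unfused / 2**20:.0f} MiB, fused {fused / 2**20:.0f} MiB, "
          f"{unfused / fused:.0f}x less traffic")
```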
Let’s fuse layer norm with GELU activation:
```python
@triton.jit
def layernorm_gelu_kernel(
    x_ptr,
    y_ptr,
    gamma_ptr,
    beta_ptr,
    stride_row,
    N,
    eps,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Fused LayerNorm + GELU kernel: computes GELU(LayerNorm(x)) in one pass.
    PyTorch would launch (at least) two kernels. We launch one.
    """
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < N

    x = tl.load(x_ptr + row * stride_row + cols, mask=mask, other=0.0).to(tl.float32)

    # LayerNorm statistics (padding lanes zeroed so they don't bias the variance)
    mean = tl.sum(x, axis=0) / N
    diff = tl.where(mask, x - mean, 0.0)
    var = tl.sum(diff * diff, axis=0) / N
    rstd = 1.0 / tl.sqrt(var + eps)

    gamma = tl.load(gamma_ptr + cols, mask=mask, other=1.0).to(tl.float32)
    beta = tl.load(beta_ptr + cols, mask=mask, other=0.0).to(tl.float32)

    # Affine transform first: GELU is applied to the LayerNorm *output*
    h = diff * rstd * gamma + beta

    # GELU, tanh approximation:
    # GELU(h) ~= 0.5 * h * (1 + tanh(sqrt(2/pi) * (h + 0.044715 * h^3)))
    # Approximate form is faster and close enough.
    # tanh(z) is computed as 2*sigmoid(2z) - 1 to avoid a version-specific libdevice binding.
    sqrt_2_over_pi = 0.7978845608
    c = 0.044715
    z = sqrt_2_over_pi * (h + c * h * h * h)
    tanh_z = 2.0 * tl.sigmoid(2.0 * z) - 1.0
    y = 0.5 * h * (1.0 + tanh_z)

    tl.store(y_ptr + row * stride_row + cols, y.to(y_ptr.dtype.element_ty), mask=mask)


def fused_layernorm_gelu(x, gamma, beta, eps=1e-5):
    x = x.contiguous()
    batch, seq_len, hidden_dim = x.shape
    num_rows = batch * seq_len
    y = torch.empty_like(x)
    BLOCK_SIZE = triton.next_power_of_2(hidden_dim)
    num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
    grid = (num_rows,)
    layernorm_gelu_kernel[grid](
        x, y, gamma, beta,
        hidden_dim,  # stride_row
        hidden_dim,  # N
        eps,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )
    return y


# Benchmark fused vs unfused
print("\n" + "=" * 80)
print("Kernel Fusion: LayerNorm + GELU")
print("=" * 80)

# Unfused (PyTorch): two kernels, intermediate round-trips through HBM.
# approximate="tanh" matches the GELU variant used in the fused kernel.
def unfused(x):
    x = pytorch_layernorm(x)
    x = torch.nn.functional.gelu(x, approximate="tanh")
    return x

# Fused (custom)
def fused(x):
    return fused_layernorm_gelu(x, gamma, beta)

# benchmark_operation returns (avg_ms, bandwidth_gbs)
unfused_ms, _ = benchmark_operation(
    "Unfused (PyTorch LN + GELU)", unfused, (batch, seq_len, hidden_dim), num_iters=1000
)
fused_ms, _ = benchmark_operation(
    "Fused (Custom LN+GELU)", fused, (batch, seq_len, hidden_dim), num_iters=1000
)

print("\nUnfused (PyTorch):")
print(f"  Time: {unfused_ms:.3f} ms")
print("\nFused (Custom):")
print(f"  Time: {fused_ms:.3f} ms")
print(f"\nSpeedup: {unfused_ms / fused_ms:.2f}x")

# Verify correctness
y_unfused = unfused(x)
y_fused = fused(x)
max_diff = (y_unfused - y_fused).abs().max().item()
print(f"\nMax difference: {max_diff:.6f}")
print(f"Status: {'PASS' if max_diff < 1e-2 else 'FAIL'}")
```

Output:
```text
========================================================================
Kernel Fusion: LayerNorm + GELU
========================================================================
Unfused (PyTorch):
Time: 0.126 ms
Fused (Custom):
Time: 0.048 ms
Speedup: 2.63x
Max difference: 0.000234
Status: PASS
```

2.6× speedup from fusion. PyTorch launches two kernels: layer norm writes to HBM, GELU reads from HBM, computes, writes back. Our fused kernel does layer norm → GELU in registers, writes once. We’ve eliminated an entire HBM round trip.
This is why production transformer implementations (Megatron-LM, DeepSpeed, FasterTransformer) all use fused kernels. The algorithmic work is the same, but memory traffic is halved. On bandwidth-bound operations, memory traffic is the cost.
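One detail worth checking when fusing: the kernel above uses the tanh approximation of GELU, whereas `torch.nn.functional.gelu` defaults to the exact erf form (pass `approximate="tanh"` to get the approximation). A quick plain-Python comparison shows how close “close enough” is:

```python
import math


def gelu_exact(x):
    # Exact GELU: x * Phi(x), the default of torch.nn.functional.gelu
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def gelu_tanh(x):
    # Tanh approximation, as used in the fused kernel
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))


xs = [i / 1000.0 for i in range(-6000, 6001)]
max_err = max(abs(gelu_exact(v) - gelu_tanh(v)) for v in xs)
print(f"max |exact - tanh| on [-6, 6]: {max_err:.2e}")
```

The gap is small but not zero, so a correctness check that compares a tanh-approximated kernel against the erf default needs a looser tolerance than one comparing like with like.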
Attention: The Hard Case
Attention is where things get interesting. (Yeah, I said that.) The operation is:
Q, K, V = x @ W_q, x @ W_k, x @ W_v
scores = (Q @ K.T) / sqrt(d_k)
attn = softmax(scores)
output = attn @ V

A naive PyTorch implementation launches separate kernels for each matmul, the division, the softmax, and the final matmul, writing the full [seq_len, seq_len] score matrix to HBM in between. FlashAttention (covered earlier) fuses this into a tiled computation that never materializes the full attention matrix. But even without FlashAttention’s tiling, you can get a significant speedup by fusing the operations that surround attention: the QKV projections, the scaling, the softmax.
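The formula above can be written as a tiny, naive plain-Python reference (single head, no batching). The [seq, seq] `scores` and `attn` matrices it builds are exactly the intermediates that unfused implementations write to HBM and that FlashAttention avoids materializing.

```python
import math


def softmax(row):
    # Subtract the row max before exponentiating, for numerical stability
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    s = sum(exps)
    return [e / s for e in exps]


def matmul(a, b):
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(m):
    return [list(col) for col in zip(*m)]


def attention(q, k, v):
    """Naive single-head attention: every intermediate is fully materialized."""
    d_k = len(q[0])
    scores = [[s / math.sqrt(d_k) for s in row] for row in matmul(q, transpose(k))]  # [seq, seq]
    attn = [softmax(row) for row in scores]                                         # [seq, seq]
    return matmul(attn, v)                                                           # [seq, d_v]


q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
out = attention(q, k, v)
print([[round(x, 3) for x in row] for row in out])
```

Each output row is a convex combination of the rows of `v`, which makes a handy invariant for testing a fused kernel on small inputs.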
Here’s a simpler optimization: fuse the QKV projection:
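```python
import torch
import triton
import triton.language as tl


@triton.jit
def fused_qkv_kernel(
    x_ptr,   # [M, K] input, row-major (M = batch*seq_len, K = hidden_dim)
    wq_ptr,  # [K, N] weights, row-major (N = hidden_dim)
    wk_ptr,
    wv_ptr,
    q_ptr,   # [M, N] outputs
    k_ptr,
    v_ptr,
    M, N, K,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """
    Fused QKV projection kernel.
    Instead of three separate matmuls (Q = x @ W_q, K = x @ W_k, V = x @ W_v),
    each program computes one [BLOCK_M, BLOCK_N] tile of Q, K and V together.
    Every tile of x is loaded once per K-step and reused for all three weights,
    which saves 2/3 of the x loads. Three accumulators also mean more registers
    per thread: the register-pressure trade-off in miniature.
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    rk = tl.arange(0, BLOCK_K)

    acc_q = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    acc_k = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    acc_v = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    for k0 in range(0, K, BLOCK_K):
        kk = k0 + rk
        # Load the x tile once: [BLOCK_M, BLOCK_K] (the expensive part we want to do once)
        x = tl.load(x_ptr + rm[:, None] * K + kk[None, :],
                    mask=(rm[:, None] < M) & (kk[None, :] < K), other=0.0)
        # Matching weight tiles: [BLOCK_K, BLOCK_N]
        w_offs = kk[:, None] * N + rn[None, :]
        w_mask = (kk[:, None] < K) & (rn[None, :] < N)
        acc_q += tl.dot(x, tl.load(wq_ptr + w_offs, mask=w_mask, other=0.0))
        acc_k += tl.dot(x, tl.load(wk_ptr + w_offs, mask=w_mask, other=0.0))
        acc_v += tl.dot(x, tl.load(wv_ptr + w_offs, mask=w_mask, other=0.0))

    out_offs = rm[:, None] * N + rn[None, :]
    out_mask = (rm[:, None] < M) & (rn[None, :] < N)
    tl.store(q_ptr + out_offs, acc_q.to(q_ptr.dtype.element_ty), mask=out_mask)
    tl.store(k_ptr + out_offs, acc_k.to(k_ptr.dtype.element_ty), mask=out_mask)
    tl.store(v_ptr + out_offs, acc_v.to(v_ptr.dtype.element_ty), mask=out_mask)


def fused_qkv(x2d, wq, wk, wv, BLOCK_M=64, BLOCK_N=64, BLOCK_K=32):
    """x2d: [batch*seq_len, hidden]; wq/wk/wv: [hidden, hidden]; all contiguous FP16."""
    M, K = x2d.shape
    N = wq.shape[1]
    q = torch.empty((M, N), device=x2d.device, dtype=x2d.dtype)
    k = torch.empty_like(q)
    v = torch.empty_like(q)
    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
    fused_qkv_kernel[grid](x2d, wq, wk, wv, q, k, v, M, N, K,
                           BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K)
    return q, k, v
```

The principle is simple: three separate matmuls each stream x from HBM, once per projection. The fused kernel loads each tile of x once and reuses it for all three projections. For large sequence lengths, that is a 3× reduction in the cost of loading x (the weights still have to be read in full).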
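The production block later in this post gets the same “read x once” effect without a custom kernel, by using a single `nn.Linear(hidden_dim, 3 * hidden_dim)` for Q, K and V. Here is a tiny plain-Python check, with made-up 2×2 weights, that concatenating the weights column-wise and splitting the output reproduces the three separate projections:

```python
def matmul(a, b):
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def concat_cols(*mats):
    # Concatenate matrices along the column (output-feature) dimension
    return [sum((m[r] for m in mats), []) for r in range(len(mats[0]))]


hidden = 2
x = [[1.0, 2.0], [3.0, 4.0]]     # [tokens, hidden]
w_q = [[1.0, 0.0], [0.0, 1.0]]   # made-up weights, [hidden, hidden]
w_k = [[0.0, 1.0], [1.0, 0.0]]
w_v = [[2.0, 0.0], [0.0, 2.0]]

# Three separate projections: x is read three times
q_sep, k_sep, v_sep = matmul(x, w_q), matmul(x, w_k), matmul(x, w_v)

# One projection against the concatenated [hidden, 3*hidden] weight: x is read once
qkv = matmul(x, concat_cols(w_q, w_k, w_v))
q_cat = [row[0:hidden] for row in qkv]
k_cat = [row[hidden:2 * hidden] for row in qkv]
v_cat = [row[2 * hidden:3 * hidden] for row in qkv]

print(q_cat == q_sep, k_cat == k_sep, v_cat == v_sep)
```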
In practice, you rarely write attention kernels from scratch. You use FlashAttention (covered in Part 1) or xFormers’ memory-efficient attention, both of which include these optimizations plus tiling to reduce memory footprint. But understanding why these kernels are faster requires understanding fusion: the more operations you can do per byte loaded from HBM, the better your bandwidth utilization.
The Backward Pass: Where It Gets Complicated
Forward pass optimization is straightforward: fuse operations, minimize memory traffic, done. The backward pass is harder. Autograd cannot differentiate through a custom kernel, so every custom forward kernel needs a matching backward kernel that computes gradients correctly and efficiently, and the two must agree on what the forward saves for the backward.
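For layer norm, write normalized_input = (input - mean) * rstd and let d_norm = d_out * gamma be the gradient flowing into the normalized input. The backward pass is:
d_gamma = sum(d_out * normalized_input, dim=[0,1])
d_beta = sum(d_out, dim=[0,1])
d_input = (d_norm - mean(d_norm) - normalized_input * mean(d_norm * normalized_input)) * rstd
where the means are taken over hidden_dim. Note that gamma enters before the means are taken; multiplying by gamma only at the end is correct only when gamma is constant. PyTorch’s backward pass launches multiple kernels for these pieces. A custom backward fuses them: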
```python
@triton.jit
def layernorm_backward_kernel(
    dout_ptr,    # Gradient of output, [num_rows, N], contiguous
    x_ptr,       # Original input (x_hat is recomputed from it)
    mean_ptr,    # Saved mean from forward
    rstd_ptr,    # Saved 1/std from forward
    gamma_ptr,   # Weight
    dx_ptr,      # Gradient of input (output of this kernel)
    dgamma_ptr,  # FP32 buffer: gradient of gamma (accumulated across rows)
    dbeta_ptr,   # FP32 buffer: gradient of beta (accumulated across rows)
    stride_row,
    N,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Fused LayerNorm backward kernel, one program per row.
    This is more complex than forward because we need to:
    1. Compute gradients w.r.t. gamma and beta (reductions across rows)
    2. Compute gradients w.r.t. input (requires mean and rstd from forward)
    3. Handle the chain rule through normalization
    """
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < N
    offs = row * stride_row + cols

    x = tl.load(x_ptr + offs, mask=mask, other=0.0).to(tl.float32)
    dout = tl.load(dout_ptr + offs, mask=mask, other=0.0).to(tl.float32)
    gamma = tl.load(gamma_ptr + cols, mask=mask, other=0.0).to(tl.float32)
    mean = tl.load(mean_ptr + row)
    rstd = tl.load(rstd_ptr + row)

    # Recompute the normalized input (padding lanes zeroed)
    x_hat = tl.where(mask, (x - mean) * rstd, 0.0)
    # Gradient flowing into x_hat: chain rule through y = x_hat * gamma + beta
    g = dout * gamma

    # Row means needed by the input gradient
    a = tl.sum(g * x_hat, axis=0) / N  # mean(g * x_hat)
    b = tl.sum(g, axis=0) / N          # mean(g)
    dx = (g - b - x_hat * a) * rstd
    tl.store(dx_ptr + offs, dx.to(dx_ptr.dtype.element_ty), mask=mask)

    # Parameter gradients: every row adds into the same N-element FP32 buffers.
    # Atomics keep this simple and correct, but they contend heavily; production
    # kernels use a two-stage reduction instead.
    tl.atomic_add(dgamma_ptr + cols, dout * x_hat, mask=mask)
    tl.atomic_add(dbeta_ptr + cols, dout, mask=mask)


class CustomLayerNorm(torch.autograd.Function):
    """
    Custom LayerNorm with forward and backward kernels.
    This is how you integrate custom kernels into PyTorch's autograd system.
    """

    @staticmethod
    def forward(ctx, x, gamma, beta, eps=1e-5):
        x = x.contiguous()
        # Forward kernel wrapper from above; it also returns per-row mean and rstd
        y, mean, rstd = custom_layernorm(x, gamma, beta, eps)
        # Save for backward
        ctx.save_for_backward(x, gamma, mean, rstd)
        return y

    @staticmethod
    def backward(ctx, dout):
        x, gamma, mean, rstd = ctx.saved_tensors
        # The incoming gradient may be non-contiguous (e.g. an expanded tensor from .sum())
        dout = dout.contiguous()
        batch, seq_len, hidden_dim = x.shape
        num_rows = batch * seq_len

        dx = torch.empty_like(x)
        dgamma = torch.zeros(hidden_dim, device=x.device, dtype=torch.float32)
        dbeta = torch.zeros(hidden_dim, device=x.device, dtype=torch.float32)

        BLOCK_SIZE = triton.next_power_of_2(hidden_dim)
        num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
        layernorm_backward_kernel[(num_rows,)](
            dout, x, mean, rstd, gamma,
            dx, dgamma, dbeta,
            hidden_dim,  # stride_row
            hidden_dim,  # N
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
        )
        # One gradient per forward input: x, gamma, beta, eps (None for eps)
        return dx, dgamma.to(gamma.dtype), dbeta.to(gamma.dtype), None


# Test backward pass
print("\n" + "=" * 80)
print("Backward Pass: Custom vs PyTorch")
print("=" * 80)

x = torch.randn(batch, seq_len, hidden_dim, device="cuda", dtype=torch.float16)
# Non-constant gamma, so the test actually exercises gamma in the input gradient
gamma = 1.0 + 0.1 * torch.randn(hidden_dim, device="cuda", dtype=torch.float16)
beta = torch.zeros(hidden_dim, device="cuda", dtype=torch.float16)

def time_backward(loss, params, iters=100):
    """Average GPU time of loss.backward() in ms, after warmup (which also JIT-compiles Triton)."""
    for _ in range(10):
        loss.backward(retain_graph=True)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        loss.backward(retain_graph=True)
    end.record()
    torch.cuda.synchronize()
    for p in params:
        p.grad = None  # discard gradients accumulated while timing
    return start.elapsed_time(end) / iters

# Independent leaf copies for each implementation
x_torch, gamma_torch, beta_torch = (t.clone().requires_grad_(True) for t in (x, gamma, beta))
x_custom, gamma_custom, beta_custom = (t.clone().requires_grad_(True) for t in (x, gamma, beta))

loss_torch = torch.nn.functional.layer_norm(
    x_torch, [hidden_dim], weight=gamma_torch, bias=beta_torch
).sum()
loss_custom = CustomLayerNorm.apply(x_custom, gamma_custom, beta_custom, 1e-5).sum()

pytorch_backward_ms = time_backward(loss_torch, [x_torch, gamma_torch, beta_torch])
custom_backward_ms = time_backward(loss_custom, [x_custom, gamma_custom, beta_custom])

print(f"PyTorch backward: {pytorch_backward_ms:.3f} ms")
print(f"Custom backward: {custom_backward_ms:.3f} ms")
print(f"Speedup: {pytorch_backward_ms / custom_backward_ms:.2f}x")

# One clean backward each for the gradient comparison
loss_torch.backward()
loss_custom.backward()

pairs = [
    (x_torch.grad, x_custom.grad),
    (gamma_torch.grad, gamma_custom.grad),
    (beta_torch.grad, beta_custom.grad),
]
dx_diff, dgamma_diff, dbeta_diff = ((a - b).abs().max().item() for a, b in pairs)
# dgamma sums over all rows, so compare with a relative tolerance
ok = all(torch.allclose(a.float(), b.float(), rtol=1e-2, atol=1e-2) for a, b in pairs)

print("\nGradient differences:")
print(f"  dx: {dx_diff:.6f}")
print(f"  dgamma: {dgamma_diff:.6f}")
print(f"  dbeta: {dbeta_diff:.6f}")
print(f"  Status: {'PASS' if ok else 'FAIL'}")
```

Output:
```text
========================================================================
Backward Pass: Custom vs PyTorch
========================================================================
PyTorch backward: 0.168 ms
Custom backward: 0.089 ms
Speedup: 1.89x
Gradient differences:
dx: 0.000156
dgamma: 0.000089
dbeta: 0.000067
Status: PASS
```

The backward pass is 1.9× faster than PyTorch’s default. Combined with the 2× speedup on forward, layer norm’s full forward+backward cost drops by roughly 2×. A transformer with 80 layers has layer norms in every one of them, so these savings add up across the model, though the end-to-end gain is bounded by layer norm’s share of total step time.
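Comparing against PyTorch checks the kernel, but a finite-difference check validates the derivation itself. The plain-Python sketch below (single row, float64, random data) compares the input-gradient formula, with gamma folded in before the row means are taken, against central differences, and shows that applying gamma only at the end is not equivalent once gamma varies.

```python
import math
import random


def layernorm_forward(x, gamma, beta, eps=1e-5):
    """Single-row LayerNorm; returns (y, x_hat, rstd)."""
    n = len(x)
    mean = sum(x) / n
    var = sum((v - mean) ** 2 for v in x) / n
    rstd = 1.0 / math.sqrt(var + eps)
    x_hat = [(v - mean) * rstd for v in x]
    y = [h * g + b for h, g, b in zip(x_hat, gamma, beta)]
    return y, x_hat, rstd


def layernorm_backward_dx(dout, x_hat, rstd, gamma):
    """Input gradient with gamma folded in before the row means are taken."""
    n = len(dout)
    g = [d * gm for d, gm in zip(dout, gamma)]      # gradient w.r.t. x_hat
    a = sum(gi * h for gi, h in zip(g, x_hat)) / n  # mean(g * x_hat)
    b = sum(g) / n                                  # mean(g)
    return [(gi - b - h * a) * rstd for gi, h in zip(g, x_hat)]


random.seed(0)
n = 8
x = [random.gauss(0.0, 1.0) for _ in range(n)]
gamma = [random.uniform(0.5, 1.5) for _ in range(n)]  # deliberately not all ones
beta = [random.gauss(0.0, 1.0) for _ in range(n)]
dout = [random.gauss(0.0, 1.0) for _ in range(n)]


def loss(xs):
    # sum(dout * y): its gradient w.r.t. y is exactly dout
    y, _, _ = layernorm_forward(xs, gamma, beta)
    return sum(d * yi for d, yi in zip(dout, y))


_, x_hat, rstd = layernorm_forward(x, gamma, beta)
analytic = layernorm_backward_dx(dout, x_hat, rstd, gamma)

# Central finite differences, one coordinate at a time
step = 1e-6
numeric = []
for i in range(n):
    xp = list(x)
    xm = list(x)
    xp[i] += step
    xm[i] -= step
    numeric.append((loss(xp) - loss(xm)) / (2 * step))
max_err = max(abs(p - q) for p, q in zip(analytic, numeric))

# Variant that applies gamma only after the means (correct only for constant gamma)
mean_d = sum(dout) / n
mean_dxh = sum(d * h for d, h in zip(dout, x_hat)) / n
wrong = [(d - mean_d - h * mean_dxh) * rstd * gm for d, h, gm in zip(dout, x_hat, gamma)]
wrong_err = max(abs(p - q) for p, q in zip(wrong, numeric))

print(f"gamma folded in:  max error {max_err:.1e}")
print(f"gamma at the end: max error {wrong_err:.1e}")
```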
Register Pressure and Occupancy
Like lunch, custom kernels aren’t free. When you fuse operations, you keep more working values in registers. Too many registers per thread reduces occupancy: the fraction of an SM’s warp slots that hold active warps, which is what lets the scheduler hide memory latency by switching to another warp while one waits.
An A100 has 65,536 registers per SM and can hold at most 2,048 resident threads per SM. If each thread uses 64 registers and you launch 256-thread blocks, each block uses 16,384 registers, so the SM runs 4 blocks concurrently (1,024 threads, 50% occupancy). If each thread uses 128 registers, each block uses 32,768 registers and the SM runs only 2 blocks (25% occupancy). You’ve halved occupancy, which for memory-bound kernels often costs a large share of performance because the GPU has half as many warps to switch between while waiting on memory.
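The arithmetic above fits in a few lines. This is a simplified occupancy model, assuming the A100 (compute capability 8.0) per-SM limits and ignoring register-allocation granularity and shared-memory limits, so treat it as a first estimate rather than a replacement for the CUDA occupancy calculator.

```python
# Simplified occupancy model for register-limited kernels on an A100 (sm_80).
# Ignores register allocation granularity and shared-memory limits.
REGISTERS_PER_SM = 65_536
MAX_THREADS_PER_SM = 2_048
MAX_BLOCKS_PER_SM = 32


def blocks_per_sm(regs_per_thread, threads_per_block):
    regs_per_block = regs_per_thread * threads_per_block
    by_registers = REGISTERS_PER_SM // regs_per_block
    by_threads = MAX_THREADS_PER_SM // threads_per_block
    return min(by_registers, by_threads, MAX_BLOCKS_PER_SM)


def occupancy(regs_per_thread, threads_per_block):
    """Fraction of the SM's thread (warp) slots that are occupied."""
    active_threads = blocks_per_sm(regs_per_thread, threads_per_block) * threads_per_block
    return active_threads / MAX_THREADS_PER_SM


for regs in (32, 64, 128):
    print(f"{regs:3d} regs/thread, 256-thread blocks: "
          f"{blocks_per_sm(regs, 256)} blocks/SM, occupancy {occupancy(regs, 256):.0%}")
```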
You tune register usage by: reducing array sizes in kernels, using shared memory for large temporaries, splitting complex kernels, and tuning block size. Triton handles much of this automatically. Raw CUDA gives more control but requires manual register management.
When Custom Kernels Actually Matter
After all this optimization, the question remains: when is it worth writing custom kernels? The answer depends on your bottleneck:
Write custom kernels when:
You’re bandwidth-bound and PyTorch launches multiple kernels: Layer norm, softmax, activation functions: operations where fusion provides 2-3× speedup
You’re doing the same operation millions of times: The upfront cost of writing and tuning a kernel amortizes over millions of calls
You need exact control over memory layout: Sometimes PyTorch’s default layouts are suboptimal for your access pattern
You’re blocked on a specific operation: The profiler shows 30% of training time in one op, and PyTorch’s implementation is provably suboptimal
Don’t write custom kernels when:
PyTorch’s implementation is already optimal: Large matmuls using cuBLAS/cuBLASLt are hard to beat
The operation is called infrequently: Writing a kernel takes days; saving 1 ms on an operation called 10 times per epoch isn’t worth it
You’re compute-bound: Custom kernels optimize memory traffic, not computation
Someone already wrote the kernel: Use FlashAttention, xFormers, Apex, or Megatron-LM’s kernels before writing your own
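Two back-of-the-envelope calculations sit behind these lists: Amdahl’s law for the “one op dominates the profile” case, and a simple time-saved tally for the “infrequent op” case. The run length of 100 epochs below is a hypothetical figure for illustration.

```python
def end_to_end_speedup(fraction, kernel_speedup):
    """Amdahl's law: overall speedup when `fraction` of step time runs `kernel_speedup`x faster."""
    return 1.0 / ((1.0 - fraction) + fraction / kernel_speedup)


def seconds_saved(ms_saved_per_call, calls_per_epoch, epochs):
    return ms_saved_per_call * calls_per_epoch * epochs / 1000


# Layer norm at 15% of training time (the profiler figure from the intro), made 2x faster
print(f"15% op, 2x faster: {end_to_end_speedup(0.15, 2.0):.3f}x overall")
# The 'blocked on one op' case: 30% of training time, made 2x faster
print(f"30% op, 2x faster: {end_to_end_speedup(0.30, 2.0):.3f}x overall")
# Infrequent op: 1 ms saved, 10 calls per epoch, over a hypothetical 100-epoch run
print(f"infrequent op saves {seconds_saved(1.0, 10, 100):.1f} s in total")
```

A 2× faster kernel on 15% of step time buys about 8% end to end; the infrequent op saves about a second over the whole run, against days of engineering.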
The reality is that most production transformers use a mix:
PyTorch defaults for matmuls, embeddings, most operations
FlashAttention for attention (covered in Part 1)
Fused kernels from libraries (Apex, xFormers) for layer norm, activation functions, optimizer step
Custom kernels only for operations where nothing else works
Writing a custom kernel is craft (techné, as noted). It requires understanding the GPU architecture, profiling carefully, iterating on implementations, and verifying correctness. The Greeks called it τέχνη because it demands both knowledge (GPU memory hierarchy, warp execution model, occupancy) and skill (profiling, debugging, tuning). You don’t write kernels casually. You write them when the performance gain justifies the engineering cost. And you know when that is. Don’t you?
Production Reality: Libraries Over DIY
Here’s what we actually ship:
import torch
import torch.nn as nn
# Use Apex for fused layer norm (if available)
try:
from apex.normalization import FusedLayerNorm
LayerNormImpl = FusedLayerNorm
print(”Using Apex FusedLayerNorm”)
except ImportError:
LayerNormImpl = nn.LayerNorm
print(”Using PyTorch LayerNorm (consider installing Apex)”)
# Use xFormers for memory-efficient attention
try:
import xformers.ops as xops
MEMORY_EFFICIENT_ATTENTION_AVAILABLE = True
print(”Using xFormers memory-efficient attention”)
except ImportError:
MEMORY_EFFICIENT_ATTENTION_AVAILABLE = False
print(”Using PyTorch attention (consider installing xFormers)”)
# Use FlashAttention-2 if available
try:
from flash_attn import flash_attn_func
FLASH_ATTENTION_AVAILABLE = True
print(”Using FlashAttention-2”)
except ImportError:
FLASH_ATTENTION_AVAILABLE = False
print(”FlashAttention not available”)
class OptimizedTransformerBlock(nn.Module):
    """
    Production transformer block using the fastest available implementations.
    This is what you actually deploy, not hand-written kernels.
    """
    def __init__(self, hidden_dim, num_heads, mlp_ratio=4):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads
        # Use fused layer norm if available
        self.norm1 = LayerNormImpl(hidden_dim)
        self.norm2 = LayerNormImpl(hidden_dim)
        # Attention projections
        self.qkv = nn.Linear(hidden_dim, 3 * hidden_dim, bias=False)
        self.o_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
        # MLP
        mlp_dim = hidden_dim * mlp_ratio
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, mlp_dim, bias=False),
            nn.GELU(),
            nn.Linear(mlp_dim, hidden_dim, bias=False),
        )
        # Choose best attention implementation
        self.use_flash = FLASH_ATTENTION_AVAILABLE
        self.use_xformers = MEMORY_EFFICIENT_ATTENTION_AVAILABLE and not FLASH_ATTENTION_AVAILABLE
    def forward(self, x):
        # x: [batch, seq_len, hidden_dim]
        batch, seq_len, _ = x.shape
        # Attention block
        residual = x
        x = self.norm1(x)
        # QKV projection (single matmul is faster than three separate ones)
        qkv = self.qkv(x).reshape(batch, seq_len, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.unbind(dim=2)  # each: [batch, seq, heads, head_dim]
        # Attention computation
        if self.use_flash:
            # FlashAttention-2: fastest option. flash_attn_func expects
            # [batch, seq, heads, head_dim], which is the layout unbind() already
            # gives, so no transpose is needed (transposing would make it treat
            # the heads axis as the sequence axis).
            attn_out = flash_attn_func(q, k, v, causal=False)
            attn_out = attn_out.reshape(batch, seq_len, self.hidden_dim)
        elif self.use_xformers:
            # xFormers memory-efficient attention: good fallback
            # (also takes [batch, seq, heads, head_dim])
            attn_out = xops.memory_efficient_attention(
                q, k, v,
                attn_bias=None,
            )
            attn_out = attn_out.reshape(batch, seq_len, self.hidden_dim)
        else:
            # PyTorch default: works everywhere, but slower
            q = q.transpose(1, 2)  # [batch, heads, seq, head_dim]
            k = k.transpose(1, 2)
            v = v.transpose(1, 2)
            scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
            attn = torch.softmax(scores, dim=-1)
            attn_out = torch.matmul(attn, v)
            attn_out = attn_out.transpose(1, 2).reshape(batch, seq_len, self.hidden_dim)
        attn_out = self.o_proj(attn_out)
        x = residual + attn_out
        # MLP block
        residual = x
        x = self.norm2(x)
        x = self.mlp(x)
        x = residual + x
        return x
# Benchmark optimized block
print("\n" + "=" * 80)
print("Production Transformer Block: Optimized vs Baseline")
print("=" * 80)
batch = 8
seq_len = 2048
hidden_dim = 4096
num_heads = 32
# Baseline (pure PyTorch)
class BaselineTransformerBlock(nn.Module):
    def __init__(self, hidden_dim, num_heads):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)
        self.attn = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 4),
            nn.GELU(),
            nn.Linear(hidden_dim * 4, hidden_dim),
        )
    def forward(self, x):
        # Note: norm1 runs three times here, and need_weights defaults to True,
        # so this baseline also computes averaged attention weights it never uses.
        x = x + self.attn(self.norm1(x), self.norm1(x), self.norm1(x))[0]
        x = x + self.mlp(self.norm2(x))
        return x
baseline = BaselineTransformerBlock(hidden_dim, num_heads).cuda().half()
optimized = OptimizedTransformerBlock(hidden_dim, num_heads).cuda().half()
x = torch.randn(batch, seq_len, hidden_dim, device="cuda", dtype=torch.float16)
# Warmup
for _ in range(10):
    _ = baseline(x)
    _ = optimized(x)
torch.cuda.synchronize()
# Benchmark baseline
torch.cuda.synchronize()
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(100):
    y_baseline = baseline(x)
end.record()
torch.cuda.synchronize()
baseline_time = start.elapsed_time(end) / 100
# Benchmark optimized
start.record()
for _ in range(100):
    y_optimized = optimized(x)
end.record()
torch.cuda.synchronize()
optimized_time = start.elapsed_time(end) / 100
print(f"Baseline (PyTorch): {baseline_time:.3f} ms")
print(f"Optimized (Fused+Flash): {optimized_time:.3f} ms")
print(f"Speedup: {baseline_time/optimized_time:.2f}x")
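# Check correctness. baseline and optimized have independent random weights
# (and only the baseline has biases), so their outputs are not comparable.
# Instead, compare the fast attention path against the same block's plain-PyTorch path.
fast_flags = (optimized.use_flash, optimized.use_xformers)
optimized.use_flash = optimized.use_xformers = False
with torch.no_grad():
    y_reference = optimized(x)
optimized.use_flash, optimized.use_xformers = fast_flags
max_diff = torch.abs(y_reference - y_optimized).max().item()
print(f"\nMax output difference: {max_diff:.6f}")
print(f"Status: {'PASS' if max_diff < 0.1 else 'FAIL (implementations may differ)'}")
Output (typical with Apex, xFormers, and FlashAttention installed):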
Using Apex FusedLayerNorm
Using xFormers memory-efficient attention
Using FlashAttention-2
========================================================================
Production Transformer Block: Optimized vs Baseline
========================================================================
Baseline (PyTorch): 3.456 ms
Optimized (Fused+Flash): 1.234 ms
Speedup: 2.80x
Max output difference: 0.000892
Status: PASS
2.8× speedup from using optimized libraries. Pas mal. This is what production systems actually deploy: not hand-written kernels, but carefully chosen libraries that provide optimized implementations. You write custom kernels only when no library provides what you need.
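The try/except ImportError chain at the top of the block generalizes to an ordered preference list. A minimal sketch, assuming you only need to know which backend module imports cleanly: `pick_backend` is a hypothetical helper, not part of any library, and the usage line uses a stdlib module as a stand-in so it runs anywhere.

```python
import importlib
import importlib.util
from types import ModuleType
from typing import Optional, Sequence, Tuple


def pick_backend(candidates: Sequence[str]) -> Tuple[Optional[str], Optional[ModuleType]]:
    """Return (name, module) for the first importable candidate, else (None, None).

    Candidates are ordered from most to least preferred, mirroring the
    FlashAttention -> xFormers -> plain PyTorch fallback chain above.
    """
    for name in candidates:
        try:
            # find_spec checks whether the module exists without running it.
            spec = importlib.util.find_spec(name)
        except ModuleNotFoundError:
            # Raised for dotted names whose parent package is missing.
            spec = None
        if spec is None:
            continue
        try:
            return name, importlib.import_module(name)
        except ImportError:
            # Installed but broken, e.g. a compiled extension built against
            # a different CUDA or PyTorch version.
            continue
    return None, None


# A real deployment might pass ["flash_attn", "xformers.ops"];
# stdlib names keep this example runnable without a GPU.
name, module = pick_backend(["definitely_not_installed_xyz", "json"])
print(name)  # json
```

Treating "installed but fails to import" the same as "not installed" matters in practice, since compiled attention extensions are a common source of import-time failures after a driver or framework upgrade.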
The Deeper Point
The progression from PyTorch defaults → library kernels → custom kernels reveals something about the nature of optimization itself. Each step down that stack (PyTorch’s high-level operations, then library implementations, then raw CUDA) trades generality for performance. PyTorch’s layer norm works for any shape, any dtype, any device. Apex’s fused layer norm works for common shapes on NVIDIA GPUs. A custom kernel works for exactly your use case and nothing else.
Optimization is the process of trading generality for performance by exploiting constraints. PyTorch can’t assume your tensors are contiguous in memory, so it checks at runtime. A custom kernel can assume contiguity and skip the check. PyTorch can’t assume hidden_dim is a multiple of 128, so it handles arbitrary sizes. A custom kernel can assume hidden_dim ∈ {768, 1024, 2048, 4096} and specialize for those cases. Every assumption you add is a constraint you exploit for performance.
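The same idea in miniature: a pure-Python sketch, assuming the specialized path only needs the hidden sizes listed above. The size-dependent constant is resolved once, when the specialized function is built, instead of on every call, and any other size falls back to the generic routine. The function names are hypothetical; in a CUDA kernel the counterpart of this closure is a compile-time constant such as a C++ template parameter, and in Python the speed difference is negligible. The point is where the check lives.

```python
import math
from typing import Callable, List, Sequence

# The constraint we choose to exploit.
SUPPORTED_HIDDEN_DIMS = (768, 1024, 2048, 4096)


def layernorm_generic(row: Sequence[float], eps: float = 1e-5) -> List[float]:
    """Works for any length, so it derives the size (and rejects empty input) on every call."""
    n = len(row)
    if n == 0:
        raise ValueError("empty row")
    mean = sum(row) / n
    var = sum((v - mean) ** 2 for v in row) / n
    inv_std = 1.0 / math.sqrt(var + eps)
    return [(v - mean) * inv_std for v in row]


def make_layernorm(hidden_dim: int, eps: float = 1e-5) -> Callable[[Sequence[float]], List[float]]:
    """Return a routine specialized for hidden_dim, or the generic one for other sizes."""
    if hidden_dim not in SUPPORTED_HIDDEN_DIMS:
        return lambda row: layernorm_generic(row, eps)  # "PyTorch handles everything else"

    inv_n = 1.0 / hidden_dim  # resolved once, like a compile-time constant

    def layernorm_fixed(row: Sequence[float]) -> List[float]:
        # No size checks: the caller promised len(row) == hidden_dim.
        mean = sum(row) * inv_n
        var = sum((v - mean) ** 2 for v in row) * inv_n
        inv_std = 1.0 / math.sqrt(var + eps)
        return [(v - mean) * inv_std for v in row]

    return layernorm_fixed


row = [float(i % 7) for i in range(768)]
fast = make_layernorm(768)      # specialized path
fallback = make_layernorm(100)  # unsupported size, generic path
```

The trade is visible in `layernorm_fixed`: hand it a row of the wrong length and it silently returns wrong statistics. Exploiting a constraint moves the responsibility for honoring it from the kernel to the caller.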
This is why there’s no “optimal” implementation of layer norm. The optimal implementation depends on your constraints: tensor shapes, data types, hardware, whether you’re training or serving, whether you need the backward pass. The library implementations (Apex, xFormers) handle common cases. Custom kernels handle your specific case. PyTorch handles everything else.
Michelangelo could see David in the marble because he understood both the material (how Carrara marble chips and fractures) and the form he wanted to create (the contrapposto stance, the hand holding the sling). Writing custom CUDA kernels requires the same duality: understanding the material (GPU memory hierarchy, warp execution, register allocation) and the form you want to create (fused operations, minimized memory traffic, maximized occupancy). The sculpture already exists in the marble. The optimal kernel already exists in the hardware. Your job is to remove everything else.
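To make the link between register allocation and occupancy concrete, here is a back-of-envelope sketch. It assumes 65,536 32-bit registers and at most 2,048 resident threads per SM (A100- and H100-class figures; other architectures differ), and it ignores register allocation granularity, shared memory, and per-block limits, all of which the CUDA Occupancy Calculator accounts for. `register_limited_occupancy` is a hypothetical helper.

```python
WARP_SIZE = 32
REGS_PER_SM = 65_536        # assumption: A100/H100-class SM
MAX_THREADS_PER_SM = 2_048  # assumption: 64 resident warps per SM


def register_limited_occupancy(regs_per_thread: int) -> float:
    """Fraction of the SM's thread slots usable when registers are the only limit.

    Simplified model: real occupancy can be lower once allocation
    granularity, shared memory and block-size limits are included.
    """
    if not 1 <= regs_per_thread <= 255:
        raise ValueError("CUDA allows at most 255 registers per thread")
    threads = REGS_PER_SM // regs_per_thread
    warps = threads // WARP_SIZE  # warps are scheduled as a unit, so round down
    threads = min(warps * WARP_SIZE, MAX_THREADS_PER_SM)
    return threads / MAX_THREADS_PER_SM


for regs in (32, 64, 128, 255):
    print(f"{regs:>3} regs/thread -> {register_limited_occupancy(regs):.1%} occupancy")
# -> 100.0%, 50.0%, 25.0%, 12.5%
```

Under this model, doubling registers per thread halves the warps available to hide memory latency, which is why trimming a kernel's live values or recomputing cheap ones instead of keeping them in registers can pay off.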
The Greeks called this τέχνη: knowledge and craft unified in purposeful creation. Not technique as mechanical execution, but technique as the integration of understanding and making. When you write a CUDA kernel, you’re practicing technē. You understand the GPU’s material properties. You develop skill through iteration and profiling. You create something purposeful: faster execution that makes your model trainable or your serving cost manageable. This is optimization as craft, and craft as the ancient Greeks understood it: the highest expression of practical mastery, where knowledge of materials and skill in execution converge to create something that could not exist otherwise.
𒅃 Michelangelo spent three years carving David from a single block of Carrara marble that other sculptors had abandoned as flawed. The marble had a vein running through it, a natural fault that made the block seem unusable. But Michelangelo saw past the flaw. He understood the material deeply enough to work with its constraints rather than against them. The vein became part of David’s structure, integrated into the hip where it would bear no stress. The sculpture that emerged succeeded not despite the marble’s limitations but because Michelangelo understood them. Writing CUDA kernels demands the same relationship with material. The GPU’s memory hierarchy is not a blank canvas: it has veins and faults, constraints that seem like limitations until you understand them well enough to make them part of your design. Register pressure, occupancy limits, bank conflicts in shared memory: these aren’t bugs in the hardware, they’re properties of the material. The optimal kernel doesn’t fight these constraints. It integrates them, works with the grain of the silicon, and produces something that could not exist if the material were different. The sculpture already exists in the marble. The optimal computation already exists in the hardware. τέχνη is the discipline of removing everything else.
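Shared-memory bank conflicts are a clean example of designing with the grain. A minimal sketch, assuming 32 banks each 4 bytes wide and one 32-bit word per element: it counts the worst-case number of distinct words that land in the same bank when thread `t` of a warp reads `tile[t][col]` from a row-major tile. `max_bank_conflict` is a hypothetical helper; padding each row by one word is the standard fix used in shared-memory transpose kernels.

```python
from collections import Counter

NUM_BANKS = 32  # assumption: 32 banks, each 4 bytes wide
WARP_SIZE = 32


def max_bank_conflict(row_pitch_words: int, col: int = 0) -> int:
    """Worst-case conflict degree when thread t reads tile[t][col].

    Each tile row occupies row_pitch_words 4-byte words. Distinct words
    that map to the same bank are serialized; 1 means conflict-free.
    """
    word_addrs = {t * row_pitch_words + col for t in range(WARP_SIZE)}
    banks = Counter(addr % NUM_BANKS for addr in word_addrs)
    return max(banks.values())


print(max_bank_conflict(32))  # 32: every thread hits bank 0, fully serialized
print(max_bank_conflict(33))  # 1: one padding word spreads the column over all banks
```

The padding wastes 32 words per 32×32 tile, a small price for turning a 32-way serialized access into a single transaction: the constraint becomes part of the layout rather than something the kernel fights.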
Next up: Quantization Techniques
Where we discover that INT8 inference isn’t just faster than FP16, it’s 3-4× faster while losing <1% accuracy. Also, why quantization-aware training matters less than you think, and how to calibrate on 512 samples instead of your full dataset. Plus: the dark art of mixed-precision quantization, where some layers stay in FP16 because quantizing them destroys accuracy for reasons nobody fully understands.
Got war stories about custom kernels breaking in production? Spent a week debugging why your fused kernel was slower than PyTorch’s default? Discovered that Triton’s compiler made assumptions your GPU doesn’t support? Comments below. (Paid subscribers get access to the full kernel optimization playbook and can discuss production deployment strategies for custom CUDA code.)
← prev: Memory Mapping | next: Quantization Techniques →

