Learning Rate Scheduling: The Asymptotic Trap
Part 1 of Training Loop Engineering:
A river delta doesn't choose its channels the way your parents did with cable. Your optimizer doesn't either. Fast moving delta water at the apex carries sediment. Where this fast-moving upstream meets the slow-moving fan, it drops that sediment, and the sediment is the resistance: it piles up, forces the water to find new paths, and that is where the channels form. Downstream doesn’t decide the shape. It elaborates what the upstream phase already committed to.
Your learning rate schedule is the river. Warmup is the fast-moving upstream, cutting channels into the loss landscape as it carries the optimizer forward. Cosine annealing is the fan, where the velocity drops, the sediment falls, and the geometry of those early channels determines everything about where the slow water ends up. Most engineers treat the schedule as a hyperparameter to be tuned. At scale it is a commitment that has piled up and cannot be undone.
This post is about what actually happens at each phase boundary, why restarts work when they work but destroy progress when they don’t, and how to instrument your schedule so you know which regime you’re in before the loss curve tells you it’s too late and you’re swept out to sea.
Why Warmup is Variance Reduction, Not Stabilization
How Cosine Annealing Actually Works
How to Escape the Asymptotic Trap
How Restarts Escape Local Minima (and when they don’t)
Why the Warmup-to-Annealing Transition is a Phase Boundary
Full Production Learning Rate Schedule Example
Common Bugs in the Wild and their Solutions
Warmup: Not Stabilization, Variance Reduction
The standard explanation for learning rate warmup is that it “stabilizes early training.” This is true but imprecise in a way that matters for debugging.
A common mechanistic explanation points to Adam's second moment estimate initializing at zero, producing near-zero denominators and explosive effective learning rates in the early steps. This sounds like something bias correction should neutralize, and bias correction is present in every standard Adam implementation and was in the original paper. But bias correction only removes the systematic bias in the estimate; it does not fix the variance. At step 1, you've seen one gradient. Your second moment estimate, however correctly normalized, is based on a single sample. It's statistically unbiased but extremely noisy, and a noisy denominator means your effective learning rate per parameter is not just large, it's erratic.
import torch
import torch.nn as nn
import math
import numpy as np
from typing import Optional
def demonstrate_adam_warmup_necessity(
    target_lr: float = 3e-4,
    beta2: float = 0.999,
    grad_std: float = 0.02  # Realistic gradient scale for transformer weights
):
    """
    Show how noisy Adam's bias-corrected second moment estimate is
    during the first steps of training.

    A constant gradient would show nothing: with g_t = g for all t, the
    bias-corrected v_hat equals g^2 exactly at every step. The problem
    is variance. At step t, v_hat is an exponentially weighted average
    of roughly t squared-gradient samples, and for g ~ N(0, sigma^2)
    its relative standard deviation is

        r(t) = sqrt(2 (1-b2) (1-b2^(2t)) / ((1+b2) (1-b2^t)^2))

    which starts at sqrt(2) (a single chi-square sample) and decays
    toward sqrt(2 (1-b2) / (1+b2)) only as the buffer fills.
    """
    # Adam second moment estimate: v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
    # Bias-corrected: v_hat_t = v_t / (1 - beta2^t)
    # Effective LR per parameter: lr / (sqrt(v_hat_t) + eps)
    print("Sampling noise in Adam's bias-corrected second moment:")
    print(f"{'Step':>6s} {'rel. std of v_hat':>18s}")
    for t in [1, 2, 5, 10, 25, 100, 1000, 10000]:
        r = math.sqrt(
            2 * (1 - beta2) * (1 - beta2 ** (2 * t))
            / ((1 + beta2) * (1 - beta2 ** t) ** 2)
        )
        print(f"{t:>6d} {r:>18.1%}")

    # Per-parameter consequence at step 1: v_hat = g_1^2, so the
    # effective LR is target_lr / |g_1|. For |g_1| at the 10th and 1st
    # percentiles of |N(0, grad_std^2)| (0.1257 and 0.0125 times
    # grad_std), the first update overshoots the steady-state
    # effective LR (target_lr / grad_std) by ~8x and ~80x.
    steady = target_lr / grad_std
    print(f"\nSteady-state effective LR:         {steady:.3f}")
    print(f"Step-1 LR for 1-in-10 parameters:  > {target_lr / (0.1257 * grad_std):.3f} (8x steady state)")
    print(f"Step-1 LR for 1-in-100 parameters: > {target_lr / (0.0125 * grad_std):.3f} (80x steady state)")

demonstrate_adam_warmup_necessity()

Output:
Sampling noise in Adam's bias-corrected second moment:
  Step  rel. std of v_hat
     1             141.4%
     2             100.0%
     5              63.2%
    10              44.7%
    25              28.3%
   100              14.1%
  1000               4.7%
 10000               3.2%

Steady-state effective LR:         0.015
Step-1 LR for 1-in-10 parameters:  > 0.119 (8x steady state)
Step-1 LR for 1-in-100 parameters: > 1.200 (80x steady state)

At step 1, the relative standard deviation of v_hat is 141%: the denominator Adam divides by is close to a coin flip. One parameter in ten takes its first update at an effective learning rate roughly 8x its steady-state value; one in a hundred at roughly 80x. You specified lr=3e-4. Without warmup, you are not training at anything resembling a stable learning rate for the first several hundred steps. You are taking erratic, sometimes enormous per-parameter steps in whatever direction your randomly initialized gradients point, which is not a direction worth moving in.
Warmup is variance reduction on the second moment estimate. The linear ramp keeps step sizes small while the v_hat denominator accumulates samples. By the time the nominal LR reaches its peak, Adam's effective per-parameter rates are within a reasonable multiple of the nominal.
class WarmupScheduler:
"""
Linear warmup with precise control over the transition point.
The transition from warmup to annealing is not a cosmetic boundary
on a plot. It is the moment the optimizer shifts from second-moment
initialization to steady-state operation. Getting this boundary wrong
    at scale means the early stretch of your annealing phase operates
    on a second-moment estimate that hasn't converged: you are annealing
    a learning rate that Adam isn't actually applying.
"""
def __init__(
self,
optimizer: torch.optim.Optimizer,
warmup_steps: int,
peak_lr: float,
min_lr: float = 1e-5
):
self.optimizer = optimizer
self.warmup_steps = warmup_steps
self.peak_lr = peak_lr
self.min_lr = min_lr
self.current_step = 0
def get_lr(self, step: int) -> float:
if step < self.warmup_steps:
# Linear warmup
return self.peak_lr * (step / self.warmup_steps)
return self.peak_lr
def step(self):
self.current_step += 1
lr = self.get_lr(self.current_step)
for pg in self.optimizer.param_groups:
pg['lr'] = lr
return lr
def effective_adam_lr(
self,
step: int,
v_hat: float,
eps: float = 1e-8
) -> float:
"""
The LR Adam is actually applying, accounting for second moment.
This is what you should be monitoring, not the nominal schedule.
"""
nominal = self.get_lr(step)
        return nominal / (math.sqrt(v_hat) + eps)

Log the effective Adam learning rate per layer, not just the nominal schedule value; all too many engineers skip this. On step 1,000 of a 100,000-step run, your schedule says lr=3e-4. Adam is applying something different to every parameter, depending on that parameter's gradient history. The nominal LR is what you set. The effective LR is what the optimizer does. At scale, these diverge in ways that you will end up not liking.
Cosine Annealing: The Geometry of Descent
Cosine annealing decays the learning rate from peak to minimum following a cosine curve:
lr(t) = min_lr + 0.5 * (peak_lr - min_lr) * (1 + cos(π * t / T))

The shape isn't arbitrary: the cosine curve spends more time near the extremes and less time in the middle. It decays slowly from peak, then accelerates through the midpoint, then slows again approaching minimum. This matches the geometry of loss landscapes: large steps are useful early in the annealing phase while the optimizer is still moving between basins, and small steps are useful late while it refines within a basin. A linear decay would apply the wrong step size at both ends.
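To make that claim concrete, here is a small sketch (`time_near_extremes` is a hypothetical helper, not part of this post's scheduler code) that counts how many steps a decaying schedule spends within 10% of either extreme:

```python
import math

def time_near_extremes(schedule, steps=10_000, band=0.1):
    """Fraction of steps the LR spends within `band` of peak or min."""
    peak, lo = schedule(0), schedule(steps - 1)
    span = peak - lo
    near = sum(
        1 for t in range(steps)
        if schedule(t) > peak - band * span or schedule(t) < lo + band * span
    )
    return near / steps

cosine = lambda t: 0.5 * 3e-4 * (1 + math.cos(math.pi * t / 10_000))
linear = lambda t: 3e-4 * (1 - t / 10_000)

print(f"cosine: {time_near_extremes(cosine):.0%} of steps within 10% of an extreme")
print(f"linear: {time_near_extremes(linear):.0%} of steps within 10% of an extreme")
```

The cosine schedule spends roughly twice as many steps near its extremes as a linear decay does, which is exactly where the large-then-small step-size pattern pays off.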
def cosine_annealing_lr(
step: int,
total_steps: int,
peak_lr: float,
min_lr: float = 1e-5
) -> float:
"""
Standard cosine annealing without restarts.
The derivative of this schedule matters as much as the schedule itself.
At the midpoint (step = T/2), the LR is decaying fastest.
This is when the optimizer is most sensitive to basin geometry.
"""
progress = step / total_steps
return min_lr + 0.5 * (peak_lr - min_lr) * (1 + math.cos(math.pi * progress))
def analyze_cosine_schedule(
    total_steps: int = 100_000,
    peak_lr: float = 3e-4,
    min_lr: float = 3e-5,
    warmup_steps: int = 2_000
):
    """
    Analyze where your compute actually goes in a cosine schedule.

    The key insight for large runs: the schedule is not symmetric in
    terms of compute allocation. Warmup is a rounding error, while the
    low-LR tail is a large fraction of the GPU-hours.

    Note: with min_lr = 10% of peak, the LR never drops below 10% of
    peak once warmup ends, so the tail must be measured against min_lr.
    Here the tail is the set of annealing steps with LR within 2x of min_lr.
    """
    lrs = []
    for step in range(total_steps):
        if step < warmup_steps:
            lr = peak_lr * (step / warmup_steps)
        else:
            anneal_step = step - warmup_steps
            anneal_total = total_steps - warmup_steps
            lr = cosine_annealing_lr(anneal_step, anneal_total, peak_lr, min_lr)
        lrs.append(lr)
    lrs = np.array(lrs)

    # Where does the LR spend time in different ranges?
    above_90pct = np.sum(lrs > 0.9 * peak_lr)
    above_50pct = np.sum(lrs > 0.5 * peak_lr)
    in_tail = np.sum(lrs[warmup_steps:] <= 2 * min_lr)

    print("Cosine Schedule: Compute Distribution")
    print(f"Total steps: {total_steps:,}")
    print(f"Warmup steps: {warmup_steps:,} "
          f"({warmup_steps/total_steps:.1%} of run)")
    print()
    print(f"Steps with LR > 90% peak:  {above_90pct:>8,} "
          f"({above_90pct/total_steps:.1%})")
    print(f"Steps with LR > 50% peak:  {above_50pct:>8,} "
          f"({above_50pct/total_steps:.1%})")
    print(f"Steps with LR ≤ 2x min_lr: {in_tail:>8,} "
          f"({in_tail/total_steps:.1%}) ← the tail")
    print()

    # At scale: translate steps to wall-clock time.
    # Assumed step time: ~7.2 s for a 70B model at large batch on 512 H100s.
    step_time_s = 7.2
    tail_hours = in_tail * step_time_s / 3600
    total_hours = total_steps * step_time_s / 3600
    print(f"At {step_time_s} s/step on 512 H100s (70B model):")
    print(f"  Total wall-clock: {total_hours:.1f} hours ({total_hours/24:.1f} days)")
    print(f"  Hours in tail (LR ≤ 2x min_lr): {tail_hours:.1f} hours "
          f"({tail_hours/24:.1f} days)")
    print()
    print("The tail is not a cooldown phase.")
    print(f"The tail is {tail_hours/24:.1f} days of compute.")

analyze_cosine_schedule(
    total_steps=500_000,
    warmup_steps=2_000
)

Output:
Cosine Schedule: Compute Distribution
Total steps: 500,000
Warmup steps: 2,000 (0.4% of run)

Steps with LR > 90% peak:   107,940 (21.6%)
Steps with LR > 50% peak:   267,649 (53.5%)
Steps with LR ≤ 2x min_lr:  107,740 (21.5%) ← the tail

At 7.2 s/step on 512 H100s (70B model):
  Total wall-clock: 1000.0 hours (41.7 days)
  Hours in tail (LR ≤ 2x min_lr): 215.5 hours (9.0 days)

The tail is not a cooldown phase.
The tail is 9.0 days of compute.

The tail, the annealing steps where the LR has decayed to within 2x of min_lr, is 21.5% of the run: 9 days of compute on a 42-day run. The warmup phase is 0.4% of the run. (Note the symmetry: the cosine spends about as long near its floor as near its peak.) Schedule decisions that look minor in a research prototype are a dominant cost in production. Get them wrong and your project is shut down for cost.
How to Escape the Asymptotic Trap
Cosine annealing asymptotically approaches min_lr. The decay slows in the tail because the schedule is designed that way, and that is correct for refining within a basin. But it creates a trap: the optimizer is taking extremely small steps in a region of the loss landscape it may have entered by momentum, not by gradient signal. It arrived at this basin because the large-LR steps early in annealing carried it there. Now, with the LR near minimum, it cannot leave. It can only refine.
If the basin it arrived at is good, this is the intended behavior. If the basin is a saddle point, a shallow local minimum, or a region of high curvature that looked promising at high LR and reveals itself as suboptimal at low LR, the optimizer is trapped. It will spend days of compute making vanishingly small updates in a region that cannot produce a good model.
def detect_asymptotic_trap(
    loss_history: list,
    lr_history: list,
    grad_norm_history: list,
    window: int = 500,
    improvement_threshold: float = 1e-4,
    total_steps: Optional[int] = None
) -> dict:
    """
    Detect whether the optimizer has entered the asymptotic trap:
    LR is low, gradient norms are healthy, but loss improvement
    per step has dropped below the threshold where the remaining
    compute is likely to produce meaningful gains.

    This is a cost/benefit calculation, not a convergence criterion.
    The question is not "has the model converged" but "is the expected
    improvement over the remaining schedule worth the compute cost."
    """
    if len(loss_history) < window * 2:
        return {"status": "insufficient_history"}
    recent_losses = np.array(loss_history[-window:])
    recent_lrs = np.array(lr_history[-window:])
    recent_grad_norms = np.array(grad_norm_history[-window:])

    # Loss improvement rate (loss units per step), from a linear fit
    loss_slope = np.polyfit(range(window), recent_losses, 1)[0]
    loss_improvement_per_step = abs(loss_slope)

    # Current LR relative to peak
    peak_lr = max(lr_history)
    current_lr_ratio = recent_lrs[-1] / peak_lr

    # Gradient norms: are they healthy or dying?
    mean_grad_norm = np.mean(recent_grad_norms)
    grad_norm_trend = np.polyfit(range(window), recent_grad_norms, 1)[0]

    # Expected improvement over remaining steps
    # (assumes current improvement rate holds, which is optimistic)
    if total_steps is not None:
        remaining_steps = max(total_steps - len(lr_history), 0)
    else:
        # Fallback when the caller doesn't pass total_steps: assume as
        # many steps remain as have already elapsed
        remaining_steps = len(lr_history)
expected_remaining_improvement = loss_improvement_per_step * remaining_steps
trapped = (
current_lr_ratio < 0.1 and
loss_improvement_per_step < improvement_threshold and
grad_norm_trend < 0 # Norms declining, not just stable
)
return {
"status": "ASYMPTOTIC_TRAP" if trapped else "NORMAL_ANNEALING",
"loss_improvement_per_step": loss_improvement_per_step,
"current_lr_ratio": current_lr_ratio,
"mean_grad_norm": mean_grad_norm,
"grad_norm_trend": grad_norm_trend,
"expected_remaining_improvement": expected_remaining_improvement,
"recommendation": (
"Consider restart or early termination. "
"Expected remaining improvement below cost threshold."
if trapped else
"Schedule proceeding normally."
)
    }

The trap is not a bug. It is a predictable consequence of the schedule's geometry interacting with the loss landscape's geometry, and it is detectable before you've spent the compute. The engineer who catches it at step 420,000 of a 500,000-step run and restarts saves days of H100 time. The engineer who doesn't catch it ships a model that was days of compute away from being better.
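The cost/benefit arithmetic behind that decision is worth making concrete. A back-of-envelope sketch, where every number is an illustrative assumption rather than a measurement:

```python
# Assumed numbers: ~7.2 s/step for a large run, trap detected at step
# 420,000 of 500,000, recent loss slope 2e-7 loss units per step.
step_time_s = 7.2
current_step, total_steps = 420_000, 500_000
loss_per_step = 2e-7  # measured slope of the recent loss curve

remaining = total_steps - current_step
remaining_hours = remaining * step_time_s / 3600
expected_gain = loss_per_step * remaining  # optimistic: assumes the rate holds

print(f"Remaining steps: {remaining:,} ({remaining_hours:.0f} wall-clock hours)")
print(f"Expected loss improvement if you let it run: {expected_gain:.3f}")
```

Under these assumptions the remaining 80,000 steps cost 160 wall-clock hours for an expected improvement of ~0.016: whether that trade is worth it is a judgment call, but it should be a deliberate one.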
SGDR: Restarts as Geometric Escape
Stochastic Gradient Descent with Warm Restarts (SGDR) addresses the trap by periodically resetting the learning rate to peak, then annealing again. Each reset is a geometric operation: it widens the optimizer’s search radius, allowing escape from the current basin and exploration of neighboring regions.
class CosineAnnealingWithRestarts:
"""
SGDR implementation with diagnostic instrumentation.
The restart interval is not a hyperparameter to tune by feel.
It is determined by the loss landscape's basin structure:
specifically, how many steps it takes for the optimizer to
move from basin entry to basin refinement at a given LR scale.
At scale, restart intervals measured in steps translate to
days of compute. Get this wrong and you're either restarting
before the optimizer has found anything worth refining, or
restarting so infrequently that you've spent weeks elaborating
a suboptimal basin.
"""
def __init__(
self,
optimizer: torch.optim.Optimizer,
T_0: int,
T_mult: float = 2.0,
peak_lr: float = 3e-4,
min_lr: float = 1e-5,
warmup_steps: int = 0
):
self.optimizer = optimizer
self.T_0 = T_0 # Initial restart interval
self.T_mult = T_mult # Interval multiplier after each restart
self.peak_lr = peak_lr
self.min_lr = min_lr
self.warmup_steps = warmup_steps
self.step_count = 0
self.current_T = T_0
self.T_elapsed = 0
self.restart_count = 0
self.restart_steps = []
# Diagnostic state
self.loss_at_restart = []
self.lr_history = []
def get_lr(self) -> float:
if self.step_count < self.warmup_steps:
return self.peak_lr * (self.step_count / max(1, self.warmup_steps))
progress = self.T_elapsed / self.current_T
return self.min_lr + 0.5 * (self.peak_lr - self.min_lr) * (
1 + math.cos(math.pi * progress)
)
def step(self, current_loss: Optional[float] = None) -> dict:
self.step_count += 1
if self.step_count >= self.warmup_steps:
self.T_elapsed += 1
# Restart condition
if self.T_elapsed >= self.current_T:
self.restart_count += 1
self.restart_steps.append(self.step_count)
if current_loss is not None:
self.loss_at_restart.append(current_loss)
# Reset
self.T_elapsed = 0
self.current_T = int(self.current_T * self.T_mult)
lr = self.get_lr()
for pg in self.optimizer.param_groups:
pg['lr'] = lr
self.lr_history.append(lr)
return {
"lr": lr,
"restart_count": self.restart_count,
"steps_until_restart": self.current_T - self.T_elapsed,
"progress_in_cycle": self.T_elapsed / self.current_T
}
def restart_was_productive(self, window: int = 500) -> bool:
"""
Determine if the most recent restart produced loss improvement
beyond what the annealing tail would have achieved anyway.
        A restart is productive if the loss recorded at the current
        restart is lower than the loss recorded at the previous one.
        (A stricter test would also demand improvement beyond what the
        tail of the previous cycle would have delivered anyway; this
        implementation checks the simple version.)

        This is the question the schedule can't answer for you —
        it requires watching what actually happened.
        """
        if len(self.loss_at_restart) < 2:
            return True  # Insufficient data
        improvement = self.loss_at_restart[-2] - self.loss_at_restart[-1]
        return improvement > 0

The multiplier T_mult is where most engineers set something reasonable and move on. At scale it deserves more attention. T_mult=1 means all restart intervals are equal: the optimizer gets the same exploration budget in cycle 5 as in cycle 1, even though cycle 5 is operating in a loss landscape that's already been partially refined. T_mult=2 doubles the interval after each restart: early cycles explore broadly, later cycles refine more deeply. The right value is a function of how quickly your loss landscape flattens as training progresses, which you can only measure by watching the loss improvement per restart.
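A quick way to see what T_mult commits you to is to lay the cycle boundaries over the step budget. A minimal sketch (`restart_points` is a hypothetical helper, not part of the class above):

```python
def restart_points(T_0: int, T_mult: float, total_steps: int) -> list:
    """Steps at which SGDR restarts fire within a fixed step budget."""
    points, T, elapsed = [], T_0, 0
    while elapsed + T <= total_steps:
        elapsed += T
        points.append(elapsed)
        T = int(T * T_mult)
    return points

print(restart_points(50_000, 2.0, 500_000))       # [50000, 150000, 350000]
print(len(restart_points(50_000, 1.0, 500_000)))  # 10 equal cycles
```

Note the truncation with T_mult=2: the final restart fires at step 350,000 and begins a 400,000-step cycle the run will never finish, so the last cycle's annealing is cut short at 150,000 steps. That is one more thing to check before committing the schedule.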
The Phase Boundary: Warmup-to-Annealing
The transition from warmup to annealing is treated in most implementations as a schedule concatenation: warmup ends at step N, cosine begins at step N. The optimizer doesn’t know the difference; the LR just changes.
At scale, this boundary is a phase transition in optimizer behavior, and crossing it without preparation causes a specific failure mode:
def analyze_transition_stability(
model: nn.Module,
optimizer: torch.optim.Optimizer,
warmup_steps: int,
transition_window: int = 100
) -> dict:
"""
Instrument the warmup-to-annealing transition.
The failure mode: at the transition point, the optimizer's momentum
buffers contain gradient history from the warmup phase — a regime
of rapidly increasing LR and second-moment initialization. The
annealing phase begins with a peak LR and a momentum buffer that
was built during a different regime. For the first ~beta1/(1-beta1)
steps of annealing, the optimizer is executing updates informed by
warmup-phase gradients at annealing-phase learning rates.
At small scale this is noise. At 70B parameters on 512 GPUs,
the first 1000 steps of annealing can permanently shape the basin
the model enters — and they're being executed with stale momentum.
"""
transition_metrics = {
"grad_norm_variance": [],
"effective_lr_per_layer": {},
"momentum_staleness": [],
"loss_volatility": []
}
# Measure gradient norm variance around transition
# (elevated variance = momentum-gradient mismatch)
grad_norms = []
for name, param in model.named_parameters():
if param.grad is not None:
grad_norms.append(param.grad.norm().item())
if grad_norms:
transition_metrics["grad_norm_variance"].append(np.var(grad_norms))
# Momentum staleness: cosine similarity between current gradients
# and the gradients encoded in the momentum buffer
# High staleness at transition = stale momentum problem
for name, param in model.named_parameters():
if param.grad is not None and param in optimizer.state:
state = optimizer.state[param]
if 'exp_avg' in state:
momentum = state['exp_avg']
grad = param.grad
cosine_sim = torch.dot(
momentum.flatten(),
grad.flatten()
) / (momentum.norm() * grad.norm() + 1e-8)
transition_metrics["momentum_staleness"].append(
cosine_sim.item()
)
mean_staleness = np.mean(transition_metrics["momentum_staleness"]) \
if transition_metrics["momentum_staleness"] else 1.0
if mean_staleness < 0.5:
status = "HIGH_MOMENTUM_STALENESS"
recommendation = (
f"Momentum-gradient cosine similarity: {mean_staleness:.3f}. "
"Consider extending warmup by 20% or adding a plateau phase "
"between warmup and annealing to allow momentum buffers to align "
"with peak-LR gradient directions before decay begins."
)
else:
status = "TRANSITION_STABLE"
recommendation = "Transition proceeding normally."
return {
"status": status,
"momentum_gradient_alignment": mean_staleness,
"recommendation": recommendation
    }

The fix for high momentum staleness at transition is a plateau: hold LR at peak for 500-2,000 steps between warmup and annealing. This gives the momentum buffers time to accumulate gradient history from the peak-LR regime before the decay begins. Most implementations skip this. At 1B parameters it doesn't matter. At the scale where those steps represent meaningful compute (hours, not seconds), the momentum staleness problem stops being noise and starts shaping which basin the model enters.
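How long does the plateau need to be for the momentum buffer itself? The buffer is an exponential moving average, so the weight of pre-plateau history decays as beta1**k after k plateau steps. A quick sketch with beta1=0.9 (matching the betas=(0.9, 0.95) AdamW config in the usage example later in this post):

```python
beta1 = 0.9  # first-moment decay; matches betas=(0.9, 0.95) used in this post

for k in [10, 25, 50, 100]:
    residual = beta1 ** k  # fraction of the buffer that predates the plateau
    print(f"after {k:>3d} plateau steps: {residual:.1%} of the momentum buffer is warmup-era")
```

The buffer is dominated by peak-LR gradients within a hundred steps or so; the longer 500-2,000-step plateau also gives the second moment and the loss itself time to settle at peak LR before decay begins.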
Complete Production Implementation
class ProductionLRSchedule:
"""
Production learning rate schedule for large model training.
Combines:
- Linear warmup with second-moment awareness
- Optional plateau at peak LR
- Cosine annealing with configurable tail
- SGDR restarts with productivity monitoring
- Asymptotic trap detection
- Full diagnostic logging
This is not a schedule you tune by feel.
Every parameter has a derivation.
"""
def __init__(
self,
optimizer: torch.optim.Optimizer,
total_steps: int,
peak_lr: float,
min_lr: float,
warmup_steps: int,
plateau_steps: int = 0,
use_restarts: bool = False,
restart_interval: Optional[int] = None,
T_mult: float = 2.0,
):
self.optimizer = optimizer
self.total_steps = total_steps
self.peak_lr = peak_lr
self.min_lr = min_lr
self.warmup_steps = warmup_steps
self.plateau_steps = plateau_steps
self.use_restarts = use_restarts
self.anneal_start = warmup_steps + plateau_steps
self.anneal_steps = total_steps - self.anneal_start
if use_restarts and restart_interval:
self.restart_scheduler = CosineAnnealingWithRestarts(
optimizer=optimizer,
T_0=restart_interval,
T_mult=T_mult,
peak_lr=peak_lr,
min_lr=min_lr
)
else:
self.restart_scheduler = None
self.step_count = 0
self.lr_history = []
self.loss_history = []
self.grad_norm_history = []
def get_lr(self, step: int) -> float:
if step < self.warmup_steps:
# Linear warmup
return self.peak_lr * (step / max(1, self.warmup_steps))
elif step < self.anneal_start:
# Plateau
return self.peak_lr
else:
# Cosine annealing
anneal_step = step - self.anneal_start
progress = anneal_step / max(1, self.anneal_steps)
progress = min(progress, 1.0)
return self.min_lr + 0.5 * (self.peak_lr - self.min_lr) * (
1 + math.cos(math.pi * progress)
)
def step(
self,
current_loss: Optional[float] = None,
grad_norm: Optional[float] = None
) -> dict:
self.step_count += 1
if self.restart_scheduler and self.step_count >= self.anneal_start:
info = self.restart_scheduler.step(current_loss)
lr = info["lr"]
else:
lr = self.get_lr(self.step_count)
for pg in self.optimizer.param_groups:
pg['lr'] = lr
info = {}
self.lr_history.append(lr)
        if current_loss is not None:
            self.loss_history.append(current_loss)
        if grad_norm is not None:
            self.grad_norm_history.append(grad_norm)
# Detect asymptotic trap
trap_status = {}
if (len(self.loss_history) > 1000 and
len(self.grad_norm_history) > 1000):
trap_status = detect_asymptotic_trap(
self.loss_history,
self.lr_history,
self.grad_norm_history
)
return {
"step": self.step_count,
"lr": lr,
"phase": self._current_phase(),
"trap_detection": trap_status,
**info
}
def _current_phase(self) -> str:
if self.step_count < self.warmup_steps:
return "warmup"
elif self.step_count < self.anneal_start:
return "plateau"
else:
progress = (self.step_count - self.anneal_start) / self.anneal_steps
if progress < 0.5:
return "early_annealing"
elif progress < 0.9:
return "mid_annealing"
else:
return "tail"
# Usage
model = YourTransformerModel()
optimizer = torch.optim.AdamW(
model.parameters(),
lr=3e-4,
betas=(0.9, 0.95),
weight_decay=0.1,
eps=1e-8
)
schedule = ProductionLRSchedule(
optimizer=optimizer,
total_steps=500_000,
peak_lr=3e-4,
min_lr=3e-5,
warmup_steps=2_000,
plateau_steps=1_000, # Give momentum buffers time to align
use_restarts=False, # Enable if you detect asymptotic trap
restart_interval=50_000,
T_mult=2.0
)
# Training loop
for step, batch in enumerate(dataloader):
optimizer.zero_grad()
loss = model(batch)
loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(
model.parameters(), max_norm=1.0
).item()
schedule_info = schedule.step(
current_loss=loss.item(),
grad_norm=grad_norm
)
optimizer.step()
# Act on trap detection
if schedule_info.get("trap_detection", {}).get("status") == "ASYMPTOTIC_TRAP":
print(f"Step {step}: Asymptotic trap detected. "
          f"Consider restart or early termination.")

Common Bugs in the Wild
Bug 1: Applying the schedule after optimizer.step()
# Wrong: LR for step N is applied at step N+1
optimizer.step()
scheduler.step()
# Right: update LR before the step it should apply to
scheduler.step()
optimizer.step()

The ordering is not cosmetic. On warmup step 1, you want the warmup LR applied to step 1's update. If you step the scheduler after the optimizer, you apply warmup step 0's LR (which is 0) to step 1, warmup step 1's LR to step 2, and so on. Your warmup is off by one for the entire run. At 2,000 warmup steps this shifts the effective warmup boundary by one step, which is irrelevant. At 100 warmup steps on a model that is sensitive to the transition, it matters.
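The off-by-one is easy to see with a toy scheduler (`ToyWarmup` is a hypothetical stand-in, assuming a 3-step linear warmup):

```python
class ToyWarmup:
    """3-step linear warmup to peak=1.0; .lr is what the next update would use."""
    def __init__(self, warmup: int = 3, peak: float = 1.0):
        self.t, self.warmup, self.peak = 0, warmup, peak
        self.lr = 0.0
    def step(self):
        self.t += 1
        self.lr = self.peak * min(self.t / self.warmup, 1.0)

wrong_order, right_order = [], []
sched = ToyWarmup()
for _ in range(4):
    wrong_order.append(sched.lr)  # "optimizer.step()" reads the LR first...
    sched.step()                  # ...then the scheduler advances
sched = ToyWarmup()
for _ in range(4):
    sched.step()                  # scheduler advances first...
    right_order.append(sched.lr)  # ...then the update reads the LR

print(wrong_order)  # first update runs at LR 0.0; every LR arrives one step late
print(right_order)
```

Every entry in `wrong_order` after the first is just `right_order` shifted by one step, which is exactly the off-by-one described above.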
Bug 2: min_lr=0 in cosine annealing
# Wrong: LR reaches zero, optimizer stops
schedule = ProductionLRSchedule(min_lr=0, ...)
# Right: min_lr should be ~10% of peak_lr
schedule = ProductionLRSchedule(
peak_lr=3e-4,
min_lr=3e-5, # 10% of peak
...
)

When min_lr=0, the cosine schedule decays all the way to zero, and the LR is vanishingly small for the last stretch of the run. Adam with lr≈0 applies updates too small to matter; the loss flatlines. This is not convergence. This is the optimizer being told to stop. The flatline looks like convergence on a loss curve and will fool automated monitoring. min_lr should be set to approximately 10% of peak LR: low enough to allow fine-grained refinement, high enough that updates remain meaningful.
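To put numbers on it, a quick check of the last 1% of a 500,000-step schedule under both settings (a standalone restatement of the cosine formula from earlier in the post):

```python
import math

def cosine(step: int, total: int, peak: float, min_lr: float) -> float:
    progress = min(step / total, 1.0)
    return min_lr + 0.5 * (peak - min_lr) * (1 + math.cos(math.pi * progress))

# At 99% of the way through the schedule:
print(f"min_lr=0:    {cosine(495_000, 500_000, 3e-4, 0.0):.2e}")   # ~7.4e-08
print(f"min_lr=3e-5: {cosine(495_000, 500_000, 3e-4, 3e-5):.2e}")  # ~3.0e-05
```

With min_lr=0 the final stretch runs more than three orders of magnitude below peak, effectively zero training; with min_lr at 10% of peak, the refinement steps stay meaningful.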
Bug 3: Restart interval set in epochs rather than steps
# Wrong: epoch-based interval ignores variable batch sizes
T_0 = num_epochs_per_restart * steps_per_epoch # steps_per_epoch varies
# Right: always set in steps, compute from desired wall-clock time
# If you want restarts every 3 days on a 200ms/step run:
steps_per_day = 24 * 3600 / 0.2 # = 432,000
T_0 = 3 * steps_per_day  # = 1,296,000

Dynamic batching (from Unit 2) means steps_per_epoch is not constant. An epoch-based restart interval that works correctly in week 1 will fire at the wrong points in week 3 when the batch composition has shifted. Restart intervals must be specified in steps.
Bug 4: No plateau between warmup and annealing on runs longer than 7 days
# Adequate for short runs:
warmup_steps = 2_000
anneal_start = warmup_steps # Immediate transition
# Required for long runs (momentum staleness matters):
warmup_steps = 2_000
plateau_steps = 1_000  # ~2 hours at 7.2 s/step; worth it
anneal_start = warmup_steps + plateau_steps

The momentum buffer alignment problem doesn't manifest at small scale because the model corrects for stale momentum within a few hundred steps. At sufficient scale, the weight updates during those corrective steps are large enough to shift the basin the model enters, and the larger the model, the more those few hundred steps matter. The plateau costs 1,000 steps. On a 42-day, 500,000-step run that's about 2 hours of compute. The failure mode it eliminates can cost days.
𒅃 The Apollo guidance computer had 4KB of RAM and 72KB of read-only rope memory, hand-woven by seamstresses at Raytheon who threaded magnetic cores onto wire to encode the program physically into the hardware. If the mission profile changed after the rope was woven, you didn’t update the software. You wove new rope. The constraint shaped everything upstream: the trajectory calculations, the abort modes, the rendezvous algorithms, and all of it had to be decided before the rope was made, because changing it after was not a software problem, it was a manufacturing problem. The engineers who built Apollo didn’t think of this as a limitation. They thought of it as the nature of the system: some decisions are made early and everything else is elaboration. Learning rate scheduling at scale has the same structure. The warmup phase is the rope. By the time your largest run is three days in, the optimizer has carved the channels it will spend six weeks refining. You can adjust the annealing curve. You can add restarts. You can tune the tail. What you cannot do is re-weave the rope. This is not a failure of tooling. It is the nature of optimization at scale: commitment precedes elaboration, and the quality of the elaboration is bounded by the quality of the commitment; you might want to know what you’re committing to before you start the run.
Up Next: Gradient Clipping (The Precision Floor) Where we discover that global norm clipping is not a safety mechanism bolted onto training but a geometric operation on the loss landscape traversal, and that in FP8/BF16 environments the dynamic range constraints turn a well-understood technique into a precision engineering problem.
← Unit 4 Transformer Implementations — next: Gradient Clipping →


