NaN detection¶
`NaNDetector` watches for NaN / Inf in the loss every step. When it fires, the
training loop zeros gradients, skips the optimizer step, and increments
a consecutive-NaN counter. If too many consecutive NaNs pile up, the
loop stops so a human can roll back to an earlier checkpoint.
Wiring¶
```python
# scripts/train.py
nan_detector = NaNDetector(action="warn", max_consecutive=10)
...
# Inside the training loop, after backward:
if not nan_detector.check_loss(avg_loss, step):
    optimizer.zero_grad()
    if nan_detector.should_rollback:
        logger.error("Too many consecutive NaNs — stopping")
        break
    step += 1
    continue
```
Two things to note:
- `action="warn"` and `max_consecutive=10` are hardcoded in `scripts/train.py` — not exposed as TOML config. If you want `"skip"` or `"raise"` behavior, edit the script or construct the detector yourself.
- `check_gradients` is not called by the training loop. Gradient NaNs typically manifest as loss NaNs on the next step anyway; skip it unless you’re specifically debugging a gradient-explosion case.
Three actions¶
```python
# kempnerforge/resilience/health.py — NaNDetector.__init__
if action not in ("warn", "skip", "raise"):
    raise ValueError(f"Invalid NaN action: {action!r} (expected warn/skip/raise)")
```
| Action | Behavior | When to pick it |
|---|---|---|
| `"warn"` | Log a warning, return `False` | Default. NaN-tolerant training where a single bad step shouldn’t kill the run. |
| `"skip"` | Same as `"warn"`, return `False` | Same as warn; the two are nearly equivalent since the caller already skips the optimizer step on `False`. |
| `"raise"` | Raise `RuntimeError` | Early development when you want the run to die loudly on first NaN. |
In the shipped training loop, "warn" and "skip" produce the same
outcome (the caller already zeros grads and advances the step). The
distinction exists for callers that only call check_loss and let the
return value drive their own logic.
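As a concrete illustration of that return-value contract, here is a minimal sketch of the dispatch. It assumes `warn` logs and returns `False`, `skip` returns `False` silently, and `raise` throws; `check_loss_sketch` and its `print`-based logging are hypothetical stand-ins, not the library's implementation:

```python
import math

def check_loss_sketch(loss: float, step: int, action: str = "warn") -> bool:
    """Illustrative dispatch over the three actions.

    Returns True on a finite loss. On NaN/Inf: returns False under
    "warn" (with a log line) and "skip" (silently), raises under "raise".
    """
    if math.isfinite(loss):
        return True
    if action == "raise":
        raise RuntimeError(f"Non-finite loss {loss!r} at step {step}")
    if action == "warn":
        print(f"WARNING: non-finite loss {loss!r} at step {step}")
    return False  # caller decides whether to skip the optimizer step
```

A caller that only consults the return value gets identical control flow from `"warn"` and `"skip"`; the only difference in this sketch is the log line.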
Cross-rank sync¶
The critical detail. On a distributed run, a NaN on one rank must stop every rank — otherwise one rank zeros grads while the others keep optimizing and FSDP gets an inconsistent view of the parameter sharding on the next step.
```python
# kempnerforge/resilience/health.py — check_loss
local_nan = not _is_finite(loss)
if dist.is_initialized():
    nan_flag = torch.tensor([1.0 if local_nan else 0.0], device="cuda")
    dist.all_reduce(nan_flag)
    any_nan = nan_flag.item() > 0
else:
    any_nan = local_nan
```
Four bytes per step — dwarfed by the gradient all-reduce. The
all_reduce is a SUM (the default) — any rank with NaN lifts the
flag above zero on every rank.
If any_nan is true but local_nan is false, the log line mentions
“detected on another rank” so you can correlate which rank blew up from
per-rank logs.
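To see why SUM works, run the reduction by hand: each rank contributes 0.0 or 1.0, every rank receives the same global sum, and any single NaN makes that sum positive everywhere. A plain-Python sketch of the semantics (the `all_reduce_sum` helper is illustrative, not `torch.distributed`):

```python
def all_reduce_sum(local_flags: list[float]) -> list[float]:
    """Simulate dist.all_reduce with the default SUM op:
    after the call, every rank holds the global sum."""
    total = sum(local_flags)
    return [total] * len(local_flags)

# Rank 2 hit a NaN; the other ranks were fine.
local_flags = [0.0, 0.0, 1.0, 0.0]
reduced = all_reduce_sum(local_flags)
any_nan_per_rank = [flag > 0 for flag in reduced]
# every rank agrees, so every rank skips the same optimizer step
```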
State tracking¶
```python
from dataclasses import dataclass, field
from math import inf

@dataclass
class NaNState:
    consecutive_nans: int = 0    # reset on a good step
    total_nans: int = 0          # monotonic across the run
    last_good_loss: float = inf  # last finite loss value
    last_good_step: int = 0
    nan_steps: list[int] = field(default_factory=list)  # capped at max_history (default 100)
```
consecutive_nans resets on any finite step. It’s the one that drives
rollback: when it reaches max_consecutive (default 5 in the class, 10
in the shipped config), should_rollback flips to True.
nan_steps is a diagnostic — a post-hoc “which steps actually failed”
list. Capped at 100 entries to bound memory on pathological runs.
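The reset-versus-accumulate behavior can be sketched as follows. The `record` helper and `MAX_HISTORY` constant are hypothetical, and the dataclass is repeated so the snippet runs standalone:

```python
from dataclasses import dataclass, field
from math import inf, isfinite

MAX_HISTORY = 100  # assumed cap, matching the default described above

@dataclass
class NaNState:
    consecutive_nans: int = 0
    total_nans: int = 0
    last_good_loss: float = inf
    last_good_step: int = 0
    nan_steps: list[int] = field(default_factory=list)

def record(state: NaNState, loss: float, step: int) -> None:
    """Hypothetical state update: reset the streak on a finite loss,
    otherwise grow both counters and the (bounded) diagnostic list."""
    if isfinite(loss):
        state.consecutive_nans = 0  # good step resets the streak
        state.last_good_loss = loss
        state.last_good_step = step
    else:
        state.consecutive_nans += 1
        state.total_nans += 1
        if len(state.nan_steps) < MAX_HISTORY:  # bound memory
            state.nan_steps.append(step)
```

After the sequence NaN, NaN, finite, NaN the streak is 1 but the totals remember all three failures, which is exactly what lets `should_rollback` distinguish a transient blip from a NaN storm.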
Rollback recommendation¶
```python
@property
def should_rollback(self) -> bool:
    return self.state.consecutive_nans >= self.max_consecutive
```
When this trips, the training loop stops — it doesn’t roll back automatically:
```python
# scripts/train.py
if nan_detector.should_rollback:
    logger.error("Too many consecutive NaNs — stopping")
    break
```
Rolling back is manual: resubmit with checkpoint.load_path pointing at
an earlier step_N directory. The reason it’s not automatic is that
the source of the NaN determines what’s safe:
- LR spike — reduce `optimizer.lr` or `scheduler.warmup_steps`, restart from an earlier checkpoint.
- Bad data — skip the offending shard, restart from the same checkpoint.
- FP8 overflow — reduce `distributed.fp8_interval` or disable FP8 for sensitive layers, restart.
A rule of thumb: if you hit `should_rollback`, don’t resume from the
most recent checkpoint. It was written just before the NaN storm, so
whatever state caused the explosion is baked in.
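A sketch of what the resubmit config might look like, assuming a TOML `[checkpoint]` table holding the `checkpoint.load_path` key described above; the path is invented for illustration:

```toml
# Roll back to a step well before the first entry in nan_steps.
# "checkpoints/myrun/step_4000" is a made-up example path.
[checkpoint]
load_path = "checkpoints/myrun/step_4000"
```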
Manual use¶
Call the detector outside the training loop for ad-hoc checks:
```python
from kempnerforge.resilience import NaNDetector

det = NaNDetector(action="raise", max_consecutive=1)  # fail fast

for step, batch in enumerate(loader):
    loss = model(batch).item()
    det.check_loss(loss, step)  # raises on first NaN
```
check_gradients(model, step) does the same but walks
model.named_parameters() and returns False on the first NaN grad.
The action="raise" mode raises RuntimeError instead of returning
False (the warning case still returns False).
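For reference, a torch-free sketch of the `check_gradients` idea: walk the named parameters and bail on the first non-finite gradient. The function name and the dict-of-lists stand-in for `model.named_parameters()` are illustrative, not the library's API:

```python
from math import isfinite

def check_gradients_sketch(named_grads: dict[str, list[float]]) -> bool:
    """Return False on the first parameter with a NaN/Inf gradient,
    naming the offender; return True if every gradient is finite."""
    for name, grad in named_grads.items():
        if not all(isfinite(g) for g in grad):
            print(f"non-finite gradient in {name}")
            return False
    return True
```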
Limitations¶
- NaN action isn’t in config. Fixed at `"warn"` / `max_consecutive=10` in `scripts/train.py:85`. Change the source if you need something different.
- No gradient scan in the hot path. `check_gradients` exists but isn’t wired in. Add it if you’re hunting a specific gradient pathology; expect a small per-step cost (one `isfinite` + `.all()` per parameter).
- Loss is already a CPU scalar. `check_loss` gets a Python float, so the distributed sync creates a new tensor on CUDA and all-reduces it — a negligible one-off each step but not free. If you optimize this path, aggregate the NaN flag into the existing grad-norm all-reduce.
See also¶
- SLURM preemption — the other “stop cleanly” mechanism; both rely on the training loop polling a flag between steps.
- GPU health — coarser health check; run it at startup and after any NCCL failure.
- Checkpointing § Auto-resume — where to point `checkpoint.load_path` when rolling back.