Source code for kempnerforge.resilience.health
"""GPU health monitoring and NaN detection.
Provides utilities for detecting common failures during training:
- NaN/Inf in loss or gradients
- GPU availability and basic health
- NCCL liveness via lightweight collectives
"""

from __future__ import annotations

import logging
import math
from dataclasses import dataclass, field
from datetime import timedelta

import torch
import torch.distributed as dist

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# NaN / Inf detection
# ---------------------------------------------------------------------------

@dataclass
class NaNState:
    """Tracks NaN/Inf occurrences across training steps."""

    consecutive_nans: int = 0
    total_nans: int = 0
    last_good_loss: float = float("inf")
    last_good_step: int = 0
    nan_steps: list[int] = field(default_factory=list)

class NaNDetector:
    """Detects and tracks NaN/Inf values in loss and gradients.

    Supports three responses to NaN:

    - ``"warn"``: Log a warning and continue.
    - ``"skip"``: Skip the optimizer step (zero gradients).
    - ``"raise"``: Raise a ``RuntimeError``.

    If the consecutive NaN count exceeds ``max_consecutive``, the detector
    signals that a checkpoint rollback is recommended.

    Args:
        action: What to do when NaN is detected.
        max_consecutive: Consecutive NaN steps before recommending rollback.
        max_history: Maximum number of NaN step indices to retain.
    """

    def __init__(
        self,
        action: str = "warn",
        max_consecutive: int = 5,
        max_history: int = 100,
    ) -> None:
        if action not in ("warn", "skip", "raise"):
            raise ValueError(f"Invalid NaN action: {action!r} (expected warn/skip/raise)")
        self.action = action
        self.max_consecutive = max_consecutive
        self.max_history = max_history
        self.state = NaNState()

    def check_loss(self, loss: float, step: int) -> bool:
        """Check a loss value for NaN/Inf.

        When running distributed, all-reduces a NaN flag so ALL ranks agree
        on whether to skip. This prevents rank desync where one rank sees NaN
        and skips its optimizer step while others proceed normally.

        Args:
            loss: The scalar loss value to check.
            step: Current training step.

        Returns:
            True if the loss is finite on ALL ranks, False if any rank
            has NaN/Inf.

        Raises:
            RuntimeError: If action is ``"raise"`` and NaN is detected.
        """
        local_nan = not _is_finite(loss)

        # Sync the NaN flag across all ranks to prevent desync.
        # One tiny all-reduce (4 bytes) — negligible vs gradient sync.
        if dist.is_initialized():
            nan_flag = torch.tensor([1.0 if local_nan else 0.0], device="cuda")
            dist.all_reduce(nan_flag)
            any_nan = nan_flag.item() > 0
        else:
            any_nan = local_nan

        if not any_nan:
            self.state.consecutive_nans = 0
            self.state.last_good_loss = loss
            self.state.last_good_step = step
            return True

        # NaN detected (on this rank or another)
        self.state.consecutive_nans += 1
        self.state.total_nans += 1
        if len(self.state.nan_steps) < self.max_history:
            self.state.nan_steps.append(step)

        if local_nan:
            msg = (
                f"NaN/Inf loss at step {step} "
                f"(consecutive={self.state.consecutive_nans}, "
                f"total={self.state.total_nans}, "
                f"last_good_loss={self.state.last_good_loss:.4f} "
                f"at step {self.state.last_good_step})"
            )
        else:
            msg = (
                f"NaN/Inf detected on another rank at step {step} "
                f"(consecutive={self.state.consecutive_nans}, "
                f"total={self.state.total_nans})"
            )

        if self.action == "raise":
            raise RuntimeError(msg)
        elif self.action == "skip":
            logger.warning(f"{msg} — skipping optimizer step")
        else:
            logger.warning(msg)
        return False

    def check_gradients(self, model: torch.nn.Module, step: int) -> bool:
        """Check model gradients for NaN/Inf before the optimizer step.

        Unlike :meth:`check_loss`, this check is purely local and does not
        synchronize its result across ranks.

        Args:
            model: The model to check.
            step: Current training step.

        Returns:
            True if all gradients are finite.
        """
        for name, param in model.named_parameters():
            if param.grad is not None and not torch.isfinite(param.grad).all():
                msg = f"NaN/Inf gradient in {name} at step {step}"
                if self.action == "raise":
                    raise RuntimeError(msg)
                logger.warning(msg)
                return False
        return True

    @property
    def should_rollback(self) -> bool:
        """Whether the consecutive NaN count suggests a checkpoint rollback."""
        return self.state.consecutive_nans >= self.max_consecutive

    def reset(self) -> None:
        """Reset NaN tracking state (e.g., after a rollback)."""
        self.state = NaNState()
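
# Usage sketch (illustrative, not part of this module): how ``NaNDetector``
# is meant to slot into a training loop. ``model``, ``optimizer``,
# ``dataloader``, and ``load_checkpoint`` are hypothetical stand-ins; only
# the detector calls come from this module.
#
#     detector = NaNDetector(action="skip", max_consecutive=5)
#     for step, batch in enumerate(dataloader):
#         loss = model(batch)
#         loss.backward()
#         # check_loss runs on every rank (it contains an all-reduce);
#         # short-circuiting is safe because its result is rank-consistent.
#         if detector.check_loss(loss.item(), step) and detector.check_gradients(model, step):
#             optimizer.step()
#         optimizer.zero_grad(set_to_none=True)
#         if detector.should_rollback:
#             load_checkpoint(model, optimizer)  # restore last good state
#             detector.reset()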
# ---------------------------------------------------------------------------
# GPU health
# ---------------------------------------------------------------------------

def check_gpu_health(device: int = 0) -> dict[str, bool | str]:
    """Run basic GPU health checks.

    Performs:

    1. CUDA availability check
    2. Small test computation on the device
    3. Memory allocation test

    Returns:
        Dict with health check results.
    """
    result: dict[str, bool | str] = {
        "cuda_available": torch.cuda.is_available(),
        "device_accessible": False,
        "compute_ok": False,
        "memory_ok": False,
        "error": "",
    }
    if not result["cuda_available"]:
        result["error"] = "CUDA not available"
        return result

    try:
        # Check that the device is accessible
        torch.cuda.set_device(device)
        result["device_accessible"] = True

        # Test computation. Explicit check rather than ``assert`` so the
        # test still runs under ``python -O``.
        x = torch.ones(16, device=f"cuda:{device}")
        total = (x + x).sum().item()
        if total != 32.0:
            raise RuntimeError(f"Compute check returned {total}, expected 32.0")
        result["compute_ok"] = True
        del x

        # Test memory allocation (1 MB)
        buf = torch.empty(256 * 1024, dtype=torch.float32, device=f"cuda:{device}")
        del buf
        result["memory_ok"] = True
    except RuntimeError as e:
        result["error"] = str(e)
        logger.error(f"GPU health check failed on device {device}: {e}")
    return result
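
# Usage sketch (illustrative): sweep all visible GPUs before launching
# training and fail fast on an unhealthy device.
#
#     for dev in range(torch.cuda.device_count()):
#         health = check_gpu_health(dev)
#         if not (health["compute_ok"] and health["memory_ok"]):
#             raise RuntimeError(f"GPU {dev} failed health check: {health['error']}")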

def check_nccl_health(timeout_sec: float = 10.0) -> bool:
    """Check NCCL communication health via a lightweight all-reduce.

    Args:
        timeout_sec: Timeout for the collective operation. NCCL only
            enforces this when blocking wait is enabled (e.g.
            ``TORCH_NCCL_BLOCKING_WAIT=1``); otherwise the process group's
            default timeout applies.

    Returns:
        True if the all-reduce succeeded, False on timeout or error.
    """
    if not dist.is_initialized():
        return True  # Not distributed; nothing to check
    try:
        tensor = torch.ones(1, device="cuda")
        # Launch asynchronously and wait on the work handle so the
        # timeout is actually applied.
        work = dist.all_reduce(tensor, async_op=True)
        work.wait(timeout=timedelta(seconds=timeout_sec))
        torch.cuda.synchronize()
        # Each rank contributes 1.0, so the result should equal world size.
        expected = dist.get_world_size()
        return abs(tensor.item() - expected) < 1e-5
    except RuntimeError as e:
        logger.error(f"NCCL health check failed: {e}")
        return False
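
# Usage sketch (illustrative): probe NCCL between steps so a hung or
# desynchronized rank surfaces as a clean failure rather than a stalled
# gradient sync. The 100-step interval is an arbitrary example value.
#
#     if step % 100 == 0 and not check_nccl_health(timeout_sec=10.0):
#         raise RuntimeError(f"NCCL health check failed at step {step}")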
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _is_finite(value: float) -> bool:
    """Check if a float value is finite (not NaN/Inf)."""
    return math.isfinite(value)