kempnerforge.resilience

Fault tolerance and resilience for KempnerForge.

class kempnerforge.resilience.NaNDetector[source]

Bases: object

Detects and tracks NaN/Inf values in loss and gradients.

Supports three responses to NaN:
  • "warn": Log a warning and continue.

  • "skip": Skip the optimizer step (zero gradients).

  • "raise": Raise a RuntimeError.

If the consecutive NaN count exceeds max_consecutive, the detector signals that a checkpoint rollback is recommended.

Parameters:
  • action – What to do when NaN is detected.

  • max_consecutive – Consecutive NaN steps before recommending rollback.

  • max_history – Number of NaN step indices to retain.
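The action dispatch and rollback bookkeeping can be sketched in plain Python. This is a minimal illustration, independent of the actual KempnerForge implementation; `MiniNaNDetector` and its `check` method are hypothetical names, not part of the API:

```python
import math

class MiniNaNDetector:
    """Minimal sketch of the warn/skip/raise dispatch and the
    consecutive-NaN rollback recommendation."""

    def __init__(self, action="warn", max_consecutive=5):
        assert action in ("warn", "skip", "raise")
        self.action = action
        self.max_consecutive = max_consecutive
        self.consecutive_nans = 0

    def check(self, loss, step):
        """Return True if the step may proceed with its optimizer update."""
        if math.isfinite(loss):
            self.consecutive_nans = 0   # a good step resets the streak
            return True
        self.consecutive_nans += 1
        if self.action == "raise":
            raise RuntimeError(f"NaN/Inf loss at step {step}")
        if self.action == "warn":
            print(f"warning: NaN/Inf loss at step {step}")
            return True                 # continue despite the bad loss
        return False                    # "skip": caller zeros gradients

    @property
    def should_rollback(self):
        return self.consecutive_nans > self.max_consecutive
```

Note that a single finite loss resets the streak, so only an unbroken run of bad steps triggers the rollback recommendation.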

__init__(action='warn', max_consecutive=5, max_history=100)[source]
Parameters:
  • action (str)

  • max_consecutive (int)

  • max_history (int)

Return type:

None

check_loss(loss, step)[source]

Check a loss value for NaN/Inf.

When running distributed, a NaN flag is all-reduced so that ALL ranks agree on whether to skip. This prevents rank desync, where one rank sees NaN and skips its optimizer step while the others proceed normally.

Parameters:
  • loss (float) – The scalar loss value to check.

  • step (int) – Current training step.

Returns:

True if the loss is valid (finite) on ALL ranks, False if any rank has NaN/Inf.

Raises:

RuntimeError – If action is "raise" and NaN is detected.

Return type:

bool
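The all-rank agreement can be illustrated without torch.distributed: each rank computes a local NaN flag, the flags are reduced as a collective would reduce them, and every rank reaches the same verdict. The function below is a simulation for illustration only; the real method uses an actual all-reduce:

```python
import math

def loss_valid_on_all_ranks(per_rank_losses):
    """Simulate check_loss's agreement step: sum-reduce a per-rank
    NaN flag so every rank reaches the same skip/proceed decision."""
    local_flags = [0.0 if math.isfinite(l) else 1.0 for l in per_rank_losses]
    global_nan_count = sum(local_flags)   # all-reduce(SUM) in the real code
    return global_nan_count == 0.0

# One rank hitting NaN makes every rank skip, avoiding desync.
print(loss_valid_on_all_ranks([0.42, float("nan"), 0.40]))  # → False
print(loss_valid_on_all_ranks([0.42, 0.41, 0.40]))          # → True
```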

check_gradients(model, step)[source]

Check model gradients for NaN/Inf before optimizer step.

Parameters:
  • model – The model whose parameter gradients are checked.

  • step (int) – Current training step.

Returns:

True if all gradients are finite.

Return type:

bool
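A torch-free sketch of the per-parameter scan; in the real method the values come from each parameter's gradient tensor, and `all_grads_finite` here is an illustrative stand-in, not the API:

```python
import math

def all_grads_finite(named_grads):
    """Scan gradient values before the optimizer step, mirroring
    check_gradients: any NaN/Inf anywhere fails the whole check."""
    for name, values in named_grads.items():
        if values is None:               # frozen params have no gradient
            continue
        if any(not math.isfinite(v) for v in values):
            print(f"non-finite gradient in {name}")
            return False
    return True

grads = {"layer1.weight": [0.1, -0.2], "layer2.weight": [float("inf"), 0.0]}
print(all_grads_finite(grads))   # → False
```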

property should_rollback: bool

Whether consecutive NaN count suggests a checkpoint rollback.

reset()[source]

Reset NaN tracking state (e.g., after a rollback).

Return type:

None

class kempnerforge.resilience.NaNState[source]

Bases: object

Tracks NaN/Inf occurrences across training steps.

consecutive_nans: int = 0
total_nans: int = 0
last_good_loss: float = inf
last_good_step: int = 0
nan_steps: list[int]
__init__(consecutive_nans=0, total_nans=0, last_good_loss=inf, last_good_step=0, nan_steps=<factory>)
Parameters:
  • consecutive_nans (int)

  • total_nans (int)

  • last_good_loss (float)

  • last_good_step (int)

  • nan_steps (list[int])

Return type:

None

class kempnerforge.resilience.SLURMInfo[source]

Bases: object

Information about the current SLURM job.

job_id: str
job_name: str
node_list: str
num_nodes: int
ntasks_per_node: int
restart_count: int
partition: str
array_task_id: str | None
property is_requeued: bool

Whether this job has been requeued (restart_count > 0).

__init__(job_id, job_name, node_list, num_nodes, ntasks_per_node, restart_count, partition, array_task_id)
Parameters:
  • job_id (str)

  • job_name (str)

  • node_list (str)

  • num_nodes (int)

  • ntasks_per_node (int)

  • restart_count (int)

  • partition (str)

  • array_task_id (str | None)

Return type:

None

class kempnerforge.resilience.ShutdownHandler[source]

Bases: object

Cooperative shutdown handler for long-running training jobs.

Register this handler before the training loop. The training loop checks should_shutdown() after each step and takes appropriate action (save checkpoint, clean up, exit).

If the graceful shutdown exceeds timeout_sec, a forced exit is triggered via os._exit to avoid hanging on stuck collectives.

Usage:

handler = ShutdownHandler(timeout_sec=120)
handler.register()

for step in range(max_steps):
    train_step()
    if handler.should_shutdown():
        save_checkpoint()
        handler.finish()
        break
Parameters:

timeout_sec – Maximum seconds allowed for graceful shutdown before forced exit. Set to 0 to disable the timeout.
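The cooperative pattern rests on a handler that only flips a flag, leaving the training loop to poll it at a safe point between steps. A minimal stand-alone sketch, without the forced-exit timer; `MiniShutdown` is illustrative, not the real class, and SIGUSR1 is used here because it is one of the signals the real handler registers:

```python
import os
import signal

class MiniShutdown:
    """Minimal cooperative-shutdown sketch: the signal handler only
    records the request; the training loop decides when to act."""

    def __init__(self):
        self._requested = False
        self._prev = None

    def register(self):
        # Must run in the main thread, as with the real handler.
        self._prev = signal.signal(signal.SIGUSR1, self._handle)

    def _handle(self, signum, frame):
        self._requested = True           # do no real work inside a handler

    def should_shutdown(self):
        return self._requested

    def unregister(self):
        signal.signal(signal.SIGUSR1, self._prev)

h = MiniShutdown()
h.register()
os.kill(os.getpid(), signal.SIGUSR1)   # simulate SLURM's pre-termination signal
print(h.should_shutdown())             # → True
h.unregister()
```

Keeping the handler body to a flag assignment is what makes the shutdown safe: checkpointing and collective teardown happen in ordinary loop code, not in signal context.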

__init__(timeout_sec=600.0)[source]
Parameters:

timeout_sec (float)

Return type:

None

property shutdown_requested: bool

Whether a shutdown signal has been received.

property signal_received: Signals | None

The signal that triggered shutdown, or None.

should_shutdown()[source]

Check if the training loop should exit.

Call this after each training step.

Return type:

bool

register()[source]

Register signal handlers for SIGTERM and SIGUSR1.

Must be called from the main thread.

Return type:

None

unregister()[source]

Restore original signal handlers.

Return type:

None

finish()[source]

Call after graceful shutdown is complete.

Cancels the forced-exit timer and restores signal handlers.

Return type:

None

kempnerforge.resilience.check_gpu_health(device=0)[source]

Run basic GPU health checks.

Performs:
  1. CUDA availability check

  2. Small test computation on the device

  3. Memory allocation test

Returns:

Dict with health check results.

Parameters:

device (int)

Return type:

dict[str, bool | str]
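One plausible reading of the dict[str, bool | str] return shape is that each check maps to True/False when it runs, and to an error string when it raises. A generic sketch of that report pattern; the check names and `build_health_report` are illustrative, not the real implementation:

```python
def failing_alloc():
    raise MemoryError("out of memory")   # stand-in for an allocation failure

def build_health_report(checks):
    """Run independent health checks and collect results:
    True/False per check, error text when a check raises."""
    report = {}
    for name, check in checks.items():
        try:
            report[name] = bool(check())
        except Exception as exc:
            report[name] = f"error: {exc}"
    return report

report = build_health_report({
    "cuda_available": lambda: True,       # 1. CUDA availability
    "test_compute": lambda: 1 + 1 == 2,   # 2. small test computation
    "memory_alloc": failing_alloc,        # 3. memory allocation test
})
print(report["memory_alloc"])   # → error: out of memory
```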

kempnerforge.resilience.check_nccl_health(timeout_sec=10.0)[source]

Check NCCL communication health via a lightweight all-reduce.

Parameters:

timeout_sec (float) – Timeout for the collective operation.

Returns:

True if the all-reduce succeeded, False on timeout or error.

Return type:

bool
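The timeout discipline can be sketched with a worker thread and a bounded wait; `probe_with_timeout` is an illustrative stand-in for how a health probe can be time-bounded, not the function's actual mechanism:

```python
import time
from concurrent.futures import ThreadPoolExecutor

def probe_with_timeout(fn, timeout_sec):
    """Bound a health probe in time, the way check_nccl_health bounds
    its all-reduce: True on success, False on timeout or any error.
    Note: the worker thread is not killed on timeout; a truly stuck
    collective still needs a process-level watchdog."""
    pool = ThreadPoolExecutor(max_workers=1)
    try:
        pool.submit(fn).result(timeout=timeout_sec)
        return True
    except Exception:
        return False
    finally:
        pool.shutdown(wait=False)

print(probe_with_timeout(lambda: None, 1.0))              # → True
print(probe_with_timeout(lambda: time.sleep(0.5), 0.05))  # → False
```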

kempnerforge.resilience.get_slurm_info()[source]

Read SLURM job information from environment variables.

Returns:

SLURMInfo if running under SLURM, None otherwise.

Return type:

SLURMInfo | None
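The lookup pattern can be sketched against standard SLURM environment variables; `read_slurm_env` and the exact set of variables below are illustrative, and the real function may read more fields (it populates the full SLURMInfo):

```python
import os

def read_slurm_env(env=None):
    """Sketch of get_slurm_info's environment lookup; returns None
    when not running under SLURM."""
    env = os.environ if env is None else env
    if "SLURM_JOB_ID" not in env:
        return None
    return {
        "job_id": env["SLURM_JOB_ID"],
        "num_nodes": int(env.get("SLURM_JOB_NUM_NODES", "1")),
        "restart_count": int(env.get("SLURM_RESTART_COUNT", "0")),
        "partition": env.get("SLURM_JOB_PARTITION", ""),
    }

info = read_slurm_env({"SLURM_JOB_ID": "4242", "SLURM_RESTART_COUNT": "1"})
print(info["restart_count"] > 0)   # requeued job, per is_slurm_requeue → True
```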

kempnerforge.resilience.is_slurm_job()[source]

Check if we are running under SLURM.

Return type:

bool

kempnerforge.resilience.is_slurm_requeue()[source]

Check if this is a requeued SLURM job.

Uses SLURM_RESTART_COUNT (set by SLURM on requeue).

Return type:

bool

kempnerforge.resilience.log_job_info()[source]

Log SLURM job information (if running under SLURM).

Return type:

None

kempnerforge.resilience.resolve_resume_path(checkpoint_dir)[source]

Find the latest checkpoint for auto-resume.

Checks:
  1. {checkpoint_dir}/latest symlink

  2. Most recent step_N directory by step number

Parameters:

checkpoint_dir (str) – Base checkpoint directory.

Returns:

Path to the latest checkpoint, or None if none found.

Return type:

Path | None

Modules

elastic

Elastic training and SLURM integration helpers.

health

GPU health monitoring and NaN detection.

signal_handler

Graceful shutdown via signal handling.