Source code for kempnerforge.resilience.signal_handler
"""Graceful shutdown via signal handling.
Handles SIGTERM (SLURM preemption / graceful shutdown) and SIGUSR1
(SLURM requeue) by setting a flag that the training loop checks after
each step. On signal, the loop saves an emergency checkpoint and exits.
Timeout protection ensures the process exits even if graceful shutdown
stalls (e.g., stuck in NCCL collective).
"""
from __future__ import annotations
import logging
import signal
import threading
from types import FrameType
# Module-scoped logger, named after this module's import path.
logger = logging.getLogger(__name__)

# Signals intercepted to request a cooperative shutdown:
# SIGTERM (preemption / graceful shutdown) and SIGUSR1 (requeue).
_SHUTDOWN_SIGNALS: tuple[signal.Signals, ...] = (
    signal.SIGTERM,
    signal.SIGUSR1,
)
[docs]
class ShutdownHandler:
    """Cooperative shutdown handler for long-running training jobs.

    Register this handler before the training loop. The training loop
    checks ``should_shutdown()`` after each step and takes appropriate
    action (save checkpoint, clean up, exit).

    If the graceful shutdown exceeds ``timeout_sec``, a forced exit is
    triggered via ``os._exit`` to avoid hanging on stuck collectives.

    Usage::

        handler = ShutdownHandler(timeout_sec=120)
        handler.register()
        for step in range(max_steps):
            train_step()
            if handler.should_shutdown():
                save_checkpoint()
                handler.finish()
                break

    Args:
        timeout_sec: Maximum seconds allowed for graceful shutdown before
            forced exit. Set to 0 (any non-positive value) to disable the
            timeout.
    """

    def __init__(self, timeout_sec: float = 600.0) -> None:
        self._shutdown_requested = False
        self._signal_received: signal.Signals | None = None
        self._timeout_sec = timeout_sec
        self._timer: threading.Timer | None = None
        # Handlers that were installed before register(); restored by
        # unregister(). A value of None means the previous handler was not
        # installed from Python (see signal.getsignal).
        self._original_handlers: dict[signal.Signals, signal._HANDLER] = {}

    @property
    def shutdown_requested(self) -> bool:
        """Whether a shutdown signal has been received."""
        return self._shutdown_requested

    @property
    def signal_received(self) -> signal.Signals | None:
        """The signal that triggered shutdown, or None."""
        return self._signal_received

    def should_shutdown(self) -> bool:
        """Check if the training loop should exit.

        Call this after each training step.
        """
        return self._shutdown_requested

    def register(self) -> None:
        """Register signal handlers for SIGTERM and SIGUSR1.

        Must be called from the main thread. Calling it more than once is
        safe: the handlers saved by the first call are kept, so
        ``unregister()`` still restores the true originals.
        """
        for sig in _SHUTDOWN_SIGNALS:
            # Save the pre-existing handler only once; on a repeated call
            # getsignal() would return our own handler and clobber it.
            if sig not in self._original_handlers:
                self._original_handlers[sig] = signal.getsignal(sig)
            signal.signal(sig, self._handle_signal)
        logger.info("Shutdown handler registered (SIGTERM, SIGUSR1)")

    def unregister(self) -> None:
        """Cancel the forced-exit timer and restore original signal handlers.

        A saved handler of ``None`` (previous handler not installed from
        Python) cannot be passed back to ``signal.signal``; the default
        disposition ``SIG_DFL`` is installed instead.
        """
        # Disarm the timer first so it is cancelled even if restoring a
        # handler raises (e.g. when called outside the main thread).
        self._cancel_timer()
        for sig, handler in self._original_handlers.items():
            signal.signal(sig, signal.SIG_DFL if handler is None else handler)
        self._original_handlers.clear()

    def finish(self) -> None:
        """Call after graceful shutdown is complete.

        Cancels the forced-exit timer and restores signal handlers.
        """
        # unregister() cancels the timer before touching the handlers.
        self.unregister()
        logger.info("Graceful shutdown complete")

    def _handle_signal(self, signum: int, frame: FrameType | None) -> None:
        """Signal handler — sets the shutdown flag and starts the timeout."""
        sig = signal.Signals(signum)
        self._shutdown_requested = True
        self._signal_received = sig

        # Arm the forced-exit timer before any logging so that timeout
        # protection is in place even if a log call fails. Only the first
        # signal starts a timer; later signals leave it running.
        timer_started = False
        if self._timeout_sec > 0 and self._timer is None:
            self._timer = threading.Timer(self._timeout_sec, self._force_exit)
            self._timer.daemon = True
            self._timer.start()
            timer_started = True

        logger.warning("Received %s — requesting graceful shutdown", sig.name)
        if timer_started:
            logger.info(
                "Forced exit in %ss if shutdown not complete", self._timeout_sec
            )

    def _force_exit(self) -> None:
        """Force-exit the process if graceful shutdown takes too long."""
        import os

        try:
            logger.error(
                "Graceful shutdown timed out after %ss — forcing exit",
                self._timeout_sec,
            )
        finally:
            # Exit with status 1 even if the log call raised.
            os._exit(1)

    def _cancel_timer(self) -> None:
        """Cancel the forced-exit timer if active."""
        if self._timer is not None:
            self._timer.cancel()
            self._timer = None