Source code for kempnerforge.resilience.elastic
"""Elastic training and SLURM integration helpers.
Provides utilities for training jobs that may be preempted, requeued,
or restarted with a different number of nodes:
- SLURM job info detection
- Requeue detection
- Auto-resume path resolution
"""
from __future__ import annotations
import logging
import os
from dataclasses import dataclass
from pathlib import Path
# Module-level logger, named after this module via ``__name__``; every
# function below reports through it.
logger = logging.getLogger(__name__)
[docs]
@dataclass
class SLURMInfo:
    """Snapshot of the SLURM job the current process belongs to.

    Instances are plain records; ``get_slurm_info`` fills them in from the
    ``SLURM_*`` environment variables.
    """

    job_id: str  # job identifier; its presence is what marks a SLURM run
    job_name: str  # empty string when unset
    node_list: str  # raw node-list string, not expanded
    num_nodes: int  # number of nodes
    ntasks_per_node: int  # tasks per node
    restart_count: int  # number of restarts so far; 0 on the first run
    partition: str  # empty string when unset
    array_task_id: str | None  # None if not an array job

    @property
    def is_requeued(self) -> bool:
        """Tell whether the job has been restarted at least once.

        Returns:
            True when ``restart_count`` is greater than zero.
        """
        restarts_so_far = self.restart_count
        return restarts_so_far > 0
[docs]
def get_slurm_info() -> SLURMInfo | None:
    """Build a ``SLURMInfo`` record from the process environment.

    ``SLURM_JOB_ID`` decides whether we are under SLURM at all. The other
    variables are optional: missing strings become ``""``, missing node and
    task counts become ``1``, a missing restart count becomes ``0``, and a
    missing array task id stays ``None``.

    Returns:
        SLURMInfo if running under SLURM, None otherwise.

    Raises:
        ValueError: If one of the integer-valued variables is set to
            something ``int()`` cannot parse.
    """
    env = os.environ
    if "SLURM_JOB_ID" not in env:
        return None

    # Integer-valued settings are parsed up front, in the same order the
    # record lists them, so a malformed value fails before construction.
    node_count = int(env.get("SLURM_NNODES", "1"))
    tasks_each = int(env.get("SLURM_NTASKS_PER_NODE", "1"))
    restarts = int(env.get("SLURM_RESTART_COUNT", "0"))

    return SLURMInfo(
        job_id=env["SLURM_JOB_ID"],
        job_name=env.get("SLURM_JOB_NAME", ""),
        node_list=env.get("SLURM_JOB_NODELIST", ""),
        num_nodes=node_count,
        ntasks_per_node=tasks_each,
        restart_count=restarts,
        partition=env.get("SLURM_JOB_PARTITION", ""),
        array_task_id=env.get("SLURM_ARRAY_TASK_ID"),
    )
[docs]
def is_slurm_job() -> bool:
    """Report whether the process was launched inside a SLURM job.

    Returns:
        True when ``SLURM_JOB_ID`` is present in the environment (even if
        empty), False otherwise.
    """
    return os.environ.get("SLURM_JOB_ID") is not None
[docs]
def is_slurm_requeue() -> bool:
    """Report whether the current SLURM job is a requeued run.

    Reads ``SLURM_RESTART_COUNT`` (set by SLURM on requeue) and treats a
    missing variable as zero restarts.

    Returns:
        True when the restart count is greater than zero.

    Raises:
        ValueError: If ``SLURM_RESTART_COUNT`` is set but is not an integer.
    """
    raw_count = os.environ.get("SLURM_RESTART_COUNT", "0")
    restarts = int(raw_count)
    return restarts > 0
[docs]
def resolve_resume_path(checkpoint_dir: str) -> Path | None:
    """Find the latest checkpoint for auto-resume.

    Lookup order:

    1. ``{checkpoint_dir}/latest`` (normally a symlink), resolved to its
       target, if that target exists.
    2. The ``step_N`` directory with the largest step number ``N``.

    Only directories named exactly ``step_<decimal digits>`` are considered
    in step 2. Names with anything after the number (``step_100_tmp``) are
    skipped, so a temporary or partial directory is never picked for resume.

    Args:
        checkpoint_dir: Base checkpoint directory.

    Returns:
        Path to the latest checkpoint, or None if none is found. None is
        also returned when ``checkpoint_dir`` is missing or not a directory.
    """
    base = Path(checkpoint_dir)
    # is_dir() rather than exists(): a regular file at this path must yield
    # None instead of NotADirectoryError from iterdir() further down.
    if not base.is_dir():
        return None

    # Check "latest" first. exists() follows symlinks, so a dangling link
    # falls through to the step_N scan.
    latest = base / "latest"
    if latest.exists():
        resolved = latest.resolve()
        if resolved.exists():
            logger.info(f"Auto-resume: found latest checkpoint at {resolved}")
            return resolved

    # Fall back to the step_N directory with the highest N.
    prefix = "step_"
    candidates = []
    for entry in base.iterdir():
        name = entry.name
        if not name.startswith(prefix):
            continue
        suffix = name[len(prefix):]
        # isdecimal() on the whole suffix guarantees int() succeeds;
        # isdigit() would also accept characters such as superscripts
        # that int() rejects. The cheap name test runs before is_dir().
        if suffix.isdecimal() and entry.is_dir():
            candidates.append((int(suffix), entry))

    if not candidates:
        return None

    _, path = max(candidates, key=lambda pair: pair[0])
    logger.info(f"Auto-resume: found checkpoint at {path}")
    return path
[docs]
def log_job_info() -> None:
    """Write a summary of the current SLURM job to the module logger.

    Outside SLURM a single "Not running under SLURM" line is logged. Under
    SLURM the job id, name, node and task counts, partition and restart
    count are logged, followed by an extra line when the job was requeued.
    """
    slurm = get_slurm_info()
    if slurm is not None:
        summary = (
            f"SLURM job: id={slurm.job_id}, name={slurm.job_name}, "
            f"nodes={slurm.num_nodes}, tasks/node={slurm.ntasks_per_node}, "
            f"partition={slurm.partition}, restart_count={slurm.restart_count}"
        )
        logger.info(summary)
        if slurm.is_requeued:
            logger.info(f"Job was requeued (restart #{slurm.restart_count}) — will auto-resume")
        return

    logger.info("Not running under SLURM")