"""Checkpoint manager for distributed checkpointing.
Uses PyTorch Distributed Checkpoint (DCP) for model and optimizer state,
which supports automatic resharding (save with N GPUs, load with M GPUs).
Non-distributed state (scheduler, dataloader, training metadata, RNG state) is
saved separately with ``torch.save`` and broadcast from rank 0 on load.
"""
from __future__ import annotations
import json
import logging
import shutil
from pathlib import Path
from typing import Any
import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
from kempnerforge.checkpoint.async_save import AsyncCheckpointer
from kempnerforge.checkpoint.state import build_train_state, restore_train_state
from kempnerforge.config.schema import CheckpointConfig
logger = logging.getLogger(__name__)
# Filename for non-distributed training state within a checkpoint directory
_TRAIN_STATE_FILE = "train_state.pt"
_METADATA_FILE = "metadata.json"
class CheckpointManager:
"""Manages save/load/cleanup of distributed checkpoints.
Each checkpoint is stored in a subdirectory: ``{dir}/step_{N}/``
containing DCP shards and a non-distributed training state file.
A ``latest`` symlink always points to the most recent checkpoint.
Args:
config: Checkpoint configuration.
model: The model (FSDP-wrapped or plain).
optimizer: The optimizer.
process_group: Optional process group used for the DCP save/load collectives.
pp_rank: Pipeline-parallel stage rank, if any; each stage writes its DCP
shards to a per-stage ``pp{rank}`` subdirectory.
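
Example (a minimal sketch of the save/resume flow using the methods below)::

    manager = CheckpointManager(config, model, optimizer)

    # Periodic save during training; wait() blocks until any pending
    # async DCP write has finished (e.g., before exiting the job).
    manager.save(step=step, tokens_seen=tokens_seen, scheduler=scheduler)
    manager.wait()

    # On resume: restores model/optimizer shards and training progress.
    step, tokens_seen, extra = manager.load(scheduler=scheduler)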
"""
def __init__(
self,
config: CheckpointConfig,
model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
process_group=None,
pp_rank: int | None = None,
) -> None:
self.config = config
self.model = model
self.optimizer = optimizer
self.base_dir = Path(config.dir)
self._rank = dist.get_rank() if dist.is_initialized() else 0
self._async_ckpt = AsyncCheckpointer(mode=config.async_mode)
self._process_group = process_group
self._pp_rank = pp_rank
# Dataloader state stashed during load() when the caller cannot yet
# provide a dataloader object. Applied later via
# apply_dataloader_state() once the loader is constructed.
self._pending_dataloader_state: dict[str, Any] | None = None
def _checkpoint_dir(self, step: int) -> Path:
return self.base_dir / f"step_{step}"
def _latest_link(self) -> Path:
return self.base_dir / "latest"
def save(
self,
step: int,
tokens_seen: int = 0,
scheduler: Any | None = None,
dataloader: Any | None = None,
extra: dict | None = None,
) -> None:
"""Save a checkpoint at the given step.
Args:
step: Current training step.
tokens_seen: Total tokens processed.
scheduler: LR scheduler to save.
dataloader: Stateful dataloader to save.
extra: Additional metadata.
"""
ckpt_dir = self._checkpoint_dir(step)
# Create directory (all ranks)
ckpt_dir.mkdir(parents=True, exist_ok=True)
# With PP, each stage has different parameters — save DCP shards to
# a per-stage subdirectory to avoid .metadata file collisions.
dcp_dir = ckpt_dir / f"pp{self._pp_rank}" if self._pp_rank is not None else ckpt_dir
dcp_dir.mkdir(parents=True, exist_ok=True)
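# Resulting on-disk layout (illustrative, for two pipeline stages):
#   step_{N}/
#     pp0/            <- DCP shards + .metadata for stage 0
#     pp1/            <- DCP shards + .metadata for stage 1
#     train_state.pt
#     metadata.json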
# Save distributed state (model + optimizer) via DCP
dcp_state = {
"model": self.model.state_dict(),
"optimizer": self.optimizer.state_dict(),
}
self._async_ckpt.save(
dcp_state, checkpoint_id=str(dcp_dir), process_group=self._process_group
)
# Save non-distributed state (rank 0 only)
if self._rank == 0:
train_state = build_train_state(
step=step,
tokens_seen=tokens_seen,
scheduler=scheduler,
dataloader=dataloader,
extra=extra,
)
torch.save(train_state, ckpt_dir / _TRAIN_STATE_FILE)
# Write human-readable metadata
meta = {"step": step, "tokens_seen": tokens_seen}
(ckpt_dir / _METADATA_FILE).write_text(json.dumps(meta, indent=2))
# Update "latest" symlink
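# Create the link under a temporary name and rename() it into place:
# rename() atomically replaces the old link on POSIX, so readers never
# observe a missing or half-updated "latest".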
latest = self._latest_link()
tmp_link = latest.with_suffix(".tmp")
tmp_link.unlink(missing_ok=True)
tmp_link.symlink_to(ckpt_dir.name)
tmp_link.rename(latest)
logger.info(f"Checkpoint saved: {ckpt_dir} (step={step})")
# Cleanup old checkpoints
self._cleanup()
# save() is a collective: non-rank-0 ranks must not return until
# rank-0 has committed train_state.pt, metadata.json, and the
# latest symlink. Without this barrier, post-save hooks or readers
# on other ranks race rank-0's writes (especially on NFS/Lustre).
if dist.is_initialized():
dist.barrier()
def wait(self) -> None:
"""Block until any pending async checkpoint save completes."""
self._async_ckpt.wait()
def load(
self,
path: str | None = None,
scheduler: Any | None = None,
dataloader: Any | None = None,
exclude_keys: list[str] | None = None,
) -> tuple[int, int, dict[str, Any]]:
"""Load a checkpoint and restore all state.
Args:
path: Checkpoint path. If None, loads from ``config.load_path``
or the ``latest`` symlink.
scheduler: LR scheduler to restore.
dataloader: Stateful dataloader to restore.
exclude_keys: DCP state keys to skip (e.g., ["optimizer"] for fine-tuning).
Returns:
Tuple of (step, tokens_seen, extra) where extra contains any
additional keys saved via ``build_train_state(extra=...)``.
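
Example (illustrative; the checkpoint path shown is hypothetical)::

    # Resume from the most recent checkpoint (follows the ``latest`` link).
    step, tokens_seen, extra = manager.load(scheduler=scheduler)

    # Fine-tune from pretrained weights only: load model shards but skip
    # the optimizer state.
    manager.load(path="checkpoints/step_50000", exclude_keys=["optimizer"])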
"""
ckpt_dir = self._resolve_load_path(path)
if ckpt_dir is None:
logger.info("No checkpoint found — starting from scratch")
return 0, 0, {}
logger.info(f"Loading checkpoint: {ckpt_dir}")
# Load distributed state via DCP
dcp_dir = ckpt_dir / f"pp{self._pp_rank}" if self._pp_rank is not None else ckpt_dir
dcp_state: dict[str, Any] = {}
if exclude_keys is None or "model" not in exclude_keys:
dcp_state["model"] = self.model.state_dict()
if exclude_keys is None or "optimizer" not in exclude_keys:
dcp_state["optimizer"] = self.optimizer.state_dict()
if dcp_state:
dcp.load(dcp_state, checkpoint_id=str(dcp_dir), process_group=self._process_group)
if "model" in dcp_state:
self.model.load_state_dict(dcp_state["model"])
if "optimizer" in dcp_state:
self.optimizer.load_state_dict(dcp_state["optimizer"])
# Load non-distributed state. On NFS/Lustre, independent stat()
# calls can disagree briefly across ranks; if some ranks enter
# this branch and others don't, the broadcast_object_list below
# hangs. Use a rank-0-authoritative existence check broadcast to
# all ranks so every rank takes the same branch.
train_state_path = ckpt_dir / _TRAIN_STATE_FILE
if dist.is_initialized():
exists_flag = [train_state_path.exists() if self._rank == 0 else False]
dist.broadcast_object_list(exists_flag, src=0)
train_state_exists = bool(exists_flag[0])
else:
train_state_exists = train_state_path.exists()
if train_state_exists:
train_state = (
torch.load(train_state_path, map_location="cpu", weights_only=False)
if self._rank == 0 or not dist.is_initialized()
else None
)
# Broadcast from rank 0 to all ranks
if dist.is_initialized():
object_list = [train_state if self._rank == 0 else None]
dist.broadcast_object_list(object_list, src=0)
train_state = object_list[0]
assert train_state is not None, "train_state broadcast failed"
# Stash dataloader state if the caller can't yet provide the loader
# object. Training loops construct the dataloader after load(), then
# call apply_dataloader_state() to restore the stashed state.
if dataloader is None and "dataloader" in train_state:
self._pending_dataloader_state = train_state["dataloader"]
step, tokens_seen, extra = restore_train_state(
train_state,
scheduler=scheduler,
dataloader=dataloader,
)
logger.info(f"Resumed from step {step}, {tokens_seen:,} tokens seen")
return step, tokens_seen, extra
return 0, 0, {}
def apply_dataloader_state(self, dataloader: Any) -> None:
"""Apply any dataloader state stashed during load().
Training loops call load() before constructing the dataloader (since
the dataloader depends on phase/annealing state that load() restores).
This method applies the stashed state once the loader exists.
No-op if no state is pending, or if the loader does not support
``load_state_dict`` (e.g., plain torch DataLoader for HF streaming).
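
Example (sketch of the intended call order; ``build_dataloader`` is a
hypothetical factory)::

    step, tokens_seen, extra = manager.load(scheduler=scheduler)
    dataloader = build_dataloader(...)  # constructed after load()
    manager.apply_dataloader_state(dataloader)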
"""
if self._pending_dataloader_state is None:
return
if dataloader is None or not hasattr(dataloader, "load_state_dict"):
self._pending_dataloader_state = None
return
dataloader.load_state_dict(self._pending_dataloader_state)
self._pending_dataloader_state = None
logger.info("Applied stashed dataloader state")
def _resolve_load_path(self, path: str | None = None) -> Path | None:
"""Resolve the checkpoint path to load from."""
if path is not None:
p = Path(path)
return p if p.exists() else None
if self.config.load_path:
p = Path(self.config.load_path)
return p if p.exists() else None
latest = self._latest_link()
if latest.exists():
return latest.resolve()
return None
def _cleanup(self) -> None:
"""Remove old checkpoints beyond the retention limit."""
keep = self.config.keep_last_n
if keep <= 0:
return
# Find all step_N directories
ckpt_dirs = sorted(
(d for d in self.base_dir.iterdir() if d.is_dir() and d.name.startswith("step_")),
key=lambda d: int(d.name.split("_")[1]),
)
# Remove oldest beyond retention
to_remove = ckpt_dirs[:-keep] if len(ckpt_dirs) > keep else []
for d in to_remove:
shutil.rmtree(d)
logger.info(f"Removed old checkpoint: {d}")