Source code for kempnerforge.checkpoint.manager

"""Checkpoint manager for distributed checkpointing.

Uses PyTorch Distributed Checkpoint (DCP) for model and optimizer state,
which supports automatic resharding (save with N GPUs, load with M GPUs).

Non-distributed state (scheduler, dataloader, training meta, RNG) is saved
separately as a torch file and broadcast from rank 0 on load.
"""

from __future__ import annotations

import json
import logging
import os
import shutil
import stat
from pathlib import Path
from typing import Any

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp

from kempnerforge.checkpoint.async_save import AsyncCheckpointer
from kempnerforge.checkpoint.state import build_train_state, restore_train_state
from kempnerforge.config.schema import AsyncCheckpointMode, CheckpointConfig

# Module-level logger used throughout this module.
logger = logging.getLogger(__name__)

# Filename for non-distributed training state within a checkpoint directory.
# Rank 0 writes it with torch.save(); it is read back through _load_train_state(),
# which unpickles it (weights_only=False) and so enforces an ownership check first.
_TRAIN_STATE_FILE = "train_state.pt"
# Human-readable JSON sidecar written by rank 0 on save with "step",
# "tokens_seen" and, when the caller's ``extra`` carries it, "vlm_freeze".
# Read back to peek the saved step and to validate freeze metadata on load.
_METADATA_FILE = "metadata.json"
# DCP writes this file LAST, once all shards are durable. Its presence is the
# authoritative signal that a checkpoint's distributed state is loadable.
_DCP_METADATA_FILE = ".metadata"


def _intersect_freeze_meta_by_module(
    saved: list[dict[str, Any]],
    expected: list[dict[str, Any]],
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
    """Restrict both freeze-metadata lists to the module keys they have in common.

    This is what makes cross-architecture resumes load cleanly. A
    Joint-Decoder checkpoint records ``vlm_freeze`` only for
    ``vision_encoder``, while a Cross-Attention config expects
    ``vision_encoder`` and ``cross_attention`` (the latter auto-added by
    ``CrossAttentionConfig.module_patterns``). When loading JD into CA,
    ``cross_attention`` exists only on the expected side and is discarded, and
    the surviving ``vision_encoder`` entries are compared as usual.

    Genuine disagreements on a shared key are kept. If both sides list
    ``vision_encoder`` with different ``frozen`` values, the filtered lists
    differ and the caller raises.

    Both inputs are expected to be canonical already (sorted and deduplicated
    by ``canonical_freeze_meta``). Filtering keeps relative order, so the
    outputs stay canonical.

    Returns:
        ``(saved_filtered, expected_filtered)``.
    """
    common_modules = {entry["module"] for entry in saved}.intersection(
        {entry["module"] for entry in expected}
    )

    def _only_common(entries: list[dict[str, Any]]) -> list[dict[str, Any]]:
        # Order-preserving filter down to the shared module keys.
        return [entry for entry in entries if entry["module"] in common_modules]

    return _only_common(saved), _only_common(expected)


def _load_train_state(path: Path) -> dict[str, Any]:
    """Load ``train_state.pt`` under an explicit trust boundary.

    ``train_state.pt`` carries scheduler state, dataloader state, and a
    caller-supplied ``extra`` dict, so it is loaded with ``weights_only=False``
    (i.e. full pickle). Any object in the file whose class defines
    ``__reduce__`` runs arbitrary Python during ``torch.load``. On shared
    filesystems this is a real attack surface: anyone who can write into
    another user's checkpoint directory gets code execution in that user's
    training process on next resume.

    Refuses to load files not owned by the current UID and warns when the
    file is group- or world-writable. This does not defend against a
    same-UID compromise — if the attacker can write as you, they already
    win — but it closes the common "group-writable shared checkpoint dir"
    foot-gun and makes the trust boundary visible.

    The ownership/permission checks are made on the *opened* file handle
    (``os.fstat``), and the same handle is then unpickled. Checking the path
    and re-opening it separately would leave a window in which a writer to the
    directory could swap the file after it passed the checks.

    Checkpoints imported from outside the lab (HuggingFace Hub, colleague
    transfers, etc.) will fail this check and must be either chown'd to the
    current user after inspection or converted to a weights-only-safe form.

    Args:
        path: Path to the ``train_state.pt`` file.

    Returns:
        The unpickled training-state dict, with tensors mapped to CPU.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        PermissionError: If the file is not owned by the current UID.
    """
    with open(path, "rb") as fh:
        st = os.fstat(fh.fileno())
        uid = os.getuid()  # POSIX-only, like the ownership model this enforces
        if st.st_uid != uid:
            raise PermissionError(
                f"Refusing to load {path}: owned by uid={st.st_uid}, current uid={uid}. "
                f"train_state.pt is a pickle and loading it executes arbitrary Python. "
                f"If you trust this checkpoint, chown it to the current user after inspection."
            )
        if st.st_mode & (stat.S_IWGRP | stat.S_IWOTH):
            logger.warning(
                f"{path} is group/world-writable (mode={oct(st.st_mode & 0o777)}); "
                f"train_state.pt is a pickle and any writer can inject arbitrary code "
                f"at load time. Consider chmod g-w,o-w on the checkpoint directory."
            )
        # Unpickle from the handle that was just validated — never re-open.
        return torch.load(fh, map_location="cpu", weights_only=False)


class CheckpointManager:
    """Manages save/load/cleanup of distributed checkpoints.

    Each checkpoint is stored in a subdirectory ``{dir}/step_{N}/``. It contains:

    - DCP shards, under ``pp{pp_rank}/`` per pipeline stage when ``pp_rank``
      is set;
    - a non-distributed training state file (``train_state.pt``);
    - a human-readable ``metadata.json``.

    A ``latest`` symlink points to the most recent *committed* checkpoint. In
    sync mode that is the one just saved. In async mode the commit is deferred
    until its DCP flush is durable, which happens at the next ``save()``,
    ``wait()`` or ``flush_pending_save()``.

    Args:
        config: Checkpoint configuration.
        model: The model (FSDP-wrapped or plain).
        optimizer: The optimizer.
        process_group: Process group forwarded to DCP save/load.
        pp_rank: Pipeline-parallel stage index, or ``None`` without PP.
            Selects the per-stage ``pp{pp_rank}`` shard subdirectory.
    """

    def __init__(
        self,
        config: CheckpointConfig,
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        process_group=None,
        pp_rank: int | None = None,
    ) -> None:
        self.config = config
        self.model = model
        self.optimizer = optimizer
        self.base_dir = Path(config.dir)
        self._rank = dist.get_rank() if dist.is_initialized() else 0
        self._async_ckpt = AsyncCheckpointer(mode=config.async_mode)
        self._process_group = process_group
        self._pp_rank = pp_rank
        # Dataloader state stashed during load() when the caller cannot yet
        # provide a dataloader object. Applied later via
        # apply_dataloader_state() once the loader is constructed.
        self._pending_dataloader_state: dict[str, Any] | None = None
        # Async-only: a checkpoint whose DCP flush was dispatched but is not
        # yet durable. Its `latest` symlink swap is deferred until the flush
        # completes (drained at the next save() or wait()), so `latest` never
        # points at a half-written checkpoint. _cleanup() also never prunes it.
        # Holds (step, ckpt_dir).
        self._pending_finalize: tuple[int, Path] | None = None

    # ------------------------------------------------------------------ paths

    def _checkpoint_dir(self, step: int) -> Path:
        """Directory holding the checkpoint for ``step``."""
        return self.base_dir / f"step_{step}"

    def _latest_link(self) -> Path:
        """Path of the ``latest`` symlink."""
        return self.base_dir / "latest"

    def _dcp_dir(self, ckpt_dir: Path) -> Path:
        """DCP shard directory for a checkpoint (per-stage subdir under PP)."""
        return ckpt_dir / f"pp{self._pp_rank}" if self._pp_rank is not None else ckpt_dir

    def _step_dirs(self) -> list[tuple[int, Path]]:
        """All ``step_<N>`` checkpoint directories as ``(N, path)``, ascending by N.

        Only names of the form this manager produces are considered: ``step_``
        followed by decimal digits only. Stray entries such as ``step_final``
        or ``step_10_bak`` are skipped rather than fed to ``int()``. Otherwise
        a single such directory would make ``_cleanup()``, and therefore every
        commit, raise ``ValueError``.
        """
        if not self.base_dir.exists():
            return []
        prefix = "step_"
        found: list[tuple[int, Path]] = []
        for entry in self.base_dir.iterdir():
            digits = entry.name[len(prefix):]
            if entry.name.startswith(prefix) and digits.isdecimal() and entry.is_dir():
                found.append((int(digits), entry))
        found.sort(key=lambda item: item[0])
        return found

    def _dcp_complete(self, ckpt_dir: Path) -> bool:
        """True once DCP has written its ``.metadata`` (all shards durable).

        Under pipeline parallelism each stage writes its own ``pp{k}/``
        subdirectory. "Complete" therefore requires this rank's stage
        directory, and every stage directory present under ``ckpt_dir``, to
        carry ``.metadata``. Checking all stages lets rank 0 make the
        interrupted-flush fallback decision in ``load()`` for every stage at
        once.
        """
        if not (self._dcp_dir(ckpt_dir) / _DCP_METADATA_FILE).exists():
            return False
        if self._pp_rank is None:
            return True
        stage_dirs = (
            d
            for d in ckpt_dir.iterdir()
            if d.is_dir() and d.name.startswith("pp") and d.name[2:].isdecimal()
        )
        return all((d / _DCP_METADATA_FILE).exists() for d in stage_dirs)

    # ---------------------------------------------------------------- commit

    def _commit_latest(self, step: int, ckpt_dir: Path) -> None:
        """Atomically point `latest` at a now-durable checkpoint, then prune.

        Rank 0 only. Called either inline (sync mode, DCP already durable)
        or deferred (async mode, after the flush future has resolved).
        """
        latest = self._latest_link()
        tmp_link = latest.with_suffix(".tmp")
        tmp_link.unlink(missing_ok=True)
        tmp_link.symlink_to(ckpt_dir.name)
        # rename() over an existing path swaps the symlink in a single step.
        tmp_link.rename(latest)
        logger.info(f"Checkpoint committed: {ckpt_dir} (step={step})")
        self._cleanup()

    def _drain_pending_finalize(self) -> None:
        """Commit a deferred async checkpoint once its flush is durable.

        Invoked after the pending DCP future has been awaited (an explicit
        wait()/flush). Rank 0 performs the symlink + cleanup; other ranks only
        clear their bookkeeping (they never touch the symlink), mirroring
        save().
        """
        if self._pending_finalize is None:
            return
        step, ckpt_dir = self._pending_finalize
        self._pending_finalize = None
        if self._rank == 0:
            self._commit_latest(step, ckpt_dir)

    @staticmethod
    def _will_load_dcp(exclude_keys: list[str] | None) -> bool:
        """True unless both ``model`` and ``optimizer`` are excluded from DCP load."""
        return (
            exclude_keys is None
            or "model" not in exclude_keys
            or "optimizer" not in exclude_keys
        )

    # ------------------------------------------------------------------ save

    def save(
        self,
        step: int,
        tokens_seen: int = 0,
        scheduler: Any | None = None,
        dataloader: Any | None = None,
        extra: dict | None = None,
    ) -> None:
        """Save a checkpoint at the given step.

        Collective: every rank must call it. It returns on all ranks only
        after rank 0 has written ``train_state.pt`` / ``metadata.json``.

        Args:
            step: Current training step.
            tokens_seen: Total tokens processed.
            scheduler: LR scheduler to save.
            dataloader: Stateful dataloader to save.
            extra: Additional metadata.
        """
        ckpt_dir = self._checkpoint_dir(step)
        # Create directory (all ranks)
        ckpt_dir.mkdir(parents=True, exist_ok=True)

        # With PP, each stage has different parameters — save DCP shards to
        # a per-stage subdirectory to avoid .metadata file collisions.
        dcp_dir = self._dcp_dir(ckpt_dir)
        dcp_dir.mkdir(parents=True, exist_ok=True)

        # Save distributed state (model + optimizer) via DCP
        dcp_state = {
            "model": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
        }
        # Dispatch the DCP save. For async modes this returns immediately but
        # FIRST awaits the previous in-flight flush, so any deferred
        # `_pending_finalize` from the prior save is now durable. For sync
        # mode (disabled) dcp.save() blocks until THIS checkpoint is durable.
        self._async_ckpt.save(
            dcp_state, checkpoint_id=str(dcp_dir), process_group=self._process_group
        )
        sync_mode = self.config.async_mode == AsyncCheckpointMode.disabled

        # The prior async checkpoint (if any) is durable now; commit it.
        # Register THIS checkpoint as in-flight *before* that commit: the
        # commit runs _cleanup(), which protects `_pending_finalize`. That way
        # the directory still being flushed can never be pruned, for example
        # with keep_last_n=1 while stale higher-numbered step_* dirs remain
        # from an earlier run.
        prior = self._pending_finalize
        self._pending_finalize = None if sync_mode else (step, ckpt_dir)
        if prior is not None and self._rank == 0:
            self._commit_latest(*prior)

        # Save non-distributed state (rank 0 only). These are small synchronous
        # writes. They finish well before `latest` is ever pointed here, so
        # the checkpoint dir is fully populated by commit time.
        if self._rank == 0:
            train_state = build_train_state(
                step=step,
                tokens_seen=tokens_seen,
                scheduler=scheduler,
                dataloader=dataloader,
                extra=extra,
            )
            torch.save(train_state, ckpt_dir / _TRAIN_STATE_FILE)

            # Write human-readable metadata
            meta: dict[str, Any] = {"step": step, "tokens_seen": tokens_seen}
            if extra is not None and "vlm_freeze" in extra:
                # Already canonicalized by canonical_freeze_meta(...); stored as
                # a sorted, deduplicated list of {"module", "frozen"} dicts so
                # the comparison on load is reorder-invariant.
                meta["vlm_freeze"] = extra["vlm_freeze"]
            (ckpt_dir / _METADATA_FILE).write_text(json.dumps(meta, indent=2))

            if sync_mode:
                # DCP shards are already durable — commit immediately. In async
                # mode the commit happens once the flush is durable (next
                # save() or wait()). A crash mid-flush therefore leaves
                # `latest` on the last GOOD step.
                self._commit_latest(step, ckpt_dir)

        # save() is a collective: non-rank-0 ranks must not return until
        # rank-0 has written train_state.pt + metadata.json (and, in sync
        # mode, advanced `latest`). Without this barrier, post-save hooks or
        # readers on other ranks race rank-0's writes (especially NFS/Lustre).
        if dist.is_initialized():
            dist.barrier()

    def wait(self) -> None:
        """Block until any pending async checkpoint save completes.

        Once the flush is durable, commit its deferred `latest` symlink +
        cleanup. The training loop calls this after the loop exits, so the
        final checkpoint's `latest` is committed before process teardown.
        """
        self._async_ckpt.wait()
        self._drain_pending_finalize()
        if dist.is_initialized():
            dist.barrier()

    def flush_pending_save(self) -> None:
        """Drain any in-flight async save before mutating model state.

        Called from the FreezeStage transition hook in the training loop.
        When a transition fires at step S, any save started at step S-1 must
        have written ``metadata.json`` with the pre-transition spec before the
        transition flips ``requires_grad``. Otherwise ``metadata.json`` lands
        with the post-transition spec attached to the pre-transition shards.

        Also commits the deferred `latest` symlink for that save. A
        transition, or any caller draining the queue, therefore leaves
        `latest` pointed at the now-durable checkpoint. Equivalent to
        :meth:`wait`; the separate name documents the call site's intent.
        """
        self.wait()

    # ------------------------------------------------------------------ load

    def peek_saved_step(
        self,
        path: str | None = None,
        exclude_keys: list[str] | None = None,
    ) -> int | None:
        """Read ``step`` from the metadata.json of the checkpoint load() would pick.

        Returns ``None`` if no checkpoint resolves or the metadata is
        missing/unreadable. Used by the training loop on resume to compute the
        expected freeze list (which depends on ``saved_step``) before calling
        ``load``.

        Args:
            path: Same meaning as in ``load()``.
            exclude_keys: Same meaning as in ``load()``. Unless both model and
                optimizer are excluded, the interrupted-async-flush fallback
                that ``load()`` applies is applied here too, so the reported
                step matches the checkpoint ``load()`` restores. Unlike
                ``load()`` this is not collective; it resolves locally.
        """
        ckpt_dir = self._resolve_load_path(path)
        if ckpt_dir is None:
            return None
        if self._will_load_dcp(exclude_keys):
            ckpt_dir = self._resolve_dcp_load_dir(ckpt_dir, path)
        metadata_path = ckpt_dir / _METADATA_FILE
        if not metadata_path.exists():
            return None
        try:
            saved_meta = json.loads(metadata_path.read_text())
        except (OSError, json.JSONDecodeError):
            return None
        step = saved_meta.get("step")
        return int(step) if step is not None else None

    def load(
        self,
        path: str | None = None,
        scheduler: Any | None = None,
        dataloader: Any | None = None,
        exclude_keys: list[str] | None = None,
        vlm_freeze_expected: list[dict[str, Any]] | None = None,
    ) -> tuple[int, int, dict[str, Any]]:
        """Load a checkpoint and restore all state.

        Collective when ``torch.distributed`` is initialized: every rank must
        call it. Rank 0 decides which checkpoint directory to load, including
        the interrupted-async-flush fallback, and broadcasts it. All ranks
        therefore load the same step even if their filesystem views briefly
        disagree.

        Args:
            path: Checkpoint path. If None, loads from ``config.load_path``
                or the ``latest`` symlink.
            scheduler: LR scheduler to restore.
            dataloader: Stateful dataloader to restore.
            exclude_keys: DCP state keys to skip (e.g., ["optimizer"] for fine-tuning).
            vlm_freeze_expected: Canonical freeze metadata (output of
                ``canonical_freeze_meta``) for the current run's VLMConfig.
                When both the saved metadata and this argument are set, a
                mismatch raises ``ValueError``. If the checkpoint config has
                ``ignore_freeze_mismatch=True``, the load proceeds with a
                warning instead.

        Returns:
            Tuple of (step, tokens_seen, extra), where extra contains any
            additional keys saved via ``build_train_state(extra=...)``.
            ``(0, 0, {})`` when no checkpoint resolves or it has no
            ``train_state.pt``.

        Raises:
            ValueError: On a VLM freeze mismatch (see ``vlm_freeze_expected``).
            PermissionError: On rank 0, if ``train_state.pt`` is not owned by
                the current user.
            RuntimeError: If the broadcast train state arrives empty.
        """
        will_load_dcp = self._will_load_dcp(exclude_keys)
        ckpt_dir = self._resolve_load_dir(path, will_load_dcp)
        if ckpt_dir is None:
            logger.info("No checkpoint found — starting from scratch")
            return 0, 0, {}

        logger.info(f"Loading checkpoint: {ckpt_dir}")

        # Check VLM freeze metadata BEFORE loading DCP shards so a mismatch
        # surfaces without leaving partial state in the live model.
        if vlm_freeze_expected is not None:
            self._check_vlm_freeze(ckpt_dir, vlm_freeze_expected)

        # Load distributed state via DCP
        dcp_state: dict[str, Any] = {}
        if exclude_keys is None or "model" not in exclude_keys:
            dcp_state["model"] = self.model.state_dict()
        if exclude_keys is None or "optimizer" not in exclude_keys:
            dcp_state["optimizer"] = self.optimizer.state_dict()
        if dcp_state:
            dcp.load(
                dcp_state,
                checkpoint_id=str(self._dcp_dir(ckpt_dir)),
                process_group=self._process_group,
            )
            if "model" in dcp_state:
                self.model.load_state_dict(dcp_state["model"])
            if "optimizer" in dcp_state:
                self.optimizer.load_state_dict(dcp_state["optimizer"])

        # Load non-distributed state. On NFS/Lustre, independent stat() calls
        # can briefly disagree across ranks. If some ranks entered this branch
        # and others did not, the broadcast_object_list below would hang. A
        # rank-0-authoritative existence check, broadcast to all ranks, makes
        # every rank take the same branch.
        train_state_path = ckpt_dir / _TRAIN_STATE_FILE
        if dist.is_initialized():
            exists_flag = [train_state_path.exists() if self._rank == 0 else False]
            dist.broadcast_object_list(exists_flag, src=0)
            train_state_exists = bool(exists_flag[0])
        else:
            train_state_exists = train_state_path.exists()
        if not train_state_exists:
            return 0, 0, {}

        train_state = (
            _load_train_state(train_state_path)
            if self._rank == 0 or not dist.is_initialized()
            else None
        )
        # Broadcast from rank 0 to all ranks. PyTorch 2.11's
        # broadcast_object_list does not accept async_op, so no per-op timeout
        # can be set here. This call inherits the 1800s process-group default.
        if dist.is_initialized():
            object_list = [train_state if self._rank == 0 else None]
            dist.broadcast_object_list(object_list, src=0)
            train_state = object_list[0]
        # Explicit check (not assert) so it still fires under `python -O`.
        if train_state is None:
            raise RuntimeError("train_state broadcast failed")

        # Stash dataloader state if the caller can't yet provide the loader
        # object. Training loops construct the dataloader after load() so
        # apply_dataloader_state() can restore it once it exists.
        if dataloader is None and "dataloader" in train_state:
            self._pending_dataloader_state = train_state["dataloader"]

        step, tokens_seen, extra = restore_train_state(
            train_state,
            scheduler=scheduler,
            dataloader=dataloader,
        )
        logger.info(f"Resumed from step {step}, {tokens_seen:,} tokens seen")
        return step, tokens_seen, extra

    def _check_vlm_freeze(
        self, ckpt_dir: Path, vlm_freeze_expected: list[dict[str, Any]]
    ) -> None:
        """Compare saved vs. expected VLM freeze metadata before any state loads.

        Cross-arch rule: both the saved and expected lists are filtered to the
        intersection of their module keys. A JD checkpoint's vlm_freeze has
        only ``vision_encoder``, while a Cross-Attention config expects
        ``vision_encoder`` + ``cross_attention`` (auto-default). When loading
        JD into CA, ``cross_attention`` is dropped from expected and the
        remaining ``vision_encoder`` entries compare cleanly. Real semantic
        mismatches on shared keys (e.g. saved ``vision_encoder=True`` vs
        expected ``vision_encoder=False``) still raise.

        A missing or unreadable metadata.json, or one without ``vlm_freeze``,
        skips the check (unreadable files log a warning).

        Raises:
            ValueError: On a mismatch, unless ``config.ignore_freeze_mismatch``
                is truthy (then a warning is logged instead).
        """
        metadata_path = ckpt_dir / _METADATA_FILE
        if not metadata_path.exists():
            return
        try:
            saved_meta = json.loads(metadata_path.read_text())
        except (OSError, json.JSONDecodeError) as e:
            logger.warning(f"Could not read {metadata_path}: {e}")
            saved_meta = {}
        saved_vlm_freeze = saved_meta.get("vlm_freeze")
        if saved_vlm_freeze is None:
            return

        saved_filt, expected_filt = _intersect_freeze_meta_by_module(
            saved_vlm_freeze, vlm_freeze_expected
        )
        if saved_filt == expected_filt:
            return

        dropped_from_saved = sorted(
            {e["module"] for e in saved_vlm_freeze} - {e["module"] for e in saved_filt}
        )
        dropped_from_expected = sorted(
            {e["module"] for e in vlm_freeze_expected}
            - {e["module"] for e in expected_filt}
        )
        cross_arch_note = ""
        if dropped_from_saved or dropped_from_expected:
            cross_arch_note = (
                f" (cross-arch keys ignored: "
                f"saved-only={dropped_from_saved}, "
                f"current-only={dropped_from_expected})"
            )
        msg = (
            f"VLM freeze mismatch at {ckpt_dir}: "
            f"saved={saved_filt}, current={expected_filt}"
            f"{cross_arch_note}"
        )
        if getattr(self.config, "ignore_freeze_mismatch", False):
            logger.warning(
                msg + " — proceeding because checkpoint.ignore_freeze_mismatch=True"
            )
        else:
            raise ValueError(
                msg + " (set checkpoint.ignore_freeze_mismatch=true to override)"
            )

    def _resolve_load_dir(self, path: str | None, will_load_dcp: bool) -> Path | None:
        """Resolve the checkpoint dir to load, identically on every rank.

        Rank 0 (or the sole process when not distributed) performs the
        filesystem resolution: the explicit path, ``load_path``, or
        ``latest``, plus the DCP-durability fallback when DCP state will be
        loaded. It then broadcasts the result. Without this, per-rank stat()
        disagreement on NFS/Lustre could make some ranks return early or fall
        back to a different step than others. Ranks would then hang in the
        later collectives or silently load mixed steps.
        """
        resolved: Path | None = None
        if self._rank == 0 or not dist.is_initialized():
            resolved = self._resolve_load_path(path)
            if resolved is not None and will_load_dcp:
                resolved = self._resolve_dcp_load_dir(resolved, path)
        return self._broadcast_path(resolved)

    def _broadcast_path(self, candidate: Path | None) -> Path | None:
        """Return rank 0's ``candidate`` on every rank (identity when not distributed)."""
        if not dist.is_initialized():
            return candidate
        payload = [str(candidate) if self._rank == 0 and candidate is not None else None]
        dist.broadcast_object_list(payload, src=0)
        return Path(payload[0]) if payload[0] is not None else None

    def apply_dataloader_state(self, dataloader: Any) -> None:
        """Apply any dataloader state stashed during load().

        Training loops call load() before constructing the dataloader, since
        the dataloader depends on phase/annealing state that load() restores.
        This method applies the stashed state once the loader exists. No-op
        if no state is pending, or if the loader does not support
        ``load_state_dict`` (e.g., plain torch DataLoader for HF streaming).
        """
        if self._pending_dataloader_state is None:
            return
        if dataloader is None or not hasattr(dataloader, "load_state_dict"):
            self._pending_dataloader_state = None
            return
        dataloader.load_state_dict(self._pending_dataloader_state)
        self._pending_dataloader_state = None
        logger.info("Applied stashed dataloader state")

    def _newest_complete_checkpoint(self) -> Path | None:
        """Newest ``step_N`` dir whose DCP shards are durable, or None.

        Defense in depth. The deferred commit keeps `latest` off half-written
        checkpoints, but some cases can still leave the newest dir without
        DCP `.metadata`: a crash mid-flush, or a checkpoint dir left
        incomplete by older buggy code or external interference. Falling back
        to the newest COMPLETE dir keeps resume working instead of
        hard-failing in dcp.load with "metadata is None".
        """
        for _, d in reversed(self._step_dirs()):
            if self._dcp_complete(d):
                return d
        return None

    def _resolve_load_path(self, path: str | None = None) -> Path | None:
        """Resolve the checkpoint path to load from.

        Returns the raw resolution: the explicit path, ``load_path``, or the
        ``latest`` symlink target. The DCP-durability fallback for an
        interrupted async flush is applied separately (``_resolve_dcp_load_dir``).
        It only applies when DCP state is actually being loaded, so it never
        interferes with DCP-excluded loads (e.g. fine-tuning).
        """
        if path is not None:
            p = Path(path)
            return p if p.exists() else None

        if self.config.load_path:
            p = Path(self.config.load_path)
            return p if p.exists() else None

        latest = self._latest_link()
        if latest.exists():
            return latest.resolve()

        return None

    def _resolve_dcp_load_dir(self, resolved: Path, path: str | None) -> Path:
        """Pick the dir to DCP-load from, falling back if interrupted.

        ``resolved`` is the ``_resolve_load_path`` result. It may have been
        reached via auto-resume (no explicit path/``load_path``) with DCP
        shards that are not durable, the signature of a crash during an async
        flush. In that case, fall back to the newest complete checkpoint so
        resume degrades to "last good step" instead of hard-failing in
        ``dcp.load`` with "metadata is None". An explicitly requested path is
        honored as-is (caller intent; fail loudly if broken).
        """
        explicit = path is not None or bool(self.config.load_path)
        if explicit or self._dcp_complete(resolved):
            return resolved
        fallback = self._newest_complete_checkpoint()
        if fallback is not None and fallback != resolved:
            logger.warning(
                f"`latest` -> {resolved} has no durable DCP metadata "
                f"(likely an interrupted async flush); resuming from "
                f"newest complete checkpoint {fallback} instead."
            )
            return fallback
        return resolved

    def _cleanup(self) -> None:
        """Remove old checkpoints beyond the retention limit.

        Two directories are never removed regardless of retention: the
        current ``latest`` target and the in-flight async checkpoint
        (``_pending_finalize``). Pruning either would let a crash strand
        resume with no loadable checkpoint. ``keep_last_n <= 0`` disables
        pruning.
        """
        keep = self.config.keep_last_n
        if keep <= 0:
            return

        ckpt_dirs = [d for _, d in self._step_dirs()]

        protected: set[Path] = set()
        latest = self._latest_link()
        if latest.exists():
            protected.add(latest.resolve())
        if self._pending_finalize is not None:
            protected.add(self._pending_finalize[1].resolve())

        # Remove oldest beyond retention, but never a protected dir.
        to_remove = ckpt_dirs[:-keep] if len(ckpt_dirs) > keep else []
        for d in to_remove:
            if d.resolve() in protected:
                continue
            shutil.rmtree(d)
            logger.info(f"Removed old checkpoint: {d}")