Source code for kempnerforge.checkpoint.state

"""Training state assembly for checkpointing.

Collects the full training state — model, optimizer, scheduler, dataloader,
training metadata, and RNG states — into a single dict for DCP save/load.

RNG state capture ensures exact reproducibility on resume.
"""

from __future__ import annotations

import logging
import random
from typing import Any

import numpy as np
import torch

logger = logging.getLogger(__name__)


def get_rng_state() -> dict[str, Any]:
    """Snapshot the state of every random number generator used in training.

    Returns:
        Dict keyed by generator name. ``"python"``, ``"numpy"`` and
        ``"torch_cpu"`` are always present. ``"torch_cuda"`` is added only
        when ``torch.cuda.is_available()`` is true, and is captured with
        ``torch.cuda.get_rng_state()`` using its default device argument.
    """
    snapshot: dict[str, Any] = {}
    snapshot["python"] = random.getstate()
    snapshot["numpy"] = np.random.get_state()
    snapshot["torch_cpu"] = torch.random.get_rng_state()

    # CPU-only hosts have no CUDA generator to record.
    if not torch.cuda.is_available():
        return snapshot

    snapshot["torch_cuda"] = torch.cuda.get_rng_state()
    return snapshot
def set_rng_state(state: dict[str, Any]) -> None:
    """Load previously captured generator states back into each RNG.

    Every entry is optional: a generator whose key is missing from ``state``
    is left untouched.

    Args:
        state: Mapping with any of the keys ``"python"``, ``"numpy"``,
            ``"torch_cpu"`` and ``"torch_cuda"``, as produced by
            ``get_rng_state``.
    """
    # Generators that can always be restored, in the order they are applied.
    setters = (
        ("python", random.setstate),
        ("numpy", np.random.set_state),
        ("torch_cpu", torch.random.set_rng_state),
    )
    for key, apply_state in setters:
        if key in state:
            apply_state(state[key])

    # A CUDA state saved on a GPU host is skipped when resuming without CUDA.
    if "torch_cuda" not in state:
        return
    if torch.cuda.is_available():
        torch.cuda.set_rng_state(state["torch_cuda"])
def build_train_state(
    step: int,
    tokens_seen: int,
    scheduler: Any | None = None,
    dataloader: Any | None = None,
    extra: dict | None = None,
) -> dict[str, Any]:
    """Assemble the training state that lives outside the model and optimizer.

    Model and optimizer state are handled by DCP directly; this gathers the
    remaining pieces needed to resume a run.

    Args:
        step: Current training step.
        tokens_seen: Total tokens processed so far.
        scheduler: LR scheduler; its ``state_dict()`` is stored when given.
        dataloader: Dataloader; stored only if it exposes ``state_dict``.
        extra: Additional metadata merged into the result. It is applied
            last, so a key that matches an existing entry replaces it.

    Returns:
        Dict with ``"step"``, ``"tokens_seen"`` and ``"rng"``, plus
        ``"scheduler"`` / ``"dataloader"`` when available, plus any keys
        from ``extra``.
    """
    train_state: dict[str, Any] = {"step": step, "tokens_seen": tokens_seen}
    train_state["rng"] = get_rng_state()

    if scheduler is not None:
        train_state["scheduler"] = scheduler.state_dict()

    # Dataloaders without a state_dict method are skipped, not rejected.
    loader_is_stateful = dataloader is not None and hasattr(dataloader, "state_dict")
    if loader_is_stateful:
        train_state["dataloader"] = dataloader.state_dict()

    # None and an empty mapping both merge nothing.
    train_state.update(extra or {})
    return train_state
def restore_train_state(
    state: dict[str, Any],
    scheduler: Any | None = None,
    dataloader: Any | None = None,
) -> tuple[int, int, dict[str, Any]]:
    """Apply a saved training state and report the progress counters.

    RNG states are restored whenever ``state`` contains ``"rng"``. The
    scheduler and dataloader are restored only when the object is supplied
    and ``state`` holds a matching entry; the dataloader must additionally
    expose ``load_state_dict``.

    Args:
        state: Training state dict (from build_train_state).
        scheduler: LR scheduler to restore.
        dataloader: Stateful dataloader to restore.

    Returns:
        Tuple of (step, tokens_seen, extra). ``step`` and ``tokens_seen``
        fall back to 0 when absent. ``extra`` holds every key of ``state``
        other than the standard ones, in their original order.
    """
    reserved_keys = frozenset({"step", "tokens_seen", "rng", "scheduler", "dataloader"})

    step = state.get("step", 0)
    tokens_seen = state.get("tokens_seen", 0)

    if "rng" in state:
        set_rng_state(state["rng"])
        logger.info("Restored RNG states")

    if scheduler is not None:
        if "scheduler" in state:
            scheduler.load_state_dict(state["scheduler"])
            logger.info("Restored scheduler state")

    if dataloader is not None:
        if "dataloader" in state and hasattr(dataloader, "load_state_dict"):
            dataloader.load_state_dict(state["dataloader"])
            logger.info("Restored dataloader state")

    # Whatever is not a standard entry was supplied by the caller as extra.
    extra: dict[str, Any] = {}
    for key, value in state.items():
        if key not in reserved_keys:
            extra[key] = value
    return step, tokens_seen, extra