# Module: kempnerforge.checkpoint.state
"""Training state assembly for checkpointing.
Collects the full training state — model, optimizer, scheduler, dataloader,
training metadata, and RNG states — into a single dict for DCP save/load.
RNG state capture ensures exact reproducibility on resume.
"""
from __future__ import annotations
import logging
import random
from typing import Any
import numpy as np
import torch
logger = logging.getLogger(__name__)
def get_rng_state() -> dict[str, Any]:
    """Capture all RNG states for reproducibility on resume.

    Returns:
        Dict with the Python, NumPy, and torch CPU RNG states. When CUDA is
        available it also contains ``torch_cuda`` (current device's state,
        kept for backward compatibility with older checkpoints) and
        ``torch_cuda_all`` (per-device states, so multi-GPU runs resume
        with exactly the same random streams on every device).
    """
    state: dict[str, Any] = {
        "python": random.getstate(),
        "numpy": np.random.get_state(),
        "torch_cpu": torch.random.get_rng_state(),
    }
    if torch.cuda.is_available():
        # Legacy single-device key plus all-device states for multi-GPU.
        state["torch_cuda"] = torch.cuda.get_rng_state()
        state["torch_cuda_all"] = torch.cuda.get_rng_state_all()
    return state
def set_rng_state(state: dict[str, Any]) -> None:
    """Restore all RNG states from a checkpoint.

    Missing keys are skipped, so partial state dicts (e.g. from older
    checkpoints) restore whatever they contain.

    Args:
        state: Mapping produced by :func:`get_rng_state`.
    """
    if "python" in state:
        random.setstate(state["python"])
    if "numpy" in state:
        np.random.set_state(state["numpy"])
    if "torch_cpu" in state:
        torch.random.set_rng_state(state["torch_cpu"])
    if torch.cuda.is_available():
        # Prefer per-device states (exact multi-GPU resume); fall back to
        # the single current-device state written by older checkpoints.
        if "torch_cuda_all" in state:
            torch.cuda.set_rng_state_all(state["torch_cuda_all"])
        elif "torch_cuda" in state:
            torch.cuda.set_rng_state(state["torch_cuda"])
def build_train_state(
step: int,
tokens_seen: int,
scheduler: Any | None = None,
dataloader: Any | None = None,
extra: dict | None = None,
) -> dict[str, Any]:
"""Build the non-distributed portion of the training state.
Model and optimizer state are handled by DCP directly.
This function captures everything else needed for exact resumption.
Args:
step: Current training step.
tokens_seen: Total tokens processed so far.
scheduler: LR scheduler (must have state_dict()).
dataloader: Stateful dataloader (must have state_dict()).
extra: Additional metadata to include.
Returns:
Dict with training state, scheduler state, dataloader state, and RNG states.
"""
state: dict[str, Any] = {
"step": step,
"tokens_seen": tokens_seen,
"rng": get_rng_state(),
}
if scheduler is not None:
state["scheduler"] = scheduler.state_dict()
if dataloader is not None and hasattr(dataloader, "state_dict"):
state["dataloader"] = dataloader.state_dict()
if extra:
state.update(extra)
return state
def restore_train_state(
state: dict[str, Any],
scheduler: Any | None = None,
dataloader: Any | None = None,
) -> tuple[int, int, dict[str, Any]]:
"""Restore the non-distributed portion of the training state.
Args:
state: Training state dict (from build_train_state).
scheduler: LR scheduler to restore.
dataloader: Stateful dataloader to restore.
Returns:
Tuple of (step, tokens_seen, extra) where extra contains any
additional keys saved via build_train_state(extra=...).
"""
step = state.get("step", 0)
tokens_seen = state.get("tokens_seen", 0)
if "rng" in state:
set_rng_state(state["rng"])
logger.info("Restored RNG states")
if scheduler is not None and "scheduler" in state:
scheduler.load_state_dict(state["scheduler"])
logger.info("Restored scheduler state")
if dataloader is not None and "dataloader" in state and hasattr(dataloader, "load_state_dict"):
dataloader.load_state_dict(state["dataloader"])
logger.info("Restored dataloader state")
_standard_keys = {"step", "tokens_seen", "rng", "scheduler", "dataloader"}
extra = {k: v for k, v in state.items() if k not in _standard_keys}
return step, tokens_seen, extra