kempnerforge.checkpoint.manager

Checkpoint manager for distributed checkpointing.

Uses PyTorch Distributed Checkpoint (DCP) for model and optimizer state, which supports automatic resharding (save with N GPUs, load with M GPUs).

Non-distributed state (scheduler, dataloader, training meta, RNG) is saved separately as a torch file and broadcast from rank 0 on load.
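
The rank-0 save-and-broadcast pattern for the non-distributed state corresponds roughly to the sketch below (illustrative only; the helper names and the train_state.pt filename are assumptions, not the manager's actual internals):

    import torch
    import torch.distributed as dist

    def save_train_state(ckpt_dir, train_state):
        # Only rank 0 writes the non-distributed state (scheduler, dataloader,
        # RNG, training meta) as a plain torch file.
        if dist.get_rank() == 0:
            torch.save(train_state, f"{ckpt_dir}/train_state.pt")

    def load_train_state(ckpt_dir):
        # Rank 0 reads the file, then broadcasts the object so every rank
        # resumes with identical non-distributed state.
        obj = [None]
        if dist.get_rank() == 0:
            # The state holds arbitrary Python objects, hence weights_only=False.
            obj[0] = torch.load(f"{ckpt_dir}/train_state.pt", weights_only=False)
        dist.broadcast_object_list(obj, src=0)
        return obj[0]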

Classes

CheckpointManager

Manages save/load/cleanup of distributed checkpoints.

class kempnerforge.checkpoint.manager.CheckpointManager[source]

Bases: object

Manages save/load/cleanup of distributed checkpoints.

Each checkpoint is stored in a subdirectory: {dir}/step_{N}/ containing DCP shards and a non-distributed training state file.

A latest symlink always points to the most recent checkpoint.

Parameters:
  • config – Checkpoint configuration.

  • model – The model (FSDP-wrapped or plain).

  • optimizer – The optimizer.

__init__(config, model, optimizer, process_group=None, pp_rank=None)[source]
Return type:

None
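
A minimal construction sketch, assuming a CheckpointConfig instance and an already-built model and optimizer; the pipeline-parallel variant in the comment is an assumption about the optional arguments, not documented behavior:

    from kempnerforge.checkpoint.manager import CheckpointManager

    manager = CheckpointManager(config, model, optimizer)

    # With pipeline parallelism, a process group and pipeline rank can also be
    # passed (assumed semantics; not documented above):
    # manager = CheckpointManager(config, model, optimizer,
    #                             process_group=pp_group, pp_rank=pp_rank)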

save(step, tokens_seen=0, scheduler=None, dataloader=None, extra=None)[source]

Save a checkpoint at the given step.

Parameters:
  • step (int) – Current training step.

  • tokens_seen (int) – Total tokens processed.

  • scheduler (Any | None) – LR scheduler to save.

  • dataloader (Any | None) – Stateful dataloader to save.

  • extra (dict | None) – Additional metadata.

Return type:

None
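
A hedged sketch of calling save from a training loop; every name other than manager.save is a placeholder:

    def maybe_checkpoint(manager, step, tokens_seen, lr_scheduler, train_loader,
                         save_every=1000):
        # Periodically checkpoint; `extra` carries arbitrary user metadata that
        # load() later returns in its third element.
        if step % save_every == 0:
            manager.save(
                step,
                tokens_seen=tokens_seen,
                scheduler=lr_scheduler,
                dataloader=train_loader,
                extra={"phase": "pretrain"},
            )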

wait()[source]

Block until any pending async checkpoint save completes.

Return type:

None
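
If checkpointing is asynchronous, a call such as the following (sketch) before evaluating from disk or tearing down the process ensures the last save has finished:

    # Block until the pending async save (if any) is fully written.
    manager.wait()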

load(path=None, scheduler=None, dataloader=None, exclude_keys=None)[source]

Load a checkpoint and restore all state.

Parameters:
  • path (str | None) – Checkpoint path. If None, loads from config.load_path or the latest symlink.

  • scheduler (Any | None) – LR scheduler to restore.

  • dataloader (Any | None) – Stateful dataloader to restore.

  • exclude_keys (list[str] | None) – DCP state keys to skip (e.g., ["optimizer"] for fine-tuning).

Returns:

Tuple of (step, tokens_seen, extra) where extra contains any additional keys saved via build_train_state(extra=...).

Return type:

tuple[int, int, dict[str, Any]]
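
Two hedged usage sketches: resuming from the latest checkpoint, and fine-tuning from a specific (hypothetical) path while skipping optimizer state:

    # Resume: with path=None, the manager falls back to config.load_path or
    # the `latest` symlink.
    step, tokens_seen, extra = manager.load(scheduler=lr_scheduler,
                                            dataloader=train_loader)

    # Fine-tuning: load weights from a hypothetical pretraining checkpoint but
    # skip the saved optimizer state.
    step, tokens_seen, extra = manager.load(
        path="/checkpoints/pretrain/step_50000",
        exclude_keys=["optimizer"],
    )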

apply_dataloader_state(dataloader)[source]

Apply any dataloader state stashed during load().

Training loops call load() before constructing the dataloader (since the dataloader depends on phase/annealing state that load() restores). This method applies the stashed state once the loader exists.

No-op if no state is pending, or if the loader does not support load_state_dict (e.g., plain torch DataLoader for HF streaming).

Parameters:

dataloader (Any)

Return type:

None
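
A sketch of the resume ordering described above; build_dataloader and the other names besides the manager methods are placeholders:

    # 1. Restore model/optimizer/scheduler state and stash any dataloader state.
    step, tokens_seen, extra = manager.load(scheduler=lr_scheduler)

    # 2. Construct the dataloader, which may depend on state that load()
    #    just restored (e.g. training phase / annealing).
    train_loader = build_dataloader(extra)

    # 3. Apply the stashed dataloader state now that the loader exists
    #    (no-op if nothing is pending or load_state_dict is unsupported).
    manager.apply_dataloader_state(train_loader)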