kempnerforge.checkpoint

Distributed checkpointing for KempnerForge.

Public API:
  • CheckpointManager: Save/load/cleanup distributed checkpoints

  • AsyncCheckpointer: Non-blocking checkpoint saves

  • build_train_state / restore_train_state: State assembly
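
A minimal end-to-end sketch of the public API. The config object and the names resuming, total_steps, save_every, and train_step are illustrative placeholders, not part of the library:

    from kempnerforge.checkpoint import CheckpointManager

    manager = CheckpointManager(config, model, optimizer)

    # Resume from the latest checkpoint if requested; otherwise start fresh.
    start_step, tokens_seen, extra = manager.load() if resuming else (0, 0, {})

    for step in range(start_step, total_steps):
        tokens_seen += train_step(model, optimizer)   # your training step; returns tokens processed
        if step > 0 and step % save_every == 0:
            manager.save(step, tokens_seen=tokens_seen)

    manager.wait()   # ensure the last (possibly async) save has finished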

class kempnerforge.checkpoint.AsyncCheckpointer[source]

Bases: object

Non-blocking checkpoint saver.

Wraps dcp.async_save() and manages the background save future. Each new save waits for the previous async save to complete first.

Parameters:

mode – Checkpoint mode (disabled/async/async_with_pinned_mem).

__init__(mode=AsyncCheckpointMode.disabled)[source]
Parameters:

mode (AsyncCheckpointMode)

Return type:

None

save(state_dict, checkpoint_id, process_group=None)[source]

Save distributed state, potentially asynchronously.

Parameters:
  • state_dict (dict) – DCP-compatible state dict (model + optimizer).

  • checkpoint_id (str) – Checkpoint directory path.

  • process_group – Process group for DCP. Required for pipeline parallelism (PP), where each stage has a different state dict: pass a group scoped to the ranks within the same PP stage. None uses the default global group.

Return type:

None
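
A hedged sketch of standalone use (normally CheckpointManager drives this class). The checkpoint path, the placeholder model_state/optim_state dicts, and the import path of the AsyncCheckpointMode enum are assumptions:

    from kempnerforge.checkpoint import AsyncCheckpointer

    saver = AsyncCheckpointer()   # default mode is disabled; pass an
                                  # AsyncCheckpointMode member to enable async saves

    # Build a DCP-compatible dict however your parallelism setup requires,
    # e.g. via torch.distributed.checkpoint.state_dict helpers.
    state_dict = {"model": model_state, "optimizer": optim_state}

    saver.save(state_dict, checkpoint_id="/checkpoints/step_500")
    # ... training continues while an async save runs in the background ...
    if saver.is_pending:
        saver.wait()   # block before the next save or before shutdown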

wait()[source]

Block until any pending async save completes.

Return type:

None

property is_pending: bool

Check if an async save is still in progress.

class kempnerforge.checkpoint.CheckpointManager[source]

Bases: object

Manages save/load/cleanup of distributed checkpoints.

Each checkpoint is stored in its own subdirectory, {dir}/step_{N}/, containing the DCP shards and a non-distributed training state file.

A latest symlink always points to the most recent checkpoint.

Parameters:
  • config – Checkpoint configuration.

  • model – The model (FSDP-wrapped or plain).

  • optimizer – The optimizer.

__init__(config, model, optimizer, process_group=None, pp_rank=None)[source]
Parameters:
  • config – Checkpoint configuration.

  • model – The model (FSDP-wrapped or plain).

  • optimizer – The optimizer.

  • process_group – Process group for DCP saves/loads. None uses the default global group.

  • pp_rank – Pipeline-parallel stage rank, if pipeline parallelism is used.

Return type:

None
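
A construction sketch. The checkpoint configuration type and its fields are project-specific and assumed here; the pipeline-parallel usage follows from the signature and the process-group notes under AsyncCheckpointer.save() above:

    from kempnerforge.checkpoint import CheckpointManager

    # Plain data-parallel / FSDP setup: the default global group is used.
    manager = CheckpointManager(config, model, optimizer)

    # Pipeline parallelism: each stage has a different state dict, so pass a
    # process group scoped to this stage's ranks plus the stage index.
    manager = CheckpointManager(
        config, model, optimizer,
        process_group=stage_group,   # placeholder: your per-stage group
        pp_rank=stage_index,         # placeholder: this rank's PP stage
    )

    # Resulting layout per save (see the class description above):
    #   {dir}/step_{N}/   DCP shards + non-distributed training state file
    #   {dir}/latest      symlink to the most recent checkpoint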

save(step, tokens_seen=0, scheduler=None, dataloader=None, extra=None)[source]

Save a checkpoint at the given step.

Parameters:
  • step (int) – Current training step.

  • tokens_seen (int) – Total tokens processed.

  • scheduler (Any | None) – LR scheduler to save.

  • dataloader (Any | None) – Stateful dataloader to save.

  • extra (dict | None) – Additional metadata.

Return type:

None
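
A sketch of periodic saving inside a training loop; save_every, lr_scheduler, train_dataloader, and best_val_loss are illustrative names:

    if step % save_every == 0:
        manager.save(
            step,
            tokens_seen=tokens_seen,
            scheduler=lr_scheduler,        # anything exposing state_dict()
            dataloader=train_dataloader,   # stateful dataloader, if you use one
            extra={"best_val_loss": best_val_loss},
        )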

wait()[source]

Block until any pending async checkpoint save completes.

Return type:

None

load(path=None, scheduler=None, dataloader=None, exclude_keys=None)[source]

Load a checkpoint and restore all state.

Parameters:
  • path (str | None) – Checkpoint path. If None, loads from config.load_path or the latest symlink.

  • scheduler (Any | None) – LR scheduler to restore.

  • dataloader (Any | None) – Stateful dataloader to restore.

  • exclude_keys (list[str] | None) – DCP state keys to skip (e.g., ["optimizer"] for fine-tuning).

Returns:

Tuple of (step, tokens_seen, extra) where extra contains any additional keys saved via build_train_state(extra=...).

Return type:

tuple[int, int, dict[str, Any]]
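
Two hedged usage patterns; the checkpoint path is hypothetical:

    # Resume training from the latest checkpoint (or config.load_path if set).
    step, tokens_seen, extra = manager.load(scheduler=lr_scheduler)

    # Fine-tune from pretrained weights only, skipping the optimizer shards.
    step, tokens_seen, extra = manager.load(
        path="/checkpoints/pretrain/step_100000",
        exclude_keys=["optimizer"],
    )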

apply_dataloader_state(dataloader)[source]

Apply any dataloader state stashed during load().

Training loops call load() before constructing the dataloader (since the dataloader depends on phase/annealing state that load() restores). This method applies the stashed state once the loader exists.

No-op if no state is pending, or if the loader does not support load_state_dict (e.g., plain torch DataLoader for HF streaming).

Parameters:

dataloader (Any)

Return type:

None
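
A sketch of the intended ordering; build_dataloader is a placeholder for your own constructor:

    # 1. Call load() before the dataloader exists; its dataloader state is stashed.
    step, tokens_seen, extra = manager.load(scheduler=lr_scheduler)

    # 2. Build the dataloader using whatever phase/annealing state load() restored.
    train_dataloader = build_dataloader(extra)   # placeholder constructor

    # 3. Apply the stashed state (a no-op if nothing is pending or the loader
    #    has no load_state_dict).
    manager.apply_dataloader_state(train_dataloader)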

kempnerforge.checkpoint.build_train_state(step, tokens_seen, scheduler=None, dataloader=None, extra=None)[source]

Build the non-distributed portion of the training state.

Model and optimizer state are handled by DCP directly. This function captures everything else needed for exact resumption.

Parameters:
  • step (int) – Current training step.

  • tokens_seen (int) – Total tokens processed so far.

  • scheduler (Any | None) – LR scheduler (must have state_dict()).

  • dataloader (Any | None) – Stateful dataloader (must have state_dict()).

  • extra (dict | None) – Additional metadata to include.

Returns:

Dict with training state, scheduler state, dataloader state, and RNG states.

Return type:

dict[str, Any]
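
A hedged example; lr_scheduler and train_dataloader are illustrative names, and the file name is illustrative too (CheckpointManager writes this file for you inside each step_{N}/ directory):

    import torch

    from kempnerforge.checkpoint import build_train_state

    state = build_train_state(
        step=1200,
        tokens_seen=2_400_000_000,
        scheduler=lr_scheduler,
        dataloader=train_dataloader,
        extra={"phase": "anneal"},   # arbitrary extra metadata, returned on restore
    )
    # Model/optimizer shards are saved separately by DCP; this dict holds the rest.
    torch.save(state, "train_state.pt")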

kempnerforge.checkpoint.get_rng_state()[source]

Capture all RNG states for reproducibility on resume.

Return type:

dict[str, Any]

kempnerforge.checkpoint.restore_train_state(state, scheduler=None, dataloader=None)[source]

Restore the non-distributed portion of the training state.

Parameters:
  • state (dict[str, Any]) – Training state dict (from build_train_state).

  • scheduler (Any | None) – LR scheduler to restore.

  • dataloader (Any | None) – Stateful dataloader to restore.

Returns:

Tuple of (step, tokens_seen, extra) where extra contains any additional keys saved via build_train_state(extra=...).

Return type:

tuple[int, int, dict[str, Any]]
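
A hedged counterpart to the build_train_state example above; the file name is illustrative, and weights_only=False is a conservative choice since the dict is assumed to hold more than plain tensors:

    import torch

    from kempnerforge.checkpoint import restore_train_state

    state = torch.load("train_state.pt", weights_only=False)
    step, tokens_seen, extra = restore_train_state(
        state,
        scheduler=lr_scheduler,
        dataloader=train_dataloader,
    )
    phase = extra.get("phase")   # anything passed via build_train_state(extra=...)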

kempnerforge.checkpoint.set_rng_state(state)[source]

Restore all RNG states from a checkpoint.

Parameters:

state (dict[str, Any])

Return type:

None
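
A round-trip sketch, assuming the torch CPU generator is among the captured states:

    import torch

    from kempnerforge.checkpoint import get_rng_state, set_rng_state

    rng = get_rng_state()
    a = torch.randn(4)    # advances the torch RNG
    set_rng_state(rng)    # rewind to the captured states
    b = torch.randn(4)    # repeats the same draw as `a`
    assert torch.equal(a, b)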

Modules

  • async_save – Async checkpointing for non-blocking saves.

  • manager – Checkpoint manager for distributed checkpointing.

  • state – Training state assembly for checkpointing.