kempnerforge.checkpoint

Distributed checkpointing for KempnerForge.

Public API:
  • CheckpointManager: Save/load/cleanup distributed checkpoints

  • AsyncCheckpointer: Non-blocking checkpoint saves

  • build_train_state / restore_train_state: State assembly

class kempnerforge.checkpoint.AsyncCheckpointer[source]

Bases: object

Non-blocking checkpoint saver.

Wraps dcp.async_save() and manages the background save future. Each new save waits for the previous async save to complete first.

Parameters:

mode – Checkpoint mode (disabled/async/async_with_pinned_mem).

__init__(mode=AsyncCheckpointMode.disabled)[source]
Parameters:

mode (AsyncCheckpointMode)

Return type:

None

save(state_dict, checkpoint_id, process_group=None)[source]

Save distributed state, potentially asynchronously.

Parameters:
  • state_dict (dict) – DCP-compatible state dict (model + optimizer).

  • checkpoint_id (str) – Checkpoint directory path.

  • process_group – Process group for DCP. Required for PP where each stage has a different state dict — pass a group scoped to ranks within the same PP stage. None uses the default global group.

Return type:

None

wait()[source]

Block until any pending async save completes.

Return type:

None

property is_pending: bool

Check if an async save is still in progress.
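The serialize-the-previous-save semantics above can be sketched with a plain thread pool. This is a minimal stand-in, not the real `dcp.async_save()`-backed implementation; the class name and the JSON payload are illustrative only:

```python
import json
from concurrent.futures import Future, ThreadPoolExecutor
from pathlib import Path


class MiniAsyncSaver:
    """Toy non-blocking saver: each save() drains the previous one first."""

    def __init__(self) -> None:
        self._pool = ThreadPoolExecutor(max_workers=1)
        self._future: Future | None = None

    def save(self, state_dict: dict, checkpoint_id: str) -> None:
        self.wait()  # serialize: previous async save must complete first
        path = Path(checkpoint_id)
        path.mkdir(parents=True, exist_ok=True)
        self._future = self._pool.submit(
            (path / "state.json").write_text, json.dumps(state_dict)
        )

    def wait(self) -> None:
        if self._future is not None:
            self._future.result()  # re-raises any background exception
            self._future = None

    @property
    def is_pending(self) -> bool:
        return self._future is not None and not self._future.done()
```

Calling `save()` twice in a row is therefore safe: the second call blocks only for as long as the first write is still in flight.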

class kempnerforge.checkpoint.CheckpointManager[source]

Bases: object

Manages save/load/cleanup of distributed checkpoints.

Each checkpoint is stored in a subdirectory: {dir}/step_{N}/ containing DCP shards and a non-distributed training state file.

A latest symlink always points to the most recent checkpoint.
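The on-disk convention can be illustrated with plain pathlib. A sketch only: the helper name is hypothetical, and the real step directories additionally hold DCP shards and the training state file:

```python
import os
from pathlib import Path


def commit_checkpoint_dir(root: Path, step: int) -> Path:
    """Create {root}/step_{N}/ and repoint the `latest` symlink at it."""
    ckpt = root / f"step_{step}"
    ckpt.mkdir(parents=True, exist_ok=True)
    latest = root / "latest"
    tmp = root / "latest.tmp"
    # Repoint via create-then-rename so readers never see a missing link.
    if tmp.is_symlink():
        tmp.unlink()
    tmp.symlink_to(ckpt.name)  # relative target keeps the tree relocatable
    os.replace(tmp, latest)
    return ckpt
```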

Parameters:
  • config – Checkpoint configuration.

  • model – The model (FSDP-wrapped or plain).

  • optimizer – The optimizer.

__init__(config, model, optimizer, process_group=None, pp_rank=None)[source]
Parameters:
  • config – Checkpoint configuration.

  • model – The model (FSDP-wrapped or plain).

  • optimizer – The optimizer.

  • process_group – Process group for DCP operations. None uses the default global group.

  • pp_rank (int | None) – Pipeline-parallel stage rank, if applicable.

Return type:

None

save(step, tokens_seen=0, scheduler=None, dataloader=None, extra=None)[source]

Save a checkpoint at the given step.

Parameters:
  • step (int) – Current training step.

  • tokens_seen (int) – Total tokens processed.

  • scheduler (Any | None) – LR scheduler to save.

  • dataloader (Any | None) – Stateful dataloader to save.

  • extra (dict | None) – Additional metadata.

Return type:

None

wait()[source]

Block until any pending async checkpoint save completes.

Once the flushed checkpoint is durable, its deferred latest symlink update and cleanup are committed. The training loop calls this after the loop exits, so the final checkpoint’s latest symlink is committed before process teardown.

Return type:

None

flush_pending_save()[source]

Drain any in-flight async save before mutating model state.

Called from the FreezeStage transition hook in the training loop: when a transition fires at step S, any save started at step S-1 must have written metadata.json with the pre-transition spec before the transition flips requires_grad. Otherwise metadata.json lands with the post-transition spec attached to the pre-transition shards.

Also commits the deferred latest symlink for that save, so a transition (or any caller draining the queue) leaves latest pointed at the now-durable checkpoint.

Return type:

None
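The ordering contract above can be sketched as a drain-before-mutate pattern. A toy illustration only, not the real FreezeStage hook; the event names are placeholders:

```python
import time
from concurrent.futures import ThreadPoolExecutor

events: list[str] = []


def background_save() -> None:
    time.sleep(0.05)  # pretend DCP is still writing shards + metadata.json
    events.append("metadata_written")


pool = ThreadPoolExecutor(max_workers=1)
pending = pool.submit(background_save)

# FreezeStage transition at step S: drain the step S-1 save first ...
pending.result()  # analogous to flush_pending_save()
# ... and only then mutate model state (flip requires_grad, etc.).
events.append("requires_grad_flipped")
```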

peek_saved_step(path=None)[source]

Read step from a candidate checkpoint’s metadata.json.

Returns None if no checkpoint resolves or the metadata is missing/unreadable. Used by the training loop on resume to compute the expected freeze list (which depends on saved_step) before calling load.

Parameters:

path (str | None)

Return type:

int | None

load(path=None, scheduler=None, dataloader=None, exclude_keys=None, vlm_freeze_expected=None)[source]

Load a checkpoint and restore all state.

Parameters:
  • path (str | None) – Checkpoint path. If None, loads from config.load_path or the latest symlink.

  • scheduler (Any | None) – LR scheduler to restore.

  • dataloader (Any | None) – Stateful dataloader to restore.

  • exclude_keys (list[str] | None) – DCP state keys to skip (e.g., ["optimizer"] for fine-tuning).

  • vlm_freeze_expected (list[dict[str, Any]] | None) – Canonical freeze metadata (output of canonical_freeze_meta) for the current run’s VLMConfig. When both the saved metadata and this argument are set, a mismatch raises ValueError unless the checkpoint config has ignore_freeze_mismatch=True, in which case the load proceeds with a warning.

Returns:

Tuple of (step, tokens_seen, extra) where extra contains any additional keys saved via build_train_state(extra=...).

Return type:

tuple[int, int, dict[str, Any]]

apply_dataloader_state(dataloader)[source]

Apply any dataloader state stashed during load().

Training loops call load() before constructing the dataloader (since the dataloader depends on phase/annealing state that load() restores). This method applies the stashed state once the loader exists.

No-op if no state is pending, or if the loader does not support load_state_dict (e.g., plain torch DataLoader for HF streaming).

Parameters:

dataloader (Any)

Return type:

None

kempnerforge.checkpoint.build_train_state(step, tokens_seen, scheduler=None, dataloader=None, extra=None)[source]

Build the non-distributed portion of the training state.

Model and optimizer state are handled by DCP directly. This function captures everything else needed for exact resumption.

Parameters:
  • step (int) – Current training step.

  • tokens_seen (int) – Total tokens processed so far.

  • scheduler (Any | None) – LR scheduler (must have state_dict()).

  • dataloader (Any | None) – Stateful dataloader (must have state_dict()).

  • extra (dict | None) – Additional metadata to include.

Returns:

Dict with training state, scheduler state, dataloader state, and RNG states.

Return type:

dict[str, Any]

kempnerforge.checkpoint.get_rng_state()[source]

Capture all RNG states for reproducibility on resume.

Return type:

dict[str, Any]

kempnerforge.checkpoint.restore_train_state(state, scheduler=None, dataloader=None)[source]

Restore the non-distributed portion of the training state.

Parameters:
  • state (dict[str, Any]) – Training state dict (from build_train_state).

  • scheduler (Any | None) – LR scheduler to restore.

  • dataloader (Any | None) – Stateful dataloader to restore.

Returns:

Tuple of (step, tokens_seen, extra) where extra contains any additional keys saved via build_train_state(extra=...).

Return type:

tuple[int, int, dict[str, Any]]

kempnerforge.checkpoint.set_rng_state(state)[source]

Restore all RNG states from a checkpoint.

Parameters:

state (dict[str, Any])

Return type:

None

Modules

async_save

Async checkpointing for non-blocking saves.

manager

Checkpoint manager for distributed checkpointing.

state

Training state assembly for checkpointing.