kempnerforge.checkpoint.manager

Checkpoint manager for distributed checkpointing.

Uses PyTorch Distributed Checkpoint (DCP) for model and optimizer state, which supports automatic resharding (save with N GPUs, load with M GPUs).

Non-distributed state (scheduler, dataloader, training meta, RNG) is saved separately as a torch file and broadcast from rank 0 on load.

Classes

CheckpointManager

Manages save/load/cleanup of distributed checkpoints.

class kempnerforge.checkpoint.manager.CheckpointManager[source]

Bases: object

Manages save/load/cleanup of distributed checkpoints.

Each checkpoint is stored in a subdirectory: {dir}/step_{N}/ containing DCP shards and a non-distributed training state file.

A latest symlink always points to the most recent checkpoint.
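The layout and symlink behavior described above can be sketched with the standard library. This is an illustrative sketch, not the manager's actual code; `checkpoint_dir` and `update_latest` are hypothetical helper names. The key idea is the create-then-rename dance, which updates `latest` atomically:

```python
import os


def checkpoint_dir(root: str, step: int) -> str:
    """Per-step checkpoint subdirectory: {root}/step_{N}/."""
    return os.path.join(root, f"step_{step}")


def update_latest(root: str, step: int) -> None:
    """Repoint {root}/latest at step_{N} atomically.

    A temporary symlink is created first and then renamed over the old
    one, so readers never observe a missing or half-written link.
    """
    target = f"step_{step}"  # relative target keeps the tree relocatable
    tmp = os.path.join(root, ".latest.tmp")
    if os.path.lexists(tmp):
        os.unlink(tmp)
    os.symlink(target, tmp)
    os.replace(tmp, os.path.join(root, "latest"))
```

Readers then resolve the most recent checkpoint with `os.path.realpath(os.path.join(root, "latest"))`.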

Parameters:
  • config – Checkpoint configuration.

  • model – The model (FSDP-wrapped or plain).

  • optimizer – The optimizer.

__init__(config, model, optimizer, process_group=None, pp_rank=None)[source]
Return type:

None

save(step, tokens_seen=0, scheduler=None, dataloader=None, extra=None)[source]

Save a checkpoint at the given step.

Parameters:
  • step (int) – Current training step.

  • tokens_seen (int) – Total tokens processed.

  • scheduler (Any | None) – LR scheduler to save.

  • dataloader (Any | None) – Stateful dataloader to save.

  • extra (dict | None) – Additional metadata.

Return type:

None

wait()[source]

Block until any pending async checkpoint save completes.

Once the flushed save is durable, its deferred latest symlink update and cleanup are committed. The training loop calls this after the loop exits, so the final checkpoint’s latest symlink is committed before process teardown.

Return type:

None

flush_pending_save()[source]

Drain any in-flight async save before mutating model state.

Called from the FreezeStage transition hook in the training loop: when a transition fires at step S, any save started at step S-1 must have written metadata.json with the pre-transition spec before the transition flips requires_grad. Otherwise metadata.json lands with the post-transition spec attached to the pre-transition shards.

Also commits the deferred latest symlink for that save, so a transition (or any caller draining the queue) leaves latest pointed at the now-durable checkpoint.

Return type:

None
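The async-save pattern behind wait() and flush_pending_save() can be sketched as follows. This is a simplified illustration, not the manager's implementation: the slow write runs on a worker thread, the latest-symlink commit is deferred, and draining the queue blocks until the write is durable before running the commit.

```python
from concurrent.futures import Future, ThreadPoolExecutor


class AsyncSaver:
    """Minimal sketch of the async-save / deferred-commit pattern."""

    def __init__(self):
        self._pool = ThreadPoolExecutor(max_workers=1)
        self._pending: Future | None = None
        self._commit = None

    def save_async(self, write_fn, commit_fn):
        """Start a background write; defer commit_fn until it is durable."""
        self.wait()  # at most one in-flight save
        self._pending = self._pool.submit(write_fn)
        self._commit = commit_fn

    def wait(self):
        """Drain any in-flight save, then commit its deferred actions."""
        if self._pending is not None:
            self._pending.result()  # block until the write is durable
            self._pending = None
            commit, self._commit = self._commit, None
            commit()  # e.g. update the latest symlink, prune old steps
```

Because the commit only runs after `result()` returns, a caller that drains the queue always observes `latest` pointing at a fully written checkpoint.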

peek_saved_step(path=None)[source]

Read step from a candidate checkpoint’s metadata.json.

Returns None if no checkpoint resolves or the metadata is missing/unreadable. Used by the training loop on resume to compute the expected freeze list (which depends on saved_step) before calling load.

Parameters:

path (str | None)

Return type:

int | None
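The peek behavior can be sketched as a cheap metadata read that never raises. This is an illustrative stand-in (a hypothetical free function over a `metadata.json` file), not the method's actual body:

```python
import json
import os


def peek_saved_step(ckpt_dir: str):
    """Read the saved step from metadata.json without loading any state.

    Returns None when the file is absent or unreadable, mirroring the
    documented behaviour.
    """
    path = os.path.join(ckpt_dir, "metadata.json")
    try:
        with open(path) as f:
            return int(json.load(f)["step"])
    except (OSError, ValueError, KeyError, TypeError):
        return None
```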

load(path=None, scheduler=None, dataloader=None, exclude_keys=None, vlm_freeze_expected=None)[source]

Load a checkpoint and restore all state.

Parameters:
  • path (str | None) – Checkpoint path. If None, loads from config.load_path or the latest symlink.

  • scheduler (Any | None) – LR scheduler to restore.

  • dataloader (Any | None) – Stateful dataloader to restore.

  • exclude_keys (list[str] | None) – DCP state keys to skip (e.g., ["optimizer"] for fine-tuning).

  • vlm_freeze_expected (list[dict[str, Any]] | None) – Canonical freeze metadata (output of canonical_freeze_meta) for the current run’s VLMConfig. When both the saved metadata and this argument are set, a mismatch raises ValueError unless the checkpoint config has ignore_freeze_mismatch=True, in which case the load proceeds with a warning.

Returns:

Tuple of (step, tokens_seen, extra) where extra contains any additional keys saved via build_train_state(extra=...).

Return type:

tuple[int, int, dict[str, Any]]
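The vlm_freeze_expected validation described above can be sketched as a comparison of the two canonical freeze-metadata lists. This is a simplified illustration of the documented contract, not the actual code; `check_freeze_meta` is a hypothetical name, and the warning is modeled as a returned flag:

```python
def check_freeze_meta(saved, expected, ignore_mismatch=False):
    """Compare saved vs. expected canonical freeze metadata.

    Returns True when the check passes (or nothing is there to check).
    Raises ValueError on mismatch unless ignore_mismatch is set, in
    which case the caller proceeds with a warning (here: returns False).
    """
    if saved is None or expected is None:
        return True  # one side missing: nothing to validate
    if saved == expected:
        return True
    if ignore_mismatch:
        return False  # caller warns and proceeds
    raise ValueError(
        f"freeze metadata mismatch: checkpoint={saved!r} expected={expected!r}"
    )
```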

apply_dataloader_state(dataloader)[source]

Apply any dataloader state stashed during load().

Training loops call load() before constructing the dataloader (since the dataloader depends on phase/annealing state that load() restores). This method applies the stashed state once the loader exists.

No-op if no state is pending, or if the loader does not support load_state_dict (e.g., plain torch DataLoader for HF streaming).

Parameters:

dataloader (Any)

Return type:

None
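The stash-then-apply pattern that makes this ordering work can be sketched as below. This is an illustrative reduction, not the manager's code: load() runs before the dataloader exists, so the loader's state is stashed; apply_dataloader_state() later applies it only if the loader supports load_state_dict.

```python
class StatefulResume:
    """Sketch of the stash-then-apply pattern for dataloader state."""

    def __init__(self):
        self._pending_loader_state = None

    def load(self, saved_state: dict):
        # The dataloader does not exist yet when load() runs, so its
        # state is stashed rather than applied immediately.
        self._pending_loader_state = saved_state.get("dataloader")

    def apply_dataloader_state(self, dataloader):
        state, self._pending_loader_state = self._pending_loader_state, None
        if state is None:
            return  # nothing stashed: no-op
        if not hasattr(dataloader, "load_state_dict"):
            return  # plain loaders (e.g. HF streaming) are a no-op
        dataloader.load_state_dict(state)
```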