kempnerforge.checkpoint¶
Distributed checkpointing for KempnerForge.
- Public API:
CheckpointManager: Save/load/cleanup distributed checkpoints
AsyncCheckpointer: Non-blocking checkpoint saves
build_train_state / restore_train_state: State assembly
- class kempnerforge.checkpoint.AsyncCheckpointer[source]¶
Bases: object
Non-blocking checkpoint saver.
Wraps dcp.async_save() and manages the background save future. Each new save waits for the previous async save to complete first.
- Parameters:
mode – Checkpoint mode (disabled/async/async_with_pinned_mem).
- __init__(mode=AsyncCheckpointMode.disabled)[source]¶
- Parameters:
mode (AsyncCheckpointMode)
- Return type:
None
- save(state_dict, checkpoint_id, process_group=None)[source]¶
Save distributed state, potentially asynchronously.
- Parameters:
state_dict (dict) – DCP-compatible state dict (model + optimizer).
checkpoint_id (str) – Checkpoint directory path.
process_group – Process group for DCP. Required for PP where each stage has a different state dict — pass a group scoped to ranks within the same PP stage. None uses the default global group.
- Return type:
None
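A minimal usage sketch, assuming torch.distributed is already initialized. The AsyncCheckpointMode import path and the use of get_state_dict from torch.distributed.checkpoint.state_dict to build a DCP-compatible state dict are assumptions about wiring, not part of this API:
```python
import torch
from torch.distributed.checkpoint.state_dict import get_state_dict

from kempnerforge.checkpoint import AsyncCheckpointer
# Import path for the mode enum is an assumption; adjust to where it lives.
from kempnerforge.checkpoint import AsyncCheckpointMode

model = torch.nn.Linear(8, 8)                      # stand-in for the real (FSDP) model
optimizer = torch.optim.AdamW(model.parameters())  # stand-in optimizer

checkpointer = AsyncCheckpointer(mode=AsyncCheckpointMode.async_with_pinned_mem)

for step in range(1, 101):
    # ... forward / backward / optimizer.step() ...
    if step % 50 == 0:
        # get_state_dict produces DCP-compatible model/optimizer state dicts.
        model_sd, optim_sd = get_state_dict(model, optimizer)
        # Waits for the previous async save (if any) to finish, then launches
        # this save in the background and returns.
        checkpointer.save(
            state_dict={"model": model_sd, "optimizer": optim_sd},
            checkpoint_id=f"checkpoints/step_{step}",
        )
```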
- class kempnerforge.checkpoint.CheckpointManager[source]¶
Bases: object
Manages save/load/cleanup of distributed checkpoints.
Each checkpoint is stored in a subdirectory, {dir}/step_{N}/, containing DCP shards and a non-distributed training state file. A latest symlink always points to the most recent checkpoint.
- Parameters:
config – Checkpoint configuration.
model – The model (FSDP-wrapped or plain).
optimizer – The optimizer.
- __init__(config, model, optimizer, process_group=None, pp_rank=None)[source]¶
- Parameters:
config (CheckpointConfig)
model (torch.nn.Module)
optimizer (torch.optim.Optimizer)
pp_rank (int | None)
- Return type:
None
- save(step, tokens_seen=0, scheduler=None, dataloader=None, extra=None)[source]¶
Save a checkpoint at the given step.
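A sketch of the periodic-save loop. Here config is an already-populated CheckpointConfig; save_every, total_steps, run_train_step, scheduler, and dataloader are hypothetical names standing in for the surrounding training setup:
```python
from kempnerforge.checkpoint import CheckpointManager

# config, model, optimizer, scheduler, and dataloader come from the training setup.
manager = CheckpointManager(config=config, model=model, optimizer=optimizer)

tokens_seen = 0
for step in range(1, total_steps + 1):
    tokens_seen += run_train_step(step)  # hypothetical training-step function
    if step % save_every == 0:
        # Writes {dir}/step_{step}/ with DCP shards plus the non-distributed
        # training state, and points the `latest` symlink at it once durable.
        manager.save(step, tokens_seen=tokens_seen,
                     scheduler=scheduler, dataloader=dataloader)

# Drain any in-flight async save and commit the final checkpoint's `latest` symlink.
manager.wait()
```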
- wait()[source]¶
Block until any pending async checkpoint save completes.
Once the flushed save is durable, commits its deferred latest symlink update and cleanup. The training loop calls this after the loop exits, so the final checkpoint's latest symlink is committed before process teardown.
- Return type:
None
- flush_pending_save()[source]¶
Drain any in-flight async save before mutating model state.
Called from the FreezeStage transition hook in the training loop: when a transition fires at step S, any save started at step S-1 must have written metadata.json with the pre-transition spec before the transition flips requires_grad. Otherwise metadata.json lands with the post-transition spec attached to the pre-transition shards.
Also commits the deferred latest symlink for that save, so a transition (or any caller draining the queue) leaves latest pointed at the now-durable checkpoint.
- Return type:
None
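A sketch of the ordering this guarantees inside a freeze-transition hook; the hook signature and params_to_freeze are hypothetical:
```python
def on_freeze_transition(manager, params_to_freeze):
    # Any save started at step S-1 must finish (writing metadata.json with the
    # pre-transition spec) before requires_grad changes below.
    manager.flush_pending_save()

    # Only now flip requires_grad for the transitioning parameters.
    for param in params_to_freeze:
        param.requires_grad = False
```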
- peek_saved_step(path=None)[source]¶
Read step from a candidate checkpoint's metadata.json.
Returns None if no checkpoint resolves or the metadata is missing/unreadable. Used by the training loop on resume to compute the expected freeze list (which depends on saved_step) before calling load.
- load(path=None, scheduler=None, dataloader=None, exclude_keys=None, vlm_freeze_expected=None)[source]¶
Load a checkpoint and restore all state.
- Parameters:
path (str | None) – Checkpoint path. If None, loads from config.load_path or the latest symlink.
scheduler (Any | None) – LR scheduler to restore.
dataloader (Any | None) – Stateful dataloader to restore.
exclude_keys (list[str] | None) – DCP state keys to skip (e.g., [“optimizer”] for fine-tuning).
vlm_freeze_expected (list[dict[str, Any]] | None) – Canonical freeze metadata (output of canonical_freeze_meta) for the current run's VLMConfig. When both the saved metadata and this argument are set, a mismatch raises ValueError unless the checkpoint config has ignore_freeze_mismatch=True, in which case the load proceeds with a warning.
- Returns:
Tuple of (step, tokens_seen, extra) where extra contains any additional keys saved via build_train_state(extra=...).
- Return type:
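A resume sketch combining peek_saved_step and load; the canonical_freeze_meta call, vlm_config, and finetuning flag are assumptions based on the parameter descriptions above:
```python
# Peek at the saved step first, so the expected freeze list (which depends on
# saved_step) can be computed before any state is restored.
saved_step = manager.peek_saved_step()

expected_freeze = None
if saved_step is not None and vlm_config is not None:
    # Hypothetical call: produce canonical freeze metadata for the current
    # VLMConfig at the saved step.
    expected_freeze = canonical_freeze_meta(vlm_config, saved_step)

step, tokens_seen, extra = manager.load(
    scheduler=scheduler,
    exclude_keys=["optimizer"] if finetuning else None,  # skip optimizer state when fine-tuning
    vlm_freeze_expected=expected_freeze,
)
```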
- apply_dataloader_state(dataloader)[source]¶
Apply any dataloader state stashed during load().
Training loops call load() before constructing the dataloader (since the dataloader depends on phase/annealing state that load() restores). This method applies the stashed state once the loader exists.
No-op if no state is pending, or if the loader does not support load_state_dict (e.g., a plain torch DataLoader for HF streaming).
- Parameters:
dataloader (Any)
- Return type:
None
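A sketch of the intended ordering; build_dataloader is a hypothetical factory that depends on the restored phase/annealing state:
```python
# 1. Restore model/optimizer and the non-distributed training state.
step, tokens_seen, extra = manager.load(scheduler=scheduler)

# 2. Build the dataloader; its construction can depend on state load() restored.
dataloader = build_dataloader(config, step=step)

# 3. Apply the dataloader state stashed during load(); no-op if nothing is
#    pending or the loader has no load_state_dict.
manager.apply_dataloader_state(dataloader)
```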
- kempnerforge.checkpoint.build_train_state(step, tokens_seen, scheduler=None, dataloader=None, extra=None)[source]¶
Build the non-distributed portion of the training state.
Model and optimizer state are handled by DCP directly. This function captures everything else needed for exact resumption.
- Parameters:
- Returns:
Dict with training state, scheduler state, dataloader state, and RNG states.
- Return type:
- kempnerforge.checkpoint.get_rng_state()[source]¶
Capture all RNG states for reproducibility on resume.
- kempnerforge.checkpoint.restore_train_state(state, scheduler=None, dataloader=None)[source]¶
Restore the non-distributed portion of the training state.
- Parameters:
- Returns:
Tuple of (step, tokens_seen, extra) where extra contains any additional keys saved via build_train_state(extra=…).
- Return type:
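A minimal round-trip sketch of the two helpers; writing the dict with torch.save/torch.load is illustrative only (CheckpointManager handles persistence itself):
```python
import torch
from kempnerforge.checkpoint import build_train_state, restore_train_state

# Capture everything DCP does not cover: step/token counters, scheduler and
# dataloader state (omitted here), RNG states, plus arbitrary extra keys.
state = build_train_state(step=1200, tokens_seen=3_000_000, extra={"phase": "anneal"})
torch.save(state, "train_state.pt")

# On resume, restore counters and RNG state and recover the extra keys.
step, tokens_seen, extra = restore_train_state(
    torch.load("train_state.pt", weights_only=False)
)
assert (step, extra["phase"]) == (1200, "anneal")
```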
Modules
- Async checkpointing for non-blocking saves.
- Checkpoint manager for distributed checkpointing.
- Training state assembly for checkpointing.