kempnerforge.checkpoint.manager¶
Checkpoint manager for distributed checkpointing.
Uses PyTorch Distributed Checkpoint (DCP) for model and optimizer state, which supports automatic resharding (save with N GPUs, load with M GPUs).
Non-distributed state (scheduler, dataloader, training meta, RNG) is saved separately as a torch file and broadcast from rank 0 on load.
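As a rough illustration of this split, a minimal sketch in plain Python; the key names here are assumptions for illustration, not the manager's actual state layout:

```python
# Illustrative sketch of the DCP vs. non-distributed split described above.
# The key names are assumptions, not the manager's actual state layout.
DISTRIBUTED_KEYS = {"model", "optimizer"}  # handled by DCP, resharded on load

def split_state(state):
    # Partition a flat state dict into the DCP-managed part and the
    # non-distributed part that rank 0 saves as a single torch file.
    dcp_state = {k: v for k, v in state.items() if k in DISTRIBUTED_KEYS}
    train_state = {k: v for k, v in state.items() if k not in DISTRIBUTED_KEYS}
    return dcp_state, train_state

dcp_state, train_state = split_state(
    {"model": ..., "optimizer": ..., "scheduler": ..., "dataloader": ..., "rng": ...}
)
print(sorted(train_state))  # ['dataloader', 'rng', 'scheduler']
```

On load, the DCP part is resharded across however many ranks exist, while the train-state file is read once and broadcast from rank 0.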
Classes
CheckpointManager – Manages save/load/cleanup of distributed checkpoints.
- class kempnerforge.checkpoint.manager.CheckpointManager[source]¶
Bases: object
Manages save/load/cleanup of distributed checkpoints.
Each checkpoint is stored in a subdirectory {dir}/step_{N}/ containing DCP shards and a non-distributed training state file. A latest symlink always points to the most recent checkpoint.
- Parameters:
config – Checkpoint configuration.
model – The model (FSDP-wrapped or plain).
optimizer – The optimizer.
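As a sketch of the directory convention above, a hypothetical helper that commits the latest symlink by renaming a fresh link over the old one; the manager's actual implementation may differ:

```python
from pathlib import Path
import tempfile

def commit_latest(ckpt_dir: Path, step: int) -> None:
    # Hypothetical helper: repoint {dir}/latest at {dir}/step_{N} by
    # creating a temporary symlink and renaming it over the old one.
    tmp = ckpt_dir / "latest.tmp"
    tmp.symlink_to(f"step_{step}")    # relative link to the subdirectory
    tmp.replace(ckpt_dir / "latest")  # rename is atomic on POSIX filesystems

root = Path(tempfile.mkdtemp())
(root / "step_100").mkdir()
commit_latest(root, 100)
print((root / "latest").resolve().name)  # step_100
```

Renaming over the existing link (rather than unlink-then-create) means readers never observe a missing latest.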
- __init__(config, model, optimizer, process_group=None, pp_rank=None)[source]¶
- Parameters:
config (CheckpointConfig)
model (torch.nn.Module)
optimizer (torch.optim.Optimizer)
pp_rank (int | None)
- Return type:
None
- save(step, tokens_seen=0, scheduler=None, dataloader=None, extra=None)[source]¶
Save a checkpoint at the given step.
- wait()[source]¶
Block until any pending async checkpoint save completes.
Once the flush is durable, commits its deferred latest symlink update and cleanup. The training loop calls this after the loop exits, so the final checkpoint’s latest is committed before process teardown.
- Return type:
None
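The wait-then-commit ordering can be illustrated with a stdlib-only sketch, with a thread pool standing in for the async DCP flush (names are hypothetical):

```python
from concurrent.futures import ThreadPoolExecutor

pool = ThreadPoolExecutor(max_workers=1)
pending = []    # (future, deferred_commit) pairs, oldest first
committed = []  # stands in for `latest` symlink updates + cleanup

def save_async(step):
    fut = pool.submit(lambda: step)  # stand-in for the async DCP flush
    pending.append((fut, lambda: committed.append(step)))

def wait():
    # Block on each in-flight save; only after the flush is durable is it
    # safe to run its deferred commit (repoint `latest`, prune old steps).
    while pending:
        fut, commit = pending.pop(0)
        fut.result()
        commit()

save_async(100)
save_async(200)
wait()
print(committed)  # [100, 200]
```

Deferring the commit until after `fut.result()` is what prevents latest from pointing at a checkpoint whose shards are still being written.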
- flush_pending_save()[source]¶
Drain any in-flight async save before mutating model state.
Called from the FreezeStage transition hook in the training loop: when a transition fires at step S, any save started at step S-1 must have written metadata.json with the pre-transition spec before the transition flips requires_grad. Otherwise metadata.json lands with the post-transition spec attached to the pre-transition shards.
Also commits the deferred latest symlink for that save, so a transition (or any caller draining the queue) leaves latest pointed at the now-durable checkpoint.
- Return type:
None
- peek_saved_step(path=None)[source]¶
Read step from a candidate checkpoint’s metadata.json.
Returns None if no checkpoint resolves or the metadata is missing/unreadable. Used by the training loop on resume to compute the expected freeze list (which depends on saved_step) before calling load.
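A minimal sketch of this lookup pattern, assuming a metadata.json with a top-level step key (the real file very likely contains more):

```python
import json
import tempfile
from pathlib import Path

def peek_saved_step(path: Path):
    # Sketch: return the saved step, or None if the metadata is
    # missing or unreadable. Assumes a top-level "step" key.
    try:
        meta = json.loads((path / "metadata.json").read_text())
        return int(meta["step"])
    except (OSError, ValueError, KeyError):
        return None

ckpt = Path(tempfile.mkdtemp())
(ckpt / "metadata.json").write_text(json.dumps({"step": 500}))
print(peek_saved_step(ckpt))              # 500
print(peek_saved_step(ckpt / "missing"))  # None
```

Returning None instead of raising lets the resume path fall back to a fresh start when no usable checkpoint exists.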
- load(path=None, scheduler=None, dataloader=None, exclude_keys=None, vlm_freeze_expected=None)[source]¶
Load a checkpoint and restore all state.
- Parameters:
path (str | None) – Checkpoint path. If None, loads from config.load_path or the latest symlink.
scheduler (Any | None) – LR scheduler to restore.
dataloader (Any | None) – Stateful dataloader to restore.
exclude_keys (list[str] | None) – DCP state keys to skip (e.g., [“optimizer”] for fine-tuning).
vlm_freeze_expected (list[dict[str, Any]] | None) – Canonical freeze metadata (output of canonical_freeze_meta) for the current run’s VLMConfig. When both the saved metadata and this argument are set, a mismatch raises ValueError unless the checkpoint config has ignore_freeze_mismatch=True, in which case the load proceeds with a warning.
- Returns:
Tuple of (step, tokens_seen, extra) where extra contains any additional keys saved via build_train_state(extra=...).
- Return type:
tuple
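The effect of exclude_keys can be sketched as a simple filter over the checkpoint state before restoration; this is an illustration of the idea, not the manager's actual code path:

```python
# Hypothetical sketch of the exclude_keys behaviour: listed keys are
# dropped from the checkpoint state before it is restored, e.g. skipping
# the saved optimizer state when fine-tuning.
def filter_state(state, exclude_keys=None):
    exclude = set(exclude_keys or [])
    return {k: v for k, v in state.items() if k not in exclude}

ckpt_state = {"model": {"w": 1.0}, "optimizer": {"momentum": 0.9}}
restored = filter_state(ckpt_state, exclude_keys=["optimizer"])
print(sorted(restored))  # ['model']
```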
- apply_dataloader_state(dataloader)[source]¶
Apply any dataloader state stashed during load().
Training loops call load() before constructing the dataloader (since the dataloader depends on phase/annealing state that load() restores). This method applies the stashed state once the loader exists.
No-op if no state is pending, or if the loader does not support load_state_dict (e.g., plain torch DataLoader for HF streaming).
- Parameters:
dataloader (Any)
- Return type:
None
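The duck-typed no-op described above can be sketched as follows (names and return value are illustrative assumptions):

```python
# Sketch of the duck-typed no-op: stashed state is applied only when
# something is pending AND the loader actually exposes load_state_dict.
def apply_dataloader_state(dataloader, stashed_state):
    if stashed_state is None or not hasattr(dataloader, "load_state_dict"):
        return False  # no-op: nothing pending, or loader is not stateful
    dataloader.load_state_dict(stashed_state)
    return True

class StatefulLoader:
    def __init__(self):
        self.state = None

    def load_state_dict(self, state):
        self.state = state

loader = StatefulLoader()
print(apply_dataloader_state(loader, {"samples_seen": 1024}))    # True
print(apply_dataloader_state(object(), {"samples_seen": 1024}))  # False
```

Checking for load_state_dict rather than an exact class keeps the hook compatible with both stateful dataloaders and plain torch DataLoaders.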