kempnerforge.checkpoint¶
Distributed checkpointing for KempnerForge.
- Public API:
CheckpointManager: Save/load/cleanup distributed checkpoints
AsyncCheckpointer: Non-blocking checkpoint saves
build_train_state / restore_train_state: State assembly
- class kempnerforge.checkpoint.AsyncCheckpointer[source]¶
Bases: object
Non-blocking checkpoint saver.
Wraps dcp.async_save() and manages the background save future. Each new save waits for the previous async save to complete first.
- Parameters:
mode – Checkpoint mode (disabled/async/async_with_pinned_mem).
- __init__(mode=AsyncCheckpointMode.disabled)[source]¶
- Parameters:
mode (AsyncCheckpointMode)
- Return type:
None
- save(state_dict, checkpoint_id, process_group=None)[source]¶
Save distributed state, potentially asynchronously.
- Parameters:
state_dict (dict) – DCP-compatible state dict (model + optimizer).
checkpoint_id (str) – Checkpoint directory path.
process_group – Process group for DCP. Required for PP where each stage has a different state dict — pass a group scoped to ranks within the same PP stage. None uses the default global group.
- Return type:
None
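The wait-for-previous pattern can be sketched with the standard library. Here `AsyncSaver` is a hypothetical stand-in for `AsyncCheckpointer`: the real class delegates to `dcp.async_save()` and writes DCP shards, while this sketch only simulates the background write.

```python
import time
from concurrent.futures import ThreadPoolExecutor


class AsyncSaver:
    """Minimal stand-in for the wait-for-previous async-save pattern."""

    def __init__(self):
        self._executor = ThreadPoolExecutor(max_workers=1)
        self._prev = None
        self.completed = []

    def save(self, state_dict: dict, checkpoint_id: str) -> None:
        # Block until the previous background save finishes, as
        # AsyncCheckpointer does before issuing a new async save.
        if self._prev is not None:
            self._prev.result()
        self._prev = self._executor.submit(self._write, state_dict, checkpoint_id)

    def _write(self, state_dict: dict, checkpoint_id: str) -> None:
        time.sleep(0.01)  # stand-in for serializing shards to disk
        self.completed.append(checkpoint_id)

    def finish(self) -> None:
        # Drain the last pending save (e.g., at end of training).
        if self._prev is not None:
            self._prev.result()


saver = AsyncSaver()
saver.save({"step": 1}, "ckpt/step_1")
saver.save({"step": 2}, "ckpt/step_2")  # waits for step_1 first
saver.finish()
print(saver.completed)  # → ['ckpt/step_1', 'ckpt/step_2']
```

The point of blocking on the previous future (rather than queueing freely) is to bound memory: at most one serialized checkpoint is in flight at a time.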
- class kempnerforge.checkpoint.CheckpointManager[source]¶
Bases: object
Manages save/load/cleanup of distributed checkpoints.
Each checkpoint is stored in a subdirectory {dir}/step_{N}/ containing DCP shards and a non-distributed training state file. A latest symlink always points to the most recent checkpoint.
- Parameters:
config – Checkpoint configuration.
model – The model (FSDP-wrapped or plain).
optimizer – The optimizer.
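The on-disk layout described above can be sketched as follows. `save_step` is a hypothetical helper, and the placeholder file stands in for the DCP shards and training-state file the real manager writes; the symlink swap shown here is one common way to keep `latest` from ever dangling.

```python
import os
import tempfile


def save_step(ckpt_dir: str, step: int) -> str:
    """Create {dir}/step_{N}/ and repoint the `latest` symlink at it."""
    step_dir = os.path.join(ckpt_dir, f"step_{step}")
    os.makedirs(step_dir, exist_ok=True)
    # The real manager writes DCP shards plus a training-state file here.
    open(os.path.join(step_dir, "train_state.pt"), "wb").close()

    # Swap the symlink via rename so `latest` always resolves.
    latest = os.path.join(ckpt_dir, "latest")
    tmp = latest + ".tmp"
    if os.path.lexists(tmp):
        os.remove(tmp)
    os.symlink(f"step_{step}", tmp)  # relative target keeps the dir movable
    os.replace(tmp, latest)
    return step_dir


root = tempfile.mkdtemp()
save_step(root, 100)
save_step(root, 200)
print(os.readlink(os.path.join(root, "latest")))  # → step_200
```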
- __init__(config, model, optimizer, process_group=None, pp_rank=None)[source]¶
- Parameters:
config (CheckpointConfig)
model (torch.nn.Module)
optimizer (torch.optim.Optimizer)
pp_rank (int | None)
- Return type:
None
- save(step, tokens_seen=0, scheduler=None, dataloader=None, extra=None)[source]¶
Save a checkpoint at the given step.
- load(path=None, scheduler=None, dataloader=None, exclude_keys=None)[source]¶
Load a checkpoint and restore all state.
- Parameters:
path (str | None) – Checkpoint path. If None, loads from config.load_path or the latest symlink.
scheduler (Any | None) – LR scheduler to restore.
dataloader (Any | None) – Stateful dataloader to restore.
exclude_keys (list[str] | None) – DCP state keys to skip (e.g., [“optimizer”] for fine-tuning).
- Returns:
Tuple of (step, tokens_seen, extra) where extra contains any additional keys saved via build_train_state(extra=...).
- Return type:
tuple
- apply_dataloader_state(dataloader)[source]¶
Apply any dataloader state stashed during load().
Training loops call load() before constructing the dataloader (since the dataloader depends on phase/annealing state that load() restores). This method applies the stashed state once the loader exists.
No-op if no state is pending, or if the loader does not support load_state_dict (e.g., a plain torch DataLoader for HF streaming).
- Parameters:
dataloader (Any)
- Return type:
None
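The stash-then-apply flow can be sketched like this. `LoaderStateStash` and `StatefulLoader` are hypothetical stand-ins for the manager's internal bookkeeping and a stateful dataloader; the `hasattr` check mirrors the documented no-op for loaders without `load_state_dict`.

```python
from typing import Any


class LoaderStateStash:
    """Holds dataloader state from load() until the loader exists."""

    def __init__(self):
        self._pending = None

    def stash(self, state: dict) -> None:
        # Called during load(), before the dataloader is constructed.
        self._pending = state

    def apply(self, dataloader: Any) -> bool:
        # No-op if nothing is pending or the loader is not stateful.
        if self._pending is None or not hasattr(dataloader, "load_state_dict"):
            return False
        dataloader.load_state_dict(self._pending)
        self._pending = None  # apply at most once
        return True


class StatefulLoader:
    def __init__(self):
        self.state = {}

    def load_state_dict(self, state: dict) -> None:
        self.state = state


stash = LoaderStateStash()
stash.stash({"samples_seen": 1024})   # during load()
loader = StatefulLoader()             # constructed after load()
print(stash.apply(loader))            # → True
print(loader.state)                   # → {'samples_seen': 1024}
print(stash.apply(loader))            # → False (nothing pending)
```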
- kempnerforge.checkpoint.build_train_state(step, tokens_seen, scheduler=None, dataloader=None, extra=None)[source]¶
Build the non-distributed portion of the training state.
Model and optimizer state are handled by DCP directly. This function captures everything else needed for exact resumption.
- Parameters:
step (int)
tokens_seen (int)
scheduler (Any | None)
dataloader (Any | None)
extra (dict | None)
- Returns:
Dict with training state, scheduler state, dataloader state, and RNG states.
- Return type:
dict
- kempnerforge.checkpoint.get_rng_state()[source]¶
Capture all RNG states for reproducibility on resume.
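A minimal sketch of what capture-and-restore looks like, using only Python's built-in generator; the real helper presumably also captures torch (and, when present, CUDA and NumPy) RNG states. `set_rng_state` is a hypothetical inverse added for illustration.

```python
import random


def get_rng_state() -> dict:
    """Capture RNG state (sketch: Python's `random` only)."""
    return {"python": random.getstate()}


def set_rng_state(state: dict) -> None:
    """Rewind the generator to a previously captured state."""
    random.setstate(state["python"])


state = get_rng_state()
a = [random.random() for _ in range(3)]
set_rng_state(state)  # rewind to the captured point
b = [random.random() for _ in range(3)]
print(a == b)  # → True: the stream replays exactly
```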
- kempnerforge.checkpoint.restore_train_state(state, scheduler=None, dataloader=None)[source]¶
Restore the non-distributed portion of the training state.
- Parameters:
state (dict)
scheduler (Any | None)
dataloader (Any | None)
- Returns:
Tuple of (step, tokens_seen, extra) where extra contains any additional keys saved via build_train_state(extra=…).
- Return type:
tuple
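The build/restore round trip can be sketched as follows. Signatures are simplified from the ones documented above: scheduler and dataloader handling is omitted, and only the step counter, token count, extra keys, and Python RNG state are carried.

```python
import random


def build_train_state(step: int, tokens_seen: int, extra: dict = None) -> dict:
    """Sketch of the non-distributed state blob (scheduler/loader omitted)."""
    return {
        "step": step,
        "tokens_seen": tokens_seen,
        "extra": extra or {},
        "rng": {"python": random.getstate()},
    }


def restore_train_state(state: dict) -> tuple:
    """Inverse of build_train_state: reinstate RNG, return resume info."""
    random.setstate(state["rng"]["python"])
    return state["step"], state["tokens_seen"], state["extra"]


blob = build_train_state(1000, 2_048_000, extra={"phase": "anneal"})
step, tokens_seen, extra = restore_train_state(blob)
print(step, tokens_seen, extra)  # → 1000 2048000 {'phase': 'anneal'}
```

Model and optimizer state never pass through this blob; as noted above, those are sharded and restored by DCP directly.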
Modules
- Async checkpointing for non-blocking saves.
- Checkpoint manager for distributed checkpointing.
- Training state assembly for checkpointing.