kempnerforge.checkpoint.manager¶
Checkpoint manager for distributed checkpointing.
Uses PyTorch Distributed Checkpoint (DCP) for model and optimizer state, which supports automatic resharding (save with N GPUs, load with M GPUs).
Non-distributed state (scheduler, dataloader, training meta, RNG) is saved separately as a torch file and broadcast from rank 0 on load.
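For orientation, a minimal usage sketch is shown below; ckpt_config, model, optimizer, scheduler, max_steps, and the save interval are placeholders, not names defined by this module:

    from kempnerforge.checkpoint.manager import CheckpointManager

    # ckpt_config, model, optimizer, scheduler, max_steps are assumed to exist already;
    # only the CheckpointManager calls mirror the API documented on this page.
    manager = CheckpointManager(config=ckpt_config, model=model, optimizer=optimizer)

    # When resuming, path=None falls back to config.load_path or the `latest` symlink.
    step, tokens_seen, extra = manager.load(scheduler=scheduler)

    save_every = 1000  # illustrative save interval
    while step < max_steps:
        step += 1
        ...  # forward / backward / optimizer.step(), updating tokens_seen
        if step % save_every == 0:
            manager.save(step, tokens_seen=tokens_seen, scheduler=scheduler)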
Classes
CheckpointManager – Manages save/load/cleanup of distributed checkpoints.
- class kempnerforge.checkpoint.manager.CheckpointManager[source]¶
Bases: object
Manages save/load/cleanup of distributed checkpoints.
Each checkpoint is stored in a subdirectory {dir}/step_{N}/ containing DCP shards and a non-distributed training state file. A latest symlink always points to the most recent checkpoint.
- Parameters:
config – Checkpoint configuration.
model – The model (FSDP-wrapped or plain).
optimizer – The optimizer.
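The on-disk layout described above can be inspected directly; a small sketch (the checkpoints/ root and step numbers are illustrative):

    from pathlib import Path

    ckpt_root = Path("checkpoints")  # illustrative {dir}
    # Expected contents, e.g.: step_1000/, step_2000/, latest -> step_2000
    for step_dir in sorted(ckpt_root.glob("step_*")):
        print(step_dir)  # each holds DCP shards plus the non-distributed state file
    print((ckpt_root / "latest").resolve())  # resolves to the most recent step_{N}/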
- __init__(config, model, optimizer, process_group=None, pp_rank=None)[source]¶
- Parameters:
config (CheckpointConfig)
model (torch.nn.Module)
optimizer (torch.optim.Optimizer)
pp_rank (int | None)
- Return type:
None
- save(step, tokens_seen=0, scheduler=None, dataloader=None, extra=None)[source]¶
Save a checkpoint at the given step.
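A call sketch for save(); train_loader and the extra dict are placeholders assumed from earlier setup:

    # Sketch only: `manager` and `train_loader` are assumed to exist.
    manager.save(
        step,
        tokens_seen=tokens_seen,
        scheduler=scheduler,
        dataloader=train_loader,
        extra={"phase": "anneal"},  # extra keys come back via load()'s `extra` return value
    )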
- load(path=None, scheduler=None, dataloader=None, exclude_keys=None)[source]¶
Load a checkpoint and restore all state.
- Parameters:
path (str | None) – Checkpoint path. If None, loads from config.load_path or the latest symlink.
scheduler (Any | None) – LR scheduler to restore.
dataloader (Any | None) – Stateful dataloader to restore.
exclude_keys (list[str] | None) – DCP state keys to skip (e.g., ["optimizer"] for fine-tuning).
- Returns:
Tuple of (step, tokens_seen, extra) where extra contains any additional keys saved via build_train_state(extra=...).
- Return type:
tuple
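For instance, a fine-tuning warm start that skips optimizer state (a sketch; the path is a placeholder):

    # Sketch: restore model weights only, skipping the optimizer DCP shards.
    step, tokens_seen, extra = manager.load(
        path="checkpoints/step_50000",  # placeholder checkpoint path
        exclude_keys=["optimizer"],
    )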
- apply_dataloader_state(dataloader)[source]¶
Apply any dataloader state stashed during load().
Training loops call load() before constructing the dataloader (since the dataloader depends on phase/annealing state that load() restores). This method applies the stashed state once the loader exists.
No-op if no state is pending, or if the loader does not support load_state_dict (e.g., plain torch DataLoader for HF streaming).
- Parameters:
dataloader (Any)
- Return type:
None
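A sketch of the intended ordering; build_dataloader() and the "phase" key are hypothetical:

    # load() first, since the dataloader depends on state that load() restores.
    step, tokens_seen, extra = manager.load(scheduler=scheduler)

    # Construct the dataloader using the restored state, then apply stashed loader state.
    train_loader = build_dataloader(phase=extra.get("phase"))  # build_dataloader is hypothetical
    manager.apply_dataloader_state(train_loader)  # no-op if nothing is pending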