kempnerforge.checkpoint.state
Training state assembly for checkpointing.
Collects the full training state — model, optimizer, scheduler, dataloader, training metadata, and RNG states — into a single dict for DCP save/load.
RNG state capture ensures exact reproducibility on resume.
Functions

- build_train_state(step, tokens_seen, scheduler=None, dataloader=None, extra=None): Build the non-distributed portion of the training state.
- get_rng_state(): Capture all RNG states for reproducibility on resume.
- restore_train_state(state, scheduler=None, dataloader=None): Restore the non-distributed portion of the training state.
- set_rng_state(state): Restore all RNG states from a checkpoint.
- kempnerforge.checkpoint.state.get_rng_state()[source]
Capture all RNG states for reproducibility on resume.
- kempnerforge.checkpoint.state.set_rng_state(state)[source]
Restore all RNG states from a checkpoint.
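A minimal sketch of the capture/restore round trip, using only Python's stdlib `random` module (the library's actual functions cover "all RNG states", which would typically also include NumPy and torch generators; the dict key used here is an assumption):

```python
import random

def get_rng_state():
    # Sketch: capture only the Python RNG state. A fuller version would
    # also record NumPy and torch (CPU/CUDA) generator states.
    return {"python": random.getstate()}

def set_rng_state(state):
    # Restore the captured state so subsequent draws replay exactly.
    random.setstate(state["python"])

snapshot = get_rng_state()
first = random.random()
set_rng_state(snapshot)
replayed = random.random()  # identical to `first` after restoring
```

Because the generator's internal state is restored bit-for-bit, any sequence of draws after `set_rng_state` replays exactly, which is what makes resumption reproducible.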
- kempnerforge.checkpoint.state.build_train_state(step, tokens_seen, scheduler=None, dataloader=None, extra=None)[source]
Build the non-distributed portion of the training state.
Model and optimizer state are handled by DCP directly. This function captures everything else needed for exact resumption.
- Parameters:
  - step – Current training step.
  - tokens_seen – Number of tokens processed so far.
  - scheduler – Optional learning-rate scheduler whose state is included.
  - dataloader – Optional dataloader whose state is included.
  - extra – Optional dict of additional keys to include in the state.
- Returns:
Dict with training state, scheduler state, dataloader state, and RNG states.
- Return type:
  dict
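A stdlib-only sketch of the assembly this function performs. The key names are assumptions, not the library's actual schema, and the RNG entry here is a stand-in for the full capture; model and optimizer tensors are handled separately by DCP:

```python
import random

def build_train_state(step, tokens_seen, scheduler=None, dataloader=None, extra=None):
    # Sketch: collect everything DCP does not handle into one dict.
    state = {
        "step": step,
        "tokens_seen": tokens_seen,
        # Stand-in for full RNG capture (Python only here).
        "rng": {"python": random.getstate()},
    }
    if scheduler is not None:
        state["scheduler"] = scheduler.state_dict()
    if dataloader is not None:
        state["dataloader"] = dataloader.state_dict()
    if extra is not None:
        state["extra"] = dict(extra)
    return state

state = build_train_state(step=100, tokens_seen=1_000_000, extra={"epoch": 2})
```

Keeping this dict separate from the DCP-managed model/optimizer shards means the small scalar metadata can be saved with an ordinary serializer while the large tensors go through the distributed checkpoint path.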
- kempnerforge.checkpoint.state.restore_train_state(state, scheduler=None, dataloader=None)[source]
Restore the non-distributed portion of the training state.
- Parameters:
  - state – Checkpoint dict produced by build_train_state().
  - scheduler – Optional learning-rate scheduler whose state is restored.
  - dataloader – Optional dataloader whose state is restored.
- Returns:
Tuple of (step, tokens_seen, extra) where extra contains any additional keys saved via build_train_state(extra=…).
- Return type:
  tuple
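A minimal sketch of the inverse operation, consuming a checkpoint dict of the shape sketched above (the key names are illustrative assumptions, not the library's actual schema):

```python
def restore_train_state(state, scheduler=None, dataloader=None):
    # Sketch: unpack the non-distributed state and push the scheduler
    # and dataloader pieces back into their objects when provided.
    if scheduler is not None and "scheduler" in state:
        scheduler.load_state_dict(state["scheduler"])
    if dataloader is not None and "dataloader" in state:
        dataloader.load_state_dict(state["dataloader"])
    return state["step"], state["tokens_seen"], state.get("extra", {})

ckpt = {"step": 100, "tokens_seen": 1_000_000, "extra": {"epoch": 2}}
step, tokens_seen, extra = restore_train_state(ckpt)
```

The caller then resumes its loop counters from the returned `step` and `tokens_seen` rather than from zero.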