kempnerforge.checkpoint.state

Training state assembly for checkpointing.

Collects the full training state (model, optimizer, scheduler, dataloader, training metadata, and RNG states) into a single dict for saving and loading with PyTorch Distributed Checkpoint (DCP).

RNG state capture ensures exact reproducibility on resume.

Functions

  • build_train_state(step, tokens_seen[, ...]) – Build the non-distributed portion of the training state.

  • get_rng_state() – Capture all RNG states for reproducibility on resume.

  • restore_train_state(state[, scheduler, ...]) – Restore the non-distributed portion of the training state.

  • set_rng_state(state) – Restore all RNG states from a checkpoint.

kempnerforge.checkpoint.state.get_rng_state()

Capture all RNG states for reproducibility on resume.

Return type:

dict[str, Any]
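This page does not say which generators get_rng_state() covers. The following is a minimal sketch that assumes a typical PyTorch training stack. It always captures Python's random module, and it also captures NumPy and PyTorch (the CPU generator plus every visible CUDA device) when those libraries are installed. The function name sketch_get_rng_state and the dict keys are hypothetical and do not describe the library's actual layout.

```python
import random
from typing import Any

try:  # NumPy is optional in this sketch.
    import numpy as np
except ImportError:
    np = None

try:  # PyTorch is optional in this sketch.
    import torch
except ImportError:
    torch = None


def sketch_get_rng_state() -> dict[str, Any]:
    """Snapshot every RNG this sketch knows about.

    Key names are hypothetical; they are not taken from kempnerforge.
    """
    state: dict[str, Any] = {"python": random.getstate()}
    if np is not None:
        state["numpy"] = np.random.get_state()
    if torch is not None:
        state["torch_cpu"] = torch.get_rng_state()
        if torch.cuda.is_available():
            # One generator state per visible CUDA device.
            state["torch_cuda"] = torch.cuda.get_rng_state_all()
    return state


snapshot = sketch_get_rng_state()
before = random.random()
# Rewind Python's generator to the snapshot and draw again.
random.setstate(snapshot["python"])
after = random.random()
```

Every generator needs a snapshot because dropout masks, data shuffling, and augmentation may each draw from a different one. If any of them is missed, the resumed run can diverge from an uninterrupted one.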

kempnerforge.checkpoint.state.set_rng_state(state)

Restore all RNG states from a checkpoint.

Parameters:

state (dict[str, Any]) – RNG states, as returned by get_rng_state().

Return type:

None
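Below is a matching restore sketch. It makes the same assumptions: hypothetical key names, with NumPy and PyTorch optional. It restores only the generators present in the dict. It then demonstrates the property this module promises: draws made after a restore repeat the draws made right after the snapshot. The name sketch_set_rng_state is illustrative only.

```python
import random
from typing import Any

try:  # NumPy is optional in this sketch.
    import numpy as np
except ImportError:
    np = None

try:  # PyTorch is optional in this sketch.
    import torch
except ImportError:
    torch = None


def sketch_set_rng_state(state: dict[str, Any]) -> None:
    """Restore whichever generators appear in `state` (hypothetical key names)."""
    if "python" in state:
        random.setstate(state["python"])
    if np is not None and "numpy" in state:
        np.random.set_state(state["numpy"])
    if torch is not None and "torch_cpu" in state:
        torch.set_rng_state(state["torch_cpu"])
    if torch is not None and "torch_cuda" in state and torch.cuda.is_available():
        torch.cuda.set_rng_state_all(state["torch_cuda"])


# Simulate: save a snapshot, keep training, then resume from the snapshot.
saved = {"python": random.getstate()}
original_draws = [random.random() for _ in range(3)]
sketch_set_rng_state(saved)
resumed_draws = [random.random() for _ in range(3)]
```

Because the sketch skips absent keys, a checkpoint written on a CPU-only machine still loads on a GPU node. The trade-off is that the CUDA generators keep whatever state they already had. This page does not document whether the real set_rng_state() tolerates missing keys or raises an error.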

kempnerforge.checkpoint.state.build_train_state(step, tokens_seen, scheduler=None, dataloader=None, extra=None)

Build the non-distributed portion of the training state.

Model and optimizer state are handled by DCP directly. This function captures everything else needed for exact resumption.

Parameters:
  • step (int) – Current training step.

  • tokens_seen (int) – Total tokens processed so far.

  • scheduler (Any | None) – LR scheduler (must have state_dict()).

  • dataloader (Any | None) – Stateful dataloader (must have state_dict()).

  • extra (dict | None) – Additional metadata to include.

Returns:

Dict with training metadata (step, tokens_seen, and any extra entries), scheduler and dataloader state (when those objects are given), and RNG states.

Return type:

dict[str, Any]
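The following is a minimal sketch of how the non-distributed state could be assembled. Several details are assumptions made for illustration:

  • the top-level key names;
  • merging extra into the top level, which the restore_train_state return description suggests but does not confirm;
  • the collision check;
  • the ToyScheduler class.

To keep the example dependency-free, it captures only Python's random state.

```python
import random
from typing import Any, Optional

# Hypothetical reserved top-level keys; the real layout may differ.
RESERVED_KEYS = {"step", "tokens_seen", "scheduler", "dataloader", "rng"}


def sketch_build_train_state(
    step: int,
    tokens_seen: int,
    scheduler: Optional[Any] = None,
    dataloader: Optional[Any] = None,
    extra: Optional[dict] = None,
) -> dict[str, Any]:
    """Collect everything except model/optimizer state, which DCP saves itself."""
    state: dict[str, Any] = {
        "step": step,
        "tokens_seen": tokens_seen,
        # Only Python's generator, to keep the sketch dependency-free.
        "rng": {"python": random.getstate()},
    }
    if scheduler is not None:
        state["scheduler"] = scheduler.state_dict()
    if dataloader is not None:
        state["dataloader"] = dataloader.state_dict()
    if extra:
        clash = RESERVED_KEYS.intersection(extra)
        if clash:
            raise ValueError(f"extra keys collide with reserved keys: {sorted(clash)}")
        # Assumption: extra metadata is merged at the top level.
        state.update(extra)
    return state


class ToyScheduler:
    """Hypothetical stand-in for an LR scheduler that exposes state_dict()."""

    def __init__(self) -> None:
        self.last_epoch = 0

    def state_dict(self) -> dict[str, int]:
        return {"last_epoch": self.last_epoch}


sched = ToyScheduler()
sched.last_epoch = 100
train_state = sketch_build_train_state(
    step=100, tokens_seen=409_600, scheduler=sched, extra={"run_name": "demo"}
)
```

Rejecting extra keys that collide with reserved names stops user metadata from silently overwriting step or tokens_seen.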

kempnerforge.checkpoint.state.restore_train_state(state, scheduler=None, dataloader=None)

Restore the non-distributed portion of the training state.

Parameters:
  • state (dict[str, Any]) – Training state dict (from build_train_state).

  • scheduler (Any | None) – LR scheduler to restore.

  • dataloader (Any | None) – Stateful dataloader to restore.

Returns:

Tuple of (step, tokens_seen, extra) where extra contains any additional keys saved via build_train_state(extra=…).

Return type:

tuple[int, int, dict[str, Any]]
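The following is a minimal restore sketch. It assumes a hypothetical layout in which step, tokens_seen, scheduler, dataloader, and rng are reserved top-level keys, and any other top-level key came from extra. It also assumes the scheduler and dataloader expose load_state_dict(), the usual PyTorch counterpart to state_dict(). The names sketch_restore_train_state and ToyScheduler are illustrative only.

```python
import random
from typing import Any, Optional

# Hypothetical reserved top-level keys; the real layout may differ.
RESERVED_KEYS = {"step", "tokens_seen", "scheduler", "dataloader", "rng"}


def sketch_restore_train_state(
    state: dict[str, Any],
    scheduler: Optional[Any] = None,
    dataloader: Optional[Any] = None,
) -> tuple[int, int, dict[str, Any]]:
    """Load scheduler/dataloader/RNG state and return (step, tokens_seen, extra)."""
    if scheduler is not None and "scheduler" in state:
        scheduler.load_state_dict(state["scheduler"])
    if dataloader is not None and "dataloader" in state:
        dataloader.load_state_dict(state["dataloader"])
    if "rng" in state and "python" in state["rng"]:
        random.setstate(state["rng"]["python"])
    # Any non-reserved key is assumed to have come from extra=... at save time.
    extra = {k: v for k, v in state.items() if k not in RESERVED_KEYS}
    return int(state["step"]), int(state["tokens_seen"]), extra


class ToyScheduler:
    """Hypothetical scheduler that exposes load_state_dict()."""

    def __init__(self) -> None:
        self.last_epoch = 0

    def load_state_dict(self, sd: dict[str, int]) -> None:
        self.last_epoch = sd["last_epoch"]


saved = {
    "step": 100,
    "tokens_seen": 409_600,
    "scheduler": {"last_epoch": 100},
    "rng": {"python": random.getstate()},
    "run_name": "demo",
}
expected_next = random.random()  # advances the generator past the snapshot
sched = ToyScheduler()
step, tokens_seen, extra = sketch_restore_train_state(saved, scheduler=sched)
```

The int(...) casts guard against counters that come back as NumPy scalars or 0-d tensors after deserialization. After the call, the next random.random() draw equals expected_next, because the generator was rewound to the saved snapshot.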