# Checkpointing

Distributed checkpoints via torch.distributed.checkpoint (DCP): what’s saved, how resharding works, auto-resume rules, and HuggingFace interchange.

## At a glance

Every checkpoint lands in `{config.checkpoint.dir}/step_{N}/` and contains two kinds of state (sharded model/optimizer state and whole-run train state) plus two small bookkeeping entries; a save/resume sketch follows the table:

| File(s) | Contents | Format |
| --- | --- | --- |
| DCP shards (`.distcp` + `.metadata`) | Model + optimizer state, one shard per rank | DCP § Model + optimizer |
| `train_state.pt` | step, tokens_seen, scheduler, RNG, extras (e.g. `phase_idx`, `wandb_run_id`) | Train state |
| `metadata.json` | Human-readable `{"step": N, "tokens_seen": M}` | Plain JSON |
| `latest` symlink | Points at the most recent `step_N` (updated atomically) | Auto-resume |
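
Here is a minimal sketch of the save side under the layout above. `get_state_dict` and `dcp.save` are real `torch.distributed.checkpoint` APIs; `save_checkpoint`, its arguments, and the `train_state.pt` field names are hypothetical, for illustration only:

```python
import json
import os

import torch
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import get_state_dict


def save_checkpoint(model, optimizer, step, tokens_seen, ckpt_root):
    """Hypothetical save path matching the directory layout described above."""
    step_dir = os.path.join(ckpt_root, f"step_{step}")

    # 1. DCP shards: a collective call; every rank writes its own .distcp
    #    shard, plus a shared .metadata file, under step_dir.
    model_sd, optim_sd = get_state_dict(model, optimizer)
    dcp.save({"model": model_sd, "optimizer": optim_sd}, checkpoint_id=step_dir)

    # Rank 0 writes the small per-checkpoint files.
    if torch.distributed.get_rank() == 0:
        # 2. train_state.pt: scalar training state (field names assumed).
        torch.save(
            {"step": step, "tokens_seen": tokens_seen},
            os.path.join(step_dir, "train_state.pt"),
        )

        # 3. metadata.json: human-readable summary.
        with open(os.path.join(step_dir, "metadata.json"), "w") as f:
            json.dump({"step": step, "tokens_seen": tokens_seen}, f)

        # 4. "latest" symlink, updated atomically: create a temporary link,
        #    then rename it over the old one (os.replace is atomic on POSIX),
        #    so readers never observe a missing or half-written link.
        tmp = os.path.join(ckpt_root, "latest.tmp")
        if os.path.lexists(tmp):
            os.remove(tmp)
        os.symlink(f"step_{step}", tmp)
        os.replace(tmp, os.path.join(ckpt_root, "latest"))
```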
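
And a matching sketch of auto-resume: follow `latest`, let DCP reshard on load, then restore the scalar train state. `dcp.load` and `set_state_dict` are real DCP APIs; `load_latest` and the return shape are assumptions:

```python
import os

import torch
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict


def load_latest(model, optimizer, ckpt_root):
    """Hypothetical resume path: returns the saved train state, or None."""
    latest = os.path.join(ckpt_root, "latest")
    if not os.path.exists(latest):
        return None  # fresh run, nothing to resume from

    step_dir = os.path.realpath(latest)

    # DCP reshards on load: the shard layout on disk need not match the
    # current world size; each rank reads only the pieces its local
    # state dict asks for.
    model_sd, optim_sd = get_state_dict(model, optimizer)
    state = {"model": model_sd, "optimizer": optim_sd}
    dcp.load(state, checkpoint_id=step_dir)
    set_state_dict(
        model,
        optimizer,
        model_state_dict=state["model"],
        optim_state_dict=state["optimizer"],
    )

    # train_state.pt holds arbitrary Python state (scheduler, RNG, extras),
    # so it cannot be loaded with weights_only=True.
    return torch.load(
        os.path.join(step_dir, "train_state.pt"),
        map_location="cpu",
        weights_only=False,
    )
```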

## Key modules