DCP model + optimizer
KempnerForge uses torch.distributed.checkpoint (DCP) to save and load model and optimizer state. DCP is designed for sharded state: each rank writes only its local slice, and the reader automatically reshards into whatever parallelism the loader has. There is no “rank 0 gathers everything” step.
Entry point: CheckpointManager.save() in kempnerforge/checkpoint/manager.py.
What goes into the DCP shard
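from torch.distributed.checkpoint.state_dict import get_model_state_dict, get_optimizer_state_dict

# Inside CheckpointManager.save(): build DCP-aware state dicts for model and optimizer.
dcp_state = {
    "model": get_model_state_dict(self.model),
    "optimizer": get_optimizer_state_dict(self.model, self.optimizer),
}
# Write this rank's shards under dcp_dir, coordinating only within self._process_group.
self._async_ckpt.save(
    dcp_state, checkpoint_id=str(dcp_dir), process_group=self._process_group
)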
get_model_state_dict / get_optimizer_state_dict are the DCP-aware helpers from torch.distributed.checkpoint.state_dict, not the raw model.state_dict() / optimizer.state_dict(). They key the optimizer state by parameter fully-qualified name (FQN) rather than positional index, and they keep the FSDP/DTensor sharding intact. That is what makes load (and resharding) line up by name. See Loading for why the raw calls break resume.
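To make the name-keying point concrete, here is a toy, pure-Python sketch (no torch; the parameter names and state values are made up). It re-keys positional optimizer state by FQN at save time, then maps it back onto whatever parameter order the loader happens to build. The real helpers do this on tensors and also handle sharding. This sketch only shows why keying by name survives a reordering that positional indices do not.

```python
# Toy model: positional optimizer state vs. state keyed by parameter FQN.
# Names and "moment" values are illustrative placeholders, not real tensors.

def to_fqn_keyed(param_names, positional_state):
    """Re-key {index: state} as {fqn: state} using the parameter order at save time."""
    return {param_names[i]: s for i, s in positional_state.items()}

def to_positional(param_names, fqn_state):
    """Re-key {fqn: state} as {index: state} for the parameter order at load time."""
    index_of = {name: i for i, name in enumerate(param_names)}
    return {index_of[name]: s for name, s in fqn_state.items()}

save_order = ["embed.weight", "layers.0.attn.wq", "norm.weight"]
positional = {0: "m_embed", 1: "m_wq", 2: "m_norm"}

# Suppose the loader builds its parameters in a different order (e.g. different wrapping).
load_order = ["norm.weight", "embed.weight", "layers.0.attn.wq"]

by_name = to_fqn_keyed(save_order, positional)
restored = to_positional(load_order, by_name)

# Reusing positional indices would hand "m_embed" to norm.weight; name keying does not.
print(restored[load_order.index("norm.weight")])  # m_norm
```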
The state dict has two top-level keys, "model" and "optimizer". DCP introspects both, finds the DTensor / ShardedTensor entries, and writes each rank's shards to disk with enough metadata to reassemble (or reshard) the full tensors.
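As a rough mental model of what that metadata buys, the sketch below works through a 1-D case: a parameter whose rows are split across N saving ranks and read back by M loading ranks. It assumes torch.chunk-style row chunking (chunk size ceil(rows / ranks)), and chunk_bounds / read_plan are hypothetical names. DCP's actual load planner works with N-D chunk offsets and sizes; this toy only shows the overlap computation behind the idea.

```python
# Toy 1-D model of resharding: shards saved by N ranks, re-read by M ranks.
# Assumes torch.chunk-style row chunking; this is not DCP's actual planner.

def chunk_bounds(rows, ranks):
    """Return [(start, stop), ...] per rank; trailing chunks may be short or empty."""
    size = -(-rows // ranks)  # ceil division
    return [(min(r * size, rows), min((r + 1) * size, rows)) for r in range(ranks)]

def read_plan(rows, saved_ranks, target_rank, target_ranks):
    """For one loading rank, list the (saved_rank, start, stop) row ranges it must read."""
    want_lo, want_hi = chunk_bounds(rows, target_ranks)[target_rank]
    plan = []
    for src, (lo, hi) in enumerate(chunk_bounds(rows, saved_ranks)):
        start, stop = max(lo, want_lo), min(hi, want_hi)
        if start < stop:  # keep only non-empty overlaps
            plan.append((src, start, stop))
    return plan

# A 10-row parameter saved by 4 ranks, loaded by 3 ranks.
for r in range(3):
    print(r, read_plan(10, saved_ranks=4, target_rank=r, target_ranks=3))
```

In this toy, each loading rank touches only the saved shards that overlap its own slice, and no rank ever needs the full tensor.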
What’s in each:
- model state: every parameter and persistent buffer (weights, RMSNorm scales, learned RoPE frequencies if present). Under FSDP2 these are DTensors; with TP added they are DTensors on a 2D mesh. DCP handles both.
- optimizer state: AdamW’s exp_avg, exp_avg_sq, and step counters; Lion’s exp_avg; Muon’s internal state. The per-parameter tensors share the parameter’s device and sharding, so DCP saves them the same way it saves the parameters.
Not in the DCP shard: scheduler, dataloader, RNG, and training metadata. Those go in train_state.pt in the same step directory; see Train state.
On-disk layout
checkpoints/
├── step_1000/
│ ├── .metadata ← DCP global manifest
│ ├── __0_0.distcp ← rank-0 shard (model + optimizer)
│ ├── __1_0.distcp ← rank-1 shard (model + optimizer)
│ ├── ... ← one `.distcp` per rank
│ ├── train_state.pt ← non-distributed state
│ └── metadata.json ← human-readable step info
├── step_2000/ ← same layout
└── latest -> step_2000 ← symlink
.distcp files are the actual tensor storage (one per rank, written
in parallel). .metadata is the global index that lets the reader
figure out which .distcp file contains what.
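This page does not say how latest is updated. One common way to repoint such a symlink, so that it never goes missing mid-update, is to create a temporary link and rename it over the old one. The sketch below assumes POSIX rename semantics; point_latest and resolve_latest are hypothetical helpers, not KempnerForge functions.

```python
import os
import tempfile
from pathlib import Path

def point_latest(root: Path, step_dir_name: str) -> None:
    """Hypothetical helper: repoint root/latest at step_dir_name via an atomic rename."""
    tmp = root / ".latest.tmp"
    if tmp.is_symlink():
        tmp.unlink()                      # leftover from an interrupted update
    os.symlink(step_dir_name, tmp)        # relative target keeps the tree relocatable
    os.replace(tmp, root / "latest")      # rename over the old link in one step (POSIX)

def resolve_latest(root: Path) -> Path:
    """Follow root/latest to the step directory; raises if the target is missing."""
    return (root / "latest").resolve(strict=True)

with tempfile.TemporaryDirectory() as d:
    root = Path(d)
    for step in (1000, 2000):
        (root / f"step_{step}").mkdir()
        point_latest(root, f"step_{step}")
    print(resolve_latest(root).name)      # step_2000
```

Readers that open latest either see the old step or the new one, never a dangling name.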
With pipeline parallelism
With PP, each stage holds a different set of parameters, so DCP needs a disjoint set of shards per stage. save() therefore writes to a per-stage subdirectory:
# manager.py
dcp_dir = ckpt_dir / f"pp{self._pp_rank}" if self._pp_rank is not None else ckpt_dir
checkpoints/step_1000/
├── pp0/ ← stage 0 shards (embedding + first layers)
│ ├── .metadata
│ └── __*_0.distcp
├── pp1/ ← stage 1 shards
│ ├── .metadata
│ └── __*_0.distcp
├── pp2/ ← ...
├── pp3/ ← last stage (final norm + output head)
└── train_state.pt ← one per checkpoint, written by global rank 0
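A quick sanity check over either layout can catch a partially written step directory. The hypothetical check_step_dir below only looks for the files named in the trees above (train_state.pt, plus a .metadata and at least one .distcp file per DCP directory). It does not validate file contents.

```python
import tempfile
from pathlib import Path

def check_step_dir(step_dir: Path, pp_stages: int = 0) -> list:
    """Hypothetical sanity check for one step_N directory; returns a list of problems."""
    problems = []
    if not (step_dir / "train_state.pt").is_file():
        problems.append("missing train_state.pt")
    # Non-PP: shards live directly in step_dir. PP: one ppN/ subdirectory per stage.
    dcp_dirs = [step_dir / f"pp{i}" for i in range(pp_stages)] if pp_stages else [step_dir]
    for d in dcp_dirs:
        if not (d / ".metadata").is_file():
            problems.append(f"{d.name}: missing .metadata")
        if not any(d.glob("*.distcp")):
            problems.append(f"{d.name}: no .distcp files")
    return problems

with tempfile.TemporaryDirectory() as d:
    step = Path(d) / "step_1000"
    for i in range(2):
        (step / f"pp{i}").mkdir(parents=True)
        (step / f"pp{i}" / "__0_0.distcp").touch()
    (step / "pp0" / ".metadata").touch()      # pp1/.metadata deliberately missing
    (step / "train_state.pt").touch()
    print(check_step_dir(step, pp_stages=2))  # ['pp1: missing .metadata']
```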
Each stage also gets a process group scoped to that stage’s DP + TP ranks:
# scripts/train.py
non_pp_dims = [d for d in device_mesh.mesh_dim_names if d != "pp"]
if len(non_pp_dims) == 1:
    ckpt_pg = device_mesh[non_pp_dims[0]].get_group()
elif len(non_pp_dims) > 1:
    ckpt_pg = device_mesh[tuple(non_pp_dims)].get_group()
ckpt_mgr = CheckpointManager(config.checkpoint, model, optimizer,
                             process_group=ckpt_pg, pp_rank=pp_rank)
A 1-D mesh slice has to be indexed by the dim name directly; tuple(...) wrapping is only valid for ≥2 dims. Both branches mean the same thing semantically: every rank at the same PP position coordinates as one group.
Without the scoped group, DCP would try to coordinate across all world ranks (including other PP stages), and the stage-0 ranks would hang waiting for stage-1’s shards.
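To see which global ranks end up in each stage's group, the sketch below enumerates a mesh in row-major order (the way init_device_mesh lays out torch.arange(world_size)) and groups ranks by their pp coordinate. The dim names and sizes are illustrative, and ranks_per_pp_stage is a hypothetical helper.

```python
from itertools import product

def ranks_per_pp_stage(shape: dict) -> dict:
    """Group global ranks by PP coordinate for a row-major mesh like {"pp": 2, "dp": 2, "tp": 2}."""
    names = list(shape)
    pp_axis = names.index("pp")
    groups: dict = {}
    # product() varies the last dim fastest, matching arange(world).view(shape).
    for rank, coords in enumerate(product(*(range(shape[n]) for n in names))):
        groups.setdefault(coords[pp_axis], []).append(rank)
    return groups

print(ranks_per_pp_stage({"pp": 2, "dp": 2, "tp": 2}))
# {0: [0, 1, 2, 3], 1: [4, 5, 6, 7]}
```

Each list is one stage's checkpoint group (all of its DP × TP ranks); no group spans two PP positions.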
Save modes
AsyncCheckpointer wraps DCP’s save / async_save behind a mode flag:
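| Config | Behavior | Use |
|---|---|---|
| | Synchronous save (training blocks until the write finishes) | Simple, debugging, small models |
| | Async save (snapshot to CPU, write in the background) | Default for production |
| | Async with pinned-memory staging (faster GPU→CPU copy) | Very large models where GPU→CPU throughput bottlenecks the snapshot |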
Every save() call first calls self.wait(): the previous async save must fully complete before a new one starts. This avoids holding two CPU snapshots at once and avoids two writers racing on the same directory.
The returned future is an AsyncCheckpointerFuture; wait() calls .result(), which blocks until the background writer has flushed to disk. Training calls ckpt_mgr.wait() once before shutdown (scripts/train.py line ~788) to flush any pending save.
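The wait-then-save discipline can be sketched without DCP. ToyAsyncCheckpointer below is a hypothetical stand-in, not KempnerForge's class. It has a single background writer thread, a save() that first waits on the previous future, and a wait() that calls .result() so writer errors surface in the training thread. The dict copy stands in for the GPU→CPU snapshot.

```python
import time
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Optional

class ToyAsyncCheckpointer:
    """Hypothetical stand-in for the wait-then-save pattern."""

    def __init__(self) -> None:
        self._pool = ThreadPoolExecutor(max_workers=1)
        self._pending: Optional[Future] = None
        self.written: list = []

    def _write(self, step: int, snapshot: dict) -> int:
        time.sleep(0.01)               # pretend to flush the snapshot to disk
        self.written.append(step)
        return step

    def wait(self) -> None:
        if self._pending is not None:
            self._pending.result()     # blocks until flushed; re-raises writer errors
            self._pending = None

    def save(self, step: int, state: dict) -> Future:
        self.wait()                    # previous save must finish: one snapshot, no racing
        snapshot = dict(state)         # stands in for the GPU->CPU staging copy
        self._pending = self._pool.submit(self._write, step, snapshot)
        return self._pending

ckpt = ToyAsyncCheckpointer()
for step in (1000, 2000, 3000):
    ckpt.save(step, {"step": step})    # returns once the snapshot is taken
ckpt.wait()                            # flush before shutdown, like ckpt_mgr.wait()
print(ckpt.written)                    # [1000, 2000, 3000]
```

The trade-off: at most one snapshot is alive at a time, but a save requested while the previous write is still flushing stalls training until that write completes.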
Process groups
The process_group= kwarg on dcp.save and dcp.load scopes the all-gather / all-reduce calls that DCP uses internally. Rules:
- Non-PP: use the default global group (process_group=None). Every rank holds a slice of the same state.
- PP: use a per-stage group. Stage i’s ranks (all DP × TP ranks at PP position i) coordinate alone and produce one DCP shard set under pp{i}/.
CheckpointManager stores the group at construction time and passes
it through on every save/load.
Loading
Load is the mirror image of save:
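import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import get_model_state_dict, get_optimizer_state_dict, set_model_state_dict, set_optimizer_state_dict

# Template: correctly sharded tensors (optimizer moments allocated) for DCP to fill.
dcp_state = {
    "model": get_model_state_dict(self.model),
    "optimizer": get_optimizer_state_dict(self.model, self.optimizer),
}
# Fill those tensors in place; DCP reshards if the saved layout differs.
dcp.load(dcp_state, checkpoint_id=str(dcp_dir), process_group=self._process_group)
# Write the loaded values back into the live model and optimizer.
set_model_state_dict(self.model, dcp_state["model"])
set_optimizer_state_dict(self.model, self.optimizer, optim_state_dict=dcp_state["optimizer"])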
The getter calls give DCP the shape to fill. The template does not hold the saved data, only the tensor metadata (and, crucially, the allocated optimizer moment tensors) that DCP needs to know what to load where. dcp.load mutates those tensors in place with the loaded values; the setters then write them back into the live model and optimizer.
Why the DCP-aware helpers, not optimizer.state_dict()? On resume the optimizer is freshly built, so optimizer.state_dict() has no exp_avg / exp_avg_sq tensors yet: AdamW creates the per-parameter state lazily on the first .step(). Passing that empty dict as the load template gives dcp.load nothing to fill, so the saved moments are silently dropped and Adam momentum resets to zero at every resume (a non-bit-exact resume). get_optimizer_state_dict allocates the moment tensors up front in the right sharded layout, so dcp.load repopulates them. The model side would work with either call, since its parameters are always allocated, but we use the matching getter/setter for symmetry.
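The failure mode is easy to reproduce in a toy. ToyAdamW and toy_dcp_load below are hypothetical and use floats instead of sharded tensors. toy_dcp_load mimics only the property of dcp.load that matters here: it fills entries already present in the template and ignores everything else in the checkpoint.

```python
# Toy model of why an empty optimizer template drops saved moments on resume.
# ToyAdamW and toy_dcp_load are hypothetical; real DCP fills tensors in place.

class ToyAdamW:
    def __init__(self, param_names):
        self.param_names = list(param_names)
        self.state = {}  # populated lazily, like torch.optim.AdamW before its first step()

    def init_state(self):
        """Allocate zeroed moments up front (the effect the DCP-aware getter provides)."""
        for name in self.param_names:
            self.state.setdefault(name, {"exp_avg": 0.0, "exp_avg_sq": 0.0})

def toy_dcp_load(template, saved):
    """Fill only entries already in the template; saved entries it lacks are ignored."""
    for name, slots in template.items():
        for key in slots:
            slots[key] = saved[name][key]

saved = {"w": {"exp_avg": 0.3, "exp_avg_sq": 0.09}}

fresh = ToyAdamW(["w"])
toy_dcp_load(fresh.state, saved)   # template is {}: nothing to fill
print(fresh.state)                 # {} (momentum silently reset)

primed = ToyAdamW(["w"])
primed.init_state()
toy_dcp_load(primed.state, saved)
print(primed.state)                # {'w': {'exp_avg': 0.3, 'exp_avg_sq': 0.09}}
```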
Loading with a different GPU count triggers DCP’s automatic resharding — see Resharding.
To skip loading the optimizer (e.g. for fine-tuning), pass exclude_keys=["optimizer"]:
ckpt_mgr.load(path=..., exclude_keys=["optimizer"]) # scripts/eval.py does this
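Conceptually, excluding a key means leaving that top-level entry out of the template handed to dcp.load, so DCP never reads it. The hypothetical build_load_template below sketches that. Rejecting unknown keys is a design choice made here (a typo such as "optimiser" would otherwise silently exclude nothing), not documented KempnerForge behavior.

```python
def build_load_template(full_template: dict, exclude_keys=()) -> dict:
    """Hypothetical helper: drop top-level entries (e.g. "optimizer") before dcp.load."""
    excluded = set(exclude_keys)
    unknown = excluded - set(full_template)
    if unknown:
        raise KeyError(f"unknown exclude_keys: {sorted(unknown)}")
    return {k: v for k, v in full_template.items() if k not in excluded}

template = {"model": {"w": 0.0}, "optimizer": {"state": {}}}
print(build_load_template(template, exclude_keys=["optimizer"]))  # {'model': {'w': 0.0}}
```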
See also
- Resharding: the save-at-N, load-at-M mechanics that make DCP worth the extra files.
- Train state: what else is in each checkpoint directory.
- Auto-resume: how KempnerForge finds the right step_N on restart.
- HF conversion: exporting DCP shards to HuggingFace safetensors.
- Configuration § CheckpointConfig: async_mode, interval, keep_last_n, dir.