Data Flow¶
The path a batch of tokens takes from the dataloader to a committed
gradient update. Follow this page with
scripts/train.py
open in another tab — line numbers move, but the structure below maps
one-to-one onto the training loop body.
One-slide view¶
┌─ MemoryMappedDataset or MixtureDataset
│
├─ DistributedSampler / MixtureSampler (rank-partitioned indices)
│
├─ StatefulDataLoader ──► batch = {input_ids, labels, doc_ids?}
│
▼
for micro_step in range(grad_accum_steps): ── maybe_no_sync sets
├─ model.set_moe_step(step, max_steps) DDP/FSDP grad-sync off
├─ logits = model(input_ids, doc_ids=doc_ids) on all but last micro
├─ loss = loss_fn(logits, labels)
├─ [+ moe_aux_loss_weight * model.get_moe_aux_loss()]
└─ (loss / grad_accum_steps).backward() ── gradient accumulation
grad_norm = clip_grad_norm_(model, grad_clip_norm)
if NaN: zero_grad, skip, maybe stop
optimizer.step() ; scheduler.step() ; optimizer.zero_grad()
tracker.end_step(step, loss, grad_norm, lr, tokens_in_step)
hook_runner.on_step_end(StepContext(...))
[every N steps] eval, NCCL health, profiler.step, ckpt_mgr.save
if shutdown_handler.should_shutdown(): save + break
Startup, once¶
Before the loop starts,
scripts/train.py
initializes the collaborators:
- init_distributed(config.distributed, seed=...) from kempnerforge.distributed.setup — reads RANK / LOCAL_RANK / WORLD_SIZE (torchrun) or SLURM_PROCID / SLURM_NTASKS (multi-node srun), calls dist.init_process_group, builds the DeviceMesh, seeds torch.
- ShutdownHandler from kempnerforge.resilience.signal_handler — installs SIGTERM / SIGUSR1 handlers for cooperative SLURM preemption.
- NaNDetector from kempnerforge.resilience.health — tracks consecutive NaN steps; action="warn" by default, escalates to rollback after max_consecutive=10.
- Loss function from the registry ("cross_entropy" or "chunked_cross_entropy").
- Model via build_parallel_model — applies the full parallelism stack (see Parallelism order).
- Optimizer and scheduler from their registries.
- CheckpointManager from kempnerforge.checkpoint.manager — owns DCP async save/load, the latest symlink, and step_N directories.
- resolve_resume_path(config.checkpoint.dir) from kempnerforge.resilience.elastic — follows the latest symlink or picks the highest step_N. If non-None, ckpt_mgr.load(resume_path) restores model, optimizer, scheduler, dataloader position, and RNG state before the loop starts.
- MetricsTracker from kempnerforge.metrics.tracker — per-step metrics + EMA smoothing + WandB / TensorBoard backends.
- Data pipeline — MemoryMappedDataset or MixtureDataset or an HF streaming / eager dataset (config dispatches), wrapped by DistributedSampler or MixtureSampler and then by StatefulDataLoader from kempnerforge.data.
- Eval dataloader and torch.profiler (optional).
- Phase schedule (optional) — for curriculum training, rebalances the mixture at phase.start_step.
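The resume-path resolution described above can be sketched as follows. This is a hypothetical re-implementation of the stated behavior (follow the latest symlink, else pick the highest step_N), not the library's actual code:

```python
import os
import re

def resolve_resume_path(ckpt_dir):
    """Sketch: follow the `latest` symlink if present, else pick the
    highest-numbered step_N directory; return None if nothing to resume."""
    if not os.path.isdir(ckpt_dir):
        return None
    latest = os.path.join(ckpt_dir, "latest")
    if os.path.islink(latest):
        return os.path.realpath(latest)
    steps = [
        (int(m.group(1)), name)
        for name in os.listdir(ckpt_dir)
        if (m := re.fullmatch(r"step_(\d+)", name))
    ]
    return os.path.join(ckpt_dir, max(steps)[1]) if steps else None
```

Note the numeric sort: a lexicographic sort would rank step_9 above step_10, so the step number is parsed as an integer before taking the maximum.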
Inside the loop: one step¶
The step body is where the interesting routing happens.
1 · Microbatch fetch¶
for micro_step in range(grad_accum_steps):
batch = next(data_iter)
input_ids = batch["input_ids"].to(device)
labels = batch["labels"].to(device)
    doc_ids = batch["doc_ids"].to(device) if "doc_ids" in batch else None
When dataloader is None, the loop generates random integer tokens —
useful for smoke-testing the parallelism stack without any corpus.
doc_ids is optional: non-None only when the dataset packs multiple
documents into one sequence. It triggers the block-diagonal attention
mask path (see Model § Three attention paths).
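To see what the block-diagonal path does with doc_ids, here is a minimal pure-Python sketch; the model builds this mask as a tensor, and the function name is illustrative:

```python
def block_diagonal_causal_mask(doc_ids):
    """True at (i, j) when position i may attend to position j:
    same packed document, and j is not in the future."""
    n = len(doc_ids)
    return [
        [doc_ids[i] == doc_ids[j] and j <= i for j in range(n)]
        for i in range(n)
    ]

# Two packed documents: positions 0-2 are doc 0, positions 3-4 are doc 1.
mask = block_diagonal_causal_mask([0, 0, 0, 1, 1])
```

Positions in the second document cannot attend back into the first, even though both share one sequence.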
2 · maybe_no_sync¶
with maybe_no_sync(model, micro_step, grad_accum_steps):
...
Utility from
kempnerforge.training.grad,
re-exported from kempnerforge.training. On all microbatches except
the last, it disables gradient synchronization so the backward pass
accumulates locally and does not fire reduce-scatter N times per
optimizer step.
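A minimal sketch of what such a helper could look like. The real maybe_no_sync lives in kempnerforge.training.grad; this version only assumes the wrapped model exposes DDP/FSDP's no_sync() context manager:

```python
from contextlib import nullcontext

def maybe_no_sync(model, micro_step, grad_accum_steps):
    """Return model.no_sync() on every microbatch except the last,
    so gradients accumulate locally without firing collectives."""
    is_last = micro_step == grad_accum_steps - 1
    if not is_last and hasattr(model, "no_sync"):
        return model.no_sync()
    return nullcontext()
```

Falling back to nullcontext() means the same loop body works for a bare (non-distributed) model.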
3 · MoE step propagation¶
if mc.is_moe:
model.set_moe_step(step, tc.max_steps)
Forwards (step, max_steps) to every SigmoidTopKRouter. Used by the
adaptive bias schedule (see docs/moe/). Dense models
skip this.
4 · Forward¶
logits = model(input_ids, doc_ids=doc_ids)
loss = loss_fn(logits, labels)
The forward pass follows the model page: token embedding →
N transformer blocks (RoPE + GQA + SwiGLU or MoE) → final RMSNorm →
output head → (batch, seq_len, vocab_size) logits.
If the dataloader is a mixture, per-dataset metrics are collected here, before backward, while the logits are still alive.
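A hypothetical sketch of that per-dataset bookkeeping, using plain Python in place of tensor ops (the function name is illustrative, not the library's API):

```python
def per_dataset_loss(token_losses, dataset_ids):
    """Average per-token loss, grouped by the dataset each token came from."""
    sums, counts = {}, {}
    for loss, ds in zip(token_losses, dataset_ids):
        sums[ds] = sums.get(ds, 0.0) + loss
        counts[ds] = counts.get(ds, 0) + 1
    return {ds: sums[ds] / counts[ds] for ds in sums}
```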
5 · MoE auxiliary loss¶
if mc.is_moe:
aux_loss = model.get_moe_aux_loss()
loss = loss + mc.moe_aux_loss_weight * aux_loss
get_moe_aux_loss() sums the per-layer MoEMLP.aux_loss attributes.
For dense runs it returns 0.0 and the line is a no-op.
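For intuition, one common formulation of a load-balancing auxiliary loss (Switch-Transformer style) is sketched below; the actual formula inside MoEMLP.aux_loss may differ:

```python
def load_balance_aux_loss(tokens_per_expert, prob_per_expert):
    """num_experts * sum_e(token_fraction_e * prob_fraction_e).
    Minimized at 1.0 when routing is perfectly uniform; grows as
    tokens and router probability concentrate on few experts."""
    num_experts = len(tokens_per_expert)
    total_tokens = sum(tokens_per_expert)
    total_prob = sum(prob_per_expert)
    return num_experts * sum(
        (t / total_tokens) * (p / total_prob)
        for t, p in zip(tokens_per_expert, prob_per_expert)
    )
```

Because it multiplies the (non-differentiable) token fractions by the (differentiable) router probabilities, the gradient pushes the router toward balance.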
6 · Backward with gradient accumulation¶
scaled_loss = loss / tc.grad_accum_steps
scaled_loss.backward()
total_loss += loss.item()
Dividing each microbatch loss by grad_accum_steps makes the accumulated gradient the mean over microbatches, so for a given learning rate the update is invariant to the accumulation factor.
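A quick numeric check of that invariance: summing each microbatch's mean loss divided by grad_accum_steps reproduces the full-batch mean, so the gradients (and therefore the update at a fixed LR) match an unaccumulated run.

```python
per_example_losses = [2.0, 4.0, 6.0, 8.0]
grad_accum_steps = 2
micro_size = len(per_example_losses) // grad_accum_steps

# Full-batch objective: mean over all examples.
full_batch = sum(per_example_losses) / len(per_example_losses)

# Accumulated objective: each microbatch mean, scaled, then summed.
micro_means = [
    sum(per_example_losses[i * micro_size:(i + 1) * micro_size]) / micro_size
    for i in range(grad_accum_steps)
]
accumulated = sum(m / grad_accum_steps for m in micro_means)

assert accumulated == full_batch  # 5.0 either way
```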
7 · Clip, NaN check, optimizer step¶
After the microbatch loop:
grad_norm = clip_grad_norm_(model, tc.grad_clip_norm)
if not nan_detector.check_loss(avg_loss, step):
optimizer.zero_grad(); step += 1; continue # skip this step
optimizer.step()
scheduler.step()
if phase_lr_scale != 1.0:
for pg in optimizer.param_groups: pg["lr"] *= phase_lr_scale
optimizer.zero_grad()
clip_grad_norm_ wraps PyTorch’s utility so it works with FSDP2 sharded
parameters. The NaN detector returns False on a NaN/Inf loss; the loop
then zeroes the grads and skips the step, and after max_consecutive
failures the detector signals a rollback to the previous checkpoint.
Phase LR scaling applies after the scheduler — it multiplies the base
LR that scheduler.step() just computed.
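A hypothetical sketch of the detector's contract as described above; the real NaNDetector in kempnerforge.resilience.health also has an action policy, and "rollback" here is reduced to a flag:

```python
import math

class NaNDetector:
    """check_loss returns False on a NaN/Inf loss; after max_consecutive
    bad steps in a row, flags that a rollback is needed."""
    def __init__(self, max_consecutive=10):
        self.max_consecutive = max_consecutive
        self.consecutive = 0
        self.needs_rollback = False

    def check_loss(self, loss, step):
        if math.isnan(loss) or math.isinf(loss):
            self.consecutive += 1
            if self.consecutive >= self.max_consecutive:
                self.needs_rollback = True
            return False
        self.consecutive = 0  # a healthy step resets the streak
        return True
```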
8 · Step accounting¶
step += 1
tokens_in_step = tc.batch_size * tc.seq_len * tc.grad_accum_steps * dp_size
tokens_seen += tokens_in_step
tokens_in_step multiplies by the data-parallel world size (dp_size), not the
full world_size — TP, PP, and EP ranks replicate the same tokens rather than
adding unique ones.
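Concretely, with hypothetical values (micro batch size 8, sequence length 2048, 4 accumulation steps, 16 data-parallel ranks):

```python
batch_size, seq_len, grad_accum_steps, dp_size = 8, 2048, 4, 16
tokens_in_step = batch_size * seq_len * grad_accum_steps * dp_size
# 8 * 2048 * 4 * 16 = 1,048,576 unique tokens per optimizer step
```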
9 · Phase transitions¶
If a mixture phase is due, sampler.update_weights(...) rebalances the
MixtureSampler and data_iter = None forces a refresh next microbatch
so the new weights take effect immediately.
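A minimal sketch of a phase schedule of that shape. The field names are assumptions inferred from the phase.start_step reference above:

```python
def active_weights(phases, step):
    """Return the mixture weights of the latest phase whose start_step
    has been reached; phases must be sorted by start_step."""
    current = phases[0]["weights"]
    for phase in phases:
        if step >= phase["start_step"]:
            current = phase["weights"]
    return current

phases = [
    {"start_step": 0,   "weights": {"web": 0.9, "code": 0.1}},
    {"start_step": 100, "weights": {"web": 0.5, "code": 0.5}},
]
```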
10 · Metrics and hooks¶
step_metrics = tracker.end_step(step=step, loss=avg_loss,
grad_norm=grad_norm_val, lr=current_lr,
tokens_in_step=tokens_in_step)
hook_runner.on_step_end(StepContext(...))
tracker.end_step dispatches to WandB / TensorBoard at
metrics.log_interval. HookRunner (from
kempnerforge.training.hooks)
runs user-defined TrainingHook callbacks each step.
MoE runs additionally log moe/aux_loss and moe/expert_balance
(min/max of per-expert token counts) at the same cadence.
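The EMA smoothing the tracker applies can be sketched as follows; the class name and default smoothing factor are assumptions, not the tracker's actual API:

```python
class EMASmoother:
    """Exponential moving average: new = (1 - alpha) * old + alpha * sample."""
    def __init__(self, alpha=0.1):
        self.alpha = alpha
        self.value = None

    def update(self, sample):
        if self.value is None:
            self.value = sample  # first sample seeds the average
        else:
            self.value = (1 - self.alpha) * self.value + self.alpha * sample
        return self.value
```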
11 · Periodic work¶
In order:
- NCCL health (every tc.nccl_health_check_interval steps) — fires a small all-reduce; on failure, break the loop.
- Eval (every eval_config.interval steps) — runs run_eval on the eval dataloader, logs metrics, fires on_eval_end hooks.
- Profiler — prof.step() advances the torch.profiler schedule.
- Checkpoint (every checkpoint.interval steps) — ckpt_mgr.save writes a DCP checkpoint asynchronously and updates the latest symlink.
- Graceful shutdown — if SIGTERM/SIGUSR1 fired, ckpt_mgr.save(emergency) and exit.
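The "every N steps" pattern shared by these tasks can be sketched with a tiny dispatcher; this is hypothetical, the loop actually inlines each check:

```python
def run_periodic(step, tasks):
    """tasks: [(interval, fn), ...] in priority order; each fn fires
    when step is a positive multiple of its interval."""
    fired = []
    for interval, fn in tasks:
        if interval and step > 0 and step % interval == 0:
            fn()
            fired.append(fn.__name__)
    return fired
```

Keeping the list ordered preserves the priority above: health check before eval, eval before checkpointing.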
Shutdown¶
After the loop exits:
- prof.stop() flushes traces.
- ckpt_mgr.wait() drains the last async save.
- hook_runner.on_train_end(step, tokens_seen).
- tracker.close() flushes WandB / TB.
- destroy_distributed() tears down the process group.
Where to read next¶
Training subsystem — loss functions, optimizers, schedulers in detail.
Checkpointing — DCP internals and the resume protocol.
Resilience — SIGTERM, NaN, NCCL health mechanics.
Metrics and profiling —
MetricsTracker, MFU, profiler.