Mix datasets and anneal data weights¶
Training on more than one corpus at once — and shifting the mix mid-run
— is common practice. KempnerForge supports both through three
orthogonal config controls: multi-dataset sources in [data], a
temperature knob on the mixture weights, and optional [[data.phases]]
that retarget the mix (and the LR) at fixed steps. This page explains
each and shows how they compose.
The three controls¶
[data]
mix_temperature = 1.0 # knob on the weights; 1.0 = as-declared
[[data.datasets]] # multi-dataset mixture
name = "web"
path = "/data/tokenized/web"
weight = 0.7
[[data.datasets]]
name = "code"
hf_name = "bigcode/the-stack-dedup"
weight = 0.3
[[data.phases]] # phase transitions
start_step = 10_000
dataset_weights = { web = 0.3, code = 0.7 }
lr_scale = 0.5
- [[data.datasets]] turns mixing on. Without it, [data] falls back to the single-dataset fields (dataset_path or hf_dataset_name); see the fallback snippet below.
- mix_temperature rescales declared weights before sampling.
- [[data.phases]] swaps weights (and scales LR) at specific steps during training.
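For reference, the single-dataset fallback looks like this (the values are illustrative, reusing the paths from the example above):

```toml
[data]
dataset_path = "/data/tokenized/web"
# or, instead of a local directory:
# hf_dataset_name = "bigcode/the-stack-dedup"
```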
A fourth control, anneal_start_step / anneal_weights, is sugar for
the common “change the mix once, late in training” pattern — see
Annealing shortcut below.
Multi-dataset mixture¶
Each [[data.datasets]] block is a
DatasetSource:
| Field | Meaning |
|---|---|
| name | Label for per-dataset metrics and phase overrides. Auto-filled from … |
| path | Pre-tokenized directory (use with …). |
| hf_name | HuggingFace dataset ID (alternative to path). |
| … | HF dataset config (e.g. …). |
| weight | Relative sampling weight (must be positive). Normalized internally. |
At least one of path / hf_name must be set per source, enforced by
DataConfig.__post_init__.
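The shape of that check, as a rough sketch rather than the actual DataConfig code:

```python
def check_sources(datasets):
    # Sketch of the rule only: every source needs a local path or an HF dataset ID.
    for ds in datasets:
        if ds.path is None and ds.hf_name is None:
            raise ValueError(f"dataset {ds.name!r}: set at least one of path / hf_name")
```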
When any [[data.datasets]] is present, scripts/train.py builds a
MixtureDataset
over the sub-datasets and drives it with a
MixtureSampler.
The sampler:
1. Partitions each sub-dataset’s indices across data-parallel ranks (stride-based, like DistributedSampler).
2. Allocates target_counts[i] = round(prob[i] × total_per_rank) indices per epoch from each source and over- or undersamples to match the target.
3. Interleaves the drawn indices with one final shuffle so the model sees a randomly mixed order, not source-blocked batches.
Every sample returned by MixtureDataset.__getitem__ includes a
dataset_idx key so the training loop can slot the per-batch loss into
per-dataset metrics.
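The real logic lives in MixtureSampler; the following is a minimal, self-contained sketch of the partition, allocate, and interleave idea described above (the function and variable names are illustrative, not the library’s API):

```python
import random

def draw_epoch_indices(dataset_sizes, probs, total_per_rank, rank, world_size, seed=0):
    """Sketch: one epoch of mixture sampling for a single data-parallel rank."""
    rng = random.Random(seed + rank)
    drawn = []  # (dataset_idx, sample_idx) pairs
    for i, size in enumerate(dataset_sizes):
        # Stride-based partition of this source's indices across ranks,
        # in the spirit of torch's DistributedSampler.
        local = list(range(rank, size, world_size))
        if not local:
            continue
        rng.shuffle(local)
        # How many samples this source should contribute on this rank.
        target = round(probs[i] * total_per_rank)
        if target <= len(local):
            chosen = local[:target]                                      # undersample
        else:
            chosen = local + rng.choices(local, k=target - len(local))   # oversample
        drawn.extend((i, j) for j in chosen)
    rng.shuffle(drawn)  # final shuffle: mixed order, not source-blocked batches
    return drawn
```

Each returned pair carries the source index, which is what surfaces as the dataset_idx key on samples.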
Per-dataset metrics¶
When the mixture is active and the metrics interval fires,
scripts/train.py emits two series per dataset name:
- loss/{name} — mean loss of samples from that dataset in the accumulation window.
- data/{name}/tokens — running token count consumed from that dataset.
Plot these in WandB / TensorBoard to see whether a dataset is
contributing normally or drifting. A rising loss/code while
loss/web stays flat is a signal worth investigating.
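For orientation, a rough sketch of how a loop can slot per-sample losses into those two series using dataset_idx (the batch fields and logger here are assumptions, not KempnerForge’s actual API):

```python
from collections import defaultdict

loss_sum, loss_n, tokens = defaultdict(float), defaultdict(int), defaultdict(int)

def record_batch(batch, per_sample_loss, dataset_names):
    # batch["dataset_idx"] and batch["num_tokens"] are assumed per-sample fields.
    for idx, loss, ntok in zip(batch["dataset_idx"], per_sample_loss, batch["num_tokens"]):
        name = dataset_names[idx]
        loss_sum[name] += float(loss)
        loss_n[name] += 1
        tokens[name] += int(ntok)

def log_metrics(log_fn):
    for name, n in loss_n.items():
        log_fn(f"loss/{name}", loss_sum[name] / n)    # mean over the accumulation window
        log_fn(f"data/{name}/tokens", tokens[name])   # running total
    loss_sum.clear(); loss_n.clear()                  # reset the window, keep token totals
```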
Temperature¶
mix_temperature rescales the declared weight values before
normalization. The math in
MixtureSampler.__init__
is:
log_w = [log(max(w, 1e-12)) / temperature for w in weights]
probs = softmax(log_w) # after subtracting the max for stability
So the three interesting regimes are:
- temperature = 1.0 (default) — probabilities are just the declared weights, normalized.
- temperature > 1.0 — weights are flattened toward uniform. At temperature → ∞, every source has probability 1/N regardless of its declared weight.
- temperature < 1.0 — weights are sharpened. At temperature → 0, the heaviest source takes everything.
The common setting is temperature > 1 when the declared weights
reflect corpus size but you want to undersample the largest corpus so
a small, high-quality source isn’t drowned out. A typical value is in
the 1.3–2.0 range; the right number depends on the relative size of
your corpora.
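As a concrete check, here is the same math applied to the declared 0.7 / 0.3 split from the example config (a standalone sketch, not the library’s code):

```python
import math

def mixture_probs(weights, temperature):
    log_w = [math.log(max(w, 1e-12)) / temperature for w in weights]
    m = max(log_w)                          # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in log_w]
    total = sum(exps)
    return [e / total for e in exps]

print(mixture_probs([0.7, 0.3], 1.0))   # ~[0.70, 0.30], as declared
print(mixture_probs([0.7, 0.3], 2.0))   # ~[0.60, 0.40], flattened toward uniform
print(mixture_probs([0.7, 0.3], 0.5))   # ~[0.84, 0.16], sharpened toward the heaviest source
```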
Phase transitions¶
A [[data.phases]] entry is a
TrainingPhase:
| Field | Meaning |
|---|---|
| start_step | Step at which this phase activates (must be non-negative). |
| dataset_weights | Per-dataset weight overrides keyed by dataset name; sources not listed keep their declared weight. |
| lr_scale | Multiplier applied to every param group’s LR once this phase fires. |
Constraints (validated in DataConfig.__post_init__):
- Phase start_step values must be strictly monotonically increasing (see the example below).
- You can use either [[data.phases]] or the annealing shortcut below — not both.
- All dataset_weights values must be non-negative; lr_scale > 0.
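For instance, a two-phase schedule that satisfies all three constraints (the steps and weights are only illustrative):

```toml
[[data.phases]]
start_step = 10_000
dataset_weights = { web = 0.5, code = 0.5 }

[[data.phases]]
start_step = 40_000                 # strictly greater than the previous phase's start_step
dataset_weights = { web = 0.2, code = 0.8 }
lr_scale = 0.5
```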
What happens on a phase transition¶
On the first training step where step >= phase.start_step, the loop
in
scripts/train.py
does two things:
1. Builds a new weights list by overriding the original declared weights with phase.dataset_weights entries where specified, then calls sampler.update_weights(new_weights, temperature=config.data.mix_temperature) — the next __iter__() call on the sampler uses the new mix.
2. Sets phase_lr_scale = phase.lr_scale. From this step forward, the training loop multiplies every parameter group’s LR by phase_lr_scale after the scheduler has computed the base LR. The scheduler still runs; phase_lr_scale is an additional factor on top.
So if your scheduler is driving lr from 3e-4 down to 3e-5 via
cosine decay, a phase with lr_scale = 0.5 means the optimizer sees
1.5e-4 → 1.5e-5 while that phase is active. Later phases overwrite
phase_lr_scale with their own value; there’s no compounding across
phases.
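Putting the two steps together, the effective state at any step can be reasoned about with a small sketch like the following (Phase is a stand-in for TrainingPhase; this is not the actual scripts/train.py code):

```python
from dataclasses import dataclass

@dataclass
class Phase:                           # stand-in for TrainingPhase
    start_step: int
    dataset_weights: dict
    lr_scale: float = 1.0

def phase_state(step, phases, declared_weights, dataset_names):
    """Return the (weights, lr_scale) in effect at `step`."""
    weights, lr_scale = list(declared_weights), 1.0
    for phase in phases:                               # assumed sorted by start_step
        if step < phase.start_step:
            break
        # Each phase overrides the *declared* weights where it names a source.
        weights = [phase.dataset_weights.get(n, w)
                   for n, w in zip(dataset_names, declared_weights)]
        lr_scale = phase.lr_scale                      # later phases overwrite, no compounding
    return weights, lr_scale

phases = [Phase(10_000, {"web": 0.3, "code": 0.7}, 0.5)]
print(phase_state(12_000, phases, [0.7, 0.3], ["web", "code"]))
# -> ([0.3, 0.7], 0.5): the sampler gets the new mix, the scheduler's LR is multiplied by 0.5.
```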
Resume behavior¶
At startup, the training loop walks the phase list and applies every
phase whose start_step <= current_step, logging
"Resumed into phase K, lr_scale=S". The checkpoint itself tracks
the current phase index (stored in the ckpt_extra field), so a
mid-phase resume lands on the correct weights without re-firing older
phases.
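In terms of the phase_state sketch above, the resume walk amounts to one evaluation at the restored step (illustrative only, not the actual resume code):

```python
# current_step is restored from the checkpoint
weights, lr_scale = phase_state(current_step, phases, declared_weights, dataset_names)
sampler.update_weights(weights, temperature=config.data.mix_temperature)
phase_lr_scale = lr_scale
```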
The names in mixture_dataset.dataset_names are the keys into phase.dataset_weights
— keep each name stable across checkpoint / resume or you’ll silently
revert to the original declared weights for any name that no longer
matches.
Annealing shortcut¶
For the very common “one change, late” pattern, you can skip the
verbose [[data.phases]] block:
[data]
anneal_start_step = 40_000
anneal_weights = { web = 0.1, code = 0.9 }
scripts/train.py converts this into a one-element TrainingPhase
list internally:
TrainingPhase(
start_step=config.data.anneal_start_step,
dataset_weights=dict(config.data.anneal_weights),
# lr_scale defaults to 1.0 — no LR change
)
Default lr_scale is 1.0 — the LR curve doesn’t change unless you
say so. If you want both weight annealing and an LR drop on
transition, use [[data.phases]] explicitly.
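If you do want both, the explicit form of the same anneal with a halved LR would look something like:

```toml
[[data.phases]]
start_step = 40_000
dataset_weights = { web = 0.1, code = 0.9 }
lr_scale = 0.5
```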
Sanity checks¶
Before committing to a long run, verify the mix is what you think:
uv run python scripts/train.py configs/train/your_mix.toml \
--train.max_steps=100 --metrics.log_interval=1
Inspect the data/{name}/tokens counters after 100 steps. The ratios
should roughly match your target probabilities.
data/web/tokens ≈ (web_prob) × total_tokens
data/code/tokens ≈ (code_prob) × total_tokens
If the numbers are off by much more than one batch’s worth of tokens, double-check:
- name values in [[data.datasets]] vs any phase.dataset_weights overrides (case-sensitive, exact match).
- That phase start_step values haven’t already fired before step 100 (phase transitions change what you’re measuring).
- That mix_temperature is what you intended — a temperature other than 1.0 changes sampling probabilities away from the declared weights (the snippet below recomputes the effective probabilities).
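To get the target probabilities you’re comparing against, you can rerun the weight math by hand, for example with the mixture_probs sketch from the Temperature section:

```python
declared = [0.7, 0.3]                   # web, code, as declared in [[data.datasets]]
web_prob, code_prob = mixture_probs(declared, temperature=1.0)
# Expect, up to roughly one batch of noise:
#   data/web/tokens  ≈ web_prob  × total_tokens
#   data/code/tokens ≈ code_prob × total_tokens
```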
For a phase transition specifically, watch the log for the
"Phase transition at step N: phase=K, lr_scale=S" line — it fires
exactly once per transition, and the data/{name}/tokens slopes
should visibly change immediately after.
See also¶
- Data § Mixing and annealing — MixtureDataset internals and the non-mixing dataset classes.
- Data § Sampler — MixtureSampler and update_weights internals.
- Configuration § [data] — every DataConfig field and its validation rule.
- Training § Training loop — where phase_lr_scale is applied and how phase checkpoint state is persisted.
- Compare optimizers — LR conventions matter when you combine lr_scale with non-AdamW optimizers.
- Prepare tokenized data — how to produce the per-source path directories this page consumes.