Optimizers¶
Four optimizers are registered in
kempnerforge/training/optimizer.py:
adamw, lion, muon, schedule_free_adamw. All are constructed by
build_optimizer(model, config), which routes by config.name through
the registry.
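The name-based routing can be pictured with a small registry sketch. This is an illustration only, not the contents of optimizer.py: OPTIMIZER_REGISTRY, register_optimizer, and build_optimizer_sketch are hypothetical names, and the builders return placeholder strings instead of real optimizers.

```python
from types import SimpleNamespace
from typing import Any, Callable, Dict

# Hypothetical registry: maps config.name to a builder function.
OPTIMIZER_REGISTRY: Dict[str, Callable[[Any, Any], Any]] = {}


def register_optimizer(name: str):
    """Decorator that records a builder under the given name."""
    def decorator(builder):
        OPTIMIZER_REGISTRY[name] = builder
        return builder
    return decorator


@register_optimizer("adamw")
def _build_adamw(model, config):
    # Placeholder: a real builder would construct torch.optim.AdamW here.
    return f"AdamW(lr={config.lr})"


@register_optimizer("lion")
def _build_lion(model, config):
    # Placeholder for a Lion builder.
    return f"Lion(lr={config.lr})"


def build_optimizer_sketch(model, config):
    """Look up config.name in the registry and call the matching builder."""
    try:
        builder = OPTIMIZER_REGISTRY[config.name]
    except KeyError:
        known = ", ".join(sorted(OPTIMIZER_REGISTRY))
        raise ValueError(
            f"unknown optimizer {config.name!r}; known: {known}"
        ) from None
    return builder(model, config)


config = SimpleNamespace(name="lion", lr=0.5)
print(build_optimizer_sketch(None, config))
```

With this shape, an unknown name fails early with a message listing the registered choices, rather than surfacing later as an attribute error.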
adamw¶
[optimizer]
name = "adamw"
lr = 3e-4
betas = [0.9, 0.95]
eps = 1e-8
weight_decay = 0.1
fused = true
Direct wrapper of torch.optim.AdamW. fused = true enables the
fused CUDA kernel on PyTorch 2.x with CUDA; otherwise it falls back to
the foreach implementation. The fused kernel is the default in every
shipped dense config.
Memory: 2 optimizer-state tensors per parameter (exp_avg, exp_avg_sq) in fp32 — the master-weight budget most estimators assume.
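As a back-of-the-envelope check, the state footprint follows directly from the tensor count. The helper below is a hypothetical sketch (optimizer_state_bytes is not part of the codebase); it counts unsharded optimizer state only and ignores parameters, gradients, and activations.

```python
def optimizer_state_bytes(
    num_params: int,
    states_per_param: int = 2,   # AdamW: exp_avg and exp_avg_sq
    bytes_per_element: int = 4,  # fp32
) -> int:
    """Total bytes of optimizer state across the whole model, before sharding."""
    return num_params * states_per_param * bytes_per_element


# Example: a 7B-parameter model with AdamW state in fp32.
adamw_bytes = optimizer_state_bytes(7_000_000_000)
print(f"AdamW state: {adamw_bytes / 1e9:.0f} GB")
```

Passing states_per_param=1 gives the figure for a single-buffer optimizer, assuming its buffer is also stored in fp32.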
lion¶
[optimizer]
name = "lion"
lr = 1e-4
betas = [0.9, 0.99]
weight_decay = 0.01
Sign-momentum optimizer. The update rule is
update = sign(β₁·m + (1-β₁)·g), with one momentum buffer per
parameter — roughly half AdamW’s optimizer-state memory.
Notes from the code:
LR is ~3-10× smaller than AdamW’s, because Lion’s sign-valued update is O(1) per coordinate regardless of gradient scale.
Decoupled weight decay: p.data.mul_(1 - lr * wd) before the sign-update step (no interaction with momentum).
No eps: Lion’s update never divides by a state tensor.
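The notes above can be put together into one step. The sketch below uses plain Python floats in place of tensors and is not the repository's implementation; lion_step_sketch is a hypothetical name, and the momentum refresh with β₂ follows the standard Lion formulation rather than anything stated on this page.

```python
def lion_step_sketch(params, grads, momentum, lr=1e-4,
                     betas=(0.9, 0.99), weight_decay=0.01):
    """One Lion step over flat lists of floats; returns (new_params, new_momentum)."""
    beta1, beta2 = betas
    new_params, new_momentum = [], []
    for p, g, m in zip(params, grads, momentum):
        # Decoupled weight decay, applied before the sign update.
        p = p * (1.0 - lr * weight_decay)
        # Interpolate momentum and gradient, then keep only the sign.
        interp = beta1 * m + (1.0 - beta1) * g
        sign = (interp > 0) - (interp < 0)  # -1, 0, or +1
        new_params.append(p - lr * sign)
        # Assumption: momentum is refreshed with beta2, as in standard Lion.
        new_momentum.append(beta2 * m + (1.0 - beta2) * g)
    return new_params, new_momentum


params, momentum = lion_step_sketch(
    params=[1.0, -2.0], grads=[0.5, -0.25], momentum=[0.0, 0.0],
    lr=0.1, weight_decay=0.0,
)
print(params, momentum)
```

Every coordinate moves by exactly lr (or not at all when the interpolated value is zero), which is why the learning rate has to be smaller than AdamW's.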
schedule_free_adamw¶
[optimizer]
name = "schedule_free_adamw"
lr = 0.025
betas = [0.9, 0.999]
eps = 1e-8
schedule_free_warmup_steps = 2000
[scheduler]
name = "none" # required
Defazio’s Schedule-Free AdamW. Replaces the external LR schedule with
an internal Polyak-averaging trick that tracks an iterate z, a
running average x, and a weight sum. The TOML-visible scheduler.name = "none" is required — adding an external cosine or linear schedule
breaks the internal averaging.
Gotchas:
Eval mode switch: call optimizer.eval_params() before validation and optimizer.train_params() after, to swap in the averaged weights x and restore the iterate z. This is not automatic in the current loop.
Internal warmup: schedule_free_warmup_steps is the optimizer’s own linear warmup, independent of any external scheduler. Zero disables it.
LR is 10-100× larger than AdamW’s: the Polyak averaging absorbs aggressive step sizes.
muon¶
[optimizer]
name = "muon"
lr = 0.02
muon_momentum = 0.95
muon_ns_steps = 5
# muon_adam_lr omitted -> use same LR as Muon for 1D params
weight_decay = 0.1
betas = [0.9, 0.95] # used by internal AdamW for 1D params
eps = 1e-8
Keller Jordan’s Muon. Applies Newton-Schulz orthogonalization to the momentum buffer of each eligible 2D weight before stepping; every other parameter (1D biases and norms, plus highly rectangular 2D matrices) falls through to an internal AdamW with the same hyperparameters.
How it decides which path each parameter takes
(_is_muon_eligible):
2D weight, aspect ratio ≤ 10 → Muon orthogonalized update.
1D, or ratio > 10 → internal AdamW fallback.
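The routing rule can be expressed as a predicate on the parameter shape. The function below is a hypothetical sketch, not the actual _is_muon_eligible; in particular it assumes the aspect ratio means the longer side divided by the shorter side.

```python
def is_muon_eligible_sketch(shape, max_aspect_ratio=10.0):
    """Return True if a parameter of this shape would take the Muon path."""
    if len(shape) != 2:
        return False  # 1D (biases, norms) and other ranks use AdamW
    rows, cols = shape
    if min(rows, cols) == 0:
        return False  # guard against empty tensors
    # Assumption: aspect ratio = longer side / shorter side.
    return max(rows, cols) / min(rows, cols) <= max_aspect_ratio


for shape in [(4096, 4096), (4096, 11008), (128000, 4096), (4096,)]:
    path = "muon" if is_muon_eligible_sketch(shape) else "adamw"
    print(shape, "->", path)
```

Under this reading, a large embedding or output projection such as (128000, 4096) has a ratio above 10 and lands on the AdamW fallback.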
DTensor / FSDP2 notes:
Muon’s orthogonalization runs on the local shard, not the full weight. It calls _get_local_tensor(p) to unwrap DTensors before Newton-Schulz. Empirically fine in practice; not mathematically identical to full-matrix orthogonalization.
Momentum buffers match the parameter dtype (torch.zeros_like(p.grad)), so DCP serializes them correctly on resume.
muon_adam_lr unset (default) makes the internal AdamW use the same LR as Muon. Pass a smaller value (e.g. muon_adam_lr = 1e-4) if you want 1D params on a gentler schedule.
Newton-Schulz coefficients (hard-coded): a = 3.4445, b = -4.7750, c = 2.0315, Frobenius-normalized input — standard values from the Muon
paper.
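For intuition, the quintic Newton-Schulz iteration with these coefficients can be written out on nested lists. This is a pure-Python sketch of the commonly published form of the iteration, not the repository's kernel: newton_schulz_sketch and the helpers are hypothetical names, and the eps value and the transpose for tall matrices are assumptions.

```python
import math

# Coefficients quoted in the text above.
A_COEF, B_COEF, C_COEF = 3.4445, -4.7750, 2.0315


def matmul(x, y):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*y)]
            for row in x]


def transpose(x):
    return [list(col) for col in zip(*x)]


def newton_schulz_sketch(g, steps=5, eps=1e-7):
    """Approximately orthogonalize g (push its singular values toward 1)."""
    # Frobenius-normalize so all singular values start at or below 1.
    fro = math.sqrt(sum(v * v for row in g for v in row))
    x = [[v / (fro + eps) for v in row] for row in g]
    # Work on the wide orientation so X X^T is the smaller Gram matrix.
    transposed = len(x) > len(x[0])
    if transposed:
        x = transpose(x)
    for _ in range(steps):
        gram = matmul(x, transpose(x))   # X X^T
        gram_sq = matmul(gram, gram)     # (X X^T)^2
        poly = [[B_COEF * p + C_COEF * q for p, q in zip(r1, r2)]
                for r1, r2 in zip(gram, gram_sq)]
        px = matmul(poly, x)
        # X <- a X + (b X X^T + c (X X^T)^2) X
        x = [[A_COEF * v + w for v, w in zip(r1, r2)]
             for r1, r2 in zip(x, px)]
    return transpose(x) if transposed else x


out = newton_schulz_sketch([[3.0, 0.0], [0.0, 1.0]])
print(out)
```

These coefficients do not converge to singular values of exactly 1; they leave them in a band around 1, which is the usual trade-off for needing only a handful of steps.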
The shipped
7b_16gpu_muon.toml
combines Muon with z-loss + chunked cross-entropy as an integration
test of all three features together.
Picking one¶
| Situation | Pick |
|---|---|
| Default, known-good | adamw |
| Half the optimizer-state memory, accept LR re-tuning | lion |
| Want to skip scheduler tuning | schedule_free_adamw |
| Research recipe for large dense runs | muon |
For any optimizer, the decay-vs-no-decay grouping is handled
automatically — no need to thread no_decay_params through your config.
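This page does not spell out the grouping rule. A common convention, shown here purely as an assumption, is to apply weight decay to tensors with two or more dimensions and exempt 1D tensors such as biases and norm scales. split_decay_groups_sketch is a hypothetical helper operating on (name, shape) pairs rather than real parameters.

```python
def split_decay_groups_sketch(named_shapes, weight_decay):
    """Split parameter names into decay / no-decay groups by rank.

    Assumption: rank >= 2 gets weight decay, rank < 2 does not.
    """
    decay, no_decay = [], []
    for name, shape in named_shapes:
        (decay if len(shape) >= 2 else no_decay).append(name)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]


groups = split_decay_groups_sketch(
    [("attn.wq.weight", (4096, 4096)),
     ("attn.wq.bias", (4096,)),
     ("norm.weight", (4096,))],
    weight_decay=0.1,
)
print(groups)
```

The returned list has the same shape as the parameter-group dictionaries that PyTorch optimizers accept, with names standing in for tensors.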
See also¶
Schedulers: the scheduler.* side, including "none" for schedule-free.
Configuration § OptimizerConfig: every field and its default.
Training loop § Optimizer step: where optimizer.step() fires and how phase LR scaling layers on after.