kempnerforge.training.optimizer
Optimizer construction for KempnerForge.
- Builds optimizers with per-parameter-group settings:
  - AdamW: standard Adam with decoupled weight decay.
  - Muon: momentum with orthogonalized updates via Newton-Schulz iteration. Applies Muon to 2D+ weight matrices and AdamW to 1D parameters (biases, norms).
  - Lion: sign-based momentum update (half the optimizer memory of AdamW).
  - Schedule-Free AdamW: eliminates the LR schedule by averaging iterates.
- All optimizers:
  - Weight decay applied to matrix weights only (see the grouping sketch below).
  - Bias and norm parameters excluded from weight decay.
  - Fused kernel used when available (PyTorch 2.x, AdamW only).
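Parameter grouping along these lines can be expressed directly in PyTorch. A minimal sketch, assuming the common ndim-based convention; the exact routing inside build_optimizer may differ:

    import torch

    def split_decay_groups(model, weight_decay):
        # Matrix weights (ndim >= 2) receive decay; biases and norm
        # parameters (ndim < 2) are excluded, per the rules above.
        decay, no_decay = [], []
        for p in model.parameters():
            if p.requires_grad:
                (decay if p.dim() >= 2 else no_decay).append(p)
        return [
            {"params": decay, "weight_decay": weight_decay},
            {"params": no_decay, "weight_decay": 0.0},
        ]

    model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.LayerNorm(16))
    groups = split_decay_groups(model, weight_decay=0.1)
    # fused=True is only valid for CUDA tensors in most PyTorch 2.x builds.
    opt = torch.optim.AdamW(groups, lr=3e-4)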
Functions

build_optimizer(model, config) – Construct an optimizer with per-parameter-group weight decay settings.

Classes

Lion – Lion optimizer (Chen et al., 2023): Evolved Sign Momentum.
Muon – Muon optimizer: Momentum with Orthogonalized Updates.
ScheduleFreeAdamW – Schedule-Free AdamW (Defazio & Mishchenko, 2024).
- class kempnerforge.training.optimizer.Lion

Bases: Optimizer

Lion optimizer (Chen et al., 2023): Evolved Sign Momentum.

Uses sign-based updates with momentum interpolation. Maintains only one momentum buffer (vs. two for AdamW), halving optimizer memory.

Update rule:

    update = sign(beta1 * m + (1 - beta1) * grad)
    m = beta2 * m + (1 - beta2) * grad
    p = p * (1 - lr * wd) - lr * update
- Parameters:
params – Parameters or parameter groups.
lr – Learning rate (typically 3-10x smaller than AdamW).
betas – (beta1, beta2) for update interpolation and momentum.
weight_decay – Decoupled weight decay coefficient.
- step(closure=None)
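A minimal usage sketch. No __init__ signature is documented for Lion, so the keyword names below follow its parameter list and the betas shown are the Lion paper's defaults; both are assumptions:

    import torch
    from kempnerforge.training.optimizer import Lion

    model = torch.nn.Linear(128, 128)
    # lr is set roughly 3-10x smaller than a comparable AdamW lr, per the note above.
    opt = Lion(model.parameters(), lr=1e-4, betas=(0.9, 0.99), weight_decay=0.1)

    loss = model(torch.randn(4, 128)).pow(2).mean()
    loss.backward()
    opt.step()
    opt.zero_grad()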
- class kempnerforge.training.optimizer.ScheduleFreeAdamW

Bases: Optimizer

Schedule-Free AdamW (Defazio & Mishchenko, 2024).

Eliminates the need for an LR schedule by maintaining an iterate z and a running average x. Parameters are set to the interpolated point y = (1 - beta1) * z + beta1 * x for gradient computation.

Use with scheduler.name = "none"; no LR schedule is needed.

- Parameters:
params – Parameters or parameter groups.
lr – Learning rate (constant; no schedule needed).
betas – (beta1, beta2) for interpolation and second moment.
eps – Denominator term for numerical stability.
weight_decay – Decoupled weight decay.
warmup_steps – Linear warmup steps (internal to the optimizer).
- __init__(params, lr=0.025, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.0, warmup_steps=0)
- step(closure=None)
- eval_params()

Set parameters to the evaluation point (running average).

Call before validation/inference for best results. Call train_params() afterward to resume training.

- Return type:
None
- train_params()

Restore parameters to the training point (interpolated y).

Call after eval_params() to resume training.

- Return type:
None
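A sketch of the eval/train toggle around validation; the constructor call uses the __init__ defaults documented above, while the toy model and data are illustrative:

    import torch
    from kempnerforge.training.optimizer import ScheduleFreeAdamW

    model = torch.nn.Linear(32, 1)
    opt = ScheduleFreeAdamW(model.parameters(), lr=0.025, warmup_steps=100)

    for step in range(1000):
        x, y = torch.randn(8, 32), torch.randn(8, 1)
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()
        opt.step()
        opt.zero_grad()

        if step % 200 == 199:
            opt.eval_params()   # swap in the running average x for evaluation
            with torch.no_grad():
                val_loss = torch.nn.functional.mse_loss(model(x), y)
            opt.train_params()  # restore the training point y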
- class kempnerforge.training.optimizer.Muon

Bases: Optimizer

Muon optimizer: Momentum with Orthogonalized Updates.
For 2D+ weight matrices: maintains momentum, then orthogonalizes the update direction via Newton-Schulz iteration. This keeps update directions independent of parameter scale.
For 1D parameters (biases, norms) and embeddings: uses standard AdamW, since orthogonalization requires 2D weight matrices.
FSDP2 note: Newton-Schulz operates on each rank's local shard independently. This is an approximation, not mathematically equivalent to orthogonalizing the full weight matrix, but it is the standard approach for distributed Muon and works well in practice.
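For reference, a minimal sketch of the quintic Newton-Schulz iteration commonly used by Muon implementations; the coefficients follow the widely circulated reference implementation, and KempnerForge's exact variant (dtype, orientation handling) may differ:

    import torch

    @torch.no_grad()
    def newton_schulz(G, steps=5, eps=1e-7):
        # Quintic iteration that pushes the singular values of G toward 1,
        # approximately orthogonalizing the update direction.
        a, b, c = 3.4445, -4.7750, 2.0315
        X = G.float()
        transposed = X.shape[0] > X.shape[1]
        if transposed:
            X = X.T                  # iterate on the wide orientation
        X = X / (X.norm() + eps)     # bound the spectral norm so the iteration converges
        for _ in range(steps):
            A = X @ X.T
            X = a * X + (b * A + c * (A @ A)) @ X
        return (X.T if transposed else X).to(G.dtype)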
- Parameters:
muon_params – Parameter groups for Muon (2D+ weights).
adam_params – Parameter groups for AdamW fallback (1D params).
lr – Learning rate for Muon (2D weights).
momentum – Momentum coefficient (default 0.95).
weight_decay – Decoupled weight decay.
adam_betas – Betas for the AdamW fallback.
adam_eps – Epsilon for the AdamW fallback.
ns_steps – Newton-Schulz iteration steps (default 5).
adam_lr – Learning rate for the AdamW fallback (1D params). If None, uses lr.
- __init__(muon_params, adam_params, lr=0.02, momentum=0.95, weight_decay=0.0, adam_betas=(0.9, 0.95), adam_eps=1e-08, ns_steps=5, adam_lr=None)
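A sketch of routing parameters into the two groups the constructor expects, using the ndim-based split described above (the toy model is illustrative):

    import torch
    from kempnerforge.training.optimizer import Muon

    model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.LayerNorm(64))

    muon_params = [p for p in model.parameters() if p.dim() >= 2]  # 2D+ weights
    adam_params = [p for p in model.parameters() if p.dim() < 2]   # biases, norms

    opt = Muon(muon_params, adam_params, lr=0.02, momentum=0.95, adam_lr=3e-4)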
- load_state_dict(state_dict)
Restore internal AdamW state from checkpoint.
- Parameters:
state_dict (dict)
- Return type:
None
- step(closure=None)
- kempnerforge.training.optimizer.build_optimizer(model, config)
Construct an optimizer with per-parameter-group weight decay settings.
- Parameters:
model (torch.nn.Module) – Model whose parameters to optimize.
config (OptimizerConfig) – Optimizer configuration.
- Returns:
Configured optimizer instance.
- Return type:
torch.optim.Optimizer
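Typical entry-point usage. The OptimizerConfig field names below (name, lr, weight_decay) and its import path are illustrative assumptions, not the confirmed schema:

    import torch
    from kempnerforge.training.optimizer import build_optimizer
    # Hypothetical import path and fields; consult the OptimizerConfig definition.
    from kempnerforge.config import OptimizerConfig

    model = torch.nn.Linear(64, 64)
    config = OptimizerConfig(name="adamw", lr=3e-4, weight_decay=0.1)
    optimizer = build_optimizer(model, config)  # decay applied to matrix weights only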