kempnerforge.training.optimizer

Optimizer construction for KempnerForge.

Builds optimizers with per-parameter-group settings:
  • AdamW: standard Adam with decoupled weight decay

  • Muon: momentum with orthogonalized updates via Newton-Schulz iteration. Applies Muon to 2D+ weight matrices, AdamW to 1D params (biases, norms).

  • Lion: sign-based momentum update (half the optimizer memory of AdamW)

  • Schedule-Free AdamW: eliminates LR schedule by averaging iterates

All optimizers:
  • Weight decay applied to matrix weights only

  • Bias and norm parameters excluded from weight decay (see the grouping sketch after this list)

  • Fused kernel when available (PyTorch 2.x, AdamW only)
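
As a concrete illustration of the grouping policy above, here is a hedged sketch; the helper split_decay_groups is illustrative and not part of the module (the real grouping lives in build_optimizer), and ndim >= 2 is assumed as the test for matrix weights:

import torch

def split_decay_groups(model, weight_decay):
    # Matrix weights (ndim >= 2) get weight decay; biases and norm
    # parameters (ndim < 2) are excluded, per the policy above.
    decay, no_decay = [], []
    for p in model.parameters():
        if not p.requires_grad:
            continue
        (decay if p.ndim >= 2 else no_decay).append(p)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.LayerNorm(8))
optimizer = torch.optim.AdamW(
    split_decay_groups(model, weight_decay=0.1),
    lr=3e-4,
    fused=torch.cuda.is_available(),  # fused kernel on PyTorch 2.x when a GPU is present
)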

Functions

build_optimizer(model, config)

Construct an optimizer with per-parameter-group weight decay settings.

Classes

Lion

Lion optimizer (Chen et al., 2023): Evolved Sign Momentum.

Muon

Muon optimizer: Momentum with Orthogonalized Updates.

ScheduleFreeAdamW

Schedule-Free AdamW (Defazio & Mishchenko, 2024).

class kempnerforge.training.optimizer.Lion[source]

Bases: Optimizer

Lion optimizer (Chen et al., 2023): Evolved Sign Momentum.

Uses sign-based updates with momentum interpolation. Only maintains one momentum buffer (vs two for AdamW), halving optimizer memory.

Update rule:

update = sign(beta1 * m + (1 - beta1) * grad)
m = beta2 * m + (1 - beta2) * grad
p = p * (1 - lr * wd) - lr * update
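
The same rule as a one-tensor sketch in plain PyTorch (illustrative only, not the class API; hyperparameters are the defaults below):

import torch

lr, wd = 1e-4, 0.0
beta1, beta2 = 0.9, 0.99
p = torch.randn(4, 4)                 # parameter
grad = torch.randn(4, 4)              # its gradient
m = torch.zeros_like(p)               # the single momentum buffer

update = (beta1 * m + (1 - beta1) * grad).sign()   # sign of the interpolated momentum
m.mul_(beta2).add_(grad, alpha=1 - beta2)          # momentum update with beta2
p.mul_(1 - lr * wd).add_(update, alpha=-lr)        # decoupled weight decay, then step
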
Parameters:
  • params – Parameters or parameter groups.

  • lr – Learning rate (typically 3-10x smaller than AdamW).

  • betas – (beta1, beta2) for update interpolation and momentum.

  • weight_decay – Decoupled weight decay coefficient.

__init__(params, lr=0.0001, betas=(0.9, 0.99), weight_decay=0.0)[source]
Return type:

None

step(closure=None)
class kempnerforge.training.optimizer.ScheduleFreeAdamW[source]

Bases: Optimizer

Schedule-Free AdamW (Defazio & Mishchenko, 2024).

Eliminates the need for an LR schedule by maintaining an iterate z and a running average x. Parameters are set to the interpolated point y = (1 - beta1) * z + beta1 * x for gradient computation.

Use with scheduler.name = "none" — no LR schedule is needed.
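
A simplified sketch of the z / x / y bookkeeping, with a plain SGD step on z and no warmup (the class itself applies AdamW-style updates, so treat this only as an illustration of the averaging scheme):

import torch

beta1, lr = 0.9, 0.025
p = torch.nn.Parameter(torch.randn(4, 4))    # what the model sees: y
z = p.detach().clone()                       # fast iterate
x = p.detach().clone()                       # running average (evaluation point)

for t in range(1, 101):
    loss = (p ** 2).sum()                    # gradient is taken at y
    loss.backward()
    with torch.no_grad():
        z -= lr * p.grad                     # advance the fast iterate (SGD here for brevity)
        x += (z - x) / t                     # running average of the iterates
        p.copy_((1 - beta1) * z + beta1 * x) # y = (1 - beta1) * z + beta1 * x
        p.grad = None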

Parameters:
  • params – Parameters or parameter groups.

  • lr – Learning rate (constant — no schedule needed).

  • betas – (beta1, beta2) for interpolation and second moment.

  • eps – Denominator term for numerical stability.

  • weight_decay – Decoupled weight decay.

  • warmup_steps – Linear warmup steps (internal to the optimizer).

__init__(params, lr=0.025, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.0, warmup_steps=0)[source]
Return type:

None

state_dict()[source]
Return type:

dict

load_state_dict(state_dict)[source]
Parameters:

state_dict (dict)

Return type:

None

step(closure=None)
eval_params()[source]

Set parameters to the evaluation point (running average).

Call before validation/inference for best results. Call train_params() afterward to resume training.

Return type:

None

train_params()[source]

Restore parameters to the training point (interpolated y).

Call after eval_params() to resume training.

Return type:

None
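
A usage sketch of the eval/train switch around validation (the toy model and random data are illustrative):

import torch
from kempnerforge.training.optimizer import ScheduleFreeAdamW

model = torch.nn.Linear(16, 1)
opt = ScheduleFreeAdamW(model.parameters(), lr=0.025, warmup_steps=10)

for step in range(100):
    loss = model(torch.randn(32, 16)).pow(2).mean()
    loss.backward()
    opt.step()
    opt.zero_grad()

    if step % 25 == 0:
        opt.eval_params()                    # switch to the running average x
        with torch.no_grad():
            val_loss = model(torch.randn(32, 16)).pow(2).mean()
        opt.train_params()                   # back to the interpolated point y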

class kempnerforge.training.optimizer.Muon[source]

Bases: Optimizer

Muon optimizer: Momentum with Orthogonalized Updates.

For 2D+ weight matrices: maintains momentum, then orthogonalizes the update direction via Newton-Schulz iteration. Orthogonalization equalizes the update's singular values, so the step's scale is decoupled from the raw magnitude of the gradient and momentum.

For 1D parameters (biases, norms) and embeddings: uses standard AdamW, since Newton-Schulz orthogonalization applies only to 2D weight matrices (embeddings, although 2D, are conventionally routed to the fallback).

FSDP2 note: Newton-Schulz operates on each rank’s local shard independently — an approximation, not mathematically equivalent to orthogonalizing the full weight matrix. This is the standard approach for distributed Muon and works well in practice.
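
A hedged sketch of the orthogonalization step using a cubic Newton-Schulz iteration; the actual implementation may use different (e.g. quintic) coefficients, and the helper name below is illustrative:

import torch

def newton_schulz_orthogonalize(g, steps=5):
    # Approximate the nearest semi-orthogonal matrix to g. Normalizing by the
    # Frobenius norm keeps singular values inside the iteration's convergence region.
    x = g / (g.norm() + 1e-7)
    transposed = x.shape[0] > x.shape[1]
    if transposed:
        x = x.T                              # work with the smaller leading dimension
    for _ in range(steps):
        x = 1.5 * x - 0.5 * x @ x.T @ x      # cubic Newton-Schulz update
    return x.T if transposed else x

momentum = torch.randn(64, 256)              # momentum buffer for a 2D weight
update = newton_schulz_orthogonalize(momentum, steps=5)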

Parameters:
  • muon_params – Parameter groups for Muon (2D+ weights); see the construction sketch after this parameter list.

  • adam_params – Parameter groups for AdamW fallback (1D params).

  • lr – Learning rate for Muon (2D weights).

  • momentum – Momentum coefficient (default 0.95).

  • weight_decay – Decoupled weight decay.

  • adam_betas – Betas for the AdamW fallback.

  • adam_eps – Epsilon for the AdamW fallback.

  • ns_steps – Newton-Schulz iteration steps (default 5).

  • adam_lr – Learning rate for AdamW fallback (1D params). None = same as lr.
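
A hedged construction sketch that splits parameters by dimensionality; passing plain parameter lists (rather than group dicts) is assumed to be accepted, as with most Optimizer subclasses:

import torch
from kempnerforge.training.optimizer import Muon

model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.LayerNorm(64))
muon_params = [p for p in model.parameters() if p.ndim >= 2]   # weight matrices
adam_params = [p for p in model.parameters() if p.ndim < 2]    # biases, norm params

opt = Muon(muon_params, adam_params, lr=0.02, momentum=0.95,
           weight_decay=0.0, ns_steps=5, adam_lr=3e-4)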

__init__(muon_params, adam_params, lr=0.02, momentum=0.95, weight_decay=0.0, adam_betas=(0.9, 0.95), adam_eps=1e-08, ns_steps=5, adam_lr=None)[source]
state_dict()[source]

Include internal AdamW state so DCP checkpoints are complete.

Return type:

dict

load_state_dict(state_dict)[source]

Restore internal AdamW state from checkpoint.

Parameters:

state_dict (dict)

Return type:

None

step(closure=None)
kempnerforge.training.optimizer.build_optimizer(model, config)[source]

Construct an optimizer with per-parameter-group weight decay settings.

Parameters:
  • model – The model whose parameters are grouped and optimized.

  • config – Training configuration selecting the optimizer and its hyperparameters.

Returns:

Configured optimizer instance.

Return type:

torch.optim.Optimizer
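
A hedged usage sketch; the training config schema is defined elsewhere in KempnerForge, so config below is a placeholder rather than a concrete construction:

import torch
from kempnerforge.training.optimizer import build_optimizer

model = torch.nn.Linear(16, 1)
optimizer = build_optimizer(model, config)   # config: project training config (not constructed here)

for _ in range(10):
    loss = model(torch.randn(8, 16)).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()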