"""Optimizer construction for KempnerForge.
Builds optimizers with per-parameter-group settings:
- AdamW: standard Adam with decoupled weight decay
- Muon: momentum with orthogonalized updates via Newton-Schulz iteration.
Applies Muon to 2D+ weight matrices with moderate aspect ratio; AdamW to the rest (biases, norms, embeddings).
- Lion: sign-based momentum update (half the optimizer memory of AdamW)
- Schedule-Free AdamW: eliminates LR schedule by averaging iterates
All optimizers:
- Weight decay applied to matrix weights only
- Bias and norm parameters excluded from weight decay
- Fused kernel when available (PyTorch 2.x with CUDA; AdamW update paths only)
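Example (hypothetical; sketches how an additional optimizer could be
registered with the same pattern used by the builders below)::

    @registry.register_optimizer("sgd")
    def _build_sgd(param_groups, config):
        return torch.optim.SGD(param_groups, lr=config.lr, momentum=0.9)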
"""
from __future__ import annotations
import logging
import torch
from kempnerforge.config.registry import registry
from kempnerforge.config.schema import OptimizerConfig
logger = logging.getLogger(__name__)
@registry.register_optimizer("adamw")
def _build_adamw(
param_groups: list[dict],
config: OptimizerConfig,
) -> torch.optim.Optimizer:
return torch.optim.AdamW(
param_groups,
lr=config.lr,
betas=config.betas,
eps=config.eps,
fused=config.fused and torch.cuda.is_available(),
)
# ---------------------------------------------------------------------------
# Lion optimizer
# ---------------------------------------------------------------------------
class Lion(torch.optim.Optimizer):
"""Lion optimizer (Chen et al., 2023): Evolved Sign Momentum.
Uses sign-based updates with momentum interpolation. Only maintains
one momentum buffer (vs two for AdamW), halving optimizer memory.
Update rule::
update = sign(beta1 * m + (1 - beta1) * grad)
m = beta2 * m + (1 - beta2) * grad
p = p * (1 - lr * wd) - lr * update
Args:
params: Parameters or parameter groups.
lr: Learning rate (typically 3-10x smaller than AdamW).
betas: ``(beta1, beta2)`` for update interpolation and momentum.
weight_decay: Decoupled weight decay coefficient.
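Example (hypothetical usage; ``model`` and ``loss`` are stand-ins for
any ``nn.Module`` and its computed loss)::

    optimizer = Lion(model.parameters(), lr=3e-5, weight_decay=0.1)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()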
"""
def __init__(
self,
params,
lr: float = 1e-4,
betas: tuple[float, float] = (0.9, 0.99),
weight_decay: float = 0.0,
) -> None:
defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
super().__init__(params, defaults)
@torch.no_grad()
def step(self, closure=None):
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
lr = group["lr"]
beta1, beta2 = group["betas"]
wd = group["weight_decay"]
for p in group["params"]:
if p.grad is None:
continue
grad = p.grad
state = self.state[p]
if len(state) == 0:
state["exp_avg"] = torch.zeros_like(p.data)
m = state["exp_avg"]
# Update direction: sign of interpolated momentum + gradient
# .mul() (not in-place) creates a temporary — m is unchanged
update = (m.mul(beta1) + grad.mul(1 - beta1)).sign_()
# Decoupled weight decay
if wd > 0:
p.data.mul_(1 - lr * wd)
# Apply update
p.data.add_(update, alpha=-lr)
# Update momentum buffer
m.mul_(beta2).add_(grad, alpha=1 - beta2)
return loss
@registry.register_optimizer("lion")
def _build_lion(
param_groups: list[dict],
config: OptimizerConfig,
) -> torch.optim.Optimizer:
return Lion(
param_groups,
lr=config.lr,
betas=config.betas,
)
# ---------------------------------------------------------------------------
# Schedule-Free AdamW
# ---------------------------------------------------------------------------
class ScheduleFreeAdamW(torch.optim.Optimizer):
"""Schedule-Free AdamW (Defazio & Mishchenko, 2024).
Eliminates the need for an LR schedule by maintaining an iterate ``z``
and a running average ``x``. Parameters are set to the interpolated
point ``y = (1 - beta1) * z + beta1 * x`` for gradient computation.
Use with ``scheduler.name = "none"`` — no LR schedule is needed.
Args:
params: Parameters or parameter groups.
lr: Learning rate (constant — no schedule needed).
betas: ``(beta1, beta2)`` for interpolation and second moment.
eps: Denominator term for numerical stability.
weight_decay: Decoupled weight decay.
warmup_steps: Linear warmup steps (internal to the optimizer).
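Example (hypothetical loop; ``loader``, ``train_step``, and
``evaluate`` are stand-ins)::

    optimizer = ScheduleFreeAdamW(model.parameters(), lr=0.025, warmup_steps=100)
    for batch in loader:
        train_step(model, batch)  # forward + backward
        optimizer.step()
        optimizer.zero_grad()
    optimizer.eval_params()   # evaluate at the running average x
    evaluate(model)
    optimizer.train_params()  # restore the training point y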
"""
def __init__(
self,
params,
lr: float = 0.025,
betas: tuple[float, float] = (0.9, 0.999),
eps: float = 1e-8,
weight_decay: float = 0.0,
warmup_steps: int = 0,
) -> None:
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
super().__init__(params, defaults)
self.warmup_steps = warmup_steps
self._k = 0
def state_dict(self) -> dict:
sd = super().state_dict()
sd["_k"] = self._k
sd["_warmup_steps"] = self.warmup_steps
return sd
def load_state_dict(self, state_dict: dict) -> None:
self._k = state_dict.pop("_k", 0)
self.warmup_steps = state_dict.pop("_warmup_steps", self.warmup_steps)
super().load_state_dict(state_dict)
@torch.no_grad()
def step(self, closure=None):
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
self._k += 1
k = self._k
warmup = min(1.0, k / max(self.warmup_steps, 1)) if self.warmup_steps > 0 else 1.0
for group in self.param_groups:
lr = group["lr"] * warmup
beta1, beta2 = group["betas"]
eps = group["eps"]
wd = group["weight_decay"]
# Weight for Polyak averaging (accounts for variable LR during warmup)
weight = lr * (1 - beta1)
for p in group["params"]:
if p.grad is None:
continue
grad = p.grad
state = self.state[p]
if len(state) == 0:
state["z"] = p.data.clone()
state["v"] = torch.zeros_like(p.data)
state["weight_sum"] = 0.0
z = state["z"]
v = state["v"]
# Second moment update
v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
# Bias-corrected denominator
bc2 = 1 - beta2**k
denom = (v / bc2).sqrt().add_(eps)
# Update z (the iterate)
z.addcdiv_(grad, denom, value=-lr)
# Decoupled weight decay on z
if wd > 0:
z.mul_(1 - lr * wd)
# Polyak average coefficient
state["weight_sum"] += weight
ck = weight / state["weight_sum"]
# Polyak average: x_new = (1 - ck) * x_old + ck * z. We store x
# explicitly in state; on the first step ck == 1, so x starts at z.
if "x" not in state:
state["x"] = z.clone()
else:
state["x"].lerp_(z, ck)
# Set params to interpolated point for next gradient computation
p.data.copy_(z).lerp_(state["x"], beta1)
return loss
def eval_params(self) -> None:
"""Set parameters to the evaluation point (running average).
Call before validation/inference for best results.
Call :meth:`train_params` afterward to resume training.
"""
for group in self.param_groups:
for p in group["params"]:
if p in self.state and "x" in self.state[p]:
p.data.copy_(self.state[p]["x"])
def train_params(self) -> None:
"""Restore parameters to the training point (interpolated y).
Call after :meth:`eval_params` to resume training.
"""
for group in self.param_groups:
beta1 = group["betas"][0]
for p in group["params"]:
state = self.state[p]
if "z" in state and "x" in state:
p.data.copy_(state["z"]).lerp_(state["x"], beta1)
@registry.register_optimizer("schedule_free_adamw")
def _build_schedule_free_adamw(
param_groups: list[dict],
config: OptimizerConfig,
) -> torch.optim.Optimizer:
return ScheduleFreeAdamW(
param_groups,
lr=config.lr,
betas=config.betas,
eps=config.eps,
warmup_steps=config.schedule_free_warmup_steps,
)
# ---------------------------------------------------------------------------
# Muon optimizer
# ---------------------------------------------------------------------------
def _newton_schulz(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
"""Approximate orthogonal projection via Newton-Schulz iteration.
Given a matrix G, approximates the nearest semi-orthogonal matrix U
(U^T U ≈ I for tall matrices, U U^T ≈ I for wide ones). This is used
to orthogonalize the momentum update in Muon.
Uses a degree-5 polynomial iteration::
A = X @ X^T
B = b*A + c*A@A
X = a*X + B@X
Coefficients (a, b, c) are from Zhu & Jordan (2024), optimized for
convergence in 5 steps from a spectrally-normalized starting point.
Cost: a handful of matmuls on the weight's shape per step, typically negligible vs the model forward/backward.
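Example (sanity check; the tolerance is approximate, since these
coefficients drive singular values into a band near 1, not exactly
to 1)::

    G = torch.randn(128, 512)
    U = _newton_schulz(G, steps=5)
    err = (U @ U.T - torch.eye(128)).abs().max()  # small, not zero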
"""
assert G.ndim == 2
a, b, c = 3.4445, -4.7750, 2.0315
# Frobenius normalization (matches reference Muon; sufficient for NS convergence)
X = G / (G.norm() + 1e-7)
for _ in range(steps):
A = X @ X.T
B = b * A + c * A @ A
X = a * X + B @ X
return X
def _get_local_tensor(t: torch.Tensor) -> torch.Tensor:
"""Extract the local (non-DTensor) tensor from a possibly-sharded parameter.
FSDP2 wraps parameters as DTensors. The optimizer needs to operate on
the underlying local shard directly to avoid DTensor/Tensor mixing errors
(e.g. 'aten.add_.Tensor got mixed torch.Tensor and DTensor').
"""
try:
from torch.distributed._tensor import DTensor
if isinstance(t, DTensor):
return t.to_local()
except ImportError:
pass
return t
class Muon(torch.optim.Optimizer):
"""Muon optimizer: Momentum with Orthogonalized Updates.
For 2D+ weight matrices: maintains momentum, then orthogonalizes the
update direction via Newton-Schulz iteration. This keeps update
directions independent of parameter scale.
For 1D parameters (biases, norms): uses standard AdamW, since
orthogonalization requires 2D matrices. Highly rectangular matrices
(embeddings, output heads) are also routed to AdamW by the builder.
FSDP2 note: Newton-Schulz operates on each rank's local shard
independently — an approximation, not mathematically equivalent to
orthogonalizing the full weight matrix. This is the standard approach
for distributed Muon and works well in practice.
Args:
muon_params: Parameter groups for Muon (2D+ weights).
adam_params: Parameter groups for AdamW fallback (1D params).
lr: Learning rate for Muon (2D weights).
momentum: Momentum coefficient (default 0.95).
weight_decay: Decoupled weight decay.
adam_betas: Betas for the AdamW fallback.
adam_eps: Epsilon for the AdamW fallback.
ns_steps: Newton-Schulz iteration steps (default 5).
adam_lr: Learning rate for AdamW fallback (1D params). None = same as lr.
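Example (hypothetical; ``params`` is a stand-in list of parameters,
and groups would normally come from the ``_build_muon`` builder)::

    muon_groups = [{"params": [p for p in params if p.ndim >= 2]}]
    adam_groups = [{"params": [p for p in params if p.ndim < 2]}]
    optimizer = Muon(muon_groups, adam_groups, lr=0.02, momentum=0.95)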
"""
def __init__(
self,
muon_params: list[dict],
adam_params: list[dict],
lr: float = 0.02,
momentum: float = 0.95,
weight_decay: float = 0.0,
adam_betas: tuple[float, float] = (0.9, 0.95),
adam_eps: float = 1e-8,
ns_steps: int = 5,
adam_lr: float | None = None,
):
adam_lr = adam_lr if adam_lr is not None else lr
self._initial_lr = lr
self._initial_adam_lr = adam_lr
# Create internal AdamW for 1D params
self._adam = (
torch.optim.AdamW(
adam_params,
lr=adam_lr,
betas=adam_betas,
eps=adam_eps,
fused=torch.cuda.is_available(),
)
if any(len(g["params"]) > 0 for g in adam_params)
else None
)
defaults = dict(lr=lr, momentum=momentum, weight_decay=weight_decay, ns_steps=ns_steps)
super().__init__(muon_params, defaults)
def state_dict(self) -> dict:
"""Include internal AdamW state so DCP checkpoints are complete."""
sd = super().state_dict()
if self._adam is not None:
sd["_adam_state"] = self._adam.state_dict()
sd["_initial_lr"] = self._initial_lr
sd["_initial_adam_lr"] = self._initial_adam_lr
return sd
def load_state_dict(self, state_dict: dict) -> None:
"""Restore internal AdamW state from checkpoint."""
adam_state = state_dict.pop("_adam_state", None)
self._initial_lr = state_dict.pop("_initial_lr", self._initial_lr)
self._initial_adam_lr = state_dict.pop("_initial_adam_lr", self._initial_adam_lr)
super().load_state_dict(state_dict)
if adam_state is not None and self._adam is not None:
self._adam.load_state_dict(adam_state)
@torch.no_grad()
def step(self, closure=None):
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
lr = group["lr"]
momentum = group["momentum"]
wd = group["weight_decay"]
ns_steps = group["ns_steps"]
for p in group["params"]:
if p.grad is None:
continue
# Work on local tensors to avoid DTensor/Tensor mixing.
# FSDP2 wraps params as DTensors; operating on the local
# shard directly sidesteps distributed tensor dispatch.
p_local = _get_local_tensor(p)
g_local = _get_local_tensor(p.grad)
# Momentum buffer: stored matching p.grad's type (DTensor if
# FSDP2) so DCP can save/restore it correctly. We operate on
# the local shard for the actual computation.
state = self.state[p]
if "momentum_buffer" not in state:
state["momentum_buffer"] = torch.zeros_like(p.grad)
buf_local = _get_local_tensor(state["momentum_buffer"])
buf_local.mul_(momentum).add_(g_local)
# Reshape to 2D for Newton-Schulz if needed
original_shape = buf_local.shape
buf_2d = buf_local.view(buf_local.shape[0], -1) if buf_local.ndim > 2 else buf_local
# Orthogonalize
update = _newton_schulz(buf_2d.float(), steps=ns_steps)
update = update.to(buf_local.dtype)
if buf_local.ndim > 2:
update = update.view(original_shape)
# Scale: NS output has ~unit singular values; multiply by
# sqrt(max(1, m/n)) so update RMS is consistent across shapes
m, n = buf_2d.shape
update.mul_(max(1, m / n) ** 0.5)
# Decoupled weight decay
if wd > 0:
p_local.mul_(1 - lr * wd)
p_local.add_(update, alpha=-lr)
# Step the internal AdamW for 1D params.
# Scale Adam LR proportionally to Muon's current LR so the scheduler's
# warmup/decay applies to both — the scheduler only sees Muon's param_groups.
if self._adam is not None:
if self._initial_lr > 0:
scale = self.param_groups[0]["lr"] / self._initial_lr
for group in self._adam.param_groups:
group["lr"] = self._initial_adam_lr * scale
self._adam.step()
return loss
def _is_muon_eligible(param: torch.Tensor) -> bool:
"""Check if a parameter should use Muon (Newton-Schulz orthogonalization).
Muon is applied to 2D weight matrices with reasonable aspect ratios.
Highly rectangular matrices (embeddings, output heads) are too expensive
for Newton-Schulz (X@X^T becomes vocab_size x vocab_size) and get AdamW.
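Example::

    _is_muon_eligible(torch.empty(4096, 4096))  # True: square hidden weight
    _is_muon_eligible(torch.empty(50257, 768))  # False: vocab embedding, ratio ~65
    _is_muon_eligible(torch.empty(768))         # False: 1D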
"""
if param.ndim < 2:
return False
# Aspect ratio check: max/min > 10 means too rectangular for NS
m, n = param.shape[0], param.view(param.shape[0], -1).shape[1]
return max(m, n) / max(min(m, n), 1) <= 10
@registry.register_optimizer("muon")
def _build_muon(
param_groups: list[dict],
config: OptimizerConfig,
) -> torch.optim.Optimizer:
# Split: 2D params with reasonable aspect ratio get Muon, rest gets AdamW
muon_groups = []
adam_groups = []
for group in param_groups:
muon_params = []
adam_params = []
for p in group["params"]:
if _is_muon_eligible(p):
muon_params.append(p)
else:
adam_params.append(p)
if muon_params:
muon_groups.append({**group, "params": muon_params})
if adam_params:
adam_groups.append({**group, "params": adam_params})
if not adam_groups:
adam_groups = [{"params": [], "weight_decay": 0.0}]
n_muon = sum(p.numel() for g in muon_groups for p in g["params"])
n_adam = sum(p.numel() for g in adam_groups for p in g["params"])
logger.info(f"Muon: {n_muon:,} params (NS-orthogonalized), {n_adam:,} params (AdamW fallback)")
return Muon(
muon_groups,
adam_groups,
lr=config.lr,
momentum=config.muon_momentum,
weight_decay=config.weight_decay,
adam_betas=config.betas,
adam_eps=config.eps,
ns_steps=config.muon_ns_steps,
adam_lr=config.muon_adam_lr,
)
# ---------------------------------------------------------------------------
# Optimizer construction
# ---------------------------------------------------------------------------
def _should_decay(name: str, param: torch.nn.Parameter) -> bool:
"""Decide whether a parameter should receive weight decay.
Excluded: 1D parameters (biases, norm scales/shifts) and any parameter with ``bias`` in its name.
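Example::

    _should_decay("attn.qkv.weight", torch.nn.Parameter(torch.empty(3072, 1024)))  # True
    _should_decay("norm.weight", torch.nn.Parameter(torch.empty(1024)))            # False: 1D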
"""
if param.ndim <= 1:
return False
return "bias" not in name
def build_optimizer(
model: torch.nn.Module,
config: OptimizerConfig,
) -> torch.optim.Optimizer:
"""Construct an optimizer with per-parameter-group weight decay settings.
Args:
model: Model whose parameters to optimize.
config: Optimizer configuration.
Returns:
Configured optimizer instance.
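Example (hypothetical; assumes ``OptimizerConfig`` accepts these
fields as keyword arguments and ``model`` is any ``nn.Module``)::

    config = OptimizerConfig(name="adamw", lr=3e-4, weight_decay=0.1)
    optimizer = build_optimizer(model, config)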
"""
decay_params = []
no_decay_params = []
for name, param in model.named_parameters():
if not param.requires_grad:
continue
if _should_decay(name, param):
decay_params.append(param)
else:
no_decay_params.append(param)
param_groups = [
{"params": decay_params, "weight_decay": config.weight_decay},
{"params": no_decay_params, "weight_decay": 0.0},
]
# Log parameter counts
n_decay = sum(p.numel() for p in decay_params)
n_no_decay = sum(p.numel() for p in no_decay_params)
logger.info(
f"Optimizer groups: {n_decay:,} params with decay, {n_no_decay:,} params without decay"
)
builder = registry.get_optimizer(config.name)
return builder(param_groups, config)