Source code for kempnerforge.config.scheduler

"""LR scheduler configuration."""

from __future__ import annotations

from dataclasses import dataclass
from enum import StrEnum


[docs] class SchedulerType(StrEnum): cosine = "cosine" linear = "linear" wsd = "wsd" # warmup-stable-decay constant = "constant" # warmup then flat LR rex = "rex" # polynomial decay: (1 - t/T)^alpha none = "none" # constant LR (for schedule-free optimizers)
[docs] @dataclass class SchedulerConfig: """Learning rate schedule settings.""" name: SchedulerType = SchedulerType.cosine warmup_steps: int = 2000 decay_steps: int | None = None # None -> decay over remaining steps min_lr_ratio: float = 0.1 # min_lr = lr * min_lr_ratio # WSD-specific stable_steps: int | None = None # For WSD: steps at constant LR wsd_decay_type: str = "cosine" # WSD cooldown shape: "cosine", "linear", "sqrt" # REX-specific rex_alpha: float = 1.0 # Exponent for REX: (1 - t/T)^alpha def __post_init__(self) -> None: if self.warmup_steps < 0: raise ValueError("warmup_steps must be non-negative") if not (0 <= self.min_lr_ratio <= 1): raise ValueError("min_lr_ratio must be in [0, 1]") if self.wsd_decay_type not in ("cosine", "linear", "sqrt"): raise ValueError( f"wsd_decay_type must be 'cosine', 'linear', or 'sqrt', got '{self.wsd_decay_type}'" ) if self.rex_alpha <= 0: raise ValueError("rex_alpha must be positive")