Source code for kempnerforge.training.scheduler

"""Learning rate schedulers for KempnerForge.

All schedulers are step-based (not epoch-based) and compose a warmup
phase with a decay phase:

  - Cosine:    warmup → cosine decay to min_lr
  - Linear:    warmup → linear decay to min_lr
  - WSD:       warmup → stable → decay to min_lr (cosine/linear/sqrt cooldown)
  - Constant:  warmup → flat LR (for ablations)
  - REX:       warmup → polynomial decay (1 - t/T)^alpha
  - None:      constant factor=1.0 (for schedule-free optimizers)
"""

from __future__ import annotations

import math

import torch

from kempnerforge.config.registry import registry
from kempnerforge.config.schema import SchedulerConfig, SchedulerType


def _warmup_factor(step: int, warmup_steps: int) -> float:
    """Linear warmup from 0 to 1 over warmup_steps."""
    if warmup_steps == 0:
        return 1.0
    return min(1.0, step / warmup_steps)


def _cosine_decay(step: int, total_steps: int, min_ratio: float) -> float:
    """Cosine decay from 1 to min_ratio over total_steps."""
    if total_steps <= 0:
        return 1.0
    progress = min(1.0, step / total_steps)
    return min_ratio + 0.5 * (1.0 - min_ratio) * (1.0 + math.cos(math.pi * progress))


def _linear_decay(step: int, total_steps: int, min_ratio: float) -> float:
    """Linear decay from 1 to min_ratio over total_steps."""
    if total_steps <= 0:
        return 1.0
    progress = min(1.0, step / total_steps)
    return 1.0 - (1.0 - min_ratio) * progress


def _sqrt_decay(step: int, total_steps: int, min_ratio: float) -> float:
    """Sqrt decay from 1 to min_ratio over total_steps.

    Stays higher than linear for longer, then drops more steeply near the end.
    Shape: ``min_ratio + (1 - min_ratio) * (1 - progress)^0.5``
    """
    if total_steps <= 0:
        return 1.0
    progress = min(1.0, step / total_steps)
    return min_ratio + (1.0 - min_ratio) * (1.0 - progress) ** 0.5
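

# Illustration of the three decay shapes near the end of a run (a quick sketch,
# not part of the library API). With total_steps=100, min_ratio=0.1, step=90
# (progress 0.9):
#
#   _linear_decay(90, 100, 0.1)  = 1 - 0.9 * 0.9                      = 0.190
#   _cosine_decay(90, 100, 0.1)  = 0.1 + 0.45 * (1 + cos(0.9 * pi))  ~= 0.122
#   _sqrt_decay(90, 100, 0.1)    = 0.1 + 0.9 * sqrt(1 - 0.9)         ~= 0.385
#
# The sqrt shape holds the LR highest for longest before dropping sharply,
# which is why it is offered as a WSD cooldown option below.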


@registry.register_scheduler("cosine")
def _build_cosine(
    optimizer: torch.optim.Optimizer,
    config: SchedulerConfig,
    max_steps: int,
) -> torch.optim.lr_scheduler.LambdaLR:
    warmup = config.warmup_steps
    min_ratio = config.min_lr_ratio
    decay_steps = config.decay_steps or (max_steps - warmup)

    def lr_fn(step: int) -> float:
        if step < warmup:
            return _warmup_factor(step, warmup)
        return _cosine_decay(step - warmup, decay_steps, min_ratio)

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_fn)


@registry.register_scheduler("linear")
def _build_linear(
    optimizer: torch.optim.Optimizer,
    config: SchedulerConfig,
    max_steps: int,
) -> torch.optim.lr_scheduler.LambdaLR:
    warmup = config.warmup_steps
    min_ratio = config.min_lr_ratio
    decay_steps = config.decay_steps or (max_steps - warmup)

    def lr_fn(step: int) -> float:
        if step < warmup:
            return _warmup_factor(step, warmup)
        return _linear_decay(step - warmup, decay_steps, min_ratio)

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_fn)


@registry.register_scheduler("wsd")
def _build_wsd(
    optimizer: torch.optim.Optimizer,
    config: SchedulerConfig,
    max_steps: int,
) -> torch.optim.lr_scheduler.LambdaLR:
    warmup = config.warmup_steps
    min_ratio = config.min_lr_ratio
    stable = config.stable_steps or 0
    decay_steps = config.decay_steps or (max_steps - warmup - stable)

    decay_fn = {
        "cosine": _cosine_decay,
        "linear": _linear_decay,
        "sqrt": _sqrt_decay,
    }[config.wsd_decay_type]

    def lr_fn(step: int) -> float:
        if step < warmup:
            return _warmup_factor(step, warmup)
        if step < warmup + stable:
            return 1.0
        return decay_fn(step - warmup - stable, decay_steps, min_ratio)

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_fn)
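

# Worked example of the WSD (warmup -> stable -> decay) factor (a hedged
# sketch; the numbers are illustrative, not defaults). With warmup_steps=100,
# stable_steps=800, decay_steps=100, min_lr_ratio=0.1, wsd_decay_type="linear":
#
#   step  50 -> 0.50   (linear warmup, 50/100)
#   step 500 -> 1.00   (stable plateau)
#   step 950 -> 0.55   (linear cooldown: 1 - 0.9 * (50/100))
#   step 999 -> ~0.109 (approaching min_lr_ratio at the end of the cooldown)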


@registry.register_scheduler("constant")
def _build_constant(
    optimizer: torch.optim.Optimizer,
    config: SchedulerConfig,
    max_steps: int,
) -> torch.optim.lr_scheduler.LambdaLR:
    warmup = config.warmup_steps

    def lr_fn(step: int) -> float:
        if step < warmup:
            return _warmup_factor(step, warmup)
        return 1.0

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_fn)


@registry.register_scheduler("rex")
def _build_rex(
    optimizer: torch.optim.Optimizer,
    config: SchedulerConfig,
    max_steps: int,
) -> torch.optim.lr_scheduler.LambdaLR:
    warmup = config.warmup_steps
    min_ratio = config.min_lr_ratio
    alpha = config.rex_alpha
    decay_steps = config.decay_steps or (max_steps - warmup)

    def lr_fn(step: int) -> float:
        if step < warmup:
            return _warmup_factor(step, warmup)
        progress = min(1.0, (step - warmup) / max(decay_steps, 1))
        return max(min_ratio, (1.0 - progress) ** alpha)

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_fn)


@registry.register_scheduler("none")
def _build_none(
    optimizer: torch.optim.Optimizer,
    config: SchedulerConfig,
    max_steps: int,
) -> torch.optim.lr_scheduler.LambdaLR:
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: 1.0)


def build_scheduler(
    optimizer: torch.optim.Optimizer,
    config: SchedulerConfig,
    max_steps: int,
) -> torch.optim.lr_scheduler.LambdaLR:
    """Build a LR scheduler from config.

    Args:
        optimizer: Optimizer to schedule.
        config: Scheduler configuration.
        max_steps: Total training steps (used to compute decay length).

    Returns:
        PyTorch LambdaLR scheduler.
    """
    name = config.name.value if isinstance(config.name, SchedulerType) else config.name
    builder = registry.get_scheduler(name)
    return builder(optimizer, config, max_steps)
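

# Example of wiring the scheduler into a step-based training loop (a minimal
# sketch; the SchedulerConfig field names and the SchedulerType.WSD member are
# assumptions inferred from the attributes accessed above, and train_step is a
# hypothetical helper):
#
#   config = SchedulerConfig(
#       name=SchedulerType.WSD,
#       warmup_steps=1_000,
#       stable_steps=80_000,
#       decay_steps=19_000,
#       min_lr_ratio=0.1,
#       wsd_decay_type="sqrt",
#   )
#   scheduler = build_scheduler(optimizer, config, max_steps=100_000)
#   for step in range(100_000):
#       loss = train_step(batch)   # hypothetical training step
#       optimizer.step()
#       scheduler.step()           # once per optimizer step, not per epoch
#       optimizer.zero_grad()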