Source code for kempnerforge.config.checkpoint

"""Checkpoint configuration."""

from __future__ import annotations

from dataclasses import dataclass, field
from enum import StrEnum
from typing import Literal

from kempnerforge.config.registry import registry


[docs] class AsyncCheckpointMode(StrEnum): disabled = "disabled" async_ = "async" async_pinned = "async_with_pinned_mem"
@dataclass
class DynamicCheckpointWindow:
    """Closed step range ``[start, stop]`` driven by a registered strategy.

    While the current step lies inside ``[start, stop]`` the selected
    strategy chooses which steps are checkpointed, and each of those steps is
    exempt from ``CheckpointConfig.keep_last_n`` retention. Steps outside the
    range fall back to the ordinary ``CheckpointConfig.interval`` cadence.

    The default strategy, ``"power2"``, saves at ``start`` and at every
    ``start + 2^k`` that is ``<= stop``: dense at the beginning of the window
    and doubling in spacing afterwards.

    Additional strategies are added with
    ``@registry.register_dyn_ckpt_strategy(name)`` and selected by assigning
    that name to ``strategy``.
    """

    start: int = 0  # 0 = capture initial weights before any training step
    stop: int = 512  # last step (inclusive) governed by the strategy
    strategy: str = "power2"  # name looked up in the strategy registry

    def __post_init__(self) -> None:
        """Validate the bounds and the strategy name.

        Raises:
            ValueError: if ``start`` is negative, if ``stop`` is smaller than
                ``start``, or if ``strategy`` is not a registered name.
        """
        if self.start < 0:
            raise ValueError("dyn_ckpt_window.start must be >= 0")
        if self.start > self.stop:
            raise ValueError("dyn_ckpt_window.stop must be >= start")

        known = registry.list_dyn_ckpt_strategies()
        if self.strategy in known:
            return
        raise ValueError(
            f"unknown dyn_ckpt_window.strategy {self.strategy!r}; registered: {known}"
        )

    def is_milestone(self, step: int) -> bool:
        """Return whether the configured strategy fires at ``step``.

        The strategy callable is resolved by name on every call and invoked
        as ``strategy(window, step)``.
        """
        strategy_fn = registry.get_dyn_ckpt_strategy(self.strategy)
        return strategy_fn(self, step)
@registry.register_dyn_ckpt_strategy("power2")
def _power2_strategy(window: DynamicCheckpointWindow, step: int) -> bool:
    """Fire at ``start`` and at each ``start + 2^k`` that is ``<= stop``.

    Args:
        window: the window whose ``start``/``stop`` bound the strategy.
        step: the training step being tested.

    Returns:
        ``True`` when ``step`` is inside the window and its distance from
        ``window.start`` is zero or an exact power of two.
    """
    if not (window.start <= step <= window.stop):
        return False

    distance = step - window.start
    if distance == 0:
        # The window's first step is always a milestone.
        return True
    # A positive integer is a power of two iff it has exactly one set bit;
    # clearing the lowest set bit (d & (d - 1)) then leaves zero.
    return (distance & (distance - 1)) == 0
@dataclass
class CheckpointConfig:
    """Checkpointing settings.

    Save cadence is ``interval`` steps, except inside the optional
    ``dyn_ckpt_window`` where the window's registered strategy decides.
    """

    dir: str = "checkpoints"
    interval: int = 1000  # save every N steps; outside any dyn_ckpt_window
    dyn_ckpt_window: DynamicCheckpointWindow | None = None  # opt-in dense window
    async_mode: AsyncCheckpointMode = AsyncCheckpointMode.disabled
    keep_last_n: int = 3  # recent ckpts kept (<=0 keeps all); dynamic milestones always kept
    load_path: str | None = None  # Path to load from (for resumption)
    export_dtype: Literal["float32", "bfloat16"] = "bfloat16"
    exclude_from_loading: list[str] = field(default_factory=list)
    # If the saved checkpoint's VLM freeze metadata differs from the current
    # config's freeze specs, the load path raises by default. Setting this
    # to True downgrades the mismatch to a warning. Useful when intentionally
    # switching from frozen to trainable mid-training.
    ignore_freeze_mismatch: bool = False

    def __post_init__(self) -> None:
        """Validate the settings.

        Raises:
            ValueError: if ``interval`` is not strictly positive (it is used
                as a modulus in ``should_save``).
        """
        if self.interval <= 0:
            raise ValueError("interval must be positive")

    def _window_covering(self, step: int) -> DynamicCheckpointWindow | None:
        """Return ``dyn_ckpt_window`` if ``step`` is in ``[start, stop]``, else ``None``."""
        window = self.dyn_ckpt_window
        if window is not None and window.start <= step <= window.stop:
            return window
        return None

    def should_save(self, step: int) -> bool:
        """Whether to write a checkpoint at ``step``.

        Inside ``dyn_ckpt_window``: the registered strategy decides (default
        ``"power2"`` saves at ``start`` and each ``start + 2^k`` while
        ``<= stop``). Outside the window: every ``interval`` steps. Dynamic
        milestones are exempt from ``keep_last_n`` (see
        ``CheckpointManager._cleanup``).
        """
        window = self._window_covering(step)
        if window is not None:
            return window.is_milestone(step)
        return step % self.interval == 0

    def is_dynamic_milestone(self, step: int) -> bool:
        """True if ``step`` is a milestone of the configured ``dyn_ckpt_window``.

        ``CheckpointManager._cleanup`` excludes these from ``keep_last_n`` so
        the dense early-window checkpoints survive a finite retention.

        The window bounds are enforced here, exactly as in ``should_save``,
        rather than trusted to the strategy. Strategies are pluggable through
        the registry, and one that does not check the bounds itself would
        otherwise mark ordinary interval checkpoints outside the window as
        retention-exempt.
        """
        window = self._window_covering(step)
        return window is not None and window.is_milestone(step)