Source code for kempnerforge.config.checkpoint
"""Checkpoint configuration."""
from __future__ import annotations
from dataclasses import dataclass, field
from enum import StrEnum
from typing import Literal
[docs]
class AsyncCheckpointMode(StrEnum):
disabled = "disabled"
async_ = "async"
async_pinned = "async_with_pinned_mem"
[docs]
@dataclass
class CheckpointConfig:
    """Checkpointing settings.

    Attributes:
        dir: Defaults to ``"checkpoints"``; presumably the directory that
            checkpoints are written to -- confirm against the writer.
        interval: Save every N steps. Must be positive.
        async_mode: An ``AsyncCheckpointMode`` member; ``disabled`` by default.
        keep_last_n: Number of checkpoints to retain. Must be at least 1.
        load_path: Path to load from (for resumption), or ``None``.
        export_dtype: ``"float32"`` or ``"bfloat16"``. The ``Literal``
            annotation is for type checkers only; this class does not
            validate the value at runtime.
        exclude_from_loading: List of strings, empty by default.
            NOTE(review): presumably names of state entries to skip when
            loading -- confirm against the loading code.

    Raises:
        ValueError: If ``interval <= 0`` or ``keep_last_n < 1``.
    """

    dir: str = "checkpoints"

    # Save every N steps.
    interval: int = 1000

    async_mode: AsyncCheckpointMode = AsyncCheckpointMode.disabled

    # Number of checkpoints to retain.
    keep_last_n: int = 3

    # Path to load from (for resumption).
    load_path: str | None = None

    export_dtype: Literal["float32", "bfloat16"] = "bfloat16"

    # default_factory gives every instance its own list rather than a shared one.
    exclude_from_loading: list[str] = field(default_factory=list)

    def __post_init__(self) -> None:
        """Validate ``interval`` and ``keep_last_n`` after dataclass init.

        Raises:
            ValueError: If ``interval`` is not positive or ``keep_last_n``
                is below 1. ``interval`` is checked first.
        """
        save_every = self.interval
        if save_every <= 0:
            raise ValueError("interval must be positive")

        retained = self.keep_last_n
        if retained < 1:
            raise ValueError("keep_last_n must be >= 1")