Source code for kempnerforge.config.training
"""Training configuration."""
from __future__ import annotations
from dataclasses import dataclass
from enum import StrEnum
from typing import TYPE_CHECKING, Literal
if TYPE_CHECKING:
    import torch

class ActivationCheckpointing(StrEnum):
    """Activation checkpointing strategy (StrEnum requires Python 3.11+)."""

    none = "none"
    full = "full"
    selective = "selective"
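# A minimal usage sketch (illustrative, not part of the original module):
# StrEnum members compare equal to their raw string values, so strings from
# YAML or CLI configs can be validated by round-tripping through the enum:
#
#     ActivationCheckpointing("full") is ActivationCheckpointing.full  # True
#     ActivationCheckpointing.full == "full"                           # True
#     ActivationCheckpointing("bogus")  # raises ValueError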
@dataclass
class TrainConfig:
    """Training hyperparameters."""

    batch_size: int = 8  # Per-device micro-batch size
    seq_len: int = 2048
    max_steps: int = 100000
    grad_accum_steps: int = 1
    grad_clip_norm: float = 1.0
    seed: int = 42
    compile_model: bool = True
    mixed_precision: Literal["bf16", "fp16", "fp32", "fp8"] = "bf16"
    activation_checkpointing: ActivationCheckpointing = ActivationCheckpointing.none
    loss_fn: str = "cross_entropy"  # Registry key for loss function
    z_loss_weight: float = 0.0  # Logit-magnitude regularizer (PaLM uses 1e-4; 0 = disabled)
    ce_chunk_size: int = 0  # Token chunk size for chunked_cross_entropy (0 = auto, 4096)
    shutdown_timeout_sec: float = 600.0  # Graceful-shutdown timeout before forced exit
    nccl_health_check_interval: int = 0  # Check NCCL health every N steps (0 = disabled)
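    # Worked example (illustrative, assuming standard data parallelism): the
    # number of tokens consumed per optimizer step is
    #     batch_size * grad_accum_steps * seq_len * data_parallel_world_size.
    # With the defaults above, each device contributes 8 * 1 * 2048 = 16,384
    # tokens per step; the global figure scales with the number of ranks.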
    def __post_init__(self) -> None:
        if self.batch_size <= 0:
            raise ValueError("batch_size must be positive")
        if self.seq_len <= 0:
            raise ValueError("seq_len must be positive")
        if self.max_steps <= 0:
            raise ValueError("max_steps must be positive")
        if self.grad_accum_steps <= 0:
            raise ValueError("grad_accum_steps must be positive")
        if self.grad_clip_norm <= 0:
            raise ValueError("grad_clip_norm must be positive")
    @property
    def param_dtype(self) -> torch.dtype:
        """Resolve mixed_precision to the master-weight dtype.

        FP8 maps to bf16 master weights because FP8 is a compute mode,
        not a storage dtype.
        """
        import torch

        _DTYPE_MAP = {
            "bf16": torch.bfloat16,
            "fp16": torch.float16,
            "fp32": torch.float32,
            "fp8": torch.bfloat16,  # FP8 compute with bf16 master weights
        }
        return _DTYPE_MAP[self.mixed_precision]
    @property
    def is_fp8(self) -> bool:
        """Whether FP8 mixed precision is enabled."""
        return self.mixed_precision == "fp8"
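# A minimal usage sketch (illustrative, not part of the original module):
#
#     cfg = TrainConfig(mixed_precision="fp8", grad_accum_steps=4)
#     cfg.param_dtype   # torch.bfloat16 (FP8 keeps bf16 master weights)
#     cfg.is_fp8        # True
#
#     TrainConfig(batch_size=0)  # raises ValueError: batch_size must be positive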