kempnerforge.config.training

Training configuration.

Classes

ActivationCheckpointing

TrainConfig: Training hyperparameters.

class kempnerforge.config.training.ActivationCheckpointing

Bases: StrEnum

none = 'none'
full = 'full'
selective = 'selective'
__new__(value)
class kempnerforge.config.training.TrainConfig

Bases: object

Training hyperparameters.

batch_size: int = 8
seq_len: int = 2048
max_steps: int = 100000
grad_accum_steps: int = 1
grad_clip_norm: float = 1.0
seed: int = 42
data_seed: int | None = None
compile_model: bool = True
mixed_precision: Literal['bf16', 'fp16', 'fp32', 'fp8'] = 'bf16'
activation_checkpointing: ActivationCheckpointing = 'none'
loss_fn: str = 'cross_entropy'
z_loss_weight: float = 0.0
ce_chunk_size: int = 0
shutdown_timeout_sec: float = 600.0
nccl_health_check_interval: int = 0
property param_dtype: torch.dtype

Resolve mixed_precision to the master weight dtype.

FP8 uses bf16 master weights – FP8 is a compute mode, not a storage dtype.

property is_fp8: bool

Whether FP8 mixed precision is enabled.
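
The relationship between mixed_precision, param_dtype and is_fp8 can be sketched as below. This is an illustration under stated assumptions, not the library's code: PrecisionSketch is a hypothetical stand-in, dtype names are plain strings so the example runs without torch (the documented property returns a torch.dtype), and the fp16, fp32 and bf16 rows of the mapping are assumed from the mode names. Only the fp8 row is stated by the documentation above.

```python
from dataclasses import dataclass
from typing import Literal

# Master-weight dtype per mode, as plain strings instead of torch dtypes.
_MASTER_DTYPE = {
    "bf16": "bfloat16",  # assumption: inferred from the mode name
    "fp16": "float16",   # assumption: inferred from the mode name
    "fp32": "float32",   # assumption: inferred from the mode name
    "fp8": "bfloat16",   # documented: FP8 keeps bf16 master weights
}


@dataclass
class PrecisionSketch:
    """Hypothetical stand-in for the precision-related part of TrainConfig."""

    mixed_precision: Literal["bf16", "fp16", "fp32", "fp8"] = "bf16"

    @property
    def param_dtype(self) -> str:
        # FP8 is a compute mode, so it resolves to a storage dtype
        # rather than to an 8-bit type.
        return _MASTER_DTYPE[self.mixed_precision]

    @property
    def is_fp8(self) -> bool:
        return self.mixed_precision == "fp8"


fp8_cfg = PrecisionSketch(mixed_precision="fp8")
default_cfg = PrecisionSketch()
```

Keeping is_fp8 separate from param_dtype lets callers enable FP8 compute paths without ever allocating parameters in an 8-bit format.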

property effective_data_seed: int

Seed for data shuffling / batch composition.

Falls back to seed when data_seed is unset, so existing configs reproduce their current trajectory. Kept independent from seed (parameter init) so stability studies can vary batch order while holding initialization fixed. Must stay identical across data-parallel ranks so the global shuffle is consistent before rank partitioning.

__init__(batch_size=8, seq_len=2048, max_steps=100000, grad_accum_steps=1, grad_clip_norm=1.0, seed=42, data_seed=None, compile_model=True, mixed_precision='bf16', activation_checkpointing=ActivationCheckpointing.none, loss_fn='cross_entropy', z_loss_weight=0.0, ce_chunk_size=0, shutdown_timeout_sec=600.0, nccl_health_check_interval=0)
Return type: None