kempnerforge.config.training
Training configuration.
Classes
ActivationCheckpointing |
TrainConfig | Training hyperparameters.
- class kempnerforge.config.training.ActivationCheckpointing
Bases: StrEnum
- none = 'none'
- full = 'full'
- selective = 'selective'
- __new__(value)
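Because the class derives from StrEnum, its members behave like plain strings, which is why a config file can carry the mode as 'none', 'full', or 'selective'. The sketch below uses a stand-in class that mirrors the documented members rather than importing the library, and parse_mode is a hypothetical helper, not part of kempnerforge.

```python
try:
    from enum import StrEnum  # available in Python 3.11+
except ImportError:  # older interpreters: approximate with a str/Enum mixin
    from enum import Enum

    class StrEnum(str, Enum):
        pass


class ActivationCheckpointing(StrEnum):
    """Stand-in mirroring the documented members; not the library class."""

    none = "none"
    full = "full"
    selective = "selective"


def parse_mode(raw: str) -> ActivationCheckpointing:
    """Hypothetical helper: look up a member by its string value.

    Enum lookup by value raises ValueError for unknown strings.
    """
    return ActivationCheckpointing(raw)


mode = parse_mode("selective")
# Members compare equal to their string values.
is_selective = mode == "selective"
```

Lookup by value returns the existing member rather than a new object, so identity checks such as `mode is ActivationCheckpointing.selective` also hold.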
- class kempnerforge.config.training.TrainConfig
Bases: object
Training hyperparameters.
- activation_checkpointing: ActivationCheckpointing = 'none'
- property param_dtype: torch.dtype
Resolve mixed_precision to the master weight dtype.
FP8 uses bf16 master weights – FP8 is a compute mode, not a storage dtype.
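One way to picture the resolution is a small lookup table. This is a sketch under assumptions, not the library's implementation: only the fp8 to bf16 row is stated above, the other three rows are assumed to map each mode to its namesake dtype, and the function returns dtype names as strings so the example runs without torch installed. The names resolve_param_dtype_name and _PARAM_DTYPE_NAMES are hypothetical.

```python
# Hypothetical lookup table keyed by the documented mixed_precision values.
# Only the "fp8" row is confirmed by the docs (bf16 master weights);
# the remaining rows are assumptions made for illustration.
_PARAM_DTYPE_NAMES = {
    "bf16": "bfloat16",
    "fp16": "float16",
    "fp32": "float32",
    "fp8": "bfloat16",  # FP8 is a compute mode, so storage stays bf16
}


def resolve_param_dtype_name(mixed_precision: str) -> str:
    """Return the name of the master weight dtype for a precision mode."""
    try:
        return _PARAM_DTYPE_NAMES[mixed_precision]
    except KeyError:
        raise ValueError(
            f"unsupported mixed_precision: {mixed_precision!r}"
        ) from None
```

With torch available, `getattr(torch, resolve_param_dtype_name("fp8"))` would turn the name into the corresponding torch.dtype object.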
- property effective_data_seed: int
Seed for data shuffling / batch composition.
Falls back to seed when data_seed is unset, so existing configs reproduce their current trajectory. Kept independent from seed (parameter init) so stability studies can vary batch order while holding initialization fixed. Must stay identical across data-parallel ranks so the global shuffle is consistent before rank partitioning.
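The fallback described above can be sketched with a small stand-in class. SeedConfig and shuffled_indices are hypothetical names used for illustration, and the shuffle uses Python's random module rather than whatever sampler the library actually uses. The `is None` test matters: a data_seed of 0 is a valid seed and should not trigger the fallback.

```python
import random
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class SeedConfig:
    """Stand-in holding only the two seed fields; not the library class."""

    seed: int = 42
    data_seed: Optional[int] = None

    @property
    def effective_data_seed(self) -> int:
        # Fall back to `seed` only when data_seed is unset (None).
        return self.seed if self.data_seed is None else self.data_seed


def shuffled_indices(cfg: SeedConfig, n: int) -> List[int]:
    """Global shuffle of range(n), seeded by the effective data seed.

    Every rank that builds this list from the same config gets the same
    order, so the list can then be partitioned by rank.
    """
    rng = random.Random(cfg.effective_data_seed)
    indices = list(range(n))
    rng.shuffle(indices)
    return indices
```

Two configs that differ in seed but share data_seed produce the same batch order, which is the setup a stability study would use in reverse: hold seed fixed and vary data_seed.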
- __init__(batch_size=8, seq_len=2048, max_steps=100000, grad_accum_steps=1, grad_clip_norm=1.0, seed=42, data_seed=None, compile_model=True, mixed_precision='bf16', activation_checkpointing=ActivationCheckpointing.none, loss_fn='cross_entropy', z_loss_weight=0.0, ce_chunk_size=0, shutdown_timeout_sec=600.0, nccl_health_check_interval=0)
- Parameters:
batch_size (int)
seq_len (int)
max_steps (int)
grad_accum_steps (int)
grad_clip_norm (float)
seed (int)
data_seed (int | None)
compile_model (bool)
mixed_precision (Literal['bf16', 'fp16', 'fp32', 'fp8'])
activation_checkpointing (ActivationCheckpointing)
loss_fn (str)
z_loss_weight (float)
ce_chunk_size (int)
shutdown_timeout_sec (float)
nccl_health_check_interval (int)
- Return type:
None
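To show how the defaults in the signature combine, the sketch below mirrors the documented parameters in a stand-in dataclass instead of importing the library. TrainConfigSketch and tokens_per_optimizer_step are hypothetical, the frozen dataclass is a choice made for this example, and the token arithmetic assumes batch_size counts sequences per rank per micro-step, which the reference above does not state.

```python
from dataclasses import dataclass, replace
from typing import Optional


@dataclass(frozen=True)
class TrainConfigSketch:
    """Stand-in with the documented parameter names and defaults."""

    batch_size: int = 8
    seq_len: int = 2048
    max_steps: int = 100000
    grad_accum_steps: int = 1
    grad_clip_norm: float = 1.0
    seed: int = 42
    data_seed: Optional[int] = None
    compile_model: bool = True
    mixed_precision: str = "bf16"
    activation_checkpointing: str = "none"
    loss_fn: str = "cross_entropy"
    z_loss_weight: float = 0.0
    ce_chunk_size: int = 0
    shutdown_timeout_sec: float = 600.0
    nccl_health_check_interval: int = 0


def tokens_per_optimizer_step(cfg: TrainConfigSketch) -> int:
    """Tokens one rank consumes per optimizer step.

    Assumes batch_size is the per-rank micro-batch size (an assumption,
    not confirmed by the reference).
    """
    return cfg.batch_size * cfg.seq_len * cfg.grad_accum_steps


default_cfg = TrainConfigSketch()
# Override a single field while keeping every other default.
accum_cfg = replace(default_cfg, grad_accum_steps=4)
```

Under that assumption the defaults give 8 * 2048 * 1 = 16384 tokens per rank per optimizer step, and raising grad_accum_steps to 4 gives 65536.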