kempnerforge.config.training

Training configuration.

Classes

ActivationCheckpointing

TrainConfig
    Training hyperparameters.

class kempnerforge.config.training.ActivationCheckpointing[source]

Bases: StrEnum

none = 'none'
full = 'full'
selective = 'selective'
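Because the class derives from StrEnum, its members compare equal to their string values, so a plain string can stand in for the enum where convenient. A minimal sketch of that behavior (the import path follows the module name above):

    from kempnerforge.config.training import ActivationCheckpointing

    # StrEnum members behave like their string values, so either form can be
    # passed where an ActivationCheckpointing is expected.
    mode = ActivationCheckpointing("selective")
    assert mode is ActivationCheckpointing.selective
    assert mode == "selective"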
class kempnerforge.config.training.TrainConfig[source]

Bases: object

Training hyperparameters.

batch_size: int = 8
seq_len: int = 2048
max_steps: int = 100000
grad_accum_steps: int = 1
grad_clip_norm: float = 1.0
seed: int = 42
compile_model: bool = True
mixed_precision: Literal['bf16', 'fp16', 'fp32', 'fp8'] = 'bf16'
activation_checkpointing: ActivationCheckpointing = 'none'
loss_fn: str = 'cross_entropy'
z_loss_weight: float = 0.0
ce_chunk_size: int = 0
shutdown_timeout_sec: float = 600.0
nccl_health_check_interval: int = 0
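The batch-related fields compose in the usual way. Assuming the conventional semantics (this page does not state them explicitly), the number of tokens consumed per optimizer step on a single rank is batch_size * seq_len * grad_accum_steps:

    from kempnerforge.config.training import TrainConfig

    # Assumes the conventional meaning of these fields; the page above does
    # not spell the relationship out explicitly.
    cfg = TrainConfig(batch_size=8, seq_len=2048, grad_accum_steps=4)
    tokens_per_step = cfg.batch_size * cfg.seq_len * cfg.grad_accum_steps
    print(tokens_per_step)  # 65536 tokens per rank per optimizer step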
property param_dtype: torch.dtype

Resolve mixed_precision to the master weight dtype.

FP8 uses bf16 master weights, because FP8 is a compute mode rather than a storage dtype.

property is_fp8: bool

Whether FP8 mixed precision is enabled.
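A sketch of how these two properties can be read, based only on the docstrings above (the actual implementation may differ): every mode other than 'fp8' maps to the matching torch dtype, while 'fp8' keeps bf16 master weights:

    import torch

    from kempnerforge.config.training import TrainConfig

    # Illustrative mapping only; it mirrors the docstrings above, not the
    # library source.
    _PARAM_DTYPES = {
        "bf16": torch.bfloat16,
        "fp16": torch.float16,
        "fp32": torch.float32,
        "fp8": torch.bfloat16,  # compute mode: master weights stay in bf16
    }

    cfg = TrainConfig(mixed_precision="fp8")
    assert cfg.is_fp8
    assert cfg.param_dtype == _PARAM_DTYPES[cfg.mixed_precision]  # torch.bfloat16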

__init__(batch_size=8, seq_len=2048, max_steps=100000, grad_accum_steps=1, grad_clip_norm=1.0, seed=42, compile_model=True, mixed_precision='bf16', activation_checkpointing=ActivationCheckpointing.none, loss_fn='cross_entropy', z_loss_weight=0.0, ce_chunk_size=0, shutdown_timeout_sec=600.0, nccl_health_check_interval=0)
Return type:
    None
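Putting it together, a config for a longer-context run might look like the following. All field names come from the signature above; the values are illustrative only:

    from kempnerforge.config.training import ActivationCheckpointing, TrainConfig

    # Illustrative values; omitted fields keep the defaults shown in the
    # signature above.
    cfg = TrainConfig(
        batch_size=4,
        seq_len=4096,
        grad_accum_steps=8,
        mixed_precision="bf16",
        activation_checkpointing=ActivationCheckpointing.selective,
    )
    print(cfg.is_fp8)       # False
    print(cfg.param_dtype)  # torch.bfloat16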