kempnerforge.config.training
Training configuration.
Classes
- ActivationCheckpointing
- TrainConfig: Training hyperparameters.
- class kempnerforge.config.training.ActivationCheckpointing[source]
  Bases: StrEnum
  - none = 'none'
  - full = 'full'
  - selective = 'selective'
  - __new__(value)
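Since ActivationCheckpointing is a StrEnum, members can be constructed from their string values (which is why autodoc lists __new__(value)) and compare equal to plain strings. An illustrative session:

>>> from kempnerforge.config.training import ActivationCheckpointing
>>> ActivationCheckpointing("selective")
<ActivationCheckpointing.selective: 'selective'>
>>> ActivationCheckpointing.selective == "selective"  # StrEnum members are str
True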
- class kempnerforge.config.training.TrainConfig[source]
  Bases: object
  Training hyperparameters.
- activation_checkpointing: ActivationCheckpointing = 'none'
- property param_dtype: torch.dtype
  Resolve mixed_precision to the master weight dtype.
  FP8 uses bf16 master weights because FP8 is a compute mode, not a storage dtype.
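  A minimal sketch of the resolution this property describes. Only the fp8-to-bf16 case is documented explicitly; the other entries assume each mode maps to its matching torch dtype, and the real implementation may differ:

  import torch

  # Assumed mapping from mixed_precision to master weight dtype.
  # Only fp8 -> bf16 is documented; the rest are assumptions.
  _MASTER_WEIGHT_DTYPE = {
      "bf16": torch.bfloat16,
      "fp16": torch.float16,
      "fp32": torch.float32,
      "fp8": torch.bfloat16,  # FP8 is a compute mode, not a storage dtype
  }

  def resolve_param_dtype(mixed_precision: str) -> torch.dtype:
      return _MASTER_WEIGHT_DTYPE[mixed_precision]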
- __init__(batch_size=8, seq_len=2048, max_steps=100000, grad_accum_steps=1, grad_clip_norm=1.0, seed=42, compile_model=True, mixed_precision='bf16', activation_checkpointing=ActivationCheckpointing.none, loss_fn='cross_entropy', z_loss_weight=0.0, ce_chunk_size=0, shutdown_timeout_sec=600.0, nccl_health_check_interval=0)
- Parameters:
batch_size (int)
seq_len (int)
max_steps (int)
grad_accum_steps (int)
grad_clip_norm (float)
seed (int)
compile_model (bool)
mixed_precision (Literal['bf16', 'fp16', 'fp32', 'fp8'])
activation_checkpointing (ActivationCheckpointing)
loss_fn (str)
z_loss_weight (float)
ce_chunk_size (int)
shutdown_timeout_sec (float)
nccl_health_check_interval (int)
- Return type:
None
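All fields are keyword-configurable with the defaults shown in the signature above. For example, a config overriding a few fields (the values here are arbitrary, for illustration only):

>>> from kempnerforge.config.training import ActivationCheckpointing, TrainConfig
>>> cfg = TrainConfig(
...     batch_size=16,
...     seq_len=4096,
...     mixed_precision="fp8",
...     activation_checkpointing=ActivationCheckpointing.selective,
... )
>>> cfg.param_dtype  # FP8 is a compute mode; master weights stay bf16
torch.bfloat16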