kempnerforge.training¶
Training loop and optimization for KempnerForge.
- Public API:
build_loss_fn / build_optimizer / build_scheduler: Component factories
run_eval: Evaluation loop (loss + perplexity)
maybe_no_sync: Gradient accumulation helper
- kempnerforge.training.build_loss_fn(config)[source]¶
Build a composed loss function from training config.
Follows the build_optimizer pattern: config in, callable out. Binds chunk_size for chunked CE and composes z-loss, so the caller gets a clean (logits, labels) -> Tensor interface.
- Return type:
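A minimal usage sketch; the `config` object is assumed to be the loss/training config consumed by build_loss_fn, and its exact fields (chunk size, z-loss weight, ...) depend on the project's config schema:

```python
import torch
from kempnerforge.training import build_loss_fn

# `config` is assumed to come from the project's config schema (fields not shown here).
loss_fn = build_loss_fn(config)

logits = torch.randn(2, 8, 32_000)           # (batch, seq, vocab)
labels = torch.randint(0, 32_000, (2, 8))    # (batch, seq)
loss = loss_fn(logits, labels)               # scalar tensor (chunked CE + z-loss)
```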
- kempnerforge.training.build_optimizer(model, config)[source]¶
Construct an optimizer with per-parameter-group weight decay settings.
- Parameters:
model (torch.nn.Module) – Model whose parameters to optimize.
config (OptimizerConfig) – Optimizer configuration.
- Returns:
Configured optimizer instance.
- Return type:
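A sketch of wiring build_optimizer into a training step, assuming `model`, `loss_fn`, and `batch` are defined elsewhere in the loop:

```python
from kempnerforge.training import build_optimizer

# `model` is any torch.nn.Module; `config` is an OptimizerConfig from the
# project's config schema (per-group weight decay is handled internally).
optimizer = build_optimizer(model, config)

# One illustrative step; `loss_fn` and `batch` are assumed from context.
loss = loss_fn(model(batch["input_ids"]), batch["labels"])
loss.backward()
optimizer.step()
optimizer.zero_grad()
```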
- kempnerforge.training.build_scheduler(optimizer, config, max_steps)[source]¶
Build a LR scheduler from config.
- Parameters:
optimizer (torch.optim.Optimizer) – Optimizer to schedule.
config (SchedulerConfig) – Scheduler configuration.
max_steps (int) – Total training steps (used to compute decay length).
- Returns:
PyTorch LambdaLR scheduler.
- Return type:
LambdaLR
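A minimal sketch of building and stepping the scheduler; `model`, `optimizer_config`, and `scheduler_config` are assumed placeholders:

```python
from kempnerforge.training import build_optimizer, build_scheduler

max_steps = 10_000
optimizer = build_optimizer(model, optimizer_config)
scheduler = build_scheduler(optimizer, scheduler_config, max_steps)

for step in range(max_steps):
    ...                 # forward / backward / optimizer.step()
    scheduler.step()    # advance the LambdaLR schedule once per training step
```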
- kempnerforge.training.maybe_no_sync(model, micro_step, grad_accum_steps)[source]¶
Skip gradient sync on intermediate accumulation steps.
On the last micro-step, gradients are synchronized normally. On earlier micro-steps, sync is skipped to avoid redundant all-reduces.
Works with FSDP2 (which implements no_sync() on the module). For non-distributed models, this is a no-op.
- Parameters:
model (torch.nn.Module)
micro_step (int)
grad_accum_steps (int)
- Return type:
Generator[None, None, None]
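A gradient-accumulation sketch using the context manager; `data_iter`, `model`, `loss_fn`, and `optimizer` are assumed from the surrounding training loop:

```python
from kempnerforge.training import maybe_no_sync

grad_accum_steps = 4
for micro_step in range(grad_accum_steps):
    batch = next(data_iter)
    # Gradients are synchronized only on the last micro-step; earlier micro-steps
    # skip the all-reduce (and this is a no-op for non-distributed models).
    with maybe_no_sync(model, micro_step, grad_accum_steps):
        loss = loss_fn(model(batch["input_ids"]), batch["labels"]) / grad_accum_steps
        loss.backward()
optimizer.step()
optimizer.zero_grad()
```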
- kempnerforge.training.run_eval(model, eval_dataloader, loss_fn, device, eval_steps, *, pp_schedule=None, pp_rank=None, pp_size=None, pp_group=None)¶
Run evaluation and return metrics.
- Parameters:
model (torch.nn.Module) – The model (FSDP-wrapped, TP-sharded, or plain).
eval_dataloader (torch.utils.data.DataLoader) – DataLoader yielding {“input_ids”, “labels”} batches.
loss_fn (callable) – Loss function (logits, labels) -> scalar tensor.
device (torch.device) – Device to move batches to.
eval_steps (int) – Number of eval batches to process.
pp_schedule – Pipeline parallel schedule (None for non-PP).
pp_rank (int | None) – This rank’s PP stage index.
pp_size (int | None) – Total number of PP stages.
pp_group – Process group for PP loss broadcast.
- Returns:
Dict with “eval/loss” and “eval/perplexity”.
- Return type:
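A sketch of the non-pipeline-parallel case (the PP keyword arguments keep their defaults); `model`, `eval_dataloader`, and `loss_fn` are assumed from context:

```python
import torch
from kempnerforge.training import run_eval

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
metrics = run_eval(model, eval_dataloader, loss_fn, device, eval_steps=50)
print(metrics["eval/loss"], metrics["eval/perplexity"])
```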
Modules
- Evaluation utilities for KempnerForge.
- Gradient utilities for distributed training.
- Training loop hooks for extensibility without forking train.py.
- Loss function registry for KempnerForge.
- Optimizer construction for KempnerForge.
- Learning rate schedulers for KempnerForge.