kempnerforge.training

Training loop and optimization for KempnerForge.

Public API:
  • build_loss_fn / build_optimizer / build_scheduler: Component factories

  • run_eval: Evaluation loop (loss + perplexity)

  • maybe_no_sync: Gradient accumulation helper

kempnerforge.training.build_loss_fn(config)[source]

Build a composed loss function from training config.

Follows the build_optimizer pattern: config in, callable out. Binds chunk_size for chunked cross-entropy and composes the z-loss term, so the caller gets a clean (logits, labels) -> Tensor interface.

Return type:

Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
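
A minimal usage sketch; the config object is assumed to be the same training config passed to the other factories, and the tensor shapes are illustrative:

    import torch

    from kempnerforge.training import build_loss_fn

    loss_fn = build_loss_fn(config)  # config: training config (assumed)

    # Toy shapes: (batch, seq, vocab) logits and (batch, seq) integer labels.
    logits = torch.randn(2, 16, 32_000)
    labels = torch.randint(0, 32_000, (2, 16))
    loss = loss_fn(logits, labels)   # scalar tensor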

kempnerforge.training.build_optimizer(model, config)[source]

Construct an optimizer with per-parameter-group weight decay settings.

Parameters:
  • model (torch.nn.Module) – Model whose parameters are grouped and optimized.

  • config – Training config carrying the optimizer and weight-decay settings.

Returns:

Configured optimizer instance.

Return type:

torch.optim.Optimizer
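
A minimal sketch of the intended call pattern, assuming model and config are the objects used elsewhere in the training loop:

    from kempnerforge.training import build_optimizer

    optimizer = build_optimizer(model, config)

    # Weight decay is already split into per-parameter groups, so the
    # training loop only needs the usual step/zero_grad calls.
    loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)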

kempnerforge.training.build_scheduler(optimizer, config, max_steps)[source]

Build an LR scheduler from config.

Parameters:
  • optimizer (torch.optim.Optimizer) – Optimizer whose learning rate is scheduled.

  • config – Training config carrying the scheduler settings.

  • max_steps (int) – Total number of training steps the schedule spans.

Returns:

PyTorch LambdaLR scheduler.

Return type:

torch.optim.lr_scheduler.LambdaLR
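
A minimal sketch pairing the scheduler with build_optimizer; the max_steps value and the once-per-step scheduler.step() placement are illustrative:

    from kempnerforge.training import build_optimizer, build_scheduler

    optimizer = build_optimizer(model, config)
    scheduler = build_scheduler(optimizer, config, max_steps=10_000)

    for step in range(10_000):
        ...                 # forward / backward / optimizer.step()
        scheduler.step()    # LambdaLR: advance the schedule once per step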

kempnerforge.training.maybe_no_sync(model, micro_step, grad_accum_steps)[source]

Skip gradient sync on intermediate accumulation steps.

On the last micro-step, gradients are synchronized normally. On earlier micro-steps, sync is skipped to avoid redundant all-reduces.

Works with FSDP2 (which implements no_sync() on the module). For non-distributed models, this is a no-op.

Parameters:
  • model (torch.nn.Module) – The model (FSDP2-wrapped or plain).

  • micro_step (int) – Index of the current micro-step within the accumulation window.

  • grad_accum_steps (int) – Number of gradient accumulation steps per optimizer step.

Return type:

Generator[None, None, None]
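
A gradient-accumulation sketch. It assumes micro_step is the zero-based index from enumerate and that the model returns logits for a batch of input_ids; adapt it to your loop:

    from kempnerforge.training import maybe_no_sync

    grad_accum_steps = 4
    for micro_step, batch in enumerate(micro_batches):
        # Gradient sync is skipped except on the last micro-step of the window.
        with maybe_no_sync(model, micro_step, grad_accum_steps):
            logits = model(batch["input_ids"])
            loss = loss_fn(logits, batch["labels"])
            (loss / grad_accum_steps).backward()

    optimizer.step()
    optimizer.zero_grad(set_to_none=True)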

kempnerforge.training.run_eval(model, eval_dataloader, loss_fn, device, eval_steps, *, pp_schedule=None, pp_rank=None, pp_size=None, pp_group=None)

Run evaluation and return metrics.

Parameters:
  • model (torch.nn.Module) – The model (FSDP-wrapped, TP-sharded, or plain).

  • eval_dataloader (torch.utils.data.DataLoader) – DataLoader yielding {"input_ids", "labels"} batches.

  • loss_fn (callable) – Loss function (logits, labels) -> scalar tensor.

  • device (torch.device) – Device to move batches to.

  • eval_steps (int) – Number of eval batches to process.

  • pp_schedule – Pipeline parallel schedule (None for non-PP).

  • pp_rank (int | None) – This rank’s PP stage index.

  • pp_size (int | None) – Total number of PP stages.

  • pp_group – Process group for PP loss broadcast.

Returns:

Dict with keys "eval/loss" and "eval/perplexity".

Return type:

dict[str, float]
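
A non-pipeline-parallel usage sketch; the pp_* keyword arguments keep their defaults and eval_steps is illustrative:

    import torch

    from kempnerforge.training import build_loss_fn, run_eval

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    loss_fn = build_loss_fn(config)

    metrics = run_eval(model, eval_dataloader, loss_fn, device, eval_steps=50)
    print(metrics["eval/loss"], metrics["eval/perplexity"])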

Modules

  • eval – Evaluation utilities for KempnerForge.

  • grad – Gradient utilities for distributed training.

  • hooks – Training loop hooks for extensibility without forking train.py.

  • loss – Loss function registry for KempnerForge.

  • optimizer – Optimizer construction for KempnerForge.

  • scheduler – Learning rate schedulers for KempnerForge.