kempnerforge.training.eval
Evaluation utilities for KempnerForge.
Provides run_eval, which computes eval loss and perplexity on a held-out dataset. It works with any parallel model (FSDP, TP, PP): pass the same model reference used for training; no unwrapping is needed.
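The relationship between the two reported metrics can be sketched as follows. This is a minimal illustration, not KempnerForge's implementation: it assumes perplexity is the exponential of the mean per-batch loss (the standard definition, which this page does not confirm), and `summarize_eval` is a hypothetical helper. Only the metric keys "eval/loss" and "eval/perplexity" and the `eval_steps` name are taken from the documentation.

```python
import math
from typing import Iterable


def summarize_eval(batch_losses: Iterable[float], eval_steps: int) -> dict[str, float]:
    """Hypothetical helper: average up to `eval_steps` per-batch losses.

    Assumption: perplexity = exp(mean loss). The real run_eval may
    aggregate differently (for example, weighting by token count).
    """
    total = 0.0
    count = 0
    for loss in batch_losses:
        if count >= eval_steps:
            break  # stop once the requested number of batches is consumed
        total += loss
        count += 1
    if count == 0:
        raise ValueError("no eval batches were processed")
    mean_loss = total / count
    return {"eval/loss": mean_loss, "eval/perplexity": math.exp(mean_loss)}


# Four losses are available, but only the first three are used.
metrics = summarize_eval([2.0, 2.5, 3.0, 9.9], eval_steps=3)
```

Averaging per-batch losses matches a token-level mean only when every batch contributes the same number of tokens, so treat the numbers from a sketch like this as approximate when batch sizes vary.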
- kempnerforge.training.eval.run_eval(model, eval_dataloader, loss_fn, device, eval_steps, *, pp_schedule=None, pp_rank=None, pp_size=None, pp_group=None)
Run evaluation and return metrics.
- Parameters:
model (torch.nn.Module) – The model (FSDP-wrapped, TP-sharded, or plain).
eval_dataloader (torch.utils.data.DataLoader) – DataLoader yielding batches with “input_ids” and “labels” keys.
loss_fn (callable) – Loss function (logits, labels) -> scalar tensor.
device (torch.device) – Device to move batches to.
eval_steps (int) – Number of eval batches to process.
pp_schedule – Pipeline parallel schedule (None for non-PP).
pp_rank (int | None) – This rank’s PP stage index.
pp_size (int | None) – Total number of PP stages.
pp_group – Process group for PP loss broadcast.
- Returns:
Dict with “eval/loss” and “eval/perplexity”.
- Return type: