kempnerforge.training.eval

Evaluation utilities for KempnerForge.

Provides run_eval, which computes evaluation loss and perplexity on a held-out dataset. It works with any parallelized model (FSDP, TP, or PP): pass the same model reference used for training, with no unwrapping needed.

kempnerforge.training.eval.run_eval(model, eval_dataloader, loss_fn, device, eval_steps, *, pp_schedule=None, pp_rank=None, pp_size=None, pp_group=None)

Run evaluation and return metrics.

Parameters:
  • model (torch.nn.Module) – The model (FSDP-wrapped, TP-sharded, or plain).

  • eval_dataloader (torch.utils.data.DataLoader) – DataLoader yielding batches with "input_ids" and "labels" keys.

  • loss_fn (callable) – Loss function (logits, labels) -> scalar tensor.

  • device (torch.device) – Device to move batches to.

  • eval_steps (int) – Number of eval batches to process.

  • pp_schedule – Pipeline-parallel schedule (None when pipeline parallelism is not used).

  • pp_rank (int | None) – This rank’s PP stage index.

  • pp_size (int | None) – Total number of PP stages.

  • pp_group – Process group used to broadcast the loss across PP stages.
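The parameter contract for the non-PP path can be illustrated with a framework-free sketch. It assumes that run_eval reads at most eval_steps batches, calls loss_fn(logits, labels) on each, averages the per-batch losses, and reports perplexity as exp of that mean. The names cross_entropy, sketch_eval, and uniform_model are hypothetical, plain lists stand in for torch tensors, and the device transfer, any cross-rank reduction, and the PP path are omitted. This is not KempnerForge's implementation.

```python
import math
from typing import Callable, Iterable, Mapping, Sequence

# Plain lists stand in for torch tensors so the sketch runs without PyTorch.
Logits = Sequence[Sequence[float]]  # one row of vocabulary scores per position
Labels = Sequence[int]              # one target token id per position


def cross_entropy(logits: Logits, labels: Labels) -> float:
    """Mean token-level cross-entropy, following the (logits, labels) -> scalar contract."""
    total = 0.0
    for row, target in zip(logits, labels):
        # log(sum(exp(row))) via the max-shift trick for numerical stability
        m = max(row)
        log_z = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_z - row[target]  # negative log-probability of the target
    return total / len(labels)


def sketch_eval(
    model: Callable[[Sequence[int]], Logits],
    eval_batches: Iterable[Mapping[str, Sequence[int]]],
    loss_fn: Callable[[Logits, Labels], float],
    eval_steps: int,
) -> dict[str, float]:
    """Hypothetical non-PP loop: mean loss over at most `eval_steps` batches."""
    losses = []
    for step, batch in enumerate(eval_batches):
        if step >= eval_steps:
            break  # stop after eval_steps batches even if more remain
        # Real code would move the batch to `device` before the forward pass.
        logits = model(batch["input_ids"])
        losses.append(loss_fn(logits, batch["labels"]))
    if not losses:
        raise ValueError("no eval batches were processed")
    mean_loss = sum(losses) / len(losses)
    return {"eval/loss": mean_loss, "eval/perplexity": math.exp(mean_loss)}


# Toy usage: a hypothetical model that scores a 4-token vocabulary uniformly.
def uniform_model(input_ids: Sequence[int]) -> Logits:
    return [[0.0] * 4 for _ in input_ids]


batches = [{"input_ids": [0, 1, 2], "labels": [1, 2, 3]}] * 3
metrics = sketch_eval(uniform_model, batches, cross_entropy, eval_steps=2)
print(metrics)  # loss ≈ 1.386 (ln 4), perplexity ≈ 4.0
```

A uniform model over a vocabulary of size V has cross-entropy ln V and perplexity V, which makes it a convenient sanity check. Averaging per-batch losses is itself an assumption: an implementation that weights each batch by its token count would give a different mean when batches differ in size.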

Returns:

Dict with "eval/loss" and "eval/perplexity" keys.

Return type:

dict[str, float]
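Under the standard definition, the two returned values are linked as shown below. This assumes that "eval/loss" is the mean of the per-batch losses ℓ_s returned by loss_fn over the S batches processed (at most eval_steps), and that the loss is a natural-log cross-entropy. Neither detail is stated in this reference.

```latex
\bar{\ell} = \frac{1}{S} \sum_{s=1}^{S} \ell_s,
\qquad
\mathrm{PPL} = \exp\!\left(\bar{\ell}\right)
```

For example, an "eval/loss" of 2.0 corresponds to an "eval/perplexity" of e^2 ≈ 7.389.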