# Source code for kempnerforge.training.eval

"""Evaluation utilities for KempnerForge.

Provides `run_eval` for computing eval loss and perplexity on a held-out
dataset. Works with any parallel model (FSDP, TP, PP) — same model reference,
no unwrapping needed.
"""

from __future__ import annotations

import logging
import math
from datetime import timedelta

import torch
import torch.distributed as dist

logger = logging.getLogger(__name__)

# Per-operation timeout for the PP eval loss broadcast. Shorter than the
# 1800s process-group default so a diverged PP stage surfaces fast rather
# than freezing eval for half an hour. Seconds; converted to a timedelta
# and passed to Work.wait() in run_eval's PP path.
_EVAL_BROADCAST_TIMEOUT_SEC = 300.0


def should_build_eval_dataloader(eval_enabled: bool, is_vlm: bool) -> tuple[bool, bool]:
    """Gate eval-dataloader construction, suppressing it for VLM configs.

    The training loop's ``run_eval(model, eval_dataloader, ...)`` calls
    ``model(input_ids)``, which is incompatible with
    ``VLMWrapper.forward(pixel_values, input_ids, labels)`` — a VLM config
    with ``eval.enabled=true`` would crash at the first eval interval. For
    that combination this helper disables eval and signals that a warning
    should be logged so the ignored setting is visible to the user. VLM
    eval support is a tracked follow-up.

    Returns:
        ``(should_build, should_warn_vlm_skip)``.
    """
    vlm_eval_requested = eval_enabled and is_vlm
    if not vlm_eval_requested:
        # Non-VLM (or eval disabled): honor the config setting, no warning.
        return eval_enabled, False
    # VLM with eval requested: suppress the dataloader and flag the warning.
    return False, True
@torch.no_grad() def run_eval( model: torch.nn.Module, eval_dataloader: torch.utils.data.DataLoader, loss_fn: callable, # type: ignore[reportGeneralTypeIssues] device: torch.device, eval_steps: int, *, pp_schedule=None, pp_rank: int | None = None, pp_size: int | None = None, pp_group=None, ) -> dict[str, float]: """Run evaluation and return metrics. Args: model: The model (FSDP-wrapped, TP-sharded, or plain). eval_dataloader: DataLoader yielding {"input_ids", "labels"} batches. loss_fn: Loss function (logits, labels) -> scalar tensor. device: Device to move batches to. eval_steps: Number of eval batches to process. pp_schedule: Pipeline parallel schedule (None for non-PP). pp_rank: This rank's PP stage index. pp_size: Total number of PP stages. pp_group: Process group for PP loss broadcast. Returns: Dict with "eval/loss" and "eval/perplexity". """ model.eval() if pp_schedule is not None: # --- PP eval path --- input_ids_list, labels_list = [], [] eval_iter = iter(eval_dataloader) for _ in range(eval_steps): try: batch = next(eval_iter) except StopIteration: eval_iter = iter(eval_dataloader) batch = next(eval_iter) input_ids_list.append(batch["input_ids"].to(device)) labels_list.append(batch["labels"].to(device)) full_input = torch.cat(input_ids_list, dim=0) full_labels = torch.cat(labels_list, dim=0) is_first = pp_rank == 0 is_last = pp_rank == pp_size - 1 # type: ignore[reportOptionalOperand] pp_losses: list[torch.Tensor] = [] if is_first: pp_schedule.step(full_input, target=full_labels, losses=pp_losses) elif is_last: pp_schedule.step(target=full_labels, losses=pp_losses) else: pp_schedule.step() if is_last and pp_losses: avg_loss = sum(loss.item() for loss in pp_losses) / len(pp_losses) else: avg_loss = 0.0 loss_tensor = torch.tensor([avg_loss], device=device) work = dist.broadcast( loss_tensor, group_src=pp_size - 1, # type: ignore[reportOptionalOperand] group=pp_group, async_op=True, ) try: done = 
work.wait(timeout=timedelta(seconds=_EVAL_BROADCAST_TIMEOUT_SEC)) # type: ignore[reportOptionalMemberAccess] except RuntimeError as e: logger.error( f"Eval loss broadcast timed out after {_EVAL_BROADCAST_TIMEOUT_SEC}s; " f"a PP stage is likely wedged. Reporting nan loss. Underlying: {e}" ) avg_loss = float("nan") else: if done is False: logger.error( f"Eval loss broadcast timed out after {_EVAL_BROADCAST_TIMEOUT_SEC}s; " "reporting nan loss." ) avg_loss = float("nan") else: avg_loss = loss_tensor[0].item() else: # --- Standard eval path --- total_loss = 0.0 eval_iter = iter(eval_dataloader) for _ in range(eval_steps): try: batch = next(eval_iter) except StopIteration: eval_iter = iter(eval_dataloader) batch = next(eval_iter) input_ids = batch["input_ids"].to(device) labels = batch["labels"].to(device) logits = model(input_ids) loss = loss_fn(logits, labels) total_loss += loss.item() avg_loss = total_loss / eval_steps model.train() return {"eval/loss": avg_loss, "eval/perplexity": math.exp(min(avg_loss, 20.0))}