Source code for kempnerforge.training.hooks
"""Training loop hooks for extensibility without forking train.py.
Researchers subclass ``TrainingHook`` and override only the methods they need.
Hooks run at defined points in the training loop; when no hooks are registered,
the overhead is a single empty-list check per call site.
Usage:
from kempnerforge.training.hooks import TrainingHook, StepContext
class GradHistogramHook(TrainingHook):
def on_step_end(self, ctx: StepContext) -> None:
for name, p in ctx.model.named_parameters():
if p.grad is not None:
wandb.log({f"grad_norm/{name}": p.grad.norm().item()}, step=ctx.step)
hooks = [GradHistogramHook()]
runner = HookRunner(hooks)
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING
import torch
if TYPE_CHECKING:
from kempnerforge.config.job import JobConfig
[docs]
@dataclass
class StepContext:
"""Per-step context passed to hooks after each training step."""
step: int
loss: float
grad_norm: float
lr: float
tokens_seen: int
model: torch.nn.Module
optimizer: torch.optim.Optimizer
[docs]
class TrainingHook:
"""Base class for training hooks. Override only the methods you need."""
[docs]
def on_train_begin(self, config: JobConfig) -> None:
"""Called once after setup, before the training loop starts."""
[docs]
def on_step_end(self, ctx: StepContext) -> None:
"""Called after each optimizer step + metrics logging."""
[docs]
def on_eval_end(self, metrics: dict[str, float], step: int) -> None:
"""Called after each evaluation round completes."""
[docs]
def on_checkpoint_save(self, step: int, path: str) -> None:
"""Called after a checkpoint is saved."""
[docs]
def on_train_end(self, step: int, tokens_seen: int) -> None:
"""Called after the training loop exits."""
[docs]
class HookRunner:
"""Dispatches hook calls to registered hooks. Zero cost when empty."""
[docs]
def __init__(self, hooks: list[TrainingHook] | None = None) -> None:
self.hooks: list[TrainingHook] = hooks or []
[docs]
def on_train_begin(self, config: JobConfig) -> None:
if not self.hooks:
return
for hook in self.hooks:
hook.on_train_begin(config)
[docs]
def on_step_end(self, ctx: StepContext) -> None:
if not self.hooks:
return
for hook in self.hooks:
hook.on_step_end(ctx)
[docs]
def on_eval_end(self, metrics: dict[str, float], step: int) -> None:
if not self.hooks:
return
for hook in self.hooks:
hook.on_eval_end(metrics, step)
[docs]
def on_checkpoint_save(self, step: int, path: str) -> None:
if not self.hooks:
return
for hook in self.hooks:
hook.on_checkpoint_save(step, path)
[docs]
def on_train_end(self, step: int, tokens_seen: int) -> None:
if not self.hooks:
return
for hook in self.hooks:
hook.on_train_end(step, tokens_seen)