Hooks¶
The TrainingHook interface in
kempnerforge/training/hooks.py
is the supported way to run custom logic during training without
forking scripts/train.py.
Interface¶
```python
from kempnerforge.training.hooks import TrainingHook, StepContext, HookRunner


class MyHook(TrainingHook):
    def on_train_begin(self, config): ...
    def on_step_end(self, ctx: StepContext): ...
    def on_eval_end(self, metrics: dict[str, float], step: int): ...
    def on_checkpoint_save(self, step: int, path: str): ...
    def on_train_end(self, step: int, tokens_seen: int): ...
```
TrainingHook is a concrete base class with empty method bodies —
subclass and override only what you need. The HookRunner dispatches
calls to a list of hooks; when the list is empty, each entry point
short-circuits on an if not self.hooks: return check, so an unused
runner costs one branch per call site.
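For example, a hook that only watches the loss overrides a single method and leaves everything else as the inherited no-ops. A minimal sketch (the class name, logging interval, and print target are illustrative choices, not part of the library):

```python
from kempnerforge.training.hooks import HookRunner, StepContext, TrainingHook


class LossPrinterHook(TrainingHook):
    """Print loss and learning rate every `every` steps; all other events stay no-ops."""

    def __init__(self, every: int = 100) -> None:
        self.every = every

    def on_step_end(self, ctx: StepContext) -> None:
        if ctx.step % self.every == 0:
            print(f"step={ctx.step} loss={ctx.loss:.4f} lr={ctx.lr:.2e}")


runner = HookRunner([LossPrinterHook(every=50)])
```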
StepContext¶
Passed to on_step_end on every step:
```python
@dataclass
class StepContext:
    step: int
    loss: float
    grad_norm: float
    lr: float
    tokens_seen: int
    model: torch.nn.Module
    optimizer: torch.optim.Optimizer
```
model and optimizer are the live references — you can read
parameter values and optimizer state. Gradients have already been
zeroed by the time on_step_end fires. The loop order is:
```python
optimizer.step()
scheduler.step()
optimizer.zero_grad()  # grads cleared here
...
tracker.end_step(...)
hook_runner.on_step_end(StepContext(...))
```
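Since ctx.model and ctx.optimizer are the live objects, a hook can read optimizer state directly in on_step_end. A sketch, assuming an Adam-family optimizer (so exp_avg exists in the per-parameter state) and WandB on the logging side; the metric name is arbitrary:

```python
import wandb

from kempnerforge.training.hooks import StepContext, TrainingHook


class AdamStateHook(TrainingHook):
    """Log the mean magnitude of Adam's first-moment buffers each step."""

    def on_step_end(self, ctx: StepContext) -> None:
        means = []
        for group in ctx.optimizer.param_groups:
            for p in group["params"]:
                state = ctx.optimizer.state.get(p, {})
                if "exp_avg" in state:
                    means.append(state["exp_avg"].abs().mean().item())
        if means:
            wandb.log({"optim/exp_avg_abs_mean": sum(means) / len(means)}, step=ctx.step)
```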
If your hook needs to inspect gradients, it cannot: you have to modify train.py and capture them before optimizer.zero_grad(). The wandb.log of p.grad.norm() in the module's docstring example requires capturing the norms during the backward pass (via a pre-zero_grad hook), not by reading p.grad in on_step_end.
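A sketch of that fork, assuming only the loop order shown above and the usual model / optimizer / scheduler names inside scripts/train.py:

```python
# In a fork of scripts/train.py, inside the step, before optimizer.zero_grad():
# p.grad is still populated here, so the pre-zero norms can be captured.
grad_norms = {
    name: p.grad.norm().item()
    for name, p in model.named_parameters()
    if p.grad is not None
}
optimizer.step()
scheduler.step()
optimizer.zero_grad()
# grad_norms can now be logged directly, or stashed somewhere a hook reads
# it later in on_step_end (the pattern in the example further down).
```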
Lifecycle events¶
| Event | Fires at | What you get |
|---|---|---|
| on_train_begin | Once, after setup, before the loop | Full config |
| on_step_end | Every training step, after metrics | StepContext |
| on_eval_end | After every eval round | metrics dict and step |
| on_checkpoint_save | After a DCP checkpoint is written | step and checkpoint path |
| on_train_end | Once, after the loop exits (normal or shutdown) | Final counters (step, tokens_seen) |
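To illustrate the non-step events, here is a sketch of a hook that appends one JSON record per eval round and a final summary when training ends; the file path and record layout are arbitrary choices, not part of the interface:

```python
import json

from kempnerforge.training.hooks import TrainingHook


class EvalLogHook(TrainingHook):
    """Append one JSON record per eval round, plus a final summary record."""

    def __init__(self, path: str = "eval_log.jsonl") -> None:
        self.path = path

    def on_eval_end(self, metrics: dict[str, float], step: int) -> None:
        with open(self.path, "a") as f:
            f.write(json.dumps({"event": "eval", "step": step, **metrics}) + "\n")

    def on_train_end(self, step: int, tokens_seen: int) -> None:
        with open(self.path, "a") as f:
            f.write(json.dumps({"event": "end", "step": step, "tokens_seen": tokens_seen}) + "\n")
```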
Ordering at step boundaries:
```python
# … step body …
tracker.end_step(...)                       # metrics dispatched to WandB/TB
hook_runner.on_step_end(StepContext(...))
# MoE-specific metrics logged
# NCCL health / eval / profiler / checkpoint / shutdown
```
on_step_end fires after the metrics tracker — logging from a
hook into WandB happens on the same step number, not a step off. At
the end of a checkpoint-save step, the order is on_step_end first,
then on_checkpoint_save.
Registering hooks¶
Hooks are created in-process, not from TOML:
```python
# In a custom launcher that imports scripts/train.py as a library,
# or in a fork that monkey-patches the hook_runner after construction:
from kempnerforge.training.hooks import HookRunner

hook_runner = HookRunner([MyHook1(), MyHook2()])
```
scripts/train.py instantiates HookRunner() with no hooks at line
502. To register hooks today, you either fork train.py to pass a
populated list, or import-and-patch from a custom entry point.
A config-driven hook registry is not yet wired up.
When to write a hook vs fork train.py¶
| Task | Hook or fork |
|---|---|
| Custom WandB logging (histograms, attention maps from eval) | Hook (on_step_end / on_eval_end) |
| External heartbeat / health ping | Hook (on_step_end) |
| Upload checkpoint to S3 after save | Hook (on_checkpoint_save) |
| Inspect gradients | Fork (grads are zeroed before on_step_end) |
| Inject a custom optimizer step | Fork (no hook fires inside the step body) |
| Change the loss function mid-run | Fork (hooks only observe the loss value, not the computation) |
| Custom learning-rate schedule not in the registry | Register a new scheduler, not a hook |
The guiding principle: hooks observe, they don’t intervene. Anything that needs to modify the step (not just log about it) belongs in a fork or a new registry entry.
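As a concrete case of the observe-only rule, the checkpoint-upload row maps onto on_checkpoint_save directly. A sketch, assuming boto3 is available and that path points at a DCP checkpoint directory (bucket name and key layout are placeholders):

```python
import os

import boto3

from kempnerforge.training.hooks import TrainingHook


class S3UploadHook(TrainingHook):
    """Mirror each saved checkpoint directory to S3 after it is written."""

    def __init__(self, bucket: str, prefix: str = "checkpoints") -> None:
        self.bucket = bucket
        self.prefix = prefix
        self.client = boto3.client("s3")

    def on_checkpoint_save(self, step: int, path: str) -> None:
        # DCP checkpoints are directories of shards; upload every file in them.
        for root, _dirs, files in os.walk(path):
            for name in files:
                local = os.path.join(root, name)
                key = f"{self.prefix}/step_{step}/{os.path.relpath(local, path)}"
                self.client.upload_file(local, self.bucket, key)
```

In a multi-rank run you would typically guard this so that only one rank uploads; the sketch leaves that out.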
Example: per-step gradient-norm histogram¶
```python
import wandb

from kempnerforge.training.hooks import StepContext, TrainingHook


class GradHistogramHook(TrainingHook):
    """Stash pre-clip grad norms during backward, log them in on_step_end.

    The norms must be captured during the backward pass (e.g. by patching
    the training loop or registering per-parameter gradient hooks) because
    gradients are zeroed by the time on_step_end fires.
    """

    def __init__(self) -> None:
        self._norms: dict[str, float] = {}

    def on_train_begin(self, config) -> None:
        # Subclass or patch the training loop to populate _norms during backward.
        pass

    def on_step_end(self, ctx: StepContext) -> None:
        if self._norms:
            wandb.log(
                {f"grad_norm/{n}": v for n, v in self._norms.items()},
                step=ctx.step,
            )
            self._norms.clear()
```
The pragmatic version: for a one-off gradient-norm histogram, just
edit scripts/train.py to log it directly after the microbatch loop —
the hook machinery adds complexity when you only ever need one
logger.
See also¶
- Training loop § Metrics and hooks — exact ordering of tracker.end_step vs hook_runner.on_step_end.
- Metrics and profiling — MetricsTracker, the out-of-the-box logger your hook complements.
- Evaluation § In-loop — where on_eval_end fires.
- Checkpointing — where on_checkpoint_save fires.