kempnerforge.training.hooks¶
Training loop hooks for extensibility without forking train.py.
Researchers subclass TrainingHook and override only the methods they need.
Hooks run at defined points in the training loop; when no hooks are registered,
the overhead is a single empty-list check per call site.
Usage:

```python
import wandb

from kempnerforge.training.hooks import HookRunner, StepContext, TrainingHook

class GradHistogramHook(TrainingHook):
    def on_step_end(self, ctx: StepContext) -> None:
        # Log per-parameter gradient norms after each optimizer step.
        for name, p in ctx.model.named_parameters():
            if p.grad is not None:
                wandb.log({f"grad_norm/{name}": p.grad.norm().item()}, step=ctx.step)

hooks = [GradHistogramHook()]
runner = HookRunner(hooks)
```
Classes

| Class | Description |
| --- | --- |
| `HookRunner` | Dispatches hook calls to registered hooks. |
| `StepContext` | Per-step context passed to hooks after each training step. |
| `TrainingHook` | Base class for training hooks. |
- class kempnerforge.training.hooks.StepContext[source]¶
Bases: `object`

Per-step context passed to hooks after each training step.
- model: torch.nn.Module¶
- optimizer: torch.optim.Optimizer¶
- __init__(step, loss, grad_norm, lr, tokens_seen, model, optimizer)¶
- Parameters:
step (int)
loss (float)
grad_norm (float)
lr (float)
tokens_seen (int)
model (torch.nn.Module)
optimizer (torch.optim.Optimizer)
- Return type:
None
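A minimal sketch of how a training loop might assemble a `StepContext` after an optimizer step. The model, optimizer, batch shape, and the `tokens_seen` arithmetic are illustrative assumptions, not part of the library; only the `StepContext` signature comes from the documentation above.

```python
import torch

from kempnerforge.training.hooks import StepContext

# Illustrative setup; any model/optimizer pair works the same way.
model = torch.nn.Linear(8, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

loss = model(torch.randn(4, 8)).mean()
loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
optimizer.zero_grad()

ctx = StepContext(
    step=100,
    loss=loss.item(),
    grad_norm=float(grad_norm),
    lr=optimizer.param_groups[0]["lr"],
    tokens_seen=100 * 4 * 1024,  # assumed convention: steps * batch * seq_len
    model=model,
    optimizer=optimizer,
)
```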
- class kempnerforge.training.hooks.TrainingHook[source]¶
Bases: `object`

Base class for training hooks. Override only the methods you need.
- on_train_begin(config)[source]¶
Called once after setup, before the training loop starts.
- Parameters:
config (JobConfig)
- Return type:
None
- on_step_end(ctx)[source]¶
Called after each optimizer step + metrics logging.
- Parameters:
ctx (StepContext)
- Return type:
None
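A sketch of a hook that overrides both callbacks, assuming only the signatures documented above. `LossSpikeHook` and its spike-detection logic are hypothetical; no `JobConfig` fields are assumed.

```python
from kempnerforge.training.hooks import StepContext, TrainingHook

class LossSpikeHook(TrainingHook):
    """Hypothetical hook: warn when the per-step loss jumps well above
    a running average."""

    def on_train_begin(self, config) -> None:
        # `config` is the run's JobConfig; it is not inspected here.
        self._running = None

    def on_step_end(self, ctx: StepContext) -> None:
        if self._running is not None and ctx.loss > 3.0 * self._running:
            print(f"loss spike at step {ctx.step}: {ctx.loss:.4f}")
        # Exponential moving average of the loss.
        if self._running is None:
            self._running = ctx.loss
        else:
            self._running = 0.9 * self._running + 0.1 * ctx.loss
```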
- class kempnerforge.training.hooks.HookRunner[source]¶
Bases: `object`

Dispatches hook calls to registered hooks. Zero cost when empty.
- __init__(hooks=None)[source]¶
- Parameters:
hooks (list[TrainingHook] | None)
- Return type:
None
- hooks: list[TrainingHook]¶
- on_step_end(ctx)[source]¶
- Parameters:
ctx (StepContext)
- Return type:
None
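A sketch of where the runner sits in a training loop, using only the `HookRunner.on_step_end` dispatch documented above. The toy model, optimizer, loop, and `PrintLossHook` are illustrative assumptions.

```python
import torch

from kempnerforge.training.hooks import HookRunner, StepContext, TrainingHook

class PrintLossHook(TrainingHook):
    def on_step_end(self, ctx: StepContext) -> None:
        if ctx.step % 10 == 0:
            print(f"step {ctx.step}: loss={ctx.loss:.4f} lr={ctx.lr:.2e}")

model = torch.nn.Linear(16, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
runner = HookRunner([PrintLossHook()])  # HookRunner() with no hooks dispatches to nothing

tokens_seen = 0
for step in range(100):
    batch = torch.randn(32, 16)
    loss = model(batch).pow(2).mean()
    loss.backward()
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    optimizer.zero_grad()
    tokens_seen += batch.shape[0]

    # Single dispatch point: each registered hook sees the same StepContext.
    runner.on_step_end(StepContext(
        step=step,
        loss=loss.item(),
        grad_norm=float(grad_norm),
        lr=optimizer.param_groups[0]["lr"],
        tokens_seen=tokens_seen,
        model=model,
        optimizer=optimizer,
    ))
```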