kempnerforge.training.hooks

Training loop hooks for extensibility without forking train.py.

Researchers subclass TrainingHook and override only the methods they need. Hooks run at defined points in the training loop; when no hooks are registered, the overhead is a single empty-list check per call site.

Usage:

    import wandb

    from kempnerforge.training.hooks import HookRunner, StepContext, TrainingHook

    class GradHistogramHook(TrainingHook):
        def on_step_end(self, ctx: StepContext) -> None:
            # Log the gradient norm of every parameter that received a gradient.
            for name, p in ctx.model.named_parameters():
                if p.grad is not None:
                    wandb.log({f"grad_norm/{name}": p.grad.norm().item()}, step=ctx.step)

    hooks = [GradHistogramHook()]
    runner = HookRunner(hooks)

Classes

HookRunner      Dispatches hook calls to registered hooks.
StepContext     Per-step context passed to hooks after each training step.
TrainingHook    Base class for training hooks.

class kempnerforge.training.hooks.StepContext

Bases: object

Per-step context passed to hooks after each training step.

step: int
loss: float
grad_norm: float
lr: float
tokens_seen: int
model: torch.nn.Module
optimizer: torch.optim.Optimizer
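__init__(step, loss, grad_norm, lr, tokens_seen, model, optimizer)

Parameters:
  • step (int)

  • loss (float)

  • grad_norm (float)

  • lr (float)

  • tokens_seen (int)

  • model (torch.nn.Module)

  • optimizer (torch.optim.Optimizer)

Return type:

None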
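As a rough illustration of how these fields fit together, the sketch below defines a stand-in dataclass with the same field names and reads a derived quantity from it. StepContextSketch is hypothetical and is not the library class; model and optimizer are typed as Any so the example runs without torch, and the numeric values are made up.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class StepContextSketch:
    """Hypothetical stand-in mirroring the documented StepContext fields."""

    step: int
    loss: float
    grad_norm: float
    lr: float
    tokens_seen: int
    model: Any  # torch.nn.Module in the documented class
    optimizer: Any  # torch.optim.Optimizer in the documented class


# Illustrative values only; a real training loop would fill these in.
ctx = StepContextSketch(
    step=100,
    loss=2.31,
    grad_norm=0.87,
    lr=3e-4,
    tokens_seen=409_600,
    model=None,
    optimizer=None,
)

# Hooks can derive quantities from the context without touching the loop.
tokens_per_step = ctx.tokens_seen / ctx.step
print(f"step {ctx.step}: loss={ctx.loss:.2f}, tokens/step={tokens_per_step:.0f}")
```

A hook receives one such object per step, so anything it needs beyond these fields has to be kept on the hook instance itself.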

class kempnerforge.training.hooks.TrainingHook

Bases: object

Base class for training hooks. Override only the methods you need.

on_train_begin(config)

Called once after setup, before the training loop starts.

Parameters:

config (JobConfig)

Return type:

None

on_step_end(ctx)

Called after each optimizer step, once that step's metrics have been logged.

Parameters:

ctx (StepContext)

Return type:

None
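A minimal sketch of an on_step_end override that keeps state across steps, here an exponential moving average of the loss. The TrainingHook class below is a local stand-in so the example runs on its own, LossEMAHook is a hypothetical name, and SimpleNamespace substitutes for a real StepContext.

```python
from types import SimpleNamespace


class TrainingHook:
    """Local stand-in for kempnerforge.training.hooks.TrainingHook."""

    def on_step_end(self, ctx) -> None:
        pass


class LossEMAHook(TrainingHook):
    """Hypothetical hook: tracks an exponential moving average of ctx.loss."""

    def __init__(self, beta: float = 0.9) -> None:
        self.beta = beta
        self.ema = None

    def on_step_end(self, ctx) -> None:
        if self.ema is None:
            # The first observation seeds the average.
            self.ema = ctx.loss
        else:
            self.ema = self.beta * self.ema + (1.0 - self.beta) * ctx.loss


hook = LossEMAHook(beta=0.5)
for step, loss in enumerate([4.0, 2.0, 1.0], start=1):
    # SimpleNamespace stands in for the StepContext the loop would pass.
    hook.on_step_end(SimpleNamespace(step=step, loss=loss))

print(hook.ema)
```

With the real base class, only the import changes; the override itself reads nothing beyond ctx.loss.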

on_eval_end(metrics, step)

Called after each evaluation round completes.

Parameters:
  • metrics

  • step (int)

Return type:

None
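One possible use of on_eval_end is remembering the best evaluation result seen so far. The sketch below assumes metrics is a mapping from metric name to float and that it contains a "loss" key; neither is stated on this page. TrainingHook is a local stand-in and BestEvalHook is a hypothetical name.

```python
class TrainingHook:
    """Local stand-in for kempnerforge.training.hooks.TrainingHook."""

    def on_eval_end(self, metrics, step) -> None:
        pass


class BestEvalHook(TrainingHook):
    """Hypothetical hook: remembers the step with the lowest value of one metric."""

    def __init__(self, key: str = "loss") -> None:
        self.key = key
        self.best_value = float("inf")
        self.best_step = None

    def on_eval_end(self, metrics, step) -> None:
        # Assumption: metrics behaves like dict[str, float].
        value = metrics.get(self.key)
        if value is not None and value < self.best_value:
            self.best_value = value
            self.best_step = step


hook = BestEvalHook()
for step, metrics in [(100, {"loss": 3.2}), (200, {"loss": 2.9}), (300, {"loss": 3.0})]:
    hook.on_eval_end(metrics, step)

print(hook.best_step, hook.best_value)
```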

on_checkpoint_save(step, path)

Called after a checkpoint is saved.

Parameters:
  • step (int)

  • path

Return type:

None
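As an illustration, a hook could keep a short in-memory record of the most recent checkpoints, for example to report them at the end of a run. The sketch assumes path is a string or path-like value, which this page does not specify. TrainingHook is a local stand-in, RecentCheckpointsHook is a hypothetical name, and the paths are made up; nothing is read from or written to disk.

```python
from collections import deque
from pathlib import Path


class TrainingHook:
    """Local stand-in for kempnerforge.training.hooks.TrainingHook."""

    def on_checkpoint_save(self, step, path) -> None:
        pass


class RecentCheckpointsHook(TrainingHook):
    """Hypothetical hook: remembers the last `keep` saved checkpoints."""

    def __init__(self, keep: int = 2) -> None:
        # A bounded deque drops the oldest entry automatically.
        self.recent = deque(maxlen=keep)

    def on_checkpoint_save(self, step, path) -> None:
        # Assumption: path is a str or os.PathLike.
        self.recent.append((step, Path(path)))


hook = RecentCheckpointsHook(keep=2)
for step in (1000, 2000, 3000):
    hook.on_checkpoint_save(step, f"checkpoints/step_{step}")

recent_steps = [step for step, _ in hook.recent]
print(recent_steps)
```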

on_train_end(step, tokens_seen)

Called after the training loop exits.

Parameters:
  • step (int)

  • tokens_seen (int)

Return type:

None
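Because on_train_begin and on_train_end bracket the loop, a hook can pair them to produce a run summary. The sketch below is one way to do that under stated assumptions: TrainingHook is a local stand-in, RunSummaryHook is a hypothetical name, and config is passed as None because the hook never reads it.

```python
import time


class TrainingHook:
    """Local stand-in for kempnerforge.training.hooks.TrainingHook."""

    def on_train_begin(self, config) -> None:
        pass

    def on_train_end(self, step, tokens_seen) -> None:
        pass


class RunSummaryHook(TrainingHook):
    """Hypothetical hook: records wall-clock time and average tokens per step."""

    def __init__(self) -> None:
        self.start = None
        self.summary = None

    def on_train_begin(self, config) -> None:
        self.start = time.perf_counter()

    def on_train_end(self, step, tokens_seen) -> None:
        elapsed = time.perf_counter() - self.start
        self.summary = {
            "steps": step,
            "tokens": tokens_seen,
            "seconds": elapsed,
            # max() guards against a run that exits before the first step.
            "tokens_per_step": tokens_seen / max(step, 1),
        }


hook = RunSummaryHook()
hook.on_train_begin(config=None)
hook.on_train_end(step=10, tokens_seen=40_960)
print(hook.summary["tokens_per_step"])
```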

class kempnerforge.training.hooks.HookRunner

Bases: object

Dispatches hook calls to registered hooks. Zero cost when empty.

__init__(hooks=None)
Parameters:

hooks (list[TrainingHook] | None)

Return type:

None

hooks: list[TrainingHook]
on_train_begin(config)
Parameters:

config (JobConfig)

Return type:

None

on_step_end(ctx)
Parameters:

ctx (StepContext)

Return type:

None

on_eval_end(metrics, step)
Parameters:
  • metrics

  • step (int)

Return type:

None

on_checkpoint_save(step, path)
Parameters:
  • step (int)

  • path

Return type:

None

on_train_end(step, tokens_seen)
Parameters:
  • step (int)

  • tokens_seen (int)

Return type:

None
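The HookRunner methods mirror the TrainingHook methods one for one. A minimal sketch of how such a dispatcher could work is shown below, including the empty-list check described at the top of this page. HookRunnerSketch is hypothetical and only illustrates the pattern; it is not the library's implementation, it covers just two of the five methods, and TrainingHook is a local stand-in.

```python
class TrainingHook:
    """Local stand-in exposing two of the documented no-op methods."""

    def on_step_end(self, ctx) -> None:
        pass

    def on_train_end(self, step, tokens_seen) -> None:
        pass


class HookRunnerSketch:
    """Hypothetical dispatcher: forwards each call to every registered hook."""

    def __init__(self, hooks=None) -> None:
        self.hooks = list(hooks) if hooks else []

    def on_step_end(self, ctx) -> None:
        if not self.hooks:
            return  # nothing registered: one cheap check, then return
        for hook in self.hooks:
            hook.on_step_end(ctx)

    def on_train_end(self, step, tokens_seen) -> None:
        if not self.hooks:
            return
        for hook in self.hooks:
            hook.on_train_end(step, tokens_seen)


class CountingHook(TrainingHook):
    """Hypothetical hook that counts how many steps it has seen."""

    def __init__(self) -> None:
        self.steps = 0

    def on_step_end(self, ctx) -> None:
        self.steps += 1


# With no hooks registered, every call is a no-op.
empty_runner = HookRunnerSketch()
empty_runner.on_step_end(ctx=None)

counter = CountingHook()
runner = HookRunnerSketch([counter])
for _ in range(3):
    runner.on_step_end(ctx=None)
runner.on_train_end(step=3, tokens_seen=0)  # falls through to the base no-op

print(counter.steps)
```

Hooks run in registration order in this sketch; whether the real HookRunner guarantees an order is not stated on this page.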