kempnerforge.training.grad

Gradient utilities for distributed training.

Handles gradient accumulation with a no_sync context that skips redundant all-reduces during micro-batching.
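For illustration, a minimal sketch of the micro-batching pattern this module supports. The model, optimizer, and micro-batch data below are placeholders (a plain nn.Linear rather than a distributed model), and the zero-based micro_step index taken from enumerate is an assumed convention.

import torch
from torch import nn

from kempnerforge.training.grad import maybe_no_sync

# Placeholder training objects; in real use these would be the distributed
# (e.g. FSDP2-wrapped) model, its optimizer, and the micro-batches of one step.
model = nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
micro_batches = [torch.randn(4, 8) for _ in range(4)]
grad_accum_steps = len(micro_batches)

for micro_step, micro_batch in enumerate(micro_batches):
    # Sync is skipped on intermediate micro-steps and performed on the last one.
    with maybe_no_sync(model, micro_step, grad_accum_steps):
        loss = model(micro_batch).pow(2).mean() / grad_accum_steps
        loss.backward()

optimizer.step()
optimizer.zero_grad()

Because sync is suppressed on intermediate micro-steps, the all-reduce cost is paid once per optimizer step rather than once per micro-batch.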

Functions

maybe_no_sync(model, micro_step, ...)

Skip gradient sync on intermediate accumulation steps.

kempnerforge.training.grad.maybe_no_sync(model, micro_step, grad_accum_steps)[source]

Skip gradient sync on intermediate accumulation steps.

On the last micro-step, gradients are synchronized normally. On earlier micro-steps, sync is skipped to avoid redundant all-reduces.

Works with FSDP2 (which implements no_sync() on the module). For non-distributed models, this is a no-op.

Parameters:

model – Model whose gradients are being accumulated (for example, an FSDP2-wrapped module).

micro_step – Index of the current micro-step within the accumulation window.

grad_accum_steps – Total number of micro-steps accumulated before each optimizer step.

Return type:

Generator[None, None, None]
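For intuition, a rough sketch of how a context manager with these semantics could be written. This is an assumption about the general shape, not the actual source; the hasattr check and the zero-based last-step test are illustrative.

from contextlib import contextmanager, nullcontext

@contextmanager
def maybe_no_sync_sketch(model, micro_step, grad_accum_steps):
    # Hypothetical reimplementation for illustration only.
    last_step = micro_step == grad_accum_steps - 1
    if not last_step and hasattr(model, "no_sync"):
        # Intermediate micro-step on a distributed model: suppress the all-reduce.
        ctx = model.no_sync()
    else:
        # Last micro-step, or a non-distributed model: behave as usual (no-op).
        ctx = nullcontext()
    with ctx:
        yield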