kempnerforge.training.grad¶
Gradient utilities for distributed training.
Handles gradient accumulation with a no_sync context, skipping redundant all-reduces during micro-batching.
Functions
maybe_no_sync(model, micro_step, grad_accum_steps) | Skip gradient sync on intermediate accumulation steps.
- kempnerforge.training.grad.maybe_no_sync(model, micro_step, grad_accum_steps)¶
Skip gradient sync on intermediate accumulation steps.
On the last micro-step, gradients are synchronized normally. On earlier micro-steps, sync is skipped to avoid redundant all-reduces.
Works with FSDP2 (which implements no_sync() on the module). For non-distributed models, this is a no-op.
- Parameters:
model (torch.nn.Module)
micro_step (int)
grad_accum_steps (int)
- Return type:
Generator[None, None, None]
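A minimal usage sketch follows. Only maybe_no_sync and its signature come from this module; the Linear model, optimizer, micro-batch data, and hyperparameters are illustrative assumptions (for a plain, non-distributed module the context is a no-op, but the loop structure is the same for an FSDP2-wrapped model).

    import torch
    import torch.nn as nn

    from kempnerforge.training.grad import maybe_no_sync

    # Assumed setup for the sketch: a toy model and four micro-batches.
    model = nn.Linear(8, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    micro_batches = [torch.randn(4, 8) for _ in range(4)]
    grad_accum_steps = len(micro_batches)

    optimizer.zero_grad()
    for micro_step, batch in enumerate(micro_batches):
        # Sync is skipped on micro-steps 0 .. grad_accum_steps - 2; on the
        # final micro-step gradients all-reduce normally.
        with maybe_no_sync(model, micro_step, grad_accum_steps):
            loss = model(batch).pow(2).mean() / grad_accum_steps
            loss.backward()
    optimizer.step()

Dividing each micro-batch loss by grad_accum_steps keeps the accumulated gradient equal to the average over the full effective batch.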