Source code for kempnerforge.training.grad

"""Gradient utilities for distributed training.

Handles gradient accumulation with no_sync context for skipping
redundant all-reduces during micro-batching.
"""

from __future__ import annotations

import contextlib
from collections.abc import Generator

import torch


@contextlib.contextmanager
def maybe_no_sync(
    model: torch.nn.Module,
    micro_step: int,
    grad_accum_steps: int,
) -> Generator[None, None, None]:
    """Skip gradient sync on intermediate accumulation steps.

    On the last micro-step, gradients are synchronized normally. On
    earlier micro-steps, sync is skipped to avoid redundant all-reduces.

    Works with FSDP2, which exposes ``set_requires_gradient_sync()`` on
    the sharded module. For non-distributed models, this is a no-op.
    """
    is_last_step = (micro_step + 1) == grad_accum_steps
    if is_last_step or not hasattr(model, "set_requires_gradient_sync"):
        yield
    else:
        # FSDP2 uses set_requires_gradient_sync() instead of a no_sync() context.
        model.set_requires_gradient_sync(False)  # type: ignore[reportCallIssue]
        try:
            yield
        finally:
            model.set_requires_gradient_sync(True)  # type: ignore[reportCallIssue]
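
The ``hasattr`` check above keys off FSDP2's ``FSDPModule`` API. For context, a minimal sketch of the wrapping that provides that hook, assuming torch >= 2.6 (earlier releases expose ``fully_shard`` under ``torch.distributed._composable.fsdp``); the ``torchrun`` launch, NCCL backend, and toy ``Linear`` model are illustrative only:

import os

import torch
import torch.distributed as dist
from torch.distributed.fsdp import fully_shard

# Launch with e.g. ``torchrun --nproc_per_node=2 train.py``.
dist.init_process_group("nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

model = torch.nn.Linear(8, 1).cuda()
fully_shard(model)  # model now exposes set_requires_gradient_sync()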
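
And a usage sketch of the accumulation loop itself; the toy model, ``micro_batches``, and the squared-error loss are placeholders, and since the model here is left unwrapped, ``maybe_no_sync`` takes its no-op path:

import torch

from kempnerforge.training.grad import maybe_no_sync

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
grad_accum_steps = 4
micro_batches = [torch.randn(2, 8) for _ in range(grad_accum_steps)]

for micro_step, batch in enumerate(micro_batches):
    with maybe_no_sync(model, micro_step, grad_accum_steps):
        # Scale the loss so the accumulated gradient matches a full-batch step.
        loss = model(batch).pow(2).mean() / grad_accum_steps
        loss.backward()

optimizer.step()
optimizer.zero_grad()

Under FSDP2, only the final micro-step's backward would trigger the all-reduce; dividing the loss by ``grad_accum_steps`` keeps the accumulated gradient equivalent to one full-batch step.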