Source code for kempnerforge.distributed.utils
"""Distributed training utilities.
Gradient clipping and data-parallel helpers that work correctly with
FSDP2 and multi-dimensional parallelism.
"""
from __future__ import annotations

import logging

import torch
from torch.distributed.device_mesh import DeviceMesh

logger = logging.getLogger(__name__)
def get_dp_info(device_mesh: DeviceMesh | None) -> tuple[int, int]:
"""Get (dp_rank, dp_size) from the device mesh, accounting for PP/TP.
Handles all DeviceMesh configurations: HSDP (dp_replicate + dp_shard),
pure FSDP (dp_shard only), replicate-only, or no mesh (single GPU).
Args:
device_mesh: Full DeviceMesh, or None for single-GPU.
Returns:
Tuple of (dp_rank, dp_world_size).
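
    Example:
        Illustrative sketch; the 4-GPU pure-FSDP mesh below is an assumption
        made for the example, not a configuration taken from this module::

            from torch.distributed.device_mesh import init_device_mesh

            mesh = init_device_mesh("cuda", (4,), mesh_dim_names=("dp_shard",))
            dp_rank, dp_size = get_dp_info(mesh)  # dp_size == 4
            dp_rank, dp_size = get_dp_info(None)  # (0, 1) on a single GPU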
"""
if device_mesh is None:
return 0, 1
dim_names = device_mesh.mesh_dim_names
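    # HSDP: the data-parallel group spans both the replicate and shard dims.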
if "dp_shard" in dim_names and "dp_replicate" in dim_names: # type: ignore[reportOperatorIssue]
dp_mesh = device_mesh["dp_replicate", "dp_shard"]
return dp_mesh.get_local_rank(), dp_mesh.size()
elif "dp_shard" in dim_names: # type: ignore[reportOperatorIssue]
dp_mesh = device_mesh["dp_shard"]
return dp_mesh.get_local_rank(), dp_mesh.size()
elif "dp_replicate" in dim_names: # type: ignore[reportOperatorIssue]
dp_mesh = device_mesh["dp_replicate"]
return dp_mesh.get_local_rank(), dp_mesh.size()
return 0, 1
def clip_grad_norm_(
model: torch.nn.Module,
max_norm: float,
foreach: bool = True,
) -> torch.Tensor:
"""Clip gradient norm across all parameters.
Handles mixed DTensor meshes that arise when TP+FSDP produces parameters on
different meshes (e.g., TP-sharded linears on (dp_shard, tp) vs FSDP-only
norms on (dp_shard)). Groups gradients by mesh, computes per-group norms
via stack, then combines across groups.
Falls back to ``torch.nn.utils.clip_grad_norm_`` when there is only one mesh
(pure FSDP or single-GPU).
Args:
model: Model whose gradients to clip.
max_norm: Maximum gradient norm.
foreach: Use the faster foreach implementation.
Returns:
Total gradient norm (before clipping).
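
    Example:
        Illustrative training-step usage; ``model``, ``optimizer``, and ``loss``
        are placeholders rather than names from this module::

            loss.backward()
            grad_norm = clip_grad_norm_(model, max_norm=1.0)
            optimizer.step()
            optimizer.zero_grad()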
"""
grads = [p.grad for p in model.parameters() if p.grad is not None]
if not grads:
return torch.tensor(0.0)
# Check if all gradients share the same mesh (or are plain tensors).
# If so, use the fast standard path.
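    # A DTensor gradient carries its placement spec in ``_spec``; keying on the
    # id of that spec's mesh groups grads by mesh, with plain tensors mapped to 0.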
def _mesh_key(g: torch.Tensor) -> int:
spec = getattr(g, "_spec", None)
return id(spec.mesh) if spec is not None else 0
mesh_keys = {_mesh_key(g) for g in grads}
if len(mesh_keys) <= 1:
        # Single mesh -> standard path works
return torch.nn.utils.clip_grad_norm_(
model.parameters(),
max_norm=max_norm,
foreach=foreach,
)
    # Mixed meshes -> group by mesh, compute per-group norms, combine.
from collections import defaultdict
groups: dict[int, list[torch.Tensor]] = defaultdict(list)
for g in grads:
groups[_mesh_key(g)].append(g.detach())
total_norm_sq = torch.tensor(0.0, device=grads[0].device)
for group_grads in groups.values():
norms = torch.stack([g.norm(2.0) for g in group_grads])
group_norm_sq = norms.pow(2).sum()
        # DTensor norm is a partial sum -> full_tensor() does the all-reduce
if hasattr(group_norm_sq, "full_tensor"):
group_norm_sq = group_norm_sq.full_tensor() # type: ignore[reportAttributeAccessIssue]
total_norm_sq = total_norm_sq + group_norm_sq
total_norm = total_norm_sq.sqrt()
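    # Same scaling rule as torch.nn.utils.clip_grad_norm_: scale each grad by
    # max_norm / total_norm, clamped so gradients are never scaled up.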
clip_coef = max_norm / (total_norm + 1e-6)
clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
for g in grads:
g.mul_(clip_coef_clamped)
return total_norm