kempnerforge.distributed.utils

Distributed training utilities.

Gradient clipping and data-parallel helpers that work correctly with FSDP2 and multi-dimensional parallelism.

Functions

clip_grad_norm_(model, max_norm[, foreach]) – Clip gradient norm across all parameters.

get_dp_info(device_mesh) – Get (dp_rank, dp_size) from the device mesh, accounting for PP/TP.

kempnerforge.distributed.utils.get_dp_info(device_mesh)[source]

Get (dp_rank, dp_size) from the device mesh, accounting for PP/TP.

Handles all DeviceMesh configurations: HSDP (dp_replicate + dp_shard), pure FSDP (dp_shard only), replicate-only, or no mesh (single GPU).

Parameters:

device_mesh (DeviceMesh | None) – Full DeviceMesh, or None for single-GPU.

Returns:

Tuple of (dp_rank, dp_world_size).

Return type:

tuple[int, int]
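
A minimal usage sketch. The mesh shape, the dimension names ("dp_shard", "tp"), and the DistributedSampler wiring are illustrative assumptions, not requirements of the API:

    import torch
    from torch.distributed.device_mesh import init_device_mesh
    from torch.utils.data import TensorDataset
    from torch.utils.data.distributed import DistributedSampler

    from kempnerforge.distributed.utils import get_dp_info

    # Hypothetical 8-GPU layout: 4-way FSDP sharding x 2-way tensor parallelism.
    mesh = init_device_mesh("cuda", (4, 2), mesh_dim_names=("dp_shard", "tp"))

    # Data-parallel coordinates only; the tp dimension is excluded.
    dp_rank, dp_size = get_dp_info(mesh)

    # Typical use: shard the dataset over the data-parallel group, so TP ranks
    # within the same DP group receive identical batches.
    dataset = TensorDataset(torch.arange(1024))
    sampler = DistributedSampler(dataset, num_replicas=dp_size, rank=dp_rank)

    # Single-GPU runs pass None instead of a mesh.
    dp_rank, dp_size = get_dp_info(None)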

kempnerforge.distributed.utils.clip_grad_norm_(model, max_norm, foreach=True)[source]

Clip gradient norm across all parameters.

Handles the mixed DTensor meshes that arise when combining TP with FSDP, which places parameters on different meshes (e.g., TP-sharded linear layers on (dp_shard, tp) vs. FSDP-only normalization layers on (dp_shard,)). Gradients are grouped by mesh, per-group norms are computed via torch.stack, and the group norms are then combined into a single total.

Falls back to torch.nn.utils.clip_grad_norm_ when there is only one mesh (pure FSDP or single-GPU).
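
A conceptual sketch of the grouping strategy described above, not this module's actual implementation. For simplicity it materialises each DTensor norm as a plain scalar via full_tensor(); the names below are hypothetical:

    from collections import defaultdict

    import torch
    from torch.distributed.tensor import DTensor  # torch >= 2.4


    def total_grad_norm_by_mesh(parameters, norm_type: float = 2.0) -> torch.Tensor:
        # Bucket gradients by the DeviceMesh they live on (None = plain local tensors).
        groups = defaultdict(list)
        for p in parameters:
            if p.grad is None:
                continue
            mesh = p.grad.device_mesh if isinstance(p.grad, DTensor) else None
            groups[mesh].append(p.grad)

        group_norms = []
        for grads in groups.values():
            per_grad = []
            for g in grads:
                n = torch.linalg.vector_norm(g, norm_type)
                if isinstance(n, DTensor):
                    n = n.full_tensor()  # reduce the sharded norm to a plain scalar
                per_grad.append(n)
            # Per-group norm via stack; group norms are then combined the same way.
            group_norms.append(torch.linalg.vector_norm(torch.stack(per_grad), norm_type))

        if not group_norms:
            return torch.tensor(0.0)
        return torch.linalg.vector_norm(torch.stack(group_norms), norm_type)

For any p, the p-norm of the per-group p-norms equals the p-norm over all gradients, so combining across groups this way preserves the total norm.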

Parameters:
  • model (torch.nn.Module) – Model whose gradients to clip.

  • max_norm (float) – Maximum gradient norm.

  • foreach (bool) – Use the faster foreach implementation.

Returns:

Total gradient norm (before clipping).

Return type:

torch.Tensor
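
A minimal training-step sketch showing where the call sits. The model, optimizer, and loss wiring are illustrative placeholders; only clip_grad_norm_ comes from this module:

    import torch

    from kempnerforge.distributed.utils import clip_grad_norm_


    def train_step(model: torch.nn.Module, batch, optimizer, max_norm: float = 1.0):
        optimizer.zero_grad(set_to_none=True)
        loss = model(batch).mean()  # stand-in loss; substitute your own criterion
        loss.backward()

        # Safe for mixed TP + FSDP meshes as well as pure FSDP or single-GPU runs.
        total_norm = clip_grad_norm_(model, max_norm)

        optimizer.step()
        # total_norm is the pre-clipping norm, handy for logging gradient-norm curves.
        return loss.detach(), total_norm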