kempnerforge.distributed.utils
Distributed training utilities.
Gradient clipping and data-parallel helpers that work correctly with FSDP2 and multi-dimensional parallelism.
Functions
| clip_grad_norm_(model, max_norm, foreach=True) | Clip gradient norm across all parameters. |
| get_dp_info(device_mesh) | Get (dp_rank, dp_size) from the device mesh, accounting for PP/TP. |
- kempnerforge.distributed.utils.get_dp_info(device_mesh)
Get (dp_rank, dp_size) from the device mesh, accounting for PP/TP.
Handles all DeviceMesh configurations: HSDP (dp_replicate + dp_shard), pure FSDP (dp_shard only), replicate-only, or no mesh (single GPU).
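A minimal sketch of how this flattening could look, assuming the data-parallel mesh dimensions are named "dp_replicate" and "dp_shard"; the function name get_dp_info_sketch and those dimension names are illustrative assumptions, not the module's actual implementation:

```python
from typing import Optional, Tuple

from torch.distributed.device_mesh import DeviceMesh


def get_dp_info_sketch(device_mesh: Optional[DeviceMesh]) -> Tuple[int, int]:
    """Return (dp_rank, dp_size), folding HSDP replicate/shard dims into one."""
    if device_mesh is None:
        # No mesh: single GPU, one data-parallel rank.
        return 0, 1

    names = device_mesh.mesh_dim_names or ()
    # Only the data-parallel dims count; PP/TP dims are ignored.
    dp_dims = [n for n in ("dp_replicate", "dp_shard") if n in names]
    if not dp_dims:
        return 0, 1

    # Fold the DP dims row-major: rank = replicate_rank * shard_size + shard_rank.
    dp_rank, dp_size = 0, 1
    for name in dp_dims:
        sub_mesh = device_mesh[name]  # 1-D sub-mesh for this dim
        dp_rank = dp_rank * sub_mesh.size() + sub_mesh.get_local_rank()
        dp_size *= sub_mesh.size()
    return dp_rank, dp_size
```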
- kempnerforge.distributed.utils.clip_grad_norm_(model, max_norm, foreach=True)
Clip gradient norm across all parameters.
Handles mixed DTensor meshes that arise when TP+FSDP produces parameters on different meshes (e.g., TP-sharded linears on (dp_shard, tp) vs FSDP-only norms on (dp_shard)). Groups gradients by mesh, computes per-group norms via stack, then combines across groups; a sketch of this strategy appears after this entry.
Falls back to torch.nn.utils.clip_grad_norm_ when there is only one mesh (pure FSDP or single GPU).
- Parameters:
model (torch.nn.Module) – Model whose gradients to clip.
max_norm (float) – Maximum gradient norm.
foreach (bool) – Use the faster foreach implementation.
- Returns:
Total gradient norm (before clipping).
- Return type:
torch.Tensor
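A rough sketch of the per-mesh grouping strategy described above, assuming gradients are either plain tensors or DTensors. The name clip_grad_norm_sketch, the grouping key, and the final scalar sync are illustrative choices for clarity, not the module's actual implementation (which also exposes the foreach fast path):

```python
from collections import defaultdict

import torch
from torch.distributed.tensor import DTensor


def clip_grad_norm_sketch(model: torch.nn.Module, max_norm: float) -> torch.Tensor:
    grads = [p.grad for p in model.parameters() if p.grad is not None]

    # Group gradients by device mesh; plain (non-DTensor) grads share one group.
    groups = defaultdict(list)
    for g in grads:
        key = g.device_mesh if isinstance(g, DTensor) else None
        groups[key].append(g)

    # Per-group L2 norm: every gradient in a group lives on the same mesh, so
    # stacking per-tensor norms and reducing is well defined within the group.
    group_norms = []
    for group in groups.values():
        per_tensor = torch.stack([torch.linalg.vector_norm(g, 2) for g in group])
        norm = torch.linalg.vector_norm(per_tensor, 2)
        # Materialize DTensor norms so groups on different meshes can be combined.
        group_norms.append(norm.full_tensor() if isinstance(norm, DTensor) else norm)

    # Combine group norms into the total norm, then clip with a shared coefficient.
    total_norm = torch.linalg.vector_norm(torch.stack(group_norms), 2)
    clip_coef = min(max_norm / (float(total_norm) + 1e-6), 1.0)
    for g in grads:
        g.mul_(clip_coef)
    return total_norm
```

The pre-clip total norm is returned, so a training loop can log it each step before calling optimizer.step().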