kempnerforge.distributed.setup¶
Distributed process group and DeviceMesh initialization.
Sets up torch.distributed, CUDA devices, and constructs a DeviceMesh with named parallelism dimensions based on DistributedConfig.
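Internally this follows the standard PyTorch DeviceMesh pattern. A minimal sketch of that pattern, assuming an illustrative 2 × 4 layout with hypothetical dimension names "dp" and "tp" (the real shape and names come from DistributedConfig):

    from torch.distributed.device_mesh import init_device_mesh

    # Sketch of the underlying PyTorch API, not the kempnerforge source.
    # Run under torchrun so the env:// rendezvous variables are set.
    # An 8-GPU job split into 2 data-parallel x 4 tensor-parallel ranks:
    mesh = init_device_mesh("cuda", mesh_shape=(2, 4), mesh_dim_names=("dp", "tp"))
    dp_mesh = mesh["dp"]  # 1-D sub-mesh addressing only the "dp" dimension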
Functions
- Clean up distributed process groups.
- get_world_info() – Return (rank, local_rank, world_size) from environment variables.
- init_distributed(config, seed=42) – Initialize distributed training and build the DeviceMesh.
- is_rank_zero() – Check if the current process is rank 0.
- kempnerforge.distributed.setup.get_world_info()[source]¶
Return (rank, local_rank, world_size) from environment variables.
Works with both torchrun (RANK/LOCAL_RANK/WORLD_SIZE) and direct srun launch (SLURM_PROCID/SLURM_LOCALID/SLURM_NTASKS). When running under srun, also sets RANK/LOCAL_RANK/WORLD_SIZE so that PyTorch’s env:// rendezvous can find them.
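A minimal sketch of the fallback logic the docstring describes (not the library source):

    import os

    def get_world_info():
        # Prefer torchrun's variables; fall back to Slurm's under srun,
        # and export the torchrun names so env:// rendezvous can find them.
        if "RANK" in os.environ:
            rank = int(os.environ["RANK"])
            local_rank = int(os.environ["LOCAL_RANK"])
            world_size = int(os.environ["WORLD_SIZE"])
        else:
            rank = int(os.environ["SLURM_PROCID"])
            local_rank = int(os.environ["SLURM_LOCALID"])
            world_size = int(os.environ["SLURM_NTASKS"])
            os.environ["RANK"] = str(rank)
            os.environ["LOCAL_RANK"] = str(local_rank)
            os.environ["WORLD_SIZE"] = str(world_size)
        return rank, local_rank, world_size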
- kempnerforge.distributed.setup.is_rank_zero()[source]¶
Check if the current process is rank 0.
- Return type:
bool
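Typical usage is to gate side effects such as logging or checkpoint writes to a single process:

    from kempnerforge.distributed.setup import is_rank_zero

    # Only the coordinating process should write logs and checkpoints.
    if is_rank_zero():
        print("rank 0: safe to log / save here")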
- kempnerforge.distributed.setup.init_distributed(config, seed=42)[source]¶
Initialize distributed training and build the DeviceMesh.
- Parameters:
config (DistributedConfig) – Distributed configuration with parallelism dimensions.
seed (int) – Random seed for reproducibility.
- Returns:
DeviceMesh if world_size > 1; None for a single-GPU run.
- Return type:
DeviceMesh | None
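Example of putting the pieces together. The construction of DistributedConfig is left abstract because its fields are defined elsewhere in kempnerforge; this is a usage sketch, not a verbatim recipe:

    from kempnerforge.distributed.setup import init_distributed, is_rank_zero

    config = ...  # a DistributedConfig with the desired parallelism dimensions

    mesh = init_distributed(config, seed=42)
    if mesh is None:
        pass  # single-GPU run: no process group or mesh to coordinate over
    elif is_rank_zero():
        print(f"mesh dims: {mesh.mesh_dim_names}")  # e.g. ('dp', 'tp')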