kempnerforge.distributed.setup

Distributed process group and DeviceMesh initialization.

Sets up torch.distributed, CUDA devices, and constructs a DeviceMesh with named parallelism dimensions based on DistributedConfig.
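A typical lifecycle is sketched below. The construction of the DistributedConfig is elided because its fields are defined elsewhere; the sketch only assumes the functions documented on this page:

    from kempnerforge.distributed.setup import (
        destroy_distributed,
        init_distributed,
        is_rank_zero,
    )

    def main(config) -> None:
        # config is a DistributedConfig; its construction is elided here.
        mesh = init_distributed(config, seed=42)
        try:
            if is_rank_zero():
                print("process group ready; starting training")
            # ... build the model, optionally sharding over `mesh`, and train ...
        finally:
            # Always tear down process groups, even if training raises.
            destroy_distributed()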

Functions

destroy_distributed()
    Clean up distributed process groups.

get_world_info()
    Return (rank, local_rank, world_size) from environment variables.

init_distributed(config[, seed])
    Initialize distributed training and build the DeviceMesh.

is_rank_zero()
    Check if the current process is rank 0.

kempnerforge.distributed.setup.get_world_info()

Return (rank, local_rank, world_size) from environment variables.

Works with both torchrun (RANK/LOCAL_RANK/WORLD_SIZE) and direct srun launches (SLURM_PROCID/SLURM_LOCALID/SLURM_NTASKS). When running under srun, the function also exports RANK/LOCAL_RANK/WORLD_SIZE so that PyTorch's env:// rendezvous can find them.

Return type:
    tuple[int, int, int]
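The resolution order can be sketched as follows. This is an illustrative re-implementation of the behavior described above, not the library's actual code:

    import os

    def get_world_info_sketch() -> tuple[int, int, int]:
        if "RANK" in os.environ:
            # torchrun already provides PyTorch's canonical variables.
            rank = int(os.environ["RANK"])
            local_rank = int(os.environ["LOCAL_RANK"])
            world_size = int(os.environ["WORLD_SIZE"])
        else:
            # Direct srun launch: translate Slurm's variables ...
            rank = int(os.environ["SLURM_PROCID"])
            local_rank = int(os.environ["SLURM_LOCALID"])
            world_size = int(os.environ["SLURM_NTASKS"])
            # ... and export the canonical names for env:// rendezvous.
            os.environ["RANK"] = str(rank)
            os.environ["LOCAL_RANK"] = str(local_rank)
            os.environ["WORLD_SIZE"] = str(world_size)
        return rank, local_rank, world_size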

kempnerforge.distributed.setup.is_rank_zero()

Check if the current process is rank 0.

Return type:
    bool
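A common use is gating side effects such as logging or checkpoint writes to a single process:

    from kempnerforge.distributed.setup import is_rank_zero

    if is_rank_zero():
        print("only rank 0 logs this")          # avoids N duplicate log lines
        # torch.save(model.state_dict(), path)  # ... and N racing checkpoint writes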

kempnerforge.distributed.setup.init_distributed(config, seed=42)

Initialize distributed training and build the DeviceMesh.

Parameters:
    • config (DistributedConfig) – Distributed configuration with parallelism dimensions.

    • seed (int) – Random seed for reproducibility. Defaults to 42.

Returns:
    A DeviceMesh if world_size > 1; None for a single-GPU run.

Return type:
    DeviceMesh | None
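Example usage. The DistributedConfig import path and field names, and the mesh dimension name, are assumptions for illustration; the real names come from your config:

    from kempnerforge.config import DistributedConfig  # import path is a guess
    from kempnerforge.distributed.setup import init_distributed

    # Hypothetical field names; consult DistributedConfig for the real ones.
    config = DistributedConfig(data_parallel=4, tensor_parallel=2)

    mesh = init_distributed(config, seed=1234)
    if mesh is None:
        pass  # single-GPU run: no process group or mesh to coordinate
    else:
        # Named dimensions index into the mesh; "data_parallel" is assumed here.
        dp_group = mesh["data_parallel"].get_group()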

kempnerforge.distributed.setup.destroy_distributed()

Clean up distributed process groups.

Return type:
    None
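Cleanup in PyTorch typically amounts to the pattern below. This is a plausible sketch, not the function's actual implementation:

    import torch.distributed as dist

    def destroy_distributed_sketch() -> None:
        if dist.is_available() and dist.is_initialized():
            dist.barrier()                # let outstanding collectives drain
            dist.destroy_process_group()  # free the default process group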