Source code for kempnerforge.config.distributed

"""Distributed parallelism configuration."""

from __future__ import annotations

from dataclasses import dataclass
from enum import StrEnum


[docs] class PipelineSchedule(StrEnum): schedule_1f1b = "1f1b" gpipe = "gpipe" interleaved_1f1b = "interleaved_1f1b"
@dataclass
class DistributedConfig:
    """Parallelism dimensions and distributed settings.

    The product ``dp_replicate * dp_shard * tp * pp * cp * ep`` is checked
    against the world size by :meth:`validate_world_size`. ``dp_shard`` may be
    left at ``-1``, in which case it is derived from the world size and the
    other dimensions (see :meth:`resolve`).

    Raises:
        ValueError: On construction, if ``dp_shard`` is neither ``-1`` nor
            positive, or if any of ``dp_replicate``, ``tp``, ``pp``, ``cp``,
            ``ep`` is below 1.
    """

    dp_shard: int = -1  # -1 -> auto (use all remaining GPUs)
    dp_replicate: int = 1
    tp: int = 1
    pp: int = 1
    pp_schedule: PipelineSchedule = PipelineSchedule.schedule_1f1b
    cp: int = 1
    ep: int = 1  # Expert parallelism degree (partitions MoE experts across ranks)
    nccl_timeout_sec: int = 1800
    backend: str = "cpu:gloo,cuda:nccl"

    def validate_world_size(self, world_size: int) -> None:
        """Validate that parallelism dimensions match world size.

        Args:
            world_size: Total number of ranks the dimensions must multiply to.

        Raises:
            ValueError: If the product of all dimensions (with ``dp_shard``
                resolved) differs from ``world_size``, or if ``dp_shard`` is
                auto and ``world_size`` is not divisible by the product of
                the other dimensions.
        """
        dp_shard = self._resolve_dp_shard(world_size)
        expected = self.dp_replicate * dp_shard * self.tp * self.pp * self.cp * self.ep
        if expected != world_size:
            raise ValueError(
                f"Parallelism dimensions ({self.dp_replicate} \u00d7 {dp_shard} \u00d7 "
                f"{self.tp} \u00d7 {self.pp} \u00d7 {self.cp} \u00d7 {self.ep} = {expected}) "
                f"do not match world_size ({world_size})"
            )

    def _resolve_dp_shard(self, world_size: int) -> int:
        """Resolve dp_shard=-1 to actual value.

        An explicit positive ``dp_shard`` is returned unchanged; otherwise the
        value is ``world_size`` divided by the product of the other dimensions.

        Raises:
            ValueError: If ``world_size`` is not divisible by that product.
        """
        if self.dp_shard > 0:
            return self.dp_shard
        other = self.dp_replicate * self.tp * self.pp * self.cp * self.ep
        if world_size % other != 0:
            raise ValueError(
                f"world_size ({world_size}) not divisible by dp_replicate*tp*pp*cp*ep ({other})"
            )
        return world_size // other

    def resolve(self, world_size: int) -> DistributedConfig:
        """Return a copy with dp_shard resolved to a concrete value.

        The receiver is left untouched. The copy is validated against
        ``world_size`` before it is returned.

        Args:
            world_size: Total number of ranks to resolve against.

        Returns:
            A new config identical to this one except for ``dp_shard``.

        Raises:
            ValueError: If the dimensions cannot be made to match
                ``world_size``.
        """
        # Function-scope import: the module top only imports `dataclass`.
        from dataclasses import replace

        # `replace` copies every field, so a field added to this dataclass
        # later is carried over instead of being silently reset to its
        # default, and the copy keeps the receiver's class. It goes through
        # __init__, so __post_init__ validation still runs on the copy.
        resolved = replace(self, dp_shard=self._resolve_dp_shard(world_size))
        resolved.validate_world_size(world_size)
        return resolved

    def __post_init__(self) -> None:
        """Reject non-sensical parallelism degrees at construction time."""
        if self.dp_shard == 0 or self.dp_shard < -1:
            raise ValueError("dp_shard must be -1 (auto) or positive")
        for name, val in [
            ("dp_replicate", self.dp_replicate),
            ("tp", self.tp),
            ("pp", self.pp),
            ("cp", self.cp),
            ("ep", self.ep),
        ]:
            if val < 1:
                raise ValueError(f"{name} must be >= 1")