kempnerforge.distributed.expert_parallel¶
Expert Parallelism: partition MoE experts across an EP process group.
Each EP rank holds num_experts // ep_size experts. Tokens are exchanged
between ranks via all-to-all so that every token reaches its assigned expert,
and the results are then returned to the originating rank by a second all-to-all.
When ep=1 (default), this module is a no-op and the model runs with
all experts replicated on every rank (the pre-EP behavior).
Functions

apply_expert_parallel(model, device_mesh) – Partition MoE experts across the EP dimension of the DeviceMesh.
ep_dispatch_and_compute(x, weights, indices, moe, ep_group, local_expert_start, num_local_experts, ep_world_size, gradient_scale=False) – All-to-all dispatch, local expert compute, all-to-all combine.
- kempnerforge.distributed.expert_parallel.ep_dispatch_and_compute(x, weights, indices, moe, ep_group, local_expert_start, num_local_experts, ep_world_size, gradient_scale=False)[source]¶
All-to-all dispatch, local expert compute, all-to-all combine.
- Parameters:
x (torch.Tensor) – (num_tokens, dim) flattened token representations.
weights (torch.Tensor) – (num_tokens, top_k) routing weights.
indices (torch.Tensor) – (num_tokens, top_k) global expert indices.
moe (MoEMLP) – The MoEMLP module. Holds either an nn.ModuleList of local experts (unpacked) or packed expert tensors already sliced to the local range.
ep_group (torch.distributed.ProcessGroup) – EP process group.
local_expert_start (int) – First global expert index on this rank.
num_local_experts (int) – Number of experts on this rank.
ep_world_size (int) – Number of ranks in the EP group.
gradient_scale (bool)
- Returns:
(num_tokens, dim) weighted combination of expert outputs.
- Return type:
torch.Tensor
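A hedged usage sketch: the parameter names match the signature above, but the router call, module wiring, and shape handling around it are assumptions. In practice the EP metadata would come from what apply_expert_parallel stores on the module; it is recomputed here only to show where each value comes from.

```python
import torch
import torch.distributed as dist
from kempnerforge.distributed.expert_parallel import ep_dispatch_and_compute

def moe_forward(hidden: torch.Tensor, router, moe, ep_group, num_experts: int) -> torch.Tensor:
    # hidden: (batch, seq_len, dim) activations entering the MoE layer.
    batch, seq_len, dim = hidden.shape
    x = hidden.reshape(-1, dim)                  # (num_tokens, dim)
    weights, indices = router(x)                 # (num_tokens, top_k) each
    ep_world_size = dist.get_world_size(ep_group)
    num_local = num_experts // ep_world_size
    start = dist.get_rank(ep_group) * num_local  # first global expert on this rank
    out = ep_dispatch_and_compute(
        x, weights, indices, moe,
        ep_group=ep_group,
        local_expert_start=start,
        num_local_experts=num_local,
        ep_world_size=ep_world_size,
    )
    return out.reshape(batch, seq_len, dim)      # restore the original shape
```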
- kempnerforge.distributed.expert_parallel.apply_expert_parallel(model, device_mesh)[source]¶
Partition MoE experts across the EP dimension of the DeviceMesh.
- For each MoEMLP in the model:
  - Prunes experts to the local subset for this EP rank
  - Stores EP metadata (group, rank, world_size, local expert range)

Must be called AFTER tensor parallelism and BEFORE FSDP2.
No-op when ep is not in the mesh or has size 1.
- Parameters:
model (torch.nn.Module)
device_mesh (DeviceMesh | None)
- Return type:
None
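A hedged sketch of the required call ordering. Only apply_expert_parallel and the "ep" mesh dimension come from the docs above; build_model, apply_tensor_parallel, and apply_fsdp2 are hypothetical stand-ins for the surrounding parallelism steps, and the mesh layout is only an example.

```python
from torch.distributed.device_mesh import init_device_mesh
from kempnerforge.distributed.expert_parallel import apply_expert_parallel

# 8 GPUs arranged as 2 data-parallel groups x 4 EP ranks; the "ep" mesh
# dimension name is what apply_expert_parallel checks for, per the docs above.
mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "ep"))

model = build_model()                # hypothetical model constructor
apply_tensor_parallel(model, mesh)   # hypothetical TP step: must come first
apply_expert_parallel(model, mesh)   # prunes each MoEMLP to this rank's experts
apply_fsdp2(model, mesh)             # hypothetical FSDP2 wrapping: must come last
```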