kempnerforge.distributed.expert_parallel

Expert Parallelism: partition MoE experts across an EP process group.

Each EP rank holds num_experts // ep_size experts. Tokens are exchanged between ranks via all-to-all collectives so every token reaches its assigned experts, then the expert outputs are returned to the originating rank.

When ep=1 (default), this module is a no-op and the model runs with all experts replicated on every rank (the pre-EP behavior).
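The partitioning above implies a simple contiguous ownership rule: expert `e` lives on EP rank `e // num_local_experts` at local slot `e % num_local_experts`. A back-of-envelope sketch (the values of `num_experts` and `ep_size` are illustrative, and the helper names are not part of the module's API):

```python
# Illustrative expert-ownership math for contiguous EP partitioning.
num_experts, ep_size = 8, 4
num_local_experts = num_experts // ep_size  # experts per EP rank -> 2

def owner_rank(expert_id: int) -> int:
    """EP rank that holds this global expert index."""
    return expert_id // num_local_experts

def local_index(expert_id: int) -> int:
    """Position of this expert within its rank's local slice."""
    return expert_id % num_local_experts

# Global expert 5 lives on rank 2, local slot 1.
assert owner_rank(5) == 2 and local_index(5) == 1
```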

Functions

apply_expert_parallel(model, device_mesh)

Partition MoE experts across the EP dimension of the DeviceMesh.

ep_dispatch_and_compute(x, weights, indices, ...)

All-to-all dispatch, local expert compute, all-to-all combine.

kempnerforge.distributed.expert_parallel.ep_dispatch_and_compute(x, weights, indices, moe, ep_group, local_expert_start, num_local_experts, ep_world_size, gradient_scale=False)[source]

All-to-all dispatch, local expert compute, all-to-all combine.

Parameters:
  • x (torch.Tensor) – (num_tokens, dim) flattened token representations.

  • weights (torch.Tensor) – (num_tokens, top_k) routing weights.

  • indices (torch.Tensor) – (num_tokens, top_k) global expert indices.

  • moe (MoEMLP) – The MoEMLP module. Holds either an nn.ModuleList of local experts (unpacked) or packed expert tensors already sliced to the local range.

  • ep_group (torch.distributed.ProcessGroup) – EP process group.

  • local_expert_start (int) – First global expert index on this rank.

  • num_local_experts (int) – Number of experts on this rank.

  • ep_world_size (int) – Number of ranks in the EP group.

  • gradient_scale (bool) – Defaults to False.

Returns:

(num_tokens, dim) weighted combination of expert outputs.

Return type:

torch.Tensor
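When ep_world_size=1 the two all-to-alls are identities, so the function reduces to local expert compute plus the weighted top-k combine. The single-process sketch below shows that reduced path; the `nn.Linear` experts are stand-ins for the real MoEMLP expert weights, and the router here is just random logits:

```python
# Single-process (ep_world_size=1) sketch of the compute + combine step.
# Experts are plain nn.Linear stand-ins, not the actual MoEMLP module.
import torch
import torch.nn as nn

torch.manual_seed(0)
num_tokens, dim, num_experts, top_k = 8, 4, 4, 2

x = torch.randn(num_tokens, dim)               # (num_tokens, dim) tokens
logits = torch.randn(num_tokens, num_experts)  # stand-in router scores
weights, indices = torch.topk(logits.softmax(dim=-1), top_k, dim=-1)

experts = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_experts))

output = torch.zeros_like(x)
for e in range(num_experts):
    # (token, slot) pairs routed to expert e
    token_idx, slot_idx = (indices == e).nonzero(as_tuple=True)
    if token_idx.numel() == 0:
        continue
    expert_out = experts[e](x[token_idx])      # local expert compute
    # weighted combine back into the originating token rows
    output.index_add_(0, token_idx, expert_out * weights[token_idx, slot_idx, None])
```

In the multi-rank case, the `x[token_idx]` gather is replaced by an all-to-all that ships each token to the rank owning its expert, and the `index_add_` combine happens after the reverse all-to-all.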

kempnerforge.distributed.expert_parallel.apply_expert_parallel(model, device_mesh)[source]

Partition MoE experts across the EP dimension of the DeviceMesh.

For each MoEMLP in the model:
  • Prunes experts to the local subset for this EP rank

  • Stores EP metadata (group, rank, world_size, local expert range)

Must be called AFTER tensor parallelism and BEFORE FSDP2.

No-op when ep is not in the mesh or has size 1.
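The pruning and metadata steps can be sketched in isolation; the slicing below and the `ep_meta` dict are illustrative stand-ins for what the real implementation stores on each MoEMLP, not its actual attribute names:

```python
# Hypothetical sketch of pruning an unpacked expert list to the local slice.
import torch.nn as nn

num_experts, ep_size, ep_rank = 8, 4, 1
num_local = num_experts // ep_size
local_start = ep_rank * num_local

experts = nn.ModuleList(nn.Linear(4, 4) for _ in range(num_experts))
# Keep only this rank's contiguous slice of experts.
local_experts = experts[local_start:local_start + num_local]

# EP metadata the real implementation would record (names illustrative):
ep_meta = dict(ep_rank=ep_rank, ep_world_size=ep_size,
               local_expert_start=local_start, num_local_experts=num_local)
```

Calling this after tensor parallelism and before FSDP2 matters because FSDP2 shards whatever parameters remain; pruning first ensures each rank only shards (and stores optimizer state for) its own experts.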

Parameters:
  • model (nn.Module) – Model containing MoEMLP modules to partition.

  • device_mesh (DeviceMesh) – Device mesh with an ep dimension.

Return type:

None