Expert parallelism¶
Expert parallelism (EP) partitions MoE experts across an EP process
group — each rank holds num_experts / ep experts and tokens are
shuffled between ranks by all-to-all so every token reaches its
assigned expert. With ep=1 (the default), experts are replicated on
every rank and EP is a no-op.
Entry points:

- apply_expert_parallel(model, device_mesh) — prunes each MoEMLP to the local expert slice and stores EP metadata (ep_group, ep_world_size, local_expert_start, num_local_experts).
- ep_dispatch_and_compute(x, weights, indices, moe, ...) — runs inside MoEMLP.forward() when ep_world_size > 1 and implements the all-to-all dispatch / local compute / all-to-all combine.
When EP kicks in¶
# kempnerforge/model/moe.py, MoEMLP.forward
if self.ep_world_size > 1:
    output = ep_dispatch_and_compute(
        x, weights, indices, self,
        self.ep_group, self.local_expert_start,
        self.num_local_experts, self.ep_world_size,
        gradient_scale=self.gradient_scale,
    )
With ep=1, ep_world_size stays at 1 (the default set in
MoEMLP.__init__) and the forward path runs experts locally. With
ep>1, apply_expert_parallel bumps ep_world_size to the EP mesh
size and populates the other metadata.
Dispatch / combine flow¶
ep_dispatch_and_compute is a seven-step sequence:

| # | Step | What it does |
|---|------|--------------|
| 1 | Expand | Replicate each token moe_top_k times so there is one entry per (token, expert) assignment |
| 2 | Sort | Stable-sort entries by target EP rank so same-destination tokens are contiguous |
| 3 | Exchange counts | All-to-all the per-rank send counts so each rank knows how many tokens it will receive |
| 4 | Dispatch | All-to-all the token activations to the ranks that own their assigned experts |
| 5 | Local compute | Grouped GEMM over received tokens (sorted by local expert) when packed_experts, otherwise a per-expert loop |
| 6 | Combine | Reverse all-to-all sends processed tokens back to the originating ranks |
| 7 | Weighted sum | Recombine each token's top-k expert outputs with its router weights |
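The bookkeeping in steps 1–3 can be sketched in plain Python from the viewpoint of a single rank. This is illustrative only (EP_SIZE, the example routing, and dest_rank are made up for the sketch, not the actual kempnerforge code):

```python
EP_SIZE = 2
NUM_EXPERTS = 4
EXPERTS_PER_RANK = NUM_EXPERTS // EP_SIZE

# Step 1 (expand): one (token_id, expert_id) entry per top-k assignment,
# as produced by the router on this rank.
entries = [(0, 3), (0, 1), (1, 0), (1, 2), (2, 1), (2, 3)]

# Step 2 (sort): stable-sort by destination EP rank so that entries
# bound for the same rank are contiguous in the send buffer.
dest_rank = lambda e: e[1] // EXPERTS_PER_RANK
entries_sorted = sorted(entries, key=dest_rank)

# Step 3 (exchange counts): per-destination send counts; in the real
# code these are swapped between ranks with an all-to-all so each rank
# can size its receive buffers.
send_counts = [sum(1 for e in entries if dest_rank(e) == r)
               for r in range(EP_SIZE)]

print(entries_sorted)  # same-destination entries are now contiguous
print(send_counts)
```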
The dispatch all-to-all is wrapped in
_AllToAll
— a custom torch.autograd.Function whose backward is the same
all-to-all with send/recv counts swapped. That’s what makes the
forward path differentiable.
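The count-swap symmetry is easy to see in a toy model. This plain-Python sketch (not the real _AllToAll) shows that running the same exchange on the received layout, i.e. with send/recv counts swapped, restores the original send layout:

```python
def all_to_all(per_rank_data):
    """Variable-size all-to-all: per_rank_data[src][dst] is the chunk
    src sends to dst; the result indexes as result[dst][src]."""
    nranks = len(per_rank_data)
    return [[per_rank_data[src][dst] for src in range(nranks)]
            for dst in range(nranks)]

send = [[["a0"], ["a1", "a2"]],   # rank 0: 1 item to rank 0, 2 to rank 1
        [["b0", "b1"], ["b2"]]]   # rank 1: 2 items to rank 0, 1 to rank 1

recv = all_to_all(send)
# The reverse exchange uses the received layout as its send layout —
# exactly the swapped-counts backward — and recovers the original chunks.
back = all_to_all(recv)
assert back == send
```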
Unused-expert kludge¶
If a local expert receives zero tokens in a step, its parameters never enter the autograd graph — and FSDP2’s reduce-scatter, which fires only after every param in a unit has accumulated a gradient, hangs forever.
ep_dispatch_and_compute forces an
AccumulateGrad hook to fire on each unused expert by adding a
zero-valued sum of its parameters into the output:
for i in range(num_local_experts):
    if tokens_per_expert[i] == 0:
        for p in moe.experts[i].parameters():
            local_output = local_output + p.sum() * 0
Similar zero-contributions handle the packed-expert path and the
case where the dispatch all-to-all would otherwise have no gradient
edge back from local_output to received_tokens (which would
cause the backward all-to-all to be skipped on one side —
positional mismatch in NCCL → deadlock).
Apply step¶
for module in model.modules():
    if not isinstance(module, MoEMLP):
        continue
    assert num_experts % ep_size == 0
    start = ep_rank * (num_experts // ep_size)
    end = start + (num_experts // ep_size)
    if module.packed_experts:
        # replace Parameter with sliced view (can't resize in place)
        module.up_w = Parameter(module.up_w.data[start:end].clone())
        module.down_w = Parameter(module.down_w.data[start:end].clone())
        ...
    else:
        module.experts = ModuleList([module.experts[i] for i in range(start, end)])
    module.ep_world_size = ep_size
    module.ep_group = ep_group
    module.local_expert_start = start
    module.num_local_experts = num_experts // ep_size
The router (moe.router) is not sharded — every rank keeps the
full router weights so it can make the routing decision locally
before dispatch. Shared experts (moe.shared_expert) are also kept
on every rank.
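The slicing arithmetic above is worth isolating; a minimal sketch (the helper name is illustrative, not part of the library):

```python
def local_expert_slice(ep_rank, ep_size, num_experts):
    """Return the [start, end) slice of global experts owned by ep_rank."""
    assert num_experts % ep_size == 0
    per_rank = num_experts // ep_size
    start = ep_rank * per_rank
    return start, start + per_rank

# 8 global experts, EP group of size 2:
print(local_expert_slice(0, 2, 8))  # (0, 4) -> experts 0-3
print(local_expert_slice(1, 2, 8))  # (4, 8) -> experts 4-7
```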
Composition with other parallelisms¶
EP runs after TP and before FSDP2 — see Parallelism order.

- EP + TP: TP shards the non-MoE Linears (attention q/k/v/o and shared-expert gate/up/down) along the tp mesh dim. EP shards the routed experts along ep. Dense TP layers are untouched by apply_expert_parallel.
- EP + FSDP2: FSDP2 wraps the MoE layer's attention and mlp separately (per-sub-module wrapping) rather than the whole block — see FSDP2 § EP-MoE. Per-block wrapping would cause FSDP2's reduce-scatter to fire between the two EP all-to-alls in backward, deadlocking.
- EP + FP8: expert Linears are excluded from the Float8 pass ("experts" in fqn → False in the filter). The grouped GEMM path (torch._grouped_mm) doesn't go through Float8Linear.forward, so FP8 applied there is ineffective and adds surprise failures. See FP8 § Exclusion rules.
Config¶
[distributed]
ep = 2 # expert parallelism degree
[model]
num_experts = 8 # global expert count
moe_top_k = 2 # experts per token
moe_shared_experts = 1 # optional always-on expert
moe_packed_experts = false # grouped GEMM with packed weights (opt-in)
moe_gradient_scale = false # per-expert gradient normalization (opt-in)
num_experts % ep == 0 is checked at apply time. The parallelism
arithmetic (see DistributedConfig)
requires dp_replicate · dp_shard · tp · pp · cp · ep == world_size.
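Both checks are simple arithmetic; a minimal sketch (the validate helper is illustrative — the real checks live in the apply step and DistributedConfig):

```python
def validate(world_size, dp_replicate=1, dp_shard=1, tp=1, pp=1, cp=1,
             ep=1, num_experts=8):
    # Experts must divide evenly across the EP group.
    assert num_experts % ep == 0, "num_experts % ep != 0"
    # All parallelism degrees must multiply out to the world size.
    assert dp_replicate * dp_shard * tp * pp * cp * ep == world_size

# The config above on 16 GPUs, e.g. dp_shard=8, ep=2:
validate(world_size=16, dp_shard=8, ep=2, num_experts=8)
```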
Example: moe_ep_32gpu.toml¶
dp_shard=4, tp=4, ep=2, pp=1
num_experts=8, moe_top_k=2
mesh: ("dp_shard", "ep", "tp") → (4, 2, 4)
- Each EP group (size 2) splits the 8 experts: experts 0-3 on rank 0, experts 4-7 on rank 1.
- Within each EP group, TP shards the per-expert Linears along tp=4.
- FSDP shards the remaining params across the dp_shard=4 axis with per-sub-module wrapping.
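The rank-to-experts mapping for this layout can be computed mechanically. A sketch assuming row-major rank ordering over the ("dp_shard", "ep", "tp") mesh (the coords helper is illustrative, not a DeviceMesh API):

```python
def coords(rank, dims=(4, 2, 4)):
    """Flat rank -> (dp_shard, ep, tp) coordinates, row-major."""
    _, e, t = dims
    return rank // (e * t), (rank // t) % e, rank % t

EXPERTS_PER_EP_RANK = 8 // 2  # num_experts / ep

for rank in (0, 4, 31):
    dp, ep, tp = coords(rank)
    lo = ep * EXPERTS_PER_EP_RANK
    print(rank, (dp, ep, tp), f"experts {lo}-{lo + EXPERTS_PER_EP_RANK - 1}")
```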
Benchmark and reproducer: Benchmarks § MoE Expert Parallelism.
Gradient scaling (optional)¶
When moe_gradient_scale = true, the output of each local expert is
multiplied by avg_tokens / tokens_for_this_expert so high-traffic
experts don’t dominate the gradient. The scaling happens on
local_output before the combine all-to-all, so the adjusted
gradient flows back through the dispatch all-to-all to the router
and expert params correctly. Disabled by default — it changes
gradient magnitudes and should be validated against a baseline run
before flipping on.
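The scale factor itself is simple per-expert arithmetic; a minimal sketch (helper name illustrative):

```python
def gradient_scales(tokens_per_expert):
    """avg_tokens / tokens_for_this_expert per local expert."""
    avg = sum(tokens_per_expert) / len(tokens_per_expert)
    # An expert with zero tokens produces no output, so its scale
    # never applies; guard the division anyway.
    return [avg / t if t > 0 else 0.0 for t in tokens_per_expert]

# A high-traffic expert (8 tokens) is scaled down; low-traffic ones up.
print(gradient_scales([8, 2, 4, 2]))  # avg = 4 -> [0.5, 2.0, 1.0, 2.0]
```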
See also¶
- MoE overview — architecture, routers, auxiliary losses. This distributed/ page is the canonical EP reference; the MoE pages link here.
- FSDP2 § EP-MoE — the per-sub-module wrapping pattern.
- FP8 § Exclusion rules — why experts + router + shared_expert are skipped by apply_float8.
- Parallelism order — EP's place in the apply sequence.
- Validation rules § Expert parallel — the num_experts % ep == 0 check.
- Benchmarks § MoE Expert Parallelism — measured EP speedup.