"""Expert Parallelism: partition MoE experts across an EP process group.

Each EP rank holds ``num_experts // ep_size`` experts. Tokens are shuffled
between ranks via all-to-all so every token reaches its assigned experts,
then results are returned to the originating rank.

When ``ep=1`` (default), this module is a no-op and the model runs with
all experts replicated on every rank (the pre-EP behavior).
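
Example (a minimal sketch; the mesh shape, dim names, and surrounding setup
are illustrative, not prescriptive)::

    from torch.distributed.device_mesh import init_device_mesh

    mesh = init_device_mesh("cuda", (dp_size, ep_size), mesh_dim_names=("dp", "ep"))
    apply_expert_parallel(model, mesh)  # prune each MoEMLP to this rank's experts
    # ... then apply FSDP2 (EP must run after tensor parallelism, before FSDP2)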
"""

from __future__ import annotations

import logging

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh

from kempnerforge.model.moe import MoEMLP

logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Differentiable all-to-all
# ---------------------------------------------------------------------------


class _AllToAll(torch.autograd.Function):
    """Differentiable wrapper around ``dist.all_to_all_single``.

    Forward sends tokens to the correct EP rank; backward reverses the
    communication (same all-to-all, swapped split/gather sizes).
    """

    @staticmethod
    def forward(
        ctx,
        x: torch.Tensor,
        output_split_sizes: list[int],
        input_split_sizes: list[int],
        group: dist.ProcessGroup,
    ) -> torch.Tensor:
        # Save for backward: reverse the all-to-all direction
        # backward receives what forward sent, sends what forward received
        ctx.bwd_output_splits = input_split_sizes
        ctx.bwd_input_splits = output_split_sizes
        ctx.group = group
        x = x.contiguous()
        out = torch.empty(sum(output_split_sizes), *x.shape[1:], dtype=x.dtype, device=x.device)
        dist.all_to_all_single(
            out,
            x,
            output_split_sizes=output_split_sizes,
            input_split_sizes=input_split_sizes,
            group=group,
        )
        return out

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        grad_output = grad_output.contiguous()
        grad_input = torch.empty(
            sum(ctx.bwd_output_splits),
            *grad_output.shape[1:],
            dtype=grad_output.dtype,
            device=grad_output.device,
        )
        dist.all_to_all_single(
            grad_input,
            grad_output,
            output_split_sizes=ctx.bwd_output_splits,
            input_split_sizes=ctx.bwd_input_splits,
            group=ctx.group,
        )
        return grad_input, None, None, None


def _all_to_all(
    x: torch.Tensor,
    output_split_sizes: list[int],
    input_split_sizes: list[int],
    group: dist.ProcessGroup,
) -> torch.Tensor:
    return _AllToAll.apply(x, output_split_sizes, input_split_sizes, group)


# ---------------------------------------------------------------------------
# EP dispatch / combine
# ---------------------------------------------------------------------------


def ep_dispatch_and_compute(
    x: torch.Tensor,
    weights: torch.Tensor,
    indices: torch.Tensor,
    moe: MoEMLP,
    ep_group: dist.ProcessGroup,
    local_expert_start: int,
    num_local_experts: int,
    ep_world_size: int,
    gradient_scale: bool = False,
) -> torch.Tensor:
    """All-to-all dispatch, local expert compute, all-to-all combine.

    Args:
        x: (num_tokens, dim) flattened token representations.
        weights: (num_tokens, top_k) routing weights.
        indices: (num_tokens, top_k) global expert indices.
        moe: The MoEMLP module. Holds either an nn.ModuleList of local experts
            (unpacked) or packed expert tensors already sliced to the local range.
        ep_group: EP process group.
        local_expert_start: First global expert index on this rank.
        num_local_experts: Number of experts on this rank.
        ep_world_size: Number of ranks in the EP group.
        gradient_scale: If True, rescale each local expert's output by its
            utilization ratio (per-expert gradient scaling, DeepSeek-V3 Sec 3.2).

    Returns:
        output: (num_tokens, dim) weighted combination of expert outputs.
    """
    num_tokens, top_k = indices.shape
    dim = x.shape[-1]

    # --- 1. Build per-token, per-expert-selection dispatch info ---
    # Expand tokens: each (token, k) pair becomes one dispatch entry
    flat_indices = indices.reshape(-1)  # (num_tokens * top_k,)
    flat_weights = weights.reshape(-1)  # (num_tokens * top_k,)
    # Token id for each entry
    token_ids = torch.arange(num_tokens, device=x.device).unsqueeze(1).expand(-1, top_k).reshape(-1)
    # Which EP rank owns each expert
    target_rank = flat_indices // num_local_experts  # (num_tokens * top_k,)

    # --- 2. Sort entries by target rank (stable sort preserves order within rank) ---
    sort_indices = torch.argsort(target_rank, stable=True)
    sorted_token_ids = token_ids[sort_indices]
    sorted_flat_indices = flat_indices[sort_indices]
    sorted_flat_weights = flat_weights[sort_indices]
    # Gather token features in send order
    x_sorted = x[sorted_token_ids]  # (num_tokens * top_k, dim)
    # Count how many entries go to each rank
    sorted_target_rank = target_rank[sort_indices]
    send_counts = torch.bincount(sorted_target_rank, minlength=ep_world_size)
    send_counts_list = send_counts.tolist()

    # --- 3. Exchange counts so each rank knows what it will receive ---
    recv_counts = torch.zeros_like(send_counts)
    dist.all_to_all_single(recv_counts, send_counts, group=ep_group)
    recv_counts_list = recv_counts.tolist()

    # --- 4. All-to-all: send tokens to expert-owning ranks ---
    received_tokens = _all_to_all(x_sorted, recv_counts_list, send_counts_list, ep_group)

    # Also exchange expert indices (as float for autograd compatibility, detached)
    sorted_expert_ids_float = sorted_flat_indices.float().unsqueeze(-1)
    received_expert_ids = torch.empty(
        sum(recv_counts_list), 1, dtype=torch.float32, device=x.device
    )
    dist.all_to_all_single(
        received_expert_ids,
        sorted_expert_ids_float,
        output_split_sizes=recv_counts_list,
        input_split_sizes=send_counts_list,
        group=ep_group,
    )
    received_expert_ids = received_expert_ids.squeeze(-1).long()

    # --- 5. Local expert computation ---
    # Sort received tokens by expert for grouped GEMM.
    from kempnerforge.model.moe import (
        _GROUPED_MM_DTYPES,
        _HAS_GROUPED_MM,
        grouped_expert_forward,
        grouped_expert_forward_packed,
    )

    use_grouped = _HAS_GROUPED_MM and received_tokens.dtype in _GROUPED_MM_DTYPES
    if use_grouped:
        sort_by_expert = torch.argsort(received_expert_ids, stable=True)
        sorted_recv = received_tokens[sort_by_expert]
        sorted_ids = received_expert_ids[sort_by_expert]
        # Map global expert IDs to local indices for bincount.
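        # Illustrative mapping (hypothetical sizes): with 8 experts and
        # ep_world_size=4, num_local_experts=2; on the rank owning experts 6-7,
        # local_expert_start=6, so received global IDs 6 and 7 become local 0 and 1.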
        local_ids = sorted_ids - local_expert_start
        tokens_per_expert = torch.bincount(local_ids, minlength=num_local_experts).tolist()
        if moe.packed_experts:
            local_output_sorted = grouped_expert_forward_packed(
                sorted_recv,
                tokens_per_expert,
                moe.up_w,
                moe.down_w,
                moe.gate_w if moe._is_swiglu else None,
                moe._packed_activation,
            )
        else:
            local_output_sorted = grouped_expert_forward(
                sorted_recv,
                tokens_per_expert,
                moe.experts,
            )
        # Unsort back to received order.
        unsort_by_expert = torch.argsort(sort_by_expert)
        local_output = local_output_sorted[unsort_by_expert]
        if moe.packed_experts:
            # grouped_mm over the full packed tensor touches every expert row, so
            # AccumulateGrad on up_w/down_w/gate_w always fires when there is
            # at least one token. If zero tokens arrived locally, grouped_mm
            # short-circuits; add an explicit zero contribution to keep the
            # packed params in the autograd graph for FSDP2.
            if sum(tokens_per_expert) == 0:
                _zero = moe.up_w.sum() * 0 + moe.down_w.sum() * 0
                if moe._is_swiglu:
                    _zero = _zero + moe.gate_w.sum() * 0
                local_output = local_output + _zero
        else:
            # Unpacked: grouped_expert_forward stacks per-expert Linear weights
            # into a temporary; experts with zero tokens never appear in the
            # graph. Force AccumulateGrad to fire on each unused expert.
            for i in range(num_local_experts):
                if tokens_per_expert[i] == 0:
                    for p in moe.experts[i].parameters():
                        local_output = local_output + p.sum() * 0
    else:
        local_output = torch.zeros_like(received_tokens)
        if moe.packed_experts:
            any_computed = False
            for i in range(num_local_experts):
                global_expert_id = local_expert_start + i
                mask = received_expert_ids == global_expert_id
                if not mask.any():
                    continue
                local_output[mask] = moe._apply_packed_expert(received_tokens[mask], i)
                any_computed = True
            # If no local expert ran, packed params never entered the graph.
            # Add zero contribution so FSDP2 reduce-scatter fires.
            if not any_computed:
                _zero = moe.up_w.sum() * 0 + moe.down_w.sum() * 0
                if moe._is_swiglu:
                    _zero = _zero + moe.gate_w.sum() * 0
                local_output = local_output + _zero
        else:
            unused_expert_params: list[torch.nn.Parameter] = []
            for i in range(num_local_experts):
                global_expert_id = local_expert_start + i
                mask = received_expert_ids == global_expert_id
                if not mask.any():
                    unused_expert_params.extend(moe.experts[i].parameters())
                    continue
                local_output[mask] = moe.experts[i](received_tokens[mask])
            # FSDP2 requires every parameter's AccumulateGrad hook to fire during
            # backward for reduce-scatter to complete.
            if unused_expert_params:
                _zero = sum(p.sum() for p in unused_expert_params) * 0
                local_output = local_output + _zero

    # Per-expert gradient scaling: normalize by utilization ratio so
    # high-traffic experts don't dominate learning (DeepSeek-V3 Sec 3.2).
    if gradient_scale and received_tokens.requires_grad:
        total_recv = sum(recv_counts_list)
        if total_recv > 0:
            local_ids = received_expert_ids - local_expert_start
            tpe = torch.bincount(local_ids, minlength=num_local_experts)
            avg_tokens = total_recv / max(num_local_experts, 1)
            for i in range(num_local_experts):
                global_id = local_expert_start + i
                mask = received_expert_ids == global_id
                count = tpe[i].item()
                if count > 0:
                    scale = avg_tokens / count
                    local_output[mask] = local_output[mask] * scale

    # Keep dispatch all-to-all in the autograd graph. When all local experts
    # are unused, local_output has no gradient path to received_tokens, so
    # the dispatch _AllToAll.backward never fires on this rank. Since NCCL
    # matches all-to-all ops by position in the communicator, the missing
    # backward causes a misalignment with the peer EP rank, causing a deadlock.
    # Adding a zero-valued contribution preserves the graph edge without
    # changing any gradient values.
    local_output = local_output + received_tokens.sum() * 0

    # --- 6. All-to-all: return processed tokens to originating ranks ---
    # Reverse the all-to-all (swap send/recv counts)
    returned_tokens = _all_to_all(local_output, send_counts_list, recv_counts_list, ep_group)

    # --- 7. Unsort and weighted combine ---
    # returned_tokens is in the same order as x_sorted (sorted by target rank)
    # Unsort back to original (token_id, k) order
    unsort_indices = torch.argsort(sort_indices)
    returned_unsorted = returned_tokens[unsort_indices]
    weights_unsorted = sorted_flat_weights[unsort_indices]
    # Weighted sum per token
    returned_unsorted = returned_unsorted * weights_unsorted.unsqueeze(-1)
    output = torch.zeros(num_tokens, dim, dtype=x.dtype, device=x.device)
    output.scatter_add_(
        0,
        token_ids.unsqueeze(-1).expand_as(returned_unsorted),
        returned_unsorted,
    )
    return output


# ---------------------------------------------------------------------------
# Apply EP to model
# ---------------------------------------------------------------------------


def apply_expert_parallel(model: torch.nn.Module, device_mesh: DeviceMesh | None) -> None:
    """Partition MoE experts across the EP dimension of the DeviceMesh.

    For each MoEMLP in the model:

    - Prunes ``experts`` to the local subset for this EP rank
    - Stores EP metadata (group, rank, world_size, local expert range)

    Must be called AFTER tensor parallelism and BEFORE FSDP2.
    No-op when ``ep`` is not in the mesh or has size 1.
    """
    if device_mesh is None:
        return
    if "ep" not in device_mesh.mesh_dim_names:  # type: ignore[reportOperatorIssue]
        return
    ep_mesh = device_mesh["ep"]
    ep_size = ep_mesh.size()
    if ep_size <= 1:
        return

    ep_group = ep_mesh.get_group()
    ep_rank = ep_mesh.get_local_rank()

    applied = 0
    for module in model.modules():
        if not isinstance(module, MoEMLP):
            continue

        num_experts = module.num_experts
        assert num_experts % ep_size == 0, (
            f"num_experts ({num_experts}) must be divisible by ep ({ep_size})"
        )
        experts_per_rank = num_experts // ep_size
        start = ep_rank * experts_per_rank
        end = start + experts_per_rank

        if module.packed_experts:
            # Packed path: slice the stacked weight tensors along the expert dim.
            # Replace the Parameter (can't resize in-place) with a sliced copy.
            # The unpacked ModuleList is already absent in packed mode.
            module.up_w = torch.nn.Parameter(module.up_w.data[start:end].clone().contiguous())
            module.down_w = torch.nn.Parameter(module.down_w.data[start:end].clone().contiguous())
            if module._is_swiglu:
                module.gate_w = torch.nn.Parameter(
                    module.gate_w.data[start:end].clone().contiguous()
                )
        else:
            # Unpacked path: prune the ModuleList to the local experts only.
            local_experts = torch.nn.ModuleList([module.experts[i] for i in range(start, end)])
            module.experts = local_experts

        # Store EP metadata
        module.ep_world_size = ep_size
        module.ep_group = ep_group  # type: ignore[reportAttributeAccessIssue]
        module.local_expert_start = start
        module.num_local_experts = experts_per_rank
        applied += 1

    logger.info(
        f"Applied expert parallelism: ep_size={ep_size}, ep_rank={ep_rank}, layers={applied}"
    )
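
# Worked example of the partitioning arithmetic (hypothetical sizes):
# with num_experts=8 and ep_size=4, experts_per_rank=2; EP rank 2 keeps global
# experts [4, 5] with local_expert_start=4, and tokens routed to experts 0-3
# and 6-7 reach the other EP ranks via ep_dispatch_and_compute's all-to-all.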