"""Mixture-of-Experts feed-forward layer for KempnerForge models."""
from __future__ import annotations
import torch
import torch.nn as nn
import torch.nn.functional as F
import kempnerforge.model.router # noqa: F401 — triggers router registration
from kempnerforge.config.registry import registry
from kempnerforge.model.mlp import build_mlp
_HAS_GROUPED_MM = hasattr(torch, "_grouped_mm")
# torch._grouped_mm under torch.compile requires bf16/fp16 inputs
# (the meta registration rejects fp32). Guard against this.
_GROUPED_MM_DTYPES = {torch.bfloat16, torch.float16}
def grouped_expert_forward(
x_sorted: torch.Tensor,
tokens_per_expert: list[int],
experts: nn.ModuleList,
) -> torch.Tensor:
"""Batched expert computation using ``torch._grouped_mm``.
Replaces the sequential expert loop with 2-3 grouped matrix multiplies
(one CUDA kernel each), giving a significant speedup when many experts are
active.
Args:
x_sorted: (total_tokens, dim) token features sorted by expert index.
tokens_per_expert: Number of tokens assigned to each expert, in order.
experts: Expert modules whose weights are stacked for the grouped GEMM.
Returns:
(total_tokens, dim) expert outputs in the same sorted order as input.
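Example (illustrative sketch; assumes a CUDA device, bf16 inputs for
``torch._grouped_mm``, and experts built with ``build_mlp``)::

    experts = nn.ModuleList(
        [build_mlp(64, 128, "silu") for _ in range(3)]
    ).cuda().bfloat16()
    # 6 tokens already sorted by expert: 3 -> expert 0, 1 -> expert 1, 2 -> expert 2
    x_sorted = torch.randn(6, 64, device="cuda", dtype=torch.bfloat16)
    out = grouped_expert_forward(x_sorted, [3, 1, 2], experts)
    assert out.shape == (6, 64)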
"""
num_experts = len(experts)
total_tokens, dim = x_sorted.shape
max_tokens = max(tokens_per_expert)
if max_tokens == 0 or total_tokens == 0:
return torch.zeros_like(x_sorted)
is_swiglu = hasattr(experts[0], "gate_proj")
# Stack expert weights into (E, in, out) for grouped matmul.
# nn.Linear stores weight as (out, in), so transpose to (in, out).
up_w = torch.stack([e.up_proj.weight.t() for e in experts]) # type: ignore[reportCallIssue, reportAttributeAccessIssue] # (E, dim, H)
down_w = torch.stack([e.down_proj.weight.t() for e in experts]) # type: ignore[reportCallIssue, reportAttributeAccessIssue] # (E, H, dim)
if is_swiglu:
gate_w = torch.stack([e.gate_proj.weight.t() for e in experts]) # type: ignore[reportCallIssue, reportAttributeAccessIssue] # (E, dim, H)
# Pad token groups into (E, max_tokens, dim) for uniform batch size.
x_padded = x_sorted.new_zeros(num_experts, max_tokens, dim)
offset = 0
for i, count in enumerate(tokens_per_expert):
if count > 0:
x_padded[i, :count] = x_sorted[offset : offset + count]
offset += count
# Grouped matmuls — 3 for SwiGLU, 2 for StandardMLP.
if is_swiglu:
gate = torch._grouped_mm(x_padded, gate_w) # (E, M, H)
up = torch._grouped_mm(x_padded, up_w) # (E, M, H)
hidden = F.silu(gate) * up # (E, M, H)
else:
hidden = torch._grouped_mm(x_padded, up_w) # (E, M, H)
act_fn = experts[0]._activation
hidden = act_fn(hidden) # type: ignore[reportCallIssue]
out_padded = torch._grouped_mm(hidden, down_w) # (E, M, dim)
# Unpad back to flat sorted order.
output = torch.zeros_like(x_sorted)
offset = 0
for i, count in enumerate(tokens_per_expert):
if count > 0:
output[offset : offset + count] = out_padded[i, :count]
offset += count
return output
def grouped_expert_forward_packed(
x_sorted: torch.Tensor,
tokens_per_expert: list[int],
up_w: torch.Tensor,
down_w: torch.Tensor,
gate_w: torch.Tensor | None,
activation,
) -> torch.Tensor:
"""Batched expert computation over pre-packed weights.
Same as ``grouped_expert_forward`` but consumes packed weight tensors
directly — no per-step ``torch.stack`` over an ``nn.ModuleList``.
Args:
x_sorted: (total_tokens, dim) token features sorted by expert index.
tokens_per_expert: Number of tokens assigned to each expert, in order.
up_w: (E, dim, hidden) packed up-projection weights.
down_w: (E, hidden, dim) packed down-projection weights.
gate_w: (E, dim, hidden) packed gate weights for SwiGLU, else None.
activation: Activation function applied to the up-projection output
when ``gate_w`` is None. The SwiGLU path (``gate_w`` given) always uses SiLU.
Returns:
(total_tokens, dim) expert outputs in the same sorted order as input.
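Example (illustrative sketch; shapes are arbitrary and bf16 CUDA inputs are
assumed for ``torch._grouped_mm``)::

    E, dim, hidden = 4, 64, 128
    up_w = torch.randn(E, dim, hidden, device="cuda", dtype=torch.bfloat16)
    down_w = torch.randn(E, hidden, dim, device="cuda", dtype=torch.bfloat16)
    gate_w = torch.randn(E, dim, hidden, device="cuda", dtype=torch.bfloat16)
    x_sorted = torch.randn(10, dim, device="cuda", dtype=torch.bfloat16)
    out = grouped_expert_forward_packed(
        x_sorted, [4, 3, 2, 1], up_w, down_w, gate_w, activation=F.silu
    )
    assert out.shape == (10, dim)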
"""
num_experts = up_w.shape[0]
total_tokens, dim = x_sorted.shape
max_tokens = max(tokens_per_expert)
if max_tokens == 0 or total_tokens == 0:
return torch.zeros_like(x_sorted)
# Pad token groups into (E, max_tokens, dim) for uniform batch size.
x_padded = x_sorted.new_zeros(num_experts, max_tokens, dim)
offset = 0
for i, count in enumerate(tokens_per_expert):
if count > 0:
x_padded[i, :count] = x_sorted[offset : offset + count]
offset += count
# Grouped matmuls — 3 for SwiGLU, 2 for StandardMLP.
if gate_w is not None:
gate = torch._grouped_mm(x_padded, gate_w) # (E, M, H)
up = torch._grouped_mm(x_padded, up_w) # (E, M, H)
hidden = F.silu(gate) * up # (E, M, H)
else:
hidden = torch._grouped_mm(x_padded, up_w) # (E, M, H)
hidden = activation(hidden)
out_padded = torch._grouped_mm(hidden, down_w) # (E, M, dim)
# Unpad back to flat sorted order.
output = torch.zeros_like(x_sorted)
offset = 0
for i, count in enumerate(tokens_per_expert):
if count > 0:
output[offset : offset + count] = out_padded[i, :count]
offset += count
return output
def _apply_capacity(
weights: torch.Tensor,
indices: torch.Tensor,
num_experts: int,
capacity_factor: float,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Zero routing weights for tokens that exceed per-expert capacity.
Capacity = ceil(num_tokens * top_k / num_experts * capacity_factor).
For each expert and top-k slot, only the first ``capacity`` tokens (in
sequence order) are kept; the rest get weight=0 and are effectively dropped.
Args:
weights: (num_tokens, top_k) routing weights.
indices: (num_tokens, top_k) expert indices.
num_experts: Total number of experts.
capacity_factor: Multiplier for capacity (1.0 = exact average, 1.25 = 25% headroom).
Returns:
(weights, indices) with overflow entries zeroed out. Tensors are cloned
to avoid mutating the router's output.
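Example (worked numbers, illustrative only): with num_tokens=1024, top_k=2,
num_experts=8 and capacity_factor=1.25,
capacity = ceil(1024 * 2 / 8 * 1.25) = 320 tokens per expert per slot.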
"""
import math
num_tokens, top_k = indices.shape
capacity = max(1, math.ceil(num_tokens * top_k / num_experts * capacity_factor))
weights = weights.clone()
for k in range(top_k):
# Count per-expert assignments in this top_k slot.
for e in range(num_experts):
assigned = (indices[:, k] == e).nonzero(as_tuple=True)[0]
if assigned.numel() <= capacity:
continue
drop = assigned[capacity:]
weights[drop, k] = 0.0
return weights, indices
class MoEMLP(nn.Module):
"""Mixture-of-Experts feed-forward layer.
Composes a router (from "router" registry) with N expert MLPs (from "mlp"
registry). Drop-in replacement for dense MLP — same forward signature.
Stores aux_loss after each forward for collection by the training loop.
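Example (illustrative sketch; ``build_moe`` below composes the pieces, and
the 0.01 aux-loss weight and ``task_loss`` are placeholders)::

    moe = build_moe(dim=256, hidden_dim=512, num_experts=8, top_k=2)
    y = moe(torch.randn(2, 16, 256))            # (batch, seq_len, dim)
    loss = task_loss + 0.01 * moe.aux_loss      # collect router balance loss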
"""
def __init__(
self,
router: nn.Module,
experts: nn.ModuleList,
shared_expert: nn.Module | None = None,
capacity_factor: float = 0.0,
gradient_scale: bool = False,
packed_experts: bool = False,
) -> None:
super().__init__()
self.router = router
self.shared_expert = shared_expert
self.num_experts = len(experts)
self.capacity_factor = capacity_factor
self.gradient_scale = gradient_scale
self.packed_experts = packed_experts
# EP attributes — set by apply_expert_parallel(); defaults = no EP
self.ep_world_size: int = 1
self.ep_group = None
self.local_expert_start: int = 0
self.num_local_experts: int = len(experts)
if packed_experts:
# Packed expert weights: stack per-expert (out, in) Linear weights into
# (E, in, out) tensors so grouped GEMM can consume them zero-copy.
# Drop the per-expert nn.ModuleList — the packed tensors are the
# sole source of truth. Tests / EP / FSDP2 read `self.up_w` etc.
self._is_swiglu = hasattr(experts[0], "gate_proj")
self.up_w = nn.Parameter(
torch.stack([e.up_proj.weight.t().contiguous() for e in experts]) # type: ignore[reportCallIssue, reportAttributeAccessIssue]
)
self.down_w = nn.Parameter(
torch.stack([e.down_proj.weight.t().contiguous() for e in experts]) # type: ignore[reportCallIssue, reportAttributeAccessIssue]
)
if self._is_swiglu:
self.gate_w = nn.Parameter(
torch.stack([e.gate_proj.weight.t().contiguous() for e in experts]) # type: ignore[reportCallIssue, reportAttributeAccessIssue]
)
self._packed_activation = F.silu
else:
self._packed_activation = experts[0]._activation
else:
self.experts = experts
def _apply_packed_expert(self, x: torch.Tensor, i: int) -> torch.Tensor:
"""Apply packed expert ``i`` to ``x`` without grouped GEMM.
Used by the sequential fallback path. Matches the unpacked
SwiGLU/StandardMLP forward exactly (no bias, same matmul order).
"""
up = x @ self.up_w[i]
if self._is_swiglu:
gate = x @ self.gate_w[i]
hidden = F.silu(gate) * up
else:
hidden = self._packed_activation(up) # type: ignore[reportCallIssue]
return hidden @ self.down_w[i]
def _local_forward(
self,
x_flat: torch.Tensor,
weights: torch.Tensor,
indices: torch.Tensor,
) -> torch.Tensor:
"""Dispatch tokens to local experts and weighted-combine results.
Uses grouped GEMM when available (torch._grouped_mm) for batched expert
computation. Falls back to sequential loop otherwise.
"""
num_tokens, dim = x_flat.shape
top_k = indices.shape[1]
use_grouped = _HAS_GROUPED_MM and x_flat.dtype in _GROUPED_MM_DTYPES
if use_grouped:
# Expand (token, k) pairs → flat entries sorted by expert.
flat_indices = indices.reshape(-1) # (T*K,)
flat_weights = weights.reshape(-1) # (T*K,)
token_ids = (
torch.arange(num_tokens, device=x_flat.device)
.unsqueeze(1)
.expand(-1, top_k)
.reshape(-1)
)
sort_order = torch.argsort(flat_indices, stable=True)
sorted_expert_ids = flat_indices[sort_order]
sorted_token_ids = token_ids[sort_order]
sorted_weights = flat_weights[sort_order]
x_sorted = x_flat[sorted_token_ids]
tokens_per_expert = torch.bincount(
sorted_expert_ids, minlength=self.num_experts
).tolist()
if self.packed_experts:
expert_out = grouped_expert_forward_packed(
x_sorted,
tokens_per_expert,
self.up_w,
self.down_w,
self.gate_w if self._is_swiglu else None,
self._packed_activation,
)
else:
expert_out = grouped_expert_forward(x_sorted, tokens_per_expert, self.experts)
# Per-expert gradient scaling: normalize by utilization ratio so
# high-traffic experts don't dominate learning (DeepSeek-V3 Sec 3.2).
if self.gradient_scale and self.training:
total_assignments = sum(tokens_per_expert)
avg_tokens = total_assignments / max(self.num_experts, 1)
offset = 0
for count in tokens_per_expert:
if count > 0:
scale = avg_tokens / count
expert_out[offset : offset + count] = (
expert_out[offset : offset + count] * scale
)
offset += count
# Weighted scatter-add back to output.
expert_out = expert_out * sorted_weights.unsqueeze(-1)
output = torch.zeros(num_tokens, dim, dtype=x_flat.dtype, device=x_flat.device)
output.scatter_add_(
0,
sorted_token_ids.unsqueeze(-1).expand_as(expert_out),
expert_out,
)
else:
output = torch.zeros_like(x_flat)
# Precompute average tokens for gradient scaling
if self.gradient_scale and self.training:
avg_tokens = num_tokens * top_k / max(self.num_experts, 1)
for i in range(self.num_experts):
mask = (indices == i).any(dim=-1)
if not mask.any():
continue
expert_input = x_flat[mask]
expert_output = (
self._apply_packed_expert(expert_input, i)
if self.packed_experts
else self.experts[i](expert_input)
)
# Per-expert gradient scaling (DeepSeek-V3 Sec 3.2)
if self.gradient_scale and self.training:
tokens_i = (indices == i).sum().detach().float()
scale = avg_tokens / tokens_i.clamp(min=1.0)
expert_output = expert_output * scale
weight_for_i = (weights * (indices == i).float()).sum(dim=-1)
output[mask] += weight_for_i[mask].unsqueeze(-1) * expert_output
return output
@property
def aux_loss(self) -> torch.Tensor:
return self.router.aux_loss # type: ignore[reportReturnType]
@property
def expert_counts(self) -> torch.Tensor:
return self.router.expert_counts # type: ignore[reportReturnType]
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass dispatching tokens to experts.
Args:
x: (batch, seq_len, dim)
Returns:
(batch, seq_len, dim)
"""
B, L, D = x.shape
x_flat = x.view(B * L, D)
# Route tokens → stores aux_loss as side effect
weights, indices = self.router(x_flat)
# Capacity factor: cap tokens per expert, drop overflow.
# Dropped tokens get zero routing weight → no expert contribution,
# carried through unchanged by the residual connection.
if self.capacity_factor > 0:
weights, indices = _apply_capacity(
weights,
indices,
self.num_experts,
self.capacity_factor,
)
if self.ep_world_size > 1:
from kempnerforge.distributed.expert_parallel import ep_dispatch_and_compute
output = ep_dispatch_and_compute(
x_flat,
weights,
indices,
self,
self.ep_group, # type: ignore[reportArgumentType]
self.local_expert_start,
self.num_local_experts,
self.ep_world_size,
gradient_scale=self.gradient_scale,
)
else:
output = self._local_forward(x_flat, weights, indices)
if self.shared_expert is not None:
output = output + self.shared_expert(x_flat)
return output.view(B, L, D)
def build_moe(
dim: int,
hidden_dim: int,
num_experts: int,
top_k: int,
activation: str = "silu",
router_type: str = "softmax_topk",
shared_experts: int = 0,
capacity_factor: float = 0.0,
gradient_scale: bool = False,
sequence_aux_loss_weight: float = 0.0,
bias_schedule: str = "constant",
packed_experts: bool = False,
) -> MoEMLP:
"""Build an MoE layer, composing router + experts from Registry.
Args:
dim: Model dimension.
hidden_dim: Expert FFN hidden dimension.
num_experts: Number of routed experts.
top_k: Experts selected per token.
activation: MLP activation (registry key).
router_type: Router registry key.
shared_experts: If > 0, a single always-active shared expert MLP is added.
capacity_factor: Token capacity per expert (0=unlimited, >0=cap).
gradient_scale: Per-expert gradient normalization.
sequence_aux_loss_weight: Sequence-level balance loss weight (sigmoid router only).
bias_schedule: Bias update rate schedule (sigmoid router only).
packed_experts: Pack expert weights into one tensor per projection.
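Example (illustrative configuration; values are arbitrary)::

    moe = build_moe(
        dim=1024,
        hidden_dim=2048,
        num_experts=16,
        top_k=2,
        router_type="sigmoid_topk",
        shared_experts=1,
        capacity_factor=1.25,
        packed_experts=True,
    )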
"""
router_builder = registry.get("router", router_type)
router_kwargs: dict[str, object] = {}
if router_type == "sigmoid_topk":
router_kwargs = {
"sequence_aux_loss_weight": sequence_aux_loss_weight,
"bias_schedule": bias_schedule,
}
router = router_builder(dim, num_experts, top_k, **router_kwargs)
experts = nn.ModuleList([build_mlp(dim, hidden_dim, activation) for _ in range(num_experts)])
shared_expert = None
if shared_experts > 0:
shared_expert = build_mlp(dim, hidden_dim, activation)
return MoEMLP(
router,
experts,
shared_expert,
capacity_factor=capacity_factor,
gradient_scale=gradient_scale,
packed_experts=packed_experts,
)