kempnerforge.model.moe

Mixture-of-Experts feed-forward layer for KempnerForge models.

Functions

build_moe(dim, hidden_dim, num_experts, top_k)

Build an MoE layer, composing router + experts from Registry.

grouped_expert_forward(x_sorted, ...)

Batched expert computation using torch._grouped_mm.

grouped_expert_forward_packed(x_sorted, ...)

Batched expert computation over pre-packed weights.

Classes

MoEMLP

Mixture-of-Experts feed-forward layer.

kempnerforge.model.moe.grouped_expert_forward(x_sorted, tokens_per_expert, experts)[source]

Batched expert computation using torch._grouped_mm.

Replaces the sequential per-expert loop with two to three grouped matrix multiplies (one CUDA kernel each), yielding a significant speedup when many experts are active.

Parameters:
  • x_sorted (torch.Tensor) – (total_tokens, dim) token features sorted by expert index.

  • tokens_per_expert (list[int]) – Number of tokens assigned to each expert, in order.

  • experts (torch.nn.ModuleList) – Expert modules whose weights are stacked for the grouped GEMM.

Returns:

(total_tokens, dim) expert outputs in the same sorted order as input.

Return type:

torch.Tensor
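
The grouped-GEMM path computes the same result as running each expert on its contiguous slice of the sorted token tensor. A minimal loop-based reference can sketch those semantics; the function name `grouped_expert_forward_reference` and the two-layer expert structure are illustrative assumptions, not the library's fused implementation:

```python
import torch
import torch.nn as nn

def grouped_expert_forward_reference(x_sorted, tokens_per_expert, experts):
    """Loop-based reference for the grouped-GEMM path (illustration only):
    each expert processes its contiguous slice of the sorted tokens."""
    chunks = torch.split(x_sorted, tokens_per_expert, dim=0)
    outs = [expert(chunk) for expert, chunk in zip(experts, chunks)]
    # Outputs stay in the same expert-sorted order as the input.
    return torch.cat(outs, dim=0)

# Tiny demo: 3 experts, 6 tokens, dim 4 (expert architecture is assumed).
experts = nn.ModuleList(
    nn.Sequential(nn.Linear(4, 8), nn.SiLU(), nn.Linear(8, 4))
    for _ in range(3)
)
x_sorted = torch.randn(6, 4)
out = grouped_expert_forward_reference(x_sorted, [1, 2, 3], experts)
print(out.shape)  # torch.Size([6, 4])
```

The grouped kernel fuses these per-expert matmuls into one launch; the reference above is only useful for checking correctness.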

kempnerforge.model.moe.grouped_expert_forward_packed(x_sorted, tokens_per_expert, up_w, down_w, gate_w, activation)[source]

Batched expert computation over pre-packed weights.

Same as grouped_expert_forward, but consumes packed weight tensors directly, avoiding a per-step torch.stack over an nn.ModuleList.

Parameters:
  • x_sorted (torch.Tensor) – (total_tokens, dim) token features sorted by expert index.

  • tokens_per_expert (list[int]) – Number of tokens assigned to each expert, in order.

  • up_w (torch.Tensor) – (E, dim, hidden) packed up-projection weights.

  • down_w (torch.Tensor) – (E, hidden, dim) packed down-projection weights.

  • gate_w (torch.Tensor | None) – (E, dim, hidden) packed gate weights for SwiGLU, else None.

  • activation – Activation function applied to the up-projection output when gate_w is None. When gate_w is provided, the SwiGLU path hardcodes silu.

Returns:

(total_tokens, dim) expert outputs in the same sorted order as input.

Return type:

torch.Tensor
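
The packed variant's math can likewise be sketched with a per-expert loop over slices of the packed tensors; `packed_forward_reference` is a hypothetical name, and the SwiGLU combination shown (silu(gate) * up) is an assumption consistent with the silu note above:

```python
import torch
import torch.nn.functional as F

def packed_forward_reference(x_sorted, tokens_per_expert, up_w, down_w,
                             gate_w=None, activation=F.silu):
    """Per-expert loop equivalent of the packed grouped-GEMM path
    (illustration only; the real kernel fuses these matmuls)."""
    outs = []
    for e, chunk in enumerate(torch.split(x_sorted, tokens_per_expert, dim=0)):
        h = chunk @ up_w[e]                    # (n_e, hidden)
        if gate_w is not None:
            h = F.silu(chunk @ gate_w[e]) * h  # assumed SwiGLU: silu(gate) * up
        else:
            h = activation(h)
        outs.append(h @ down_w[e])             # (n_e, dim)
    return torch.cat(outs, dim=0)

E, dim, hidden = 2, 4, 8
up_w = torch.randn(E, dim, hidden)
down_w = torch.randn(E, hidden, dim)
gate_w = torch.randn(E, dim, hidden)
x = torch.randn(5, dim)
y = packed_forward_reference(x, [2, 3], up_w, down_w, gate_w)
print(y.shape)  # torch.Size([5, 4])
```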

class kempnerforge.model.moe.MoEMLP[source]

Bases: Module

Mixture-of-Experts feed-forward layer.

Composes a router (from the "router" registry) with N expert MLPs (from the "mlp" registry). Drop-in replacement for a dense MLP with the same forward signature.

Stores aux_loss after each forward for collection by the training loop.

__init__(router, experts, shared_expert=None, capacity_factor=0.0, gradient_scale=False, packed_experts=False)[source]
Parameters:
  • router (nn.Module)

  • experts (nn.ModuleList)

  • shared_expert (nn.Module | None)

  • capacity_factor (float)

  • gradient_scale (bool)

  • packed_experts (bool)

Return type:

None

property aux_loss: torch.Tensor
property expert_counts: torch.Tensor
forward(x)[source]

Forward pass dispatching tokens to experts.

Parameters:

x (torch.Tensor) – (batch, seq_len, dim)

Returns:

(batch, seq_len, dim)

Return type:

torch.Tensor
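
The dispatch step that produces the expert-sorted layout consumed by the grouped functions can be sketched as follows; top-1 routing and the helper name `dispatch_sort` are simplifying assumptions, not MoEMLP's actual internals:

```python
import torch

def dispatch_sort(x_flat, expert_idx, num_experts):
    """Sort tokens by assigned expert so each expert sees a
    contiguous slice (assumed dispatch scheme, illustration only)."""
    order = torch.argsort(expert_idx, stable=True)
    x_sorted = x_flat[order]
    tokens_per_expert = torch.bincount(expert_idx, minlength=num_experts).tolist()
    return x_sorted, tokens_per_expert, order

# Top-1 routing over 6 flattened tokens, 3 experts.
x = torch.randn(6, 4)
logits = torch.randn(6, 3)
expert_idx = logits.argmax(dim=-1)
x_sorted, counts, order = dispatch_sort(x, expert_idx, 3)

# After expert computation, outputs are scattered back to token order:
y = torch.empty_like(x_sorted)
y[order] = x_sorted  # identity here; real code scatters expert outputs
print(sum(counts))   # 6
```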

kempnerforge.model.moe.build_moe(dim, hidden_dim, num_experts, top_k, activation='silu', router_type='softmax_topk', shared_experts=0, capacity_factor=0.0, gradient_scale=False, sequence_aux_loss_weight=0.0, bias_schedule='constant', packed_experts=False)[source]

Build an MoE layer, composing router + experts from Registry.

Parameters:
  • dim (int) – Model dimension.

  • hidden_dim (int) – Expert FFN hidden dimension.

  • num_experts (int) – Number of routed experts.

  • top_k (int) – Experts selected per token.

  • activation (str) – MLP activation (registry key).

  • router_type (str) – Router registry key.

  • shared_experts (int) – Number of shared experts (always active).

  • capacity_factor (float) – Token capacity factor per expert (0 = unlimited; values > 0 cap tokens per expert).

  • gradient_scale (bool) – Enable per-expert gradient normalization.

  • sequence_aux_loss_weight (float) – Sequence-level balance loss weight (sigmoid router only).

  • bias_schedule (str) – Bias update rate schedule (sigmoid router only).

  • packed_experts (bool) – Pack expert weights into one tensor per projection.

Return type:

MoEMLP
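
What build_moe composes can be sketched with plain torch modules; `build_moe_sketch`, the linear router, and the SiLU expert MLPs are hypothetical stand-ins for the registry-backed components, not the library's actual construction:

```python
import torch
import torch.nn as nn

def build_moe_sketch(dim, hidden_dim, num_experts, top_k):
    """Hypothetical composition: a linear router producing softmax_topk
    logits plus num_experts expert MLPs (illustration only)."""
    router = nn.Linear(dim, num_experts, bias=False)
    experts = nn.ModuleList(
        nn.Sequential(nn.Linear(dim, hidden_dim), nn.SiLU(),
                      nn.Linear(hidden_dim, dim))
        for _ in range(num_experts)
    )
    return router, experts

router, experts = build_moe_sketch(dim=4, hidden_dim=8, num_experts=3, top_k=2)
x = torch.randn(2, 5, 4)                                  # (batch, seq_len, dim)
weights, idx = torch.softmax(router(x), dim=-1).topk(2, dim=-1)
print(idx.shape)  # torch.Size([2, 5, 2]): top_k expert ids per token
```

The real function additionally wires in shared experts, capacity limits, and the auxiliary-loss machinery described in the parameter list above.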