kempnerforge.model.router

MoE router implementations for KempnerForge models.

Classes

SigmoidTopKRouter

DeepSeek-V3 style sigmoid router with auxiliary-loss-free balancing.

SoftmaxTopKRouter

Mixtral-style softmax top-k router with auxiliary load-balancing loss.

class kempnerforge.model.router.SoftmaxTopKRouter[source]

Bases: Module

Mixtral-style softmax top-k router with auxiliary load-balancing loss.

Each token independently selects top_k experts. Routing weights are softmax-normalized, then renormalized over the selected top_k.

Stores aux_loss and expert_counts as side effects of each forward call for collection by the training loop and metrics.

__init__(dim, num_experts, top_k)[source]
Parameters:
  • dim (int)

  • num_experts (int)

  • top_k (int)

Return type:

forward(x)[source]

Route tokens to experts.

Parameters:

x (torch.Tensor) – (num_tokens, dim) — flattened token representations.

Returns:

  weights: (num_tokens, top_k) — renormalized routing weights.

  expert_indices: (num_tokens, top_k) — selected expert indices.

Return type:

tuple[torch.Tensor, torch.Tensor]
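
The routing described above can be sketched as follows. This is a hypothetical re-implementation for illustration, not the package's actual code: the `gate` projection, the attribute names, and the Switch-Transformer-style auxiliary-loss formula are assumptions consistent with the docstring, not confirmed internals.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SoftmaxTopKRouterSketch(nn.Module):
    """Sketch of Mixtral-style softmax top-k routing with an aux load-balancing loss."""

    def __init__(self, dim: int, num_experts: int, top_k: int) -> None:
        super().__init__()
        self.gate = nn.Linear(dim, num_experts, bias=False)  # assumed linear gate
        self.num_experts = num_experts
        self.top_k = top_k
        self.aux_loss = torch.tensor(0.0)
        self.expert_counts = torch.zeros(num_experts)

    def forward(self, x: torch.Tensor):
        # x: (num_tokens, dim)
        logits = self.gate(x)                          # (num_tokens, num_experts)
        probs = F.softmax(logits, dim=-1)
        weights, expert_indices = probs.topk(self.top_k, dim=-1)
        # Renormalize over the selected top_k so per-token weights sum to 1.
        weights = weights / weights.sum(dim=-1, keepdim=True)

        # Side effects for the training loop: Switch-style balance loss
        # (fraction of tokens per expert times mean router probability).
        one_hot = F.one_hot(expert_indices, self.num_experts).float()  # (T, k, E)
        frac_tokens = one_hot.sum(dim=1).mean(dim=0)   # per-expert routing fraction
        frac_probs = probs.mean(dim=0)                 # per-expert mean probability
        self.aux_loss = self.num_experts * (frac_tokens * frac_probs).sum()
        self.expert_counts = one_hot.sum(dim=(0, 1))
        return weights, expert_indices
```

A training loop would read `router.aux_loss` after each forward and add it (scaled) to the main loss.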

class kempnerforge.model.router.SigmoidTopKRouter[source]

Bases: Module

DeepSeek-V3 style sigmoid router with auxiliary-loss-free balancing.

Uses per-expert sigmoid scoring (each expert scored independently) instead of softmax. Load balancing is maintained by a learnable expert_bias adjusted via running EMA of expert utilization — no auxiliary loss term is added to the training loss by default.

The bias adjustment nudges under-utilized experts up and over-utilized experts down, achieving balance without interfering with the main loss gradient signal.

Optional enhancements (all disabled by default for backward compatibility):

  • Sequence-level auxiliary loss: lightweight variance-based balance penalty (10-100x smaller than Switch Transformer’s aux loss) that complements bias balancing to prevent degenerate routing in long runs.

  • Adaptive bias schedule: decays or warms up the bias update rate over training to stabilize routing after initial exploration.
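
The bias-nudging rule can be sketched as below. This is a minimal illustration assuming a DeepSeek-V3-style sign-based update against an EMA of expert utilization; the function name and exact update form are assumptions, not the module's confirmed implementation.

```python
import torch

def update_expert_bias(
    expert_bias: torch.Tensor,    # (num_experts,) selection bias
    expert_counts: torch.Tensor,  # (num_experts,) tokens routed this step
    ema_load: torch.Tensor,       # (num_experts,) running EMA of utilization
    ema_decay: float = 0.99,
    bias_update_rate: float = 0.001,
):
    """Nudge under-utilized experts up and over-utilized experts down (sketch)."""
    load = expert_counts / expert_counts.sum().clamp(min=1)
    ema_load = ema_decay * ema_load + (1 - ema_decay) * load
    target = 1.0 / expert_bias.numel()  # uniform utilization target
    # Sign-based nudge: +rate if under-utilized, -rate if over-utilized.
    # No gradient flows through this, so the main loss signal is untouched.
    expert_bias = expert_bias + bias_update_rate * torch.sign(target - ema_load)
    return expert_bias, ema_load
```

Because the update is applied outside autograd, balancing costs no extra loss term, which is the "auxiliary-loss-free" property described above.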

__init__(dim, num_experts, top_k, bias_update_rate=0.001, ema_decay=0.99, sequence_aux_loss_weight=0.0, bias_schedule='constant')[source]
Parameters:
  • dim (int)

  • num_experts (int)

  • top_k (int)

  • bias_update_rate (float)

  • ema_decay (float)

  • sequence_aux_loss_weight (float)

  • bias_schedule (str)

Return type:

None

set_step(step, max_steps)[source]

Set current training step for adaptive bias scheduling.

Parameters:
  • step (int)

  • max_steps (int)

Return type:

None

forward(x)[source]

Route tokens to experts using sigmoid scoring.

Parameters:

x (torch.Tensor) – (num_tokens, dim) — flattened token representations.

Returns:

  weights: (num_tokens, top_k) — routing weights (sigmoid scores).

  expert_indices: (num_tokens, top_k) — selected expert indices.

Return type:

tuple[torch.Tensor, torch.Tensor]
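
A sketch of the sigmoid forward pass, under the assumption (as in DeepSeek-V3) that `expert_bias` shifts expert *selection* only while the returned weights come from the raw sigmoid scores; the internal names here are illustrative, not confirmed:

```python
import torch
import torch.nn as nn

class SigmoidTopKRouterSketch(nn.Module):
    """Sketch of sigmoid top-k routing with a selection-only balance bias."""

    def __init__(self, dim: int, num_experts: int, top_k: int) -> None:
        super().__init__()
        self.gate = nn.Linear(dim, num_experts, bias=False)  # assumed linear gate
        # Updated out-of-band via the EMA rule, not by gradient descent.
        self.register_buffer("expert_bias", torch.zeros(num_experts))
        self.top_k = top_k

    def forward(self, x: torch.Tensor):
        # Each expert is scored independently: sigmoid, not softmax.
        scores = torch.sigmoid(self.gate(x))          # (num_tokens, num_experts)
        # Bias steers which experts are *selected*...
        _, expert_indices = (scores + self.expert_bias).topk(self.top_k, dim=-1)
        # ...but the routing weights are the unbiased sigmoid scores.
        weights = scores.gather(-1, expert_indices)   # (num_tokens, top_k)
        return weights, expert_indices
```

Keeping the bias out of the returned weights is what lets balancing avoid perturbing the gradient signal of the main loss.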