kempnerforge.model.router
MoE router implementations for KempnerForge models.
Classes
SigmoidTopKRouter | DeepSeek-V3 style sigmoid router with auxiliary-loss-free balancing.
SoftmaxTopKRouter | Mixtral-style softmax top-k router with auxiliary load-balancing loss.
- class kempnerforge.model.router.SoftmaxTopKRouter
Bases: torch.nn.Module
Mixtral-style softmax top-k router with auxiliary load-balancing loss.
Each token independently selects its top_k experts. Routing weights are first softmax-normalized over all experts, then renormalized over the selected top_k experts so that each token's weights sum to 1.
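A minimal sketch of the softmax-then-renormalize selection described above, written as a standalone function rather than the KempnerForge module. It assumes the router logits come from a bias-free linear gate; the gate projection and the helper name `softmax_topk_route` are hypothetical.

```python
import torch
import torch.nn.functional as F


def softmax_topk_route(logits: torch.Tensor, top_k: int):
    """Softmax over all experts, keep each token's top_k, renormalize.

    logits: (num_tokens, num_experts) raw router scores.
    Returns weights and expert_indices, both (num_tokens, top_k).
    """
    probs = F.softmax(logits, dim=-1)                            # distribution over all experts
    topk_probs, expert_indices = probs.topk(top_k, dim=-1)       # each token's top_k experts
    weights = topk_probs / topk_probs.sum(dim=-1, keepdim=True)  # renormalize over selected experts
    return weights, expert_indices


# Example: 4 tokens, dim=16, 8 experts, top_k=2 (sizes are arbitrary).
torch.manual_seed(0)
gate = torch.nn.Linear(16, 8, bias=False)  # hypothetical gate projection
x = torch.randn(4, 16)
weights, expert_indices = softmax_topk_route(gate(x), top_k=2)
```

Renormalizing after top-k means a token's output is a convex combination of its selected experts, regardless of how much probability mass the unselected experts held.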
Stores aux_loss and expert_counts as side effects of each forward call, for collection by the training loop and metrics.
- forward(x)
Route tokens to experts.
- Parameters:
x (torch.Tensor) – Flattened token representations, shape (num_tokens, dim).
- Returns:
weights – Renormalized routing weights, shape (num_tokens, top_k).
expert_indices – Selected expert indices, shape (num_tokens, top_k).
- Return type:
tuple[torch.Tensor, torch.Tensor]
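The exact form of aux_loss is not specified on this page. One common formulation, assumed for this sketch, is the Switch Transformer balance loss num_experts * sum_i(f_i * P_i), where f_i is the fraction of top_k assignments routed to expert i and P_i is the mean router probability for expert i; KempnerForge's loss may be scaled or defined differently. The helper name `load_balancing_stats` and the bincount-based expert_counts are hypothetical.

```python
import torch
import torch.nn.functional as F


def load_balancing_stats(probs: torch.Tensor, expert_indices: torch.Tensor):
    """Switch-Transformer-style balance loss plus per-expert assignment counts.

    probs: (num_tokens, num_experts) full softmax router probabilities.
    expert_indices: (num_tokens, top_k) experts selected for each token.
    """
    num_tokens, num_experts = probs.shape
    top_k = expert_indices.shape[1]
    # expert_counts[i] = number of (token, slot) assignments routed to expert i
    expert_counts = torch.bincount(expert_indices.flatten(), minlength=num_experts)
    # f_i: fraction of all top_k assignments that went to expert i
    frac_assigned = expert_counts.float() / (num_tokens * top_k)
    # P_i: mean router probability given to expert i
    mean_probs = probs.mean(dim=0)
    # Equals 1.0 under perfectly uniform routing.
    aux_loss = num_experts * torch.sum(frac_assigned * mean_probs)
    return aux_loss, expert_counts


torch.manual_seed(0)
probs = F.softmax(torch.randn(32, 8), dim=-1)
_, expert_indices = probs.topk(2, dim=-1)
aux_loss, expert_counts = load_balancing_stats(probs, expert_indices)
```

Only P_i carries a gradient here (f_i comes from a discrete top-k choice), so the loss pushes probability mass away from experts that are already receiving many tokens.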
- class kempnerforge.model.router.SigmoidTopKRouter
Bases: torch.nn.Module
DeepSeek-V3 style sigmoid router with auxiliary-loss-free balancing.
Uses per-expert sigmoid scoring (each expert is scored independently) instead of a softmax over all experts. Load balancing is maintained by a learnable expert_bias that is adjusted using a running exponential moving average (EMA) of expert utilization, so by default no auxiliary loss term is added to the training loss.
The bias adjustment nudges under-utilized experts up and over-utilized experts down, achieving balance without interfering with the gradient signal of the main loss.
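A simplified, hypothetical sketch of DeepSeek-V3 style auxiliary-loss-free balancing follows; it is not the SigmoidTopKRouter source, and `SigmoidBiasRouterSketch` is a made-up name. It assumes that the bias affects only which experts are selected (the returned weights are the unbiased sigmoid scores) and that the bias moves by a fixed step of bias_update_rate in the direction of sign(target - EMA load), as described for DeepSeek-V3; the real update rule may differ.

```python
import torch


class SigmoidBiasRouterSketch:
    """Hypothetical sigmoid top-k routing with bias-based balancing (no aux loss)."""

    def __init__(self, num_experts: int, top_k: int,
                 bias_update_rate: float = 0.001, ema_decay: float = 0.99):
        self.num_experts = num_experts
        self.top_k = top_k
        self.bias_update_rate = bias_update_rate
        self.ema_decay = ema_decay
        self.expert_bias = torch.zeros(num_experts)
        # Start the utilization EMA at perfectly uniform load.
        self.ema_load = torch.full((num_experts,), 1.0 / num_experts)

    def route(self, logits: torch.Tensor):
        scores = torch.sigmoid(logits)  # each expert scored independently
        # The bias only influences which experts are selected...
        _, expert_indices = (scores + self.expert_bias).topk(self.top_k, dim=-1)
        # ...while the returned weights are the unbiased sigmoid scores.
        weights = scores.gather(-1, expert_indices)
        self._update_bias(expert_indices)
        return weights, expert_indices

    def _update_bias(self, expert_indices: torch.Tensor) -> None:
        counts = torch.bincount(expert_indices.flatten(), minlength=self.num_experts)
        load = counts.float() / expert_indices.numel()  # fraction of assignments per expert
        self.ema_load = self.ema_decay * self.ema_load + (1 - self.ema_decay) * load
        target = 1.0 / self.num_experts
        # Raise the bias of under-utilized experts, lower it for over-utilized ones.
        self.expert_bias += self.bias_update_rate * torch.sign(target - self.ema_load)


torch.manual_seed(0)
router = SigmoidBiasRouterSketch(num_experts=8, top_k=2)
for _ in range(3):
    weights, expert_indices = router.route(torch.randn(16, 8))
```

Because the bias is applied outside the gradient path, it steers token assignment without adding any term to the loss.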
Optional enhancements (all disabled by default for backward compatibility):
- Sequence-level auxiliary loss (sequence_aux_loss_weight, default 0.0): a lightweight variance-based balance penalty, 10-100x smaller than the Switch Transformer auxiliary loss, that complements bias balancing to prevent degenerate routing in long runs.
- Adaptive bias schedule (bias_schedule, default 'constant'): decays or warms up the bias update rate over training to stabilize routing after initial exploration.
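Neither enhancement is specified in detail here, so the sketch below is hypothetical. It assumes the penalty is the variance, across experts, of each sequence's average normalized router score, scaled by a weight playing the role of sequence_aux_loss_weight. The schedule names 'linear_decay' and 'linear_warmup' are made up; only 'constant' appears in the documented signature.

```python
import torch


def sequence_balance_penalty(scores: torch.Tensor, weight: float) -> torch.Tensor:
    """Hypothetical sequence-level balance penalty.

    scores: (batch, seq_len, num_experts) sigmoid router scores.
    """
    probs = scores / scores.sum(dim=-1, keepdim=True)  # per-token distribution over experts
    load = probs.mean(dim=1)                           # (batch, num_experts) per-sequence load
    centered = load - load.mean(dim=-1, keepdim=True)
    # Population variance across experts; zero when every sequence spreads its load evenly.
    return weight * (centered ** 2).mean(dim=-1).mean()


def scheduled_bias_rate(base_rate: float, step: int, total_steps: int,
                        schedule: str = "constant") -> float:
    """Hypothetical bias-update-rate schedule; only 'constant' is documented."""
    progress = min(step / max(total_steps, 1), 1.0)
    if schedule == "constant":
        return base_rate
    if schedule == "linear_decay":   # hypothetical name
        return base_rate * (1.0 - progress)
    if schedule == "linear_warmup":  # hypothetical name
        return base_rate * progress
    raise ValueError(f"unknown bias schedule: {schedule!r}")
```

Unlike the bias update, the penalty is computed from continuous scores, so it contributes a gradient to the router weights.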
- __init__(dim, num_experts, top_k, bias_update_rate=0.001, ema_decay=0.99, sequence_aux_loss_weight=0.0, bias_schedule='constant')
- forward(x)
Route tokens to experts using sigmoid scoring.
- Parameters:
x (torch.Tensor) – Flattened token representations, shape (num_tokens, dim).
- Returns:
weights – Routing weights (sigmoid scores), shape (num_tokens, top_k).
expert_indices – Selected expert indices, shape (num_tokens, top_k).
- Return type:
tuple[torch.Tensor, torch.Tensor]
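Both forward methods return a (weights, expert_indices) pair of shape (num_tokens, top_k). A minimal sketch of how a caller might use that pair to combine expert outputs is shown below; the loop-over-experts dispatch and the name `moe_combine` are assumptions for illustration, not KempnerForge's MoE layer. The combine step does not assume the weights sum to 1, so it applies to raw sigmoid scores as well as renormalized softmax weights.

```python
import torch
import torch.nn as nn


def moe_combine(x, weights, expert_indices, experts):
    """Send each token to its selected experts and sum the weighted outputs.

    x: (num_tokens, dim); weights, expert_indices: (num_tokens, top_k);
    experts: sequence of callables mapping (n, dim) -> (n, dim).
    """
    out = torch.zeros_like(x)
    for e, expert in enumerate(experts):
        # (token, slot) positions whose selected expert is e
        token_idx, slot_idx = (expert_indices == e).nonzero(as_tuple=True)
        if token_idx.numel() == 0:
            continue  # no tokens routed to this expert
        expert_out = expert(x[token_idx])                  # run the expert on its tokens only
        gate = weights[token_idx, slot_idx].unsqueeze(-1)  # (n, 1) routing weight per pair
        out.index_add_(0, token_idx, gate * expert_out)    # accumulate into token rows
    return out


torch.manual_seed(0)
experts = nn.ModuleList([nn.Linear(16, 16) for _ in range(8)])
x = torch.randn(4, 16)
probs = torch.randn(4, 8).softmax(dim=-1)
topk_probs, expert_indices = probs.topk(2, dim=-1)
weights = topk_probs / topk_probs.sum(dim=-1, keepdim=True)
y = moe_combine(x, weights, expert_indices, experts)
```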