kempnerforge.model.moma
Mixture of Modality-Aware Experts (MoMa) operator, FFN, and block.
Implements Lin et al. 2024 (“MoMa: Efficient Early-Fusion Pre-training with Mixture of Modality-Aware Experts”, arXiv:2407.21770) on top of KempnerForge’s existing VLM stack.
Architecture at a glance, per transformer layer:
- Pre-norm Attention (the standard module, shared Q/K/V/O across modalities) running a single global SDPA over the concatenated image+text sequence.
- Pre-norm MoMaFFN: a ModuleDict of per-modality ExpertChoiceMoE groups dispatched by modality_ids. Each group’s MoE uses Expert-Choice + Sigmoid routing (paper §2.2): each expert independently picks its top-k_e tokens by sigmoid score, and the token output is the sum across experts that selected it, weighted by their sigmoid scores (Eq. 1). Optional Gumbel-Sigmoid noise during training (Eq. 5).
Differs from MoT (also in this codebase): MoT has per-modality Q/K/V/O
projections and per-modality FFN. MoMa has shared Q/K/V/O and per-modality
MoE FFN groups (multiple experts per modality, learned routing within
each group). Both share the residual-stream layout (image tokens prepended
to text) and modality_ids tagging mechanism.
The module exposes four public symbols:
- ExpertChoiceSigmoidRouter: per-modality gate (W_g^M), Sigmoid scoring, optional Gumbel noise, and per-expert top-k_e token selection.
- ExpertChoiceMoE: composes a router + num_experts SwiGLU experts; forward(x) returns the sigmoid-weighted expert combination.
- MoMaFFN: holds one ExpertChoiceMoE per modality and dispatches tokens via modality_ids.
- MoMaBlock: pre-norm block with shared Attention + MoMaFFN.
Inference note: expert-choice routing is non-causal (each expert’s
top-k_e depends on all tokens in the batch). v1 supports training only;
autoregressive generation requires auxiliary routers (paper §2.4), deferred
to a follow-up PR.
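The non-causality noted above can be illustrated with a minimal, dependency-free sketch. It assumes a single expert, hand-picked sigmoid scores, and a hypothetical helper `expert_picks` that applies the `k_e = ceil(c_e * N)` rule; it is not the module's code, only a toy showing that a token's selection can change once later tokens are appended.

```python
import math


def expert_picks(scores, capacity_factor):
    """Return the sorted token indices one expert would select (toy helper)."""
    n = len(scores)
    k_e = min(n, math.ceil(capacity_factor * n))  # capped by N
    ranked = sorted(range(n), key=lambda i: scores[i], reverse=True)
    return sorted(ranked[:k_e])


# Illustrative sigmoid scores for four tokens under one expert.
scores = [0.6, 0.2, 0.9, 0.8]

# Routing over the first two tokens only (what a causal decoder would see).
prefix_picks = expert_picks(scores[:2], capacity_factor=0.5)
# Routing over the full sequence (what training sees).
full_picks = expert_picks(scores, capacity_factor=0.5)

print(prefix_picks)  # [0]
print(full_picks)    # [2, 3]
```

Token 0 is selected when only the prefix is visible but dropped once tokens 2 and 3 arrive, which is why the selection cannot be reproduced token by token during generation.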
Classes
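| Class | Summary |
| --- | --- |
| ExpertChoiceMoE | Expert-Choice MoE for one modality group. |
| ExpertChoiceSigmoidRouter | Expert-Choice + Sigmoid router for one modality group (Lin et al. 2024 §2.2). |
| MoMaBlock | Pre-norm transformer block: shared Attention + MoMaFFN. |
| MoMaFFN | Per-modality MoE FFN groups dispatched by modality_ids. |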
- class kempnerforge.model.moma.ExpertChoiceSigmoidRouter
Bases: Module
Expert-Choice + Sigmoid router for one modality group (Lin et al. 2024 §2.2).
Scoring: score = sigmoid(W_g x) per token-expert pair (independent across experts because Sigmoid does not normalize). Optional Gumbel perturbation during training: Gumbel-Sigmoid(x) = sigmoid(x + G' - G'') with independent Gumbel(0, 1) samples G', G'' (paper Eq. 5).
Selection: each expert independently picks its top-k_e tokens by score (torch.topk on the (expert, token) score matrix). This is the inverse of token-choice routing: there a token picks experts; here an expert picks tokens. A token can be picked by 0, 1, or more experts (the residual stream carries the unmodified token through when no expert picks it).
capacity_factor controls k_e as k_e = ceil(c_e * N) where N is the number of tokens of this modality in the current batch. The paper’s default c_e = 1/|E^M| gives k_e ≈ N/|E^M| so each expert sees the average load (perfect balance under EC routing).
- forward(x)
Route tokens to experts via expert-choice.
- Parameters:
x (torch.Tensor) – (N, D) token representations for one modality group.
- Returns:
topk_scores: (E, k_e) sigmoid scores of the tokens each expert selected.
topk_indices: (E, k_e) token indices into x that each expert selected.
k_e is computed from capacity_factor * N, capped by N.
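The scoring and selection steps described for this router can be sketched without PyTorch. The following is an illustrative reimplementation under stated assumptions, not the class's actual code: `expert_choice_route`, `gate_weight`, and the `rng` argument are hypothetical names, plain lists stand in for tensors, and a sort stands in for torch.topk.

```python
import math
import random


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def gumbel(rng):
    # Gumbel(0, 1) sample by inverse transform; u is clamped away from 0 and 1.
    u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
    return -math.log(-math.log(u))


def expert_choice_route(x, gate_weight, capacity_factor, rng=None):
    """x: N token vectors; gate_weight: E gate rows (stand-in for W_g).

    Returns (topk_scores, topk_indices), each a list of E rows with k_e entries.
    """
    n = len(x)
    k_e = min(n, math.ceil(capacity_factor * n))  # k_e = ceil(c_e * N), capped by N
    topk_scores, topk_indices = [], []
    for w in gate_weight:
        logits = [sum(wi * xi for wi, xi in zip(w, tok)) for tok in x]
        if rng is not None:
            # Gumbel-Sigmoid: perturb each logit with G' - G'' before the sigmoid.
            logits = [z + gumbel(rng) - gumbel(rng) for z in logits]
        scores = [sigmoid(z) for z in logits]
        # Each expert independently keeps its k_e highest-scoring tokens.
        picked = sorted(range(n), key=lambda i: scores[i], reverse=True)[:k_e]
        topk_indices.append(picked)
        topk_scores.append([scores[i] for i in picked])
    return topk_scores, topk_indices


tokens = [[1.0, 0.0], [0.0, 1.0], [2.0, 0.0], [0.0, 3.0]]
gate = [[1.0, 0.0], [0.0, 1.0]]  # two experts, so c_e = 1/2 gives k_e = 2
scores, indices = expert_choice_route(tokens, gate, capacity_factor=0.5)
print(indices)  # [[2, 0], [3, 1]]
```

Passing `rng=random.Random(0)` enables the noise path; the selected indices may then differ from the noiseless ones, but each expert still picks exactly k_e tokens.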
- class kempnerforge.model.moma.ExpertChoiceMoE
Bases: Module
Expert-Choice MoE for one modality group.
Composes an ExpertChoiceSigmoidRouter with num_experts SwiGLU expert MLPs. Forward: each expert selects top-k_e tokens, runs its MLP on those tokens, and contributes sigmoid_score * MLP(x) to the output. Tokens not picked by any expert receive zero contribution from this MoE block (the outer residual skip preserves them).
State-dict layout (FQN-stable):
    router.gate.weight    # (num_experts, dim), gate Linear
    experts.0.gate_proj.weight
    experts.0.up_proj.weight
    experts.0.down_proj.weight
    experts.1...
    ...
- __init__(dim, hidden_dim, num_experts, capacity_factor, activation='silu', gumbel_noise=True)
- property expert_counts: torch.Tensor
Per-expert token count from the most recent forward (metrics).
- forward(x)
Expert-choice MoE forward over one modality group.
- Parameters:
x (torch.Tensor) – (N, D) token representations.
- Returns:
(N, D) output where each token has accumulated weighted outputs from every expert that selected it (zero contribution from this block when no expert selected the token).
- Return type:
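The accumulation described for this forward pass can be sketched as a scatter-add over each expert's selected tokens. This is a minimal illustration under assumptions: `expert_choice_combine` is a hypothetical name, the two toy functions stand in for SwiGLU expert MLPs, and the routing result is hard-coded rather than produced by a router.

```python
def expert_choice_combine(x, topk_scores, topk_indices, experts):
    """Sum score-weighted expert outputs back into an (N, D)-shaped list.

    Returns (out, counts), where counts[e] is the number of tokens expert e saw.
    """
    dim = len(x[0]) if x else 0
    out = [[0.0] * dim for _ in x]  # unpicked tokens keep a zero contribution
    counts = []
    for expert, scores, picked in zip(experts, topk_scores, topk_indices):
        counts.append(len(picked))
        for score, i in zip(scores, picked):
            y = expert(x[i])
            out[i] = [o + score * v for o, v in zip(out[i], y)]
    return out, counts


def double(token):      # toy stand-in for one expert MLP
    return [2.0 * v for v in token]


def plus_one(token):    # toy stand-in for another expert MLP
    return [v + 1.0 for v in token]


x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
topk_indices = [[0, 2], [2, 0]]          # token 1 is picked by no expert
topk_scores = [[0.5, 0.25], [0.5, 0.25]]
out, counts = expert_choice_combine(x, topk_scores, topk_indices, [double, plus_one])
print(out)     # [[1.5, 2.75], [0.0, 0.0], [5.5, 6.5]]
print(counts)  # [2, 2]
```

Token 1 receives zeros here, matching the note that the outer residual connection is what carries unselected tokens forward.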
- class kempnerforge.model.moma.MoMaFFN
Bases: Module
Per-modality MoE FFN groups dispatched by modality_ids.
Holds one ExpertChoiceMoE per modality (keys: modality name). Forward dispatches tokens by modality_ids (level-1 deterministic routing), runs each modality’s EC-MoE (level-2 learned routing), then scatters per-modality outputs back to their original positions.
Modality index convention: self.modalities[i] corresponds to modality_ids == i. With the default ("image", "text"), modality_ids == 0 selects the image expert group and modality_ids == 1 selects the text expert group.
- __init__(config, modalities, experts_per_modality, capacity_factor_per_modality, gumbel_noise=True)
- forward(x, modality_ids)
Dispatch tokens by modality and run per-modality EC-MoE.
- Parameters:
x (torch.Tensor) – (B, S, D) residual stream.
modality_ids (torch.Tensor) – (B, S) long tensor. modality_ids == i routes that token to self.modalities[i]’s expert group.
- Returns:
(B, S, D) tensor with each modality’s positions filled by its EC-MoE output. Positions that no expert in their modality group selected get zeros (the outer residual skip preserves them).
- Return type:
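The gather, per-group call, and scatter pattern described for this forward pass can be sketched as follows. It is a simplified illustration, not the module's code: `dispatch_by_modality` is a hypothetical name, the batch and sequence axes are assumed to be flattened into one token list, and the two toy group functions stand in for per-modality ExpertChoiceMoE instances.

```python
def dispatch_by_modality(x, modality_ids, groups):
    """x: flattened token vectors; modality_ids: one int per token;
    groups[i] handles the tokens with modality_ids == i."""
    out = [[0.0] * len(tok) for tok in x]
    for m, group in enumerate(groups):
        # Level 1: deterministic routing by modality id.
        positions = [p for p, mid in enumerate(modality_ids) if mid == m]
        if not positions:
            continue  # no tokens of this modality in the batch
        # Level 2: the group's own (learned) routing happens inside group().
        group_out = group([x[p] for p in positions])
        # Scatter the group's outputs back to the original positions.
        for p, y in zip(positions, group_out):
            out[p] = y
    return out


def image_group(tokens):  # toy stand-in for the image expert group
    return [[10.0 * v for v in tok] for tok in tokens]


def text_group(tokens):   # toy stand-in for the text expert group
    return [[-v for v in tok] for tok in tokens]


x = [[1.0], [2.0], [3.0], [4.0]]
modality_ids = [0, 1, 0, 1]  # 0 selects image, 1 selects text
print(dispatch_by_modality(x, modality_ids, [image_group, text_group]))
# [[10.0], [-2.0], [30.0], [-4.0]]
```

Each group only ever sees its own modality's tokens, so N in the capacity rule is the per-modality token count rather than the full sequence length.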
- class kempnerforge.model.moma.MoMaBlock
Bases: Module
Pre-norm transformer block: shared Attention + MoMaFFN.
Operates on a single residual tensor (B, S, D) like the dense TransformerBlock (unlike MoTBlock which operates on a per-modality dict of streams). The only structural difference from TransformerBlock is the FFN: MoMaFFN instead of a dense MLP or a flat MoE.
State-dict layout:
    attention_norm.weight
    attention.q_proj.weight
    attention.k_proj.weight
    attention.v_proj.weight
    attention.o_proj.weight
    # qk_norm only:
    attention.q_norm.weight
    attention.k_norm.weight
    mlp_norm.weight
    mlp.experts.{m}.router.gate.weight
    mlp.experts.{m}.experts.0.gate_proj.weight
    mlp.experts.{m}.experts.0.up_proj.weight
    mlp.experts.{m}.experts.0.down_proj.weight
    mlp.experts.{m}.experts.1...
    ...
- __init__(config, modalities, experts_per_modality, capacity_factor_per_modality, gumbel_noise, layer_idx)
- forward(x, rope_cos, rope_sin, modality_ids, *, doc_ids=None)
- Parameters:
x (torch.Tensor)
rope_cos (torch.Tensor)
rope_sin (torch.Tensor)
modality_ids (torch.Tensor)
doc_ids (torch.Tensor | None)
- Return type:
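As a rough illustration of how a pre-norm block of this shape composes its two sublayers, the sketch below assumes the usual pre-norm residual form (normalize, apply the sublayer, add back to the stream). Everything in it is a stand-in: `pre_norm_block` and the toy callables are hypothetical, and the RoPE and doc_ids arguments of the real forward are omitted.

```python
def add(a, b):
    """Element-wise sum of two token lists of equal shape."""
    return [[u + v for u, v in zip(ra, rb)] for ra, rb in zip(a, b)]


def pre_norm_block(x, modality_ids, attention, ffn, attention_norm, mlp_norm):
    # Attention sublayer: shared across modalities, added to the residual stream.
    h = add(x, attention(attention_norm(x)))
    # FFN sublayer: modality-aware, also added to the residual stream.
    return add(h, ffn(mlp_norm(h), modality_ids))


def identity(x):                 # toy stand-in for a norm layer
    return [list(row) for row in x]


def zero_attention(x):           # toy stand-in for attention
    return [[0.0] * len(row) for row in x]


def toy_ffn(x, modality_ids):
    # Pretend the experts selected only position 0; other positions get zeros.
    return [[2.0 * v for v in row] if p == 0 else [0.0] * len(row)
            for p, row in enumerate(x)]


x = [[1.0, 2.0], [3.0, 4.0]]
out = pre_norm_block(x, [0, 1], zero_attention, toy_ffn, identity, identity)
print(out)  # [[3.0, 6.0], [3.0, 4.0]]
```

Position 1 comes out unchanged, showing how the residual skip preserves a token that no expert selected.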