kempnerforge.model.moma

Mixture of Modality-Aware Experts (MoMa) operator, FFN, and block.

Implements Lin et al. 2024 (“MoMa: Efficient Early-Fusion Pre-training with Mixture of Modality-Aware Experts”, arXiv:2407.21770) on top of KempnerForge’s existing VLM stack.

Architecture at a glance, per transformer layer:

  • Pre-norm Attention (the standard module, shared Q/K/V/O across modalities) running a single global SDPA over the concatenated image+text sequence.

  • Pre-norm MoMaFFN: a ModuleDict of per-modality ExpertChoiceMoE groups dispatched by modality_ids. Each group’s MoE uses Expert-Choice + Sigmoid routing (paper §2.2): each expert independently picks its top-k_e tokens by sigmoid score, and the token output is the sum across experts that selected it, weighted by their sigmoid scores (Eq. 1). Optional Gumbel-Sigmoid noise during training (Eq. 5).

Difference from MoT (also in this codebase): MoT has per-modality Q/K/V/O projections and a per-modality FFN. MoMa has shared Q/K/V/O projections and per-modality MoE FFN groups (multiple experts per modality, learned routing within each group). Both share the residual-stream layout (image tokens prepended to text) and the modality_ids tagging mechanism.

The module exposes four public symbols:

  • ExpertChoiceSigmoidRouter — per-modality gate (W_g^M), Sigmoid scoring, optional Gumbel noise, and per-expert top-k_e token selection.

  • ExpertChoiceMoE — composes a router + num_experts SwiGLU experts; forward(x) returns the sigmoid-weighted expert combination.

  • MoMaFFN — holds one ExpertChoiceMoE per modality and dispatches tokens via modality_ids.

  • MoMaBlock — pre-norm block: shared Attention + MoMaFFN.

Inference note: expert-choice routing is non-causal (each expert’s top-k_e depends on all tokens in the batch). v1 supports training only; autoregressive generation requires auxiliary routers (paper §2.4), deferred to a follow-up PR.
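
The non-causality can be seen in a toy example. The sketch below is framework-free Python rather than this module's code; expert_top_k is a hypothetical helper and the scores are made up. With k_e = 1, token 0 is selected when the expert sees only the first two tokens, but loses its slot once a higher-scoring later token joins the batch.

```python
def expert_top_k(scores, k):
    """Return the indices of the k highest scores (one expert's selection)."""
    order = sorted(range(len(scores)), key=lambda t: scores[t], reverse=True)
    return sorted(order[:k])

# Hypothetical sigmoid scores of one expert for three tokens, in sequence order.
scores = [0.6, 0.2, 0.9]

prefix_pick = expert_top_k(scores[:2], k=1)  # expert sees tokens 0..1 only
full_pick = expert_top_k(scores, k=1)        # expert sees tokens 0..2

print(prefix_pick)  # [0]
print(full_pick)    # [2]
```

Whether token 0 gets expert output therefore depends on a token that comes after it, which is what rules out naive autoregressive decoding.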

Classes

ExpertChoiceMoE

Expert-Choice MoE for one modality group.

ExpertChoiceSigmoidRouter

Expert-Choice + Sigmoid router for one modality group (Lin et al. 2024 §2.2).

MoMaBlock

Pre-norm transformer block: shared Attention + MoMaFFN.

MoMaFFN

Per-modality MoE FFN groups dispatched by modality_ids.

class kempnerforge.model.moma.ExpertChoiceSigmoidRouter

Bases: Module

Expert-Choice + Sigmoid router for one modality group (Lin et al. 2024 §2.2).

Scoring: score = sigmoid(W_g x) per token-expert pair (independent across experts because Sigmoid does not normalize). Optional Gumbel perturbation during training: Gumbel-Sigmoid(x) = sigmoid(x + G' - G'') with independent Gumbel(0, 1) samples G', G'' (paper Eq. 5).
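
As an illustration of the scoring rule, here is a minimal stdlib-only sketch for a single token-expert logit (the value of W_g x for that pair). The function names are hypothetical, and treating "no RNG passed" as the noise-free path is an assumption made for the example, not a description of the router's actual flag handling.

```python
import math
import random

def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)

def sample_gumbel(rng):
    """Draw one Gumbel(0, 1) sample via inverse transform: -log(-log(U))."""
    u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)  # keep both logs finite
    return -math.log(-math.log(u))

def gumbel_sigmoid(logit, rng=None):
    """sigmoid(logit + G' - G''); plain sigmoid when rng is None."""
    if rng is None:
        return sigmoid(logit)
    return sigmoid(logit + sample_gumbel(rng) - sample_gumbel(rng))

rng = random.Random(0)
clean = gumbel_sigmoid(0.0)       # 0.5, no noise
noisy = gumbel_sigmoid(0.0, rng)  # perturbed, still strictly inside (0, 1)
```

Because each score is computed independently, the scores for one token need not sum to 1 across experts, unlike a softmax router.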

Selection: each expert independently picks its top-k_e tokens by score (torch.topk on the (expert, token) score matrix). This is the inverse of token-choice routing: there a token picks experts; here an expert picks tokens. A token can be picked by 0, 1, or more experts (the residual stream carries the unmodified token through when no expert picks it).
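
The selection step can be sketched without torch. The snippet below uses nested lists in place of the (expert, token) score tensor; expert_choice_select is a hypothetical stand-in for the torch.topk call, and the scores are invented to show a token picked by 0, 1, and 2 experts.

```python
def expert_choice_select(scores, k):
    """scores[e][t] is expert e's score for token t; each expert keeps its top k."""
    picks = []
    for row in scores:
        order = sorted(range(len(row)), key=lambda t: row[t], reverse=True)
        picks.append(order[:k])
    return picks

# Hypothetical (expert, token) sigmoid scores: 2 experts, 4 tokens.
scores = [
    [0.9, 0.1, 0.8, 0.2],   # expert 0
    [0.3, 0.2, 0.95, 0.6],  # expert 1
]
picks = expert_choice_select(scores, k=2)

# Count how many experts selected each token.
times_picked = [sum(t in p for p in picks) for t in range(4)]
print(picks)         # [[0, 2], [2, 3]]
print(times_picked)  # [1, 0, 2, 1]
```

Every expert processes exactly k tokens, so load is balanced by construction, while token 1 receives no expert output at all.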

capacity_factor controls k_e as k_e = ceil(c_e * N), where N is the number of tokens of this modality in the current batch. The paper’s default c_e = 1/|E^M| gives k_e ≈ N/|E^M|, so each expert sees the average load (perfect balance under EC routing).
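
A few worked values make the capacity rule concrete. This is a sketch under the formula stated above plus the cap at N mentioned in the forward() documentation; tokens_per_expert is a hypothetical name, not a function of this module.

```python
import math

def tokens_per_expert(capacity_factor, num_tokens):
    """k_e = ceil(c_e * N), capped at N so top-k never asks for more tokens than exist."""
    return min(num_tokens, math.ceil(capacity_factor * num_tokens))

num_experts = 4
balanced = 1.0 / num_experts  # c_e = 1/|E^M| = 0.25

print(tokens_per_expert(balanced, 8))   # 2  (8 / 4 exactly)
print(tokens_per_expert(balanced, 10))  # 3  (ceil of 2.5)
print(tokens_per_expert(1.5, 4))        # 4  (ceil(6.0) capped at N)
```

When N is not a multiple of the expert count, the ceiling means the experts together make slightly more than N selections.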

__init__(dim, num_experts, capacity_factor, gumbel_noise=True)
Parameters:
  • dim (int)

  • num_experts (int)

  • capacity_factor (float)

  • gumbel_noise (bool)

Return type:

None

forward(x)

Route tokens to experts via expert-choice.

Parameters:

x (torch.Tensor) – (N, D) token representations for one modality group.

Returns:

topk_scores: (E, k_e) sigmoid scores of the tokens each expert selected.

topk_indices: (E, k_e) token indices into x that each expert selected. k_e is computed from capacity_factor * N, capped by N.

Return type:

tuple (topk_scores, topk_indices)

class kempnerforge.model.moma.ExpertChoiceMoE

Bases: Module

Expert-Choice MoE for one modality group.

Composes an ExpertChoiceSigmoidRouter with num_experts SwiGLU expert MLPs. Forward: each expert selects top-k_e tokens, runs its MLP on those tokens, and contributes sigmoid_score * MLP(x) to the output. Tokens not picked by any expert receive zero contribution from this MoE block (the outer residual skip preserves them).
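
The combine step described above can be sketched as follows. This is an illustrative, framework-free version: tokens are lists of floats, the two lambdas are stand-ins for the SwiGLU expert MLPs, and expert_choice_combine is a hypothetical name. The routing result (topk_indices, topk_scores) is supplied by hand.

```python
def expert_choice_combine(tokens, experts, topk_indices, topk_scores):
    """Accumulate score * expert(token) for every (expert, selected token) pair."""
    dim = len(tokens[0])
    out = [[0.0] * dim for _ in tokens]
    for expert, idxs, scs in zip(experts, topk_indices, topk_scores):
        for t, s in zip(idxs, scs):
            y = expert(tokens[t])
            for d in range(dim):
                out[t][d] += s * y[d]
    return out

# Stand-in experts (the real ones are SwiGLU MLPs): simple elementwise maps.
experts = [lambda v: [2.0 * a for a in v], lambda v: [a + 1.0 for a in v]]
tokens = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
topk_indices = [[0, 2], [2, 0]]  # token 1 is selected by nobody
topk_scores = [[0.5, 0.5], [0.25, 0.5]]

out = expert_choice_combine(tokens, experts, topk_indices, topk_scores)
print(out)  # [[2.0, 3.5], [0.0, 0.0], [6.5, 7.75]]
```

Token 1 comes back as zeros, so after the block's residual addition it is carried through unchanged.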

State-dict layout (FQN-stable):

router.gate.weight    # (num_experts, dim) — gate Linear
experts.0.gate_proj.weight
experts.0.up_proj.weight
experts.0.down_proj.weight
experts.1...
...
__init__(dim, hidden_dim, num_experts, capacity_factor, activation='silu', gumbel_noise=True)
Parameters:
  • dim (int)

  • hidden_dim (int)

  • num_experts (int)

  • capacity_factor (float)

  • activation (str)

  • gumbel_noise (bool)

Return type:

None

property expert_counts: torch.Tensor

Per-expert token count from the most recent forward (metrics).

forward(x)

Expert-choice MoE forward over one modality group.

Parameters:

x (torch.Tensor) – (N, D) token representations.

Returns:

(N, D) output where each token has accumulated weighted outputs from every expert that selected it (zero contribution from this block when no expert selected the token).

Return type:

torch.Tensor

class kempnerforge.model.moma.MoMaFFN

Bases: Module

Per-modality MoE FFN groups dispatched by modality_ids.

Holds one ExpertChoiceMoE per modality (keys: modality name). Forward dispatches tokens by modality_ids (level-1 deterministic routing), runs each modality’s EC-MoE (level-2 learned routing), then scatters per-modality outputs back to their original positions.

Modality index convention: self.modalities[i] corresponds to modality_ids == i. With the default ("image", "text"), modality_ids == 0 selects the image expert group and modality_ids == 1 selects the text expert group.
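
The dispatch-and-scatter pattern can be illustrated on a flattened sequence. The sketch below is not the module's implementation: it works on Python lists instead of (B, S, D) tensors, dispatch_by_modality is a hypothetical name, and each group callable stands in for one modality's ExpertChoiceMoE.

```python
def dispatch_by_modality(tokens, modality_ids, groups):
    """Level-1 routing: group tokens by modality id, run each group, scatter back.

    groups[i] handles modality_ids == i and maps a list of tokens to a
    same-length list of outputs.
    """
    out = [None] * len(tokens)
    for i, group in enumerate(groups):
        positions = [p for p, m in enumerate(modality_ids) if m == i]
        if not positions:
            continue  # no tokens of this modality in the batch
        results = group([tokens[p] for p in positions])
        for p, r in zip(positions, results):
            out[p] = r  # scatter back to the original position
    return out

modalities = ("image", "text")
# Stand-ins for the per-modality expert groups; each tags its input.
groups = [
    lambda xs: [("image", x) for x in xs],
    lambda xs: [("text", x) for x in xs],
]
tokens = ["i0", "i1", "t0", "t1", "t2"]  # image tokens prepended to text
modality_ids = [0, 0, 1, 1, 1]

out = dispatch_by_modality(tokens, modality_ids, groups)
print(out[0], out[2])  # ('image', 'i0') ('text', 't0')
```

Each group only ever sees tokens of its own modality, so the N used for k_e inside a group is that modality's token count, not the full sequence length.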

__init__(config, modalities, experts_per_modality, capacity_factor_per_modality, gumbel_noise=True)
Parameters:
Return type:

None

forward(x, modality_ids)

Dispatch tokens by modality and run per-modality EC-MoE.

Parameters:
  • x (torch.Tensor) – (B, S, D) residual stream.

  • modality_ids (torch.Tensor) – (B, S) long tensor. modality_ids == i routes that token to self.modalities[i]’s expert group.

Returns:

(B, S, D) tensor with each modality’s positions filled by its EC-MoE output. Positions that no expert selected get zeros (the outer residual skip preserves them).

Return type:

torch.Tensor

class kempnerforge.model.moma.MoMaBlock

Bases: Module

Pre-norm transformer block: shared Attention + MoMaFFN.

Operates on a single residual tensor (B, S, D) like the dense TransformerBlock (unlike MoTBlock which operates on a per-modality dict of streams). The only structural difference from TransformerBlock is the FFN: MoMaFFN instead of a dense MLP or a flat MoE.
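
The pre-norm data flow can be sketched with scalar "tokens" and trivial stand-in sub-layers. This is an assumption-laden illustration of the standard pre-norm residual pattern, not the block's code: pre_norm_block and its argument names are hypothetical, and the RoPE and doc_ids arguments of the real forward are omitted.

```python
def pre_norm_block(x, attn_norm, attention, mlp_norm, ffn, modality_ids):
    """h = x + Attention(norm(x)); out = h + FFN(norm(h), modality_ids)."""
    h = [a + b for a, b in zip(x, attention(attn_norm(x)))]
    f = ffn(mlp_norm(h), modality_ids)
    return [a + b for a, b in zip(h, f)]

# Scalar "tokens" and trivial stand-ins keep the arithmetic easy to follow.
identity = lambda v: list(v)
zero_attention = lambda v: [0.0 for _ in v]

def toy_ffn(v, modality_ids):
    # Pretend only the text tokens (id 1) were selected by some expert.
    return [10.0 if m == 1 else 0.0 for m in modality_ids]

x = [1.0, 2.0, 3.0]
out = pre_norm_block(x, identity, zero_attention, identity, toy_ffn, [0, 1, 1])
print(out)  # [1.0, 12.0, 13.0]
```

The first token gets a zero FFN contribution and leaves the block with its incoming value, which is how tokens that no expert selects survive through the residual stream.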

State-dict layout:

attention_norm.weight
attention.q_proj.weight
attention.k_proj.weight
attention.v_proj.weight
attention.o_proj.weight
# qk_norm only:
attention.q_norm.weight
attention.k_norm.weight
mlp_norm.weight
mlp.experts.{m}.router.gate.weight
mlp.experts.{m}.experts.0.gate_proj.weight
mlp.experts.{m}.experts.0.up_proj.weight
mlp.experts.{m}.experts.0.down_proj.weight
mlp.experts.{m}.experts.1...
...
__init__(config, modalities, experts_per_modality, capacity_factor_per_modality, gumbel_noise, layer_idx)
Parameters:
Return type:

None

forward(x, rope_cos, rope_sin, modality_ids, *, doc_ids=None)
Parameters:
Return type:

torch.Tensor