kempnerforge.model.mot

Mixture-of-Transformers (MoT) operator and block.

Implements Algorithm 1 of Liang et al. (2024) “Mixture-of-Transformers” and Figure 1c of the multimodal_paper. At every layer, every modality has its own dense Q/K/V/O projections and a dedicated FFN; a single global self-attention mixes all modality streams within the layer.
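The core computation, as a toy sketch in plain PyTorch (illustrative only, not the library operator; the modality names, sizes, and random projections are made up):

import torch
import torch.nn.functional as F

# Toy sketch of the idea above: every modality owns its Q/K/V projections;
# a single global causal SDPA mixes the concatenated streams.
dim, n_heads, head_dim = 64, 4, 16
streams = {"image": torch.randn(1, 6, dim), "text": torch.randn(1, 3, dim)}
proj = {m: {p: torch.randn(n_heads * head_dim, dim) / dim**0.5 for p in "qkv"}
        for m in streams}

def split_heads(x):                      # (B, S, H*D) -> (B, H, S, D)
    b, s, _ = x.shape
    return x.view(b, s, n_heads, head_dim).transpose(1, 2)

q = torch.cat([split_heads(streams[m] @ proj[m]["q"].T) for m in streams], dim=2)
k = torch.cat([split_heads(streams[m] @ proj[m]["k"].T) for m in streams], dim=2)
v = torch.cat([split_heads(streams[m] @ proj[m]["v"].T) for m in streams], dim=2)
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)   # one global SDPA
# `out` is then split back per modality and passed through per-modality O and FFN.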

The module exposes three public symbols:

  • MoTAttention — per-modality Q/K/V/O projections, one global SDPA over the concatenated multi-modality sequence.

  • MoTBlock — pre-norm block: per-modality norms + MoTAttention + per-modality FFN. Identity at construction (zero-init residual).

  • mot_warm_start_from_text_stack — copy dense TransformerBlock weights from a source state dict into every per-modality copy of every MoTBlock in a Transformer, warm-starting MoT from a JD or text-only checkpoint.

Functions

mot_warm_start_from_text_stack(transformer, ...)

Copy dense TransformerBlock weights into every per-modality copy inside each MoTBlock in transformer.layers.

Classes

MoTAttention

Per-modality Q/K/V/O projections; one global SDPA over all modalities.

MoTBlock

Modality-aware transformer block: per-modality norms + MoTAttention + per-modality FFN.

class kempnerforge.model.mot.MoTAttention[source]

Bases: Module

Per-modality Q/K/V/O projections; one global SDPA over all modalities.

State-dict layout (per-modality nesting via nn.ModuleDict):

q_proj.{m}.weight    # (n_heads * head_dim, dim)
k_proj.{m}.weight    # (n_kv_heads * head_dim, dim)
v_proj.{m}.weight    # (n_kv_heads * head_dim, dim)
o_proj.{m}.weight    # (dim, n_heads * head_dim)
q_norm.{m}.weight    # (head_dim,) when qk_norm=True
k_norm.{m}.weight    # (head_dim,) when qk_norm=True

Initialization: each per-modality o_proj.weight is zero so the operator’s contribution to the residual stream is zero at construction (warm-start identity).

Causal mask: a single is_causal=True mask over the concatenated sequence. With image-then-text concatenation order this matches Chameleon-style autoregressive multimodal modeling: image tokens attend causally among image tokens; text tokens attend to all image tokens and to earlier text tokens.
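For illustration (pure PyTorch, not library code), the equivalent boolean mask over a concatenated [image, text] sequence with 3 image and 2 text positions:

import torch

img_len, txt_len = 3, 2
total = img_len + txt_len
mask = torch.ones(total, total).tril().bool()   # position i attends to j <= i
# Rows 0-2 (image) attend only to earlier image positions;
# rows 3-4 (text) attend to every image position plus earlier text.
print(mask.int())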

__init__(dim, n_heads, n_kv_heads, modalities, head_dim=None, qk_norm=False)[source]
Parameters:
  • dim (int) – model (residual-stream) width.

  • n_heads (int) – number of query heads.

  • n_kv_heads (int) – number of key/value heads (grouped-query attention when smaller than n_heads).

  • modalities – names of the modality streams; these become the nn.ModuleDict keys ({m} in the layout above).

  • head_dim (int | None) – optional explicit per-head dimension.

  • qk_norm (bool) – when True, add per-head q_norm.{m} / k_norm.{m} normalization.

Return type:

None

forward(streams, rope, is_causal=True)[source]

Run the per-modality input projections, a single global SDPA over the concatenated streams, and the per-modality output projections.

Parameters:
  • streams (dict[str, torch.Tensor]) – per-modality input. Keys must equal the construction-time modalities. Each value has shape (batch, seq_m, dim).

  • rope (dict[str, tuple[torch.Tensor, torch.Tensor]]) – per-modality (cos, sin) RoPE tables. Each cos[m] / sin[m] has shape (seq_m, head_dim // 2); positions are counted from 0 along that modality's own axis.

  • is_causal (bool) – passed through to F.scaled_dot_product_attention.

Returns:

per-modality output of shape (batch, seq_m, dim).

Return type:

dict[str, torch.Tensor]
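A minimal usage sketch (hypothetical sizes; the RoPE helper below builds generic (cos, sin) tables of the documented shape (seq_m, head_dim // 2) and may differ from the tables the surrounding Transformer constructs):

import torch
from kempnerforge.model.mot import MoTAttention

dim, n_heads, n_kv_heads, head_dim = 256, 8, 4, 32
attn = MoTAttention(dim, n_heads, n_kv_heads, modalities=["image", "text"], head_dim=head_dim)

def rope_tables(seq_len, head_dim, base=10000.0):
    # Generic (cos, sin) tables of shape (seq_len, head_dim // 2), positions from 0.
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    angles = torch.arange(seq_len).float()[:, None] * inv_freq
    return angles.cos(), angles.sin()

streams = {"image": torch.randn(2, 16, dim), "text": torch.randn(2, 8, dim)}
rope = {m: rope_tables(x.shape[1], head_dim) for m, x in streams.items()}

out = attn(streams, rope)                        # dict of (batch, seq_m, dim)
# o_proj is zero-initialized, so at construction every output tensor is all zeros.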

class kempnerforge.model.mot.MoTBlock[source]

Bases: Module

Modality-aware transformer block: per-modality norms + MoTAttention + per-modality FFN.

State-dict layout (per-modality nesting):

attn_norm.{m}.weight     # RMSNorm or LayerNorm per modality
attn.q_proj.{m}.weight   # ... see MoTAttention
mlp_norm.{m}.weight
mlp.{m}.gate_proj.weight # SwiGLU per modality (or up_proj/down_proj for standard)

Initialization: per-modality mlp.{m}.down_proj.weight is zero so the FFN contribution is zero at construction. Combined with MoTAttention’s zero-init o_proj, the block is identity at construction (warm-start property: a fresh MoT block is bit-equal to passing the residual through unchanged).

MoE branches (when the layer index hits moe_frequency) skip the zero-init, since MoEMLP does not have a single down_proj. Identity-at-construction is not a hard requirement for MoT: the warm-start helper overwrites these weights from a source state dict, and from-scratch runs simply train the residual.
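The identity property itself can be illustrated with a generic pre-norm branch whose output projection is zero-initialized (plain PyTorch, not the library block):

import torch
import torch.nn as nn

class ZeroInitBranch(nn.Module):
    # Mirrors the zero-init o_proj / down_proj pattern: the branch contributes
    # exactly zero, so the residual passes through bit-unchanged.
    def __init__(self, dim):
        super().__init__()
        self.norm = nn.RMSNorm(dim)                  # needs PyTorch >= 2.4
        self.proj = nn.Linear(dim, dim, bias=False)
        nn.init.zeros_(self.proj.weight)
    def forward(self, x):
        return x + self.proj(self.norm(x))

x = torch.randn(2, 5, 64)
assert torch.equal(ZeroInitBranch(64)(x), x)         # bit-equal pass-through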

__init__(config, modalities, layer_idx)[source]
Parameters:
  • config – model configuration (dimensions, FFN variant, moe_frequency).

  • modalities – names of the modality streams; keys of the per-modality nesting.

  • layer_idx (int) – index of this layer; determines whether the FFN is the MoE branch (see moe_frequency).

Return type:

None

forward(streams, rope)[source]
Parameters:
  • streams (dict[str, torch.Tensor]) – per-modality input of shape (batch, seq_m, dim); keys must equal the construction-time modalities.

  • rope (dict[str, tuple[torch.Tensor, torch.Tensor]]) – per-modality (cos, sin) RoPE tables, as in MoTAttention.forward().

Returns:

per-modality output of shape (batch, seq_m, dim).

Return type:

dict[str, torch.Tensor]

kempnerforge.model.mot.mot_warm_start_from_text_stack(transformer, source_state_dict)[source]

Copy dense TransformerBlock weights into every per-modality copy inside each MoTBlock in transformer.layers.

Use case: warm-start a fresh MoT training run from a JD or text-only checkpoint. The caller loads the source state dict (e.g., via torch.load or DCP), then calls this helper to translate dense block keys into per-modality copies.
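A sketch of that flow (the model constructor and checkpoint path below are placeholders):

import torch
from kempnerforge.model.mot import mot_warm_start_from_text_stack

model = build_mot_transformer(config)                  # placeholder: a Transformer whose layers are MoTBlocks
source = torch.load("text_only_checkpoint.pt", map_location="cpu")   # dense TransformerBlock state dict
mot_warm_start_from_text_stack(model, source)          # fans the dense weights out to every modality copy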

Translation, per layer index i and per modality m:

layers.{i}.attention_norm.weight    -> layers.{i}.attn_norm.{m}.weight
layers.{i}.attention.q_proj.weight  -> layers.{i}.attn.q_proj.{m}.weight
layers.{i}.attention.k_proj.weight  -> layers.{i}.attn.k_proj.{m}.weight
layers.{i}.attention.v_proj.weight  -> layers.{i}.attn.v_proj.{m}.weight
layers.{i}.attention.o_proj.weight  -> layers.{i}.attn.o_proj.{m}.weight
# qk_norm only:
layers.{i}.attention.q_norm.weight  -> layers.{i}.attn.q_norm.{m}.weight
layers.{i}.attention.k_norm.weight  -> layers.{i}.attn.k_norm.{m}.weight
layers.{i}.mlp_norm.weight          -> layers.{i}.mlp_norm.{m}.weight
layers.{i}.mlp.<proj>.weight        -> layers.{i}.mlp.{m}.<proj>.weight
    for <proj> in {gate_proj, up_proj, down_proj}

Plus the final norm (when present on the target):

norm.weight -> mot_norms.{m}.weight

Source keys may optionally have a transformer. prefix; both are accepted. MoE FFN branches are skipped (their state dict layout is incompatible with a dense mlp.<proj> translation).

No-op when transformer has no MoTBlock layers. Idempotent: repeated calls with the same source produce identical state. Shape-checked: raises ValueError on per-tensor shape mismatch.

Parameters:
  • transformer – target Transformer whose layers include MoTBlock instances.

  • source_state_dict (dict[str, torch.Tensor]) – dense TransformerBlock state dict; keys may optionally carry a transformer. prefix.

Return type:

None