kempnerforge.model.mot¶
Mixture-of-Transformers (MoT) operator and block.
Implements Algorithm 1 of Liang et al. (2024) “Mixture-of-Transformers” and Figure 1c of the multimodal_paper. At every layer, every modality has its own dense Q/K/V/O projections and a dedicated FFN; a single global self-attention mixes all modality streams within the layer.
The module exposes three public symbols:
- MoTAttention — per-modality Q/K/V/O projections, one global SDPA over the concatenated multi-modality sequence.
- MoTBlock — pre-norm block: per-modality norms + MoTAttention + per-modality FFN. Identity at construction (zero-init residual).
- mot_warm_start_from_text_stack — copy dense TransformerBlock weights from a source state dict into every per-modality copy of every MoTBlock in a Transformer. JD/text-only -> MoT warm start.
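For orientation, a self-contained toy sketch of one layer's dataflow (illustrative only; the sizes, module names, and the plain concat/split below are made up here, and the real MoTAttention / MoTBlock documented below add norms, RoPE, GQA, and zero-init):

    import torch
    import torch.nn.functional as F

    # Toy MoT layer: per-modality projections, one global causal attention over the
    # concatenated image-then-text sequence, then per-modality output proj + FFN.
    torch.manual_seed(0)
    dim, n_heads = 64, 4
    head_dim = dim // n_heads
    modalities = ["image", "text"]
    streams = {"image": torch.randn(1, 6, dim), "text": torch.randn(1, 10, dim)}

    # Per-modality dense Q/K/V/O projections and a per-modality FFN (hypothetical shapes).
    proj = {m: {p: torch.nn.Linear(dim, dim, bias=False) for p in "qkvo"} for m in modalities}
    ffn = {m: torch.nn.Sequential(torch.nn.Linear(dim, 4 * dim), torch.nn.GELU(),
                                  torch.nn.Linear(4 * dim, dim)) for m in modalities}

    def split_heads(x):  # (batch, seq, dim) -> (batch, n_heads, seq, head_dim)
        b, s, _ = x.shape
        return x.view(b, s, n_heads, head_dim).transpose(1, 2)

    # 1) per-modality projections, concatenated along the sequence axis
    q = torch.cat([split_heads(proj[m]["q"](streams[m])) for m in modalities], dim=2)
    k = torch.cat([split_heads(proj[m]["k"](streams[m])) for m in modalities], dim=2)
    v = torch.cat([split_heads(proj[m]["v"](streams[m])) for m in modalities], dim=2)

    # 2) one global causal self-attention over the whole multimodal sequence
    attn = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    attn = attn.transpose(1, 2).reshape(1, -1, dim)

    # 3) split back per modality: per-modality output projection, FFN, residual adds
    out, offset = {}, 0
    for m in modalities:
        s = streams[m].shape[1]
        h = streams[m] + proj[m]["o"](attn[:, offset:offset + s])
        out[m] = h + ffn[m](h)
        offset += s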
Functions

mot_warm_start_from_text_stack(transformer, source_state_dict) | Copy dense TransformerBlock weights into every per-modality copy inside each MoTBlock in transformer.layers.

Classes

MoTAttention | Per-modality Q/K/V/O projections; one global SDPA over all modalities.

MoTBlock | Modality-aware transformer block: per-modality norms + MoTAttention + per-modality FFN.
- class kempnerforge.model.mot.MoTAttention[source]¶
Bases: Module

Per-modality Q/K/V/O projections; one global SDPA over all modalities.
State-dict layout (per-modality nesting via nn.ModuleDict):

    q_proj.{m}.weight   # (n_heads * head_dim, dim)
    k_proj.{m}.weight   # (n_kv_heads * head_dim, dim)
    v_proj.{m}.weight   # (n_kv_heads * head_dim, dim)
    o_proj.{m}.weight   # (dim, n_heads * head_dim)
    q_norm.{m}.weight   # (head_dim,) when qk_norm=True
    k_norm.{m}.weight   # (head_dim,) when qk_norm=True
Initialization: each per-modality o_proj.weight is zero, so the operator's contribution to the residual stream is zero at construction (warm-start identity).

Causal mask: a single is_causal=True over the concatenated sequence. With image-then-text concatenation order this matches Chameleon-style autoregressive multimodal generation: image tokens attend causally among image tokens; text tokens attend to all earlier image and text tokens.

- forward(streams, rope, is_causal=True)[source]¶
Run per-modality input projections, one global SDPA over the concatenated sequence, then per-modality output projections.
- Parameters:
streams (dict[str, torch.Tensor]) – per-modality input. Keys must equal the construction-time modalities. Each value has shape (batch, seq_m, dim).

rope (dict[str, tuple[torch.Tensor, torch.Tensor]]) – per-modality (cos, sin) RoPE tables. Each cos[m] / sin[m] has shape (seq_m, head_dim // 2) and counts from position 0 within that modality's axis.

is_causal (bool) – passed through to F.scaled_dot_product_attention.
- Returns:
per-modality output of shape (batch, seq_m, dim).

- Return type:

dict[str, torch.Tensor]
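A minimal usage sketch for forward, assuming an already-constructed MoTAttention instance attn built for the two modalities below; the RoPE base of 10000 and all sizes are conventional placeholder choices, not values prescribed by this module:

    import torch

    batch, dim, head_dim = 2, 512, 64   # hypothetical sizes
    streams = {
        "image": torch.randn(batch, 16, dim),   # (batch, seq_image, dim)
        "text": torch.randn(batch, 32, dim),    # (batch, seq_text, dim)
    }

    def rope_tables(seq_len, base=10000.0):
        # (cos, sin) tables of shape (seq_len, head_dim // 2),
        # counted from position 0 within a single modality's axis.
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        angles = torch.arange(seq_len).float()[:, None] * inv_freq[None, :]
        return angles.cos(), angles.sin()

    rope = {m: rope_tables(x.shape[1]) for m, x in streams.items()}

    # `attn` is assumed to be a MoTAttention built for modalities {"image", "text"}.
    out = attn(streams, rope)                    # global causal SDPA, image then text
    assert out["text"].shape == streams["text"].shape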
- class kempnerforge.model.mot.MoTBlock[source]¶
Bases: Module

Modality-aware transformer block: per-modality norms + MoTAttention + per-modality FFN.
State-dict layout (per-modality nesting):
    attn_norm.{m}.weight       # RMSNorm or LayerNorm per modality
    attn.q_proj.{m}.weight     # ... see MoTAttention
    mlp_norm.{m}.weight
    mlp.{m}.gate_proj.weight   # SwiGLU per modality (or up_proj/down_proj for standard)
Initialization: per-modality mlp.{m}.down_proj.weight is zero, so the FFN contribution is zero at construction. Combined with MoTAttention's zero-init o_proj, the block is identity at construction (warm-start property: a fresh MoT block is bit-equal to passing the residual through unchanged); see the sketch after forward() below.

MoE branches (when the layer index hits moe_frequency) skip the zero-init since MoEMLP does not have a single down_proj. Identity-at-construction is not a hard requirement for MoT: the warm-start helper will overwrite these weights from a source state dict, and from-scratch runs simply train the residual.

- __init__(config, modalities, layer_idx)[source]¶
- Parameters:
config (ModelConfig)
layer_idx (int)
- Return type:
None
- forward(streams, rope)[source]¶
- Parameters:
streams (dict[str, torch.Tensor])
rope (dict[str, tuple[torch.Tensor, torch.Tensor]])
- Return type:

dict[str, torch.Tensor]
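A minimal sketch of the identity-at-construction property noted above, assuming a freshly constructed dense (non-MoE) block instance and reusing the hypothetical streams / rope dictionaries from the MoTAttention example:

    import torch

    # `block` is assumed to be a freshly constructed, dense (non-MoE) MoTBlock; `streams`
    # and `rope` are built as in the MoTAttention sketch above (hypothetical sizes).
    with torch.no_grad():
        out = block(streams, rope)

    # Zero-init o_proj and down_proj mean both residual contributions are exactly zero,
    # so each stream passes through bit-equal at construction.
    for m in streams:
        assert torch.equal(out[m], streams[m])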
- kempnerforge.model.mot.mot_warm_start_from_text_stack(transformer, source_state_dict)[source]¶
Copy dense TransformerBlock weights into every per-modality copy inside each MoTBlock in transformer.layers.

Use case: warm-start a fresh MoT training run from a JD or text-only checkpoint. The caller loads the source state dict (e.g., via torch.load or DCP), then calls this helper to translate dense block keys into per-modality copies (see the sketch at the end of this entry).

Translation, per layer index i and per modality m:

    layers.{i}.attention_norm.weight   -> layers.{i}.attn_norm.{m}.weight
    layers.{i}.attention.q_proj.weight -> layers.{i}.attn.q_proj.{m}.weight
    layers.{i}.attention.k_proj.weight -> layers.{i}.attn.k_proj.{m}.weight
    layers.{i}.attention.v_proj.weight -> layers.{i}.attn.v_proj.{m}.weight
    layers.{i}.attention.o_proj.weight -> layers.{i}.attn.o_proj.{m}.weight
    # qk_norm only:
    layers.{i}.attention.q_norm.weight -> layers.{i}.attn.q_norm.{m}.weight
    layers.{i}.attention.k_norm.weight -> layers.{i}.attn.k_norm.{m}.weight
    layers.{i}.mlp_norm.weight         -> layers.{i}.mlp_norm.{m}.weight
    layers.{i}.mlp.<proj>.weight       -> layers.{i}.mlp.{m}.<proj>.weight
        for <proj> in {gate_proj, up_proj, down_proj}
Plus the final norm (when present on the target):
    norm.weight -> mot_norms.{m}.weight
Source keys may optionally carry a transformer. prefix; both prefixed and unprefixed keys are accepted. MoE FFN branches are skipped (their state-dict layout is incompatible with a dense mlp.<proj> translation).

No-op when transformer has no MoTBlock layers. Idempotent: repeated calls with the same source produce identical state. Shape-checked: raises ValueError on per-tensor shape mismatch.

- Parameters:
transformer (torch.nn.Module)
source_state_dict (Mapping[str, torch.Tensor])
- Return type:
None
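A minimal usage sketch for the warm start, assuming transformer has already been built with MoTBlock layers and that the checkpoint path below (a plain torch.save state dict) is hypothetical:

    import torch

    from kempnerforge.model.mot import mot_warm_start_from_text_stack

    # Hypothetical path to a JD / text-only checkpoint saved as a plain state dict.
    source_state_dict = torch.load("text_only_checkpoint.pt", map_location="cpu")

    # Translate dense TransformerBlock keys into every per-modality copy. Safe to call
    # repeatedly (idempotent); raises ValueError on a per-tensor shape mismatch.
    mot_warm_start_from_text_stack(transformer, source_state_dict)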