Source code for kempnerforge.model.mot

"""Mixture-of-Transformers (MoT) operator and block.

Implements Algorithm 1 of Liang et al. (2024) "Mixture-of-Transformers"
and Figure 1c of the multimodal_paper. At every layer, every modality
has its own dense Q/K/V/O projections and a dedicated FFN; a single
global self-attention mixes all modality streams within the layer.

The module exposes three public symbols:

- ``MoTAttention`` — per-modality Q/K/V/O projections, one global SDPA
  over the concatenated multi-modality sequence.
- ``MoTBlock`` — pre-norm block: per-modality norms + MoTAttention +
  per-modality FFN. Identity at construction (zero-init residual).
- ``mot_warm_start_from_text_stack`` — copy dense ``TransformerBlock``
  weights from a source state dict into every per-modality copy of
  every ``MoTBlock`` in a Transformer. JD/text-only -> MoT warm start.
"""

from __future__ import annotations

from collections.abc import Mapping
from typing import cast

import torch
import torch.nn as nn
import torch.nn.functional as F

from kempnerforge.config.schema import ModelConfig
from kempnerforge.model.mlp import build_mlp
from kempnerforge.model.moe import MoEMLP, build_moe
from kempnerforge.model.norm import RMSNorm, build_norm
from kempnerforge.model.position import apply_rope



class MoTAttention(nn.Module):
    """Per-modality Q/K/V/O projections; one global SDPA over all modalities.

    State-dict layout (per-modality nesting via ``nn.ModuleDict``):

    .. code::

        q_proj.{m}.weight  # (n_heads * head_dim, dim)
        k_proj.{m}.weight  # (n_kv_heads * head_dim, dim)
        v_proj.{m}.weight  # (n_kv_heads * head_dim, dim)
        o_proj.{m}.weight  # (dim, n_heads * head_dim)
        q_norm.{m}.weight  # (head_dim,) when qk_norm=True
        k_norm.{m}.weight  # (head_dim,) when qk_norm=True

    Initialization: each per-modality ``o_proj.weight`` is zero so the
    operator's contribution to the residual stream is zero at construction
    (warm-start identity).

    Causal mask: a single ``is_causal=True`` over the concatenated sequence.
    With image-then-text concatenation order this matches Chameleon-style
    autoregressive multimodal: image attends causally among image; text
    attends to all earlier image and earlier text.
    """

    def __init__(
        self,
        dim: int,
        n_heads: int,
        n_kv_heads: int,
        modalities: tuple[str, ...],
        head_dim: int | None = None,
        qk_norm: bool = False,
    ) -> None:
        super().__init__()
        if not modalities:
            raise ValueError("MoTAttention requires at least one modality")
        if n_kv_heads <= 0 or n_heads % n_kv_heads != 0:
            raise ValueError(
                f"MoTAttention: n_heads={n_heads} must be a positive multiple of "
                f"n_kv_heads={n_kv_heads}"
            )
        self.modalities = tuple(modalities)
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.head_dim = head_dim or (dim // n_heads)
        self.n_rep = n_heads // n_kv_heads

        self.q_proj = nn.ModuleDict(
            {m: nn.Linear(dim, n_heads * self.head_dim, bias=False) for m in self.modalities}
        )
        self.k_proj = nn.ModuleDict(
            {m: nn.Linear(dim, n_kv_heads * self.head_dim, bias=False) for m in self.modalities}
        )
        self.v_proj = nn.ModuleDict(
            {m: nn.Linear(dim, n_kv_heads * self.head_dim, bias=False) for m in self.modalities}
        )
        self.o_proj = nn.ModuleDict(
            {m: nn.Linear(n_heads * self.head_dim, dim, bias=False) for m in self.modalities}
        )
        for m in self.modalities:
            nn.init.zeros_(self.o_proj[m].weight)  # type: ignore[reportArgumentType]

        if qk_norm:
            self.q_norm: nn.ModuleDict | None = nn.ModuleDict(
                {m: RMSNorm(self.head_dim) for m in self.modalities}
            )
            self.k_norm: nn.ModuleDict | None = nn.ModuleDict(
                {m: RMSNorm(self.head_dim) for m in self.modalities}
            )
        else:
            self.q_norm = None
            self.k_norm = None

    def forward(
        self,
        streams: dict[str, torch.Tensor],
        rope: dict[str, tuple[torch.Tensor, torch.Tensor]],
        is_causal: bool = True,
    ) -> dict[str, torch.Tensor]:
        """Run per-modality projections, global SDPA, per-modality output.

        Args:
            streams: per-modality input. Keys must equal the construction-time
                modalities. Each value has shape ``(batch, seq_m, dim)``.
            rope: per-modality ``(cos, sin)`` RoPE tables. Each ``cos[m]`` /
                ``sin[m]`` has shape ``(seq_m, head_dim // 2)`` — counts from
                position 0 within that modality's axis.
            is_causal: passed through to ``F.scaled_dot_product_attention``.

        Returns:
            per-modality output of shape ``(batch, seq_m, dim)``.
        """
        if set(streams.keys()) != set(self.modalities):
            raise ValueError(
                f"MoTAttention.forward: streams keys {sorted(streams.keys())} "
                f"do not match construction-time modalities {sorted(self.modalities)}"
            )

        batch = next(iter(streams.values())).shape[0]
        qs: list[torch.Tensor] = []
        ks: list[torch.Tensor] = []
        vs: list[torch.Tensor] = []
        lengths: dict[str, int] = {}

        for m in self.modalities:
            x_m = streams[m]
            t_m = x_m.shape[1]
            lengths[m] = t_m

            q_m = self.q_proj[m](x_m).view(batch, t_m, -1, self.head_dim)
            k_m = self.k_proj[m](x_m).view(batch, t_m, -1, self.head_dim)
            v_m = self.v_proj[m](x_m).view(batch, t_m, -1, self.head_dim)

            if self.q_norm is not None:
                q_m = self.q_norm[m](q_m)
                k_m = self.k_norm[m](k_m)  # type: ignore[reportOptionalSubscript,reportOptionalCall]

            # Transpose to (batch, heads, seq, head_dim) for RoPE + SDPA.
            q_m = q_m.transpose(1, 2)
            k_m = k_m.transpose(1, 2)
            v_m = v_m.transpose(1, 2)

            cos_m, sin_m = rope[m]
            q_m = apply_rope(q_m, cos_m, sin_m)
            k_m = apply_rope(k_m, cos_m, sin_m)

            qs.append(q_m)
            ks.append(k_m)
            vs.append(v_m)

        # Concat along the seq dim of (B, heads, T, head_dim).
        q = torch.cat(qs, dim=2)
        k = torch.cat(ks, dim=2)
        v = torch.cat(vs, dim=2)

        if self.n_rep > 1:
            k = k.repeat_interleave(self.n_rep, dim=1)
            v = v.repeat_interleave(self.n_rep, dim=1)

        out = F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)

        # (B, n_heads, total_seq, head_dim) -> (B, total_seq, n_heads, head_dim)
        out = out.transpose(1, 2).contiguous()

        out_streams: dict[str, torch.Tensor] = {}
        offset = 0
        for m in self.modalities:
            t_m = lengths[m]
            o_m = out[:, offset : offset + t_m, :, :].reshape(batch, t_m, -1)
            out_streams[m] = self.o_proj[m](o_m)
            offset += t_m
        return out_streams
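

# Usage sketch (illustration only, not part of the module API). It assumes two
# modalities named "image" and "text" and builds sinusoidal RoPE tables by
# hand; in the real stack the tables come from ``kempnerforge.model.position``.
# Shapes follow the ``forward`` docstring above: ``streams[m]`` is
# ``(batch, seq_m, dim)`` and ``rope[m] = (cos, sin)`` with shape
# ``(seq_m, head_dim // 2)``.
def _mot_attention_usage_sketch() -> None:
    dim, n_heads, n_kv_heads = 64, 4, 2
    head_dim = dim // n_heads
    attn = MoTAttention(dim, n_heads, n_kv_heads, modalities=("image", "text"))

    def rope_tables(seq_len: int) -> tuple[torch.Tensor, torch.Tensor]:
        # Hand-rolled (cos, sin) tables of the documented shape; the exact
        # frequency schedule is an assumption for the sketch.
        inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2) / head_dim))
        angles = torch.arange(seq_len)[:, None] * inv_freq[None, :]
        return angles.cos(), angles.sin()

    streams = {"image": torch.randn(2, 16, dim), "text": torch.randn(2, 8, dim)}
    rope = {m: rope_tables(x.shape[1]) for m, x in streams.items()}
    out = attn(streams, rope)
    assert out["image"].shape == (2, 16, dim) and out["text"].shape == (2, 8, dim)
    # Freshly constructed o_proj weights are zero, so the operator output is zero.
    assert out["text"].abs().max().item() == 0.0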


class MoTBlock(nn.Module):
    """Modality-aware transformer block: per-modality norms + MoTAttention +
    per-modality FFN.

    State-dict layout (per-modality nesting):

    .. code::

        attn_norm.{m}.weight      # RMSNorm or LayerNorm per modality
        attn.q_proj.{m}.weight    # ... see MoTAttention
        mlp_norm.{m}.weight
        mlp.{m}.gate_proj.weight  # SwiGLU per modality (or up_proj/down_proj for standard)

    Initialization: per-modality ``mlp.{m}.down_proj.weight`` is zero so the
    FFN contribution is zero at construction. Combined with ``MoTAttention``'s
    zero-init ``o_proj``, the block is identity at construction (warm-start
    property: a fresh MoT block is bit-equal to passing the residual through
    unchanged).

    MoE branches (when the layer index hits ``moe_frequency``) skip the
    zero-init since ``MoEMLP`` does not have a single ``down_proj``;
    identity-at-construction is not a hard requirement for MoT — the
    warm-start helper will overwrite these weights from a source state dict,
    and from-scratch runs simply train the residual.
    """

    def __init__(
        self,
        config: ModelConfig,
        modalities: tuple[str, ...],
        layer_idx: int,
    ) -> None:
        super().__init__()
        if not modalities:
            raise ValueError("MoTBlock requires at least one modality")
        self.layer_idx = layer_idx
        self.modalities = tuple(modalities)

        self.attn_norm = nn.ModuleDict(
            {
                m: build_norm(config.norm_type, config.dim, eps=config.norm_eps)
                for m in self.modalities
            }
        )
        self.attn = MoTAttention(
            dim=config.dim,
            n_heads=config.n_heads,
            n_kv_heads=config.n_kv_heads,  # type: ignore[reportArgumentType]
            modalities=self.modalities,
            head_dim=config.head_dim,
            qk_norm=config.qk_norm,
        )
        self.mlp_norm = nn.ModuleDict(
            {
                m: build_norm(config.norm_type, config.dim, eps=config.norm_eps)
                for m in self.modalities
            }
        )

        use_moe = config.is_moe and ((layer_idx + 1) % config.moe_frequency == 0)
        if use_moe:
            self.mlp = nn.ModuleDict(
                {
                    m: build_moe(
                        dim=config.dim,
                        hidden_dim=config.computed_ffn_hidden_dim,
                        num_experts=config.num_experts,
                        top_k=config.moe_top_k,
                        activation=config.activation,
                        router_type=config.moe_router,
                        shared_experts=config.moe_shared_experts,
                        capacity_factor=config.moe_capacity_factor,
                        gradient_scale=config.moe_gradient_scale,
                        sequence_aux_loss_weight=config.moe_sequence_aux_loss_weight,
                        bias_schedule=config.moe_bias_schedule,
                        packed_experts=config.moe_packed_experts,
                    )
                    for m in self.modalities
                }
            )
        else:
            self.mlp = nn.ModuleDict(
                {
                    m: build_mlp(
                        dim=config.dim,
                        hidden_dim=config.computed_ffn_hidden_dim,
                        activation=config.activation,
                    )
                    for m in self.modalities
                }
            )

        # Zero-init per-modality down_proj on dense FFNs for warm-start identity.
        # MoE FFNs have no single down_proj — they keep registry-default init.
        for m in self.modalities:
            mlp_m = self.mlp[m]
            if isinstance(mlp_m, MoEMLP):
                continue
            if hasattr(mlp_m, "down_proj"):
                nn.init.zeros_(mlp_m.down_proj.weight)  # type: ignore[union-attr]

    def forward(
        self,
        streams: dict[str, torch.Tensor],
        rope: dict[str, tuple[torch.Tensor, torch.Tensor]],
    ) -> dict[str, torch.Tensor]:
        if set(streams.keys()) != set(self.modalities):
            raise ValueError(
                f"MoTBlock.forward: streams keys {sorted(streams.keys())} "
                f"do not match construction-time modalities {sorted(self.modalities)}"
            )
        normed_attn = {m: self.attn_norm[m](streams[m]) for m in self.modalities}
        attn_out = self.attn(normed_attn, rope, is_causal=True)
        post_attn = {m: streams[m] + attn_out[m] for m in self.modalities}

        normed_mlp = {m: self.mlp_norm[m](post_attn[m]) for m in self.modalities}
        mlp_out = {m: self.mlp[m](normed_mlp[m]) for m in self.modalities}
        return {m: post_attn[m] + mlp_out[m] for m in self.modalities}
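

# Identity-at-construction sketch (illustration only). With zero-init o_proj
# and down_proj, a freshly built dense (non-MoE) MoTBlock adds nothing to the
# residual stream, so its output equals its input. A valid ModelConfig that
# selects a dense FFN at layer 0 is assumed and passed in by the caller; the
# rope values do not matter for the identity, only their documented shape.
def _mot_block_identity_sketch(config: ModelConfig) -> None:
    block = MoTBlock(config, modalities=("image", "text"), layer_idx=0)
    streams = {
        "image": torch.randn(1, 16, config.dim),
        "text": torch.randn(1, 8, config.dim),
    }
    head_dim = config.head_dim or config.dim // config.n_heads
    rope = {
        m: (torch.ones(x.shape[1], head_dim // 2), torch.zeros(x.shape[1], head_dim // 2))
        for m, x in streams.items()
    }
    out = block(streams, rope)
    for m in streams:
        # Both residual branches contribute exactly zero at construction.
        torch.testing.assert_close(out[m], streams[m])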


# ---------------------------------------------------------------------------
# Warm-start helper
# ---------------------------------------------------------------------------


def _copy_weight(
    target_module: nn.Module,
    src_tensor: torch.Tensor,
    src_key: str,
    modality: str,
) -> None:
    """Copy ``src_tensor`` into ``target_module.weight`` in place, shape-checked.

    Handles both plain ``torch.Tensor`` targets and FSDP2 ``DTensor`` targets:
    when the target is a ``DTensor``, ``src_tensor`` is sharded to match the
    target's mesh + placements before the in-place copy.
    """
    target = cast(torch.Tensor, target_module.weight)
    if target.shape != src_tensor.shape:
        raise ValueError(
            f"mot_warm_start: shape mismatch for source key '{src_key}' -> "
            f"modality '{modality}': source {tuple(src_tensor.shape)} vs "
            f"target {tuple(target.shape)}"
        )
    src_cast = src_tensor.to(dtype=target.dtype, device=target.device)

    # DTensor path: shard the source to the target's placement, then copy
    # local-shard-to-local-shard. ``DTensor.copy_`` does not accept a plain
    # tensor under FSDP2.
    if hasattr(target, "_local_tensor") and hasattr(target, "device_mesh"):
        from torch.distributed.tensor import distribute_tensor  # noqa: PLC0415

        src_d = distribute_tensor(src_cast, target.device_mesh, target.placements)  # type: ignore[attr-defined]
        target._local_tensor.copy_(src_d._local_tensor)  # type: ignore[attr-defined]
    else:
        target.data.copy_(src_cast)


def mot_warm_start_from_text_stack(
    transformer: nn.Module,
    source_state_dict: Mapping[str, torch.Tensor],
) -> None:
    """Copy dense ``TransformerBlock`` weights into every per-modality copy
    inside each ``MoTBlock`` in ``transformer.layers``.

    Use case: warm-start a fresh MoT training run from a JD or text-only
    checkpoint. The caller loads the source state dict (e.g., via
    ``torch.load`` or DCP), then calls this helper to translate dense block
    keys into per-modality copies.

    Translation, per layer index ``i`` and per modality ``m``:

    .. code::

        layers.{i}.attention_norm.weight   -> layers.{i}.attn_norm.{m}.weight
        layers.{i}.attention.q_proj.weight -> layers.{i}.attn.q_proj.{m}.weight
        layers.{i}.attention.k_proj.weight -> layers.{i}.attn.k_proj.{m}.weight
        layers.{i}.attention.v_proj.weight -> layers.{i}.attn.v_proj.{m}.weight
        layers.{i}.attention.o_proj.weight -> layers.{i}.attn.o_proj.{m}.weight
        # qk_norm only:
        layers.{i}.attention.q_norm.weight -> layers.{i}.attn.q_norm.{m}.weight
        layers.{i}.attention.k_norm.weight -> layers.{i}.attn.k_norm.{m}.weight
        layers.{i}.mlp_norm.weight         -> layers.{i}.mlp_norm.{m}.weight
        layers.{i}.mlp.<proj>.weight       -> layers.{i}.mlp.{m}.<proj>.weight
            for <proj> in {gate_proj, up_proj, down_proj}

    Plus the final norm (when present on the target):

    .. code::

        norm.weight -> mot_norms.{m}.weight

    Source keys may optionally have a ``transformer.`` prefix; both are
    accepted. MoE FFN branches are skipped (their state dict layout is
    incompatible with a dense ``mlp.<proj>`` translation). No-op when
    ``transformer`` has no ``MoTBlock`` layers. Idempotent: repeated calls
    with the same source produce identical state. Shape-checked: raises
    ``ValueError`` on per-tensor shape mismatch.
    """
    layers = getattr(transformer, "layers", None)
    if layers is None:
        return
    mot_layers = {idx: layer for idx, layer in layers.items() if isinstance(layer, MoTBlock)}
    if not mot_layers:
        return

    state: dict[str, torch.Tensor] = {}
    for k, v in source_state_dict.items():
        canonical = k[len("transformer.") :] if k.startswith("transformer.") else k
        state[canonical] = v

    with torch.no_grad():
        for idx, layer in mot_layers.items():
            modalities = layer.modalities

            attn_translations: list[tuple[str, nn.ModuleDict]] = [
                (f"layers.{idx}.attention_norm.weight", layer.attn_norm),
                (f"layers.{idx}.attention.q_proj.weight", layer.attn.q_proj),
                (f"layers.{idx}.attention.k_proj.weight", layer.attn.k_proj),
                (f"layers.{idx}.attention.v_proj.weight", layer.attn.v_proj),
                (f"layers.{idx}.attention.o_proj.weight", layer.attn.o_proj),
                (f"layers.{idx}.mlp_norm.weight", layer.mlp_norm),
            ]
            if layer.attn.q_norm is not None:
                attn_translations.append(
                    (f"layers.{idx}.attention.q_norm.weight", layer.attn.q_norm)
                )
                attn_translations.append(
                    (f"layers.{idx}.attention.k_norm.weight", layer.attn.k_norm)  # type: ignore[arg-type]
                )

            for src_key, target_dict in attn_translations:
                if src_key not in state:
                    continue
                src_tensor = state[src_key]
                for m in modalities:
                    _copy_weight(target_dict[m], src_tensor, src_key, m)

            for proj in ("gate_proj", "up_proj", "down_proj"):
                src_key = f"layers.{idx}.mlp.{proj}.weight"
                if src_key not in state:
                    continue
                src_tensor = state[src_key]
                for m in modalities:
                    mlp_m = layer.mlp[m]
                    if isinstance(mlp_m, MoEMLP) or not hasattr(mlp_m, proj):
                        continue
                    _copy_weight(getattr(mlp_m, proj), src_tensor, src_key, m)

        mot_norms = getattr(transformer, "mot_norms", None)
        if mot_norms is not None and len(mot_norms) > 0 and "norm.weight" in state:
            src_tensor = state["norm.weight"]
            for m in mot_norms:
                _copy_weight(mot_norms[m], src_tensor, "norm.weight", m)