Source code for kempnerforge.model.transformer

"""Transformer model for KempnerForge.

Architecture: Llama-style pre-norm transformer.
  Token Embedding → [TransformerBlock × N] → Final Norm → Output Head

Design choices:
  - ModuleDict (not ModuleList) for layers — preserves FQNs for DCP checkpointing.
  - Embedding and output head are optional (can be None for PP middle stages).
  - Forward is a simple loop over blocks — pipeline-parallelism friendly.
"""

from __future__ import annotations

from typing import cast

import torch
import torch.nn as nn

from kempnerforge.config.registry import registry
from kempnerforge.config.schema import ModelConfig
from kempnerforge.config.vlm import CrossAttentionConfig, MoMaConfig, MoTConfig, VLMConfig
from kempnerforge.model.attention import Attention, KVCache
from kempnerforge.model.cross_attention import CrossAttentionBlock
from kempnerforge.model.embedding import OutputHead, TokenEmbedding
from kempnerforge.model.init import init_weights
from kempnerforge.model.mlp import build_mlp
from kempnerforge.model.modality import ModalityContext
from kempnerforge.model.moe import MoEMLP, build_moe
from kempnerforge.model.moma import ExpertChoiceMoE, MoMaBlock, MoMaFFN
from kempnerforge.model.mot import MoTBlock
from kempnerforge.model.norm import build_norm
from kempnerforge.model.position import precompute_rope_frequencies


[docs] class TransformerBlock(nn.Module): """Single transformer block with pre-norm architecture. Structure: norm → attention → residual, norm → mlp → residual """
[docs] def __init__(self, config: ModelConfig, layer_idx: int) -> None: super().__init__() self.layer_idx = layer_idx self.attention_norm = build_norm(config.norm_type, config.dim, eps=config.norm_eps) self.attention = Attention( dim=config.dim, n_heads=config.n_heads, n_kv_heads=config.n_kv_heads, # type: ignore[reportArgumentType] head_dim=config.head_dim, qk_norm=config.qk_norm, sdpa_backend=config.sdpa_backend, ) self.mlp_norm = build_norm(config.norm_type, config.dim, eps=config.norm_eps) # MoE placement: with moe_frequency=1, all layers are MoE. # With moe_frequency=2, layers 1,3,5... are MoE (layer 0 stays dense). use_moe = config.is_moe and ((layer_idx + 1) % config.moe_frequency == 0) if use_moe: self.mlp = build_moe( dim=config.dim, hidden_dim=config.computed_ffn_hidden_dim, num_experts=config.num_experts, top_k=config.moe_top_k, activation=config.activation, router_type=config.moe_router, shared_experts=config.moe_shared_experts, capacity_factor=config.moe_capacity_factor, gradient_scale=config.moe_gradient_scale, sequence_aux_loss_weight=config.moe_sequence_aux_loss_weight, bias_schedule=config.moe_bias_schedule, packed_experts=config.moe_packed_experts, ) else: self.mlp = build_mlp( dim=config.dim, hidden_dim=config.computed_ffn_hidden_dim, activation=config.activation, )
[docs] def forward( self, x: torch.Tensor, rope_cos: torch.Tensor, rope_sin: torch.Tensor, *, kv_cache: KVCache | None = None, doc_ids: torch.Tensor | None = None, ) -> torch.Tensor: # Pre-norm attention with residual x = x + self.attention( self.attention_norm(x), rope_cos, rope_sin, kv_cache=kv_cache, doc_ids=doc_ids ) # Pre-norm MLP with residual x = x + self.mlp(self.mlp_norm(x)) return x
[docs] class Transformer(nn.Module): """Full transformer model built from ModelConfig. Embedding → TransformerBlocks → Norm → Output Head """
[docs] def __init__( self, config: ModelConfig, *, vlm_config: VLMConfig | None = None, num_image_tokens: int = 0, ) -> None: super().__init__() self.config = config # Token embedding (can be None for PP middle stages) self.token_embedding: TokenEmbedding | None = TokenEmbedding(config.vocab_size, config.dim) # MoT branch: build MoTBlocks instead of TransformerBlocks. v1 # enforces equal head counts across modalities (single global # SDPA over the concatenated multi-modality sequence). # MoMa branch: build MoMaBlocks (shared Q/K/V/O attention + # per-modality MoE FFN groups). The branches are mutually # exclusive on layer construction; CA layers are still attached # separately below for ``CrossAttentionConfig``. # num_image_tokens flows in from the vision encoder via the VLM # build path; it is unused for non-MoT arches but kept as a single # constructor arg so the signature is uniform across arches. self._mot_modalities: tuple[str, ...] = () self._mot_n_image: int = 0 self._moma_modalities: tuple[str, ...] = () if isinstance(vlm_config, MoTConfig): text_n_kv_heads = config.n_kv_heads if config.n_kv_heads is not None else config.n_heads img_n_heads, img_n_kv_heads = vlm_config.resolved_image_heads( config.n_heads, text_n_kv_heads ) if img_n_heads != config.n_heads or img_n_kv_heads != text_n_kv_heads: raise ValueError( "MoT v1 requires equal head counts across modalities (single global SDPA): " f"image=({img_n_heads}, {img_n_kv_heads}) vs " f"text=({config.n_heads}, {text_n_kv_heads})" ) self._mot_modalities = vlm_config.mot_modalities self._mot_n_image = num_image_tokens self.layers = nn.ModuleDict( { str(i): MoTBlock(config, modalities=self._mot_modalities, layer_idx=i) for i in range(config.n_layers) } ) elif isinstance(vlm_config, MoMaConfig): self._moma_modalities = vlm_config.moma_modalities experts_per_modality = dict(vlm_config.moma_experts_per_modality) capacity_factor_per_modality = { m: vlm_config.effective_capacity_factor(m) for m in self._moma_modalities } self.layers = nn.ModuleDict( { str(i): MoMaBlock( config, modalities=self._moma_modalities, experts_per_modality=experts_per_modality, capacity_factor_per_modality=capacity_factor_per_modality, gumbel_noise=vlm_config.moma_gumbel_noise, layer_idx=i, ) for i in range(config.n_layers) } ) else: # Transformer blocks — use ModuleDict to preserve FQNs for DCP self.layers = nn.ModuleDict( {str(i): TransformerBlock(config, layer_idx=i) for i in range(config.n_layers)} ) # Cross-Attention layers (only populated when vlm_config is a # CrossAttentionConfig). Empty ModuleDict registers no # state_dict keys, so JD checkpoints load unchanged on builds # where this dict ends up empty. self.cross_attention_layers: nn.ModuleDict = nn.ModuleDict() self._ca_cadence: int = 0 if isinstance(vlm_config, CrossAttentionConfig): self._ca_cadence = vlm_config.cross_attention_every_n_layers n_h, n_kv = vlm_config.resolved_heads(config.n_heads) num_ca_blocks = config.n_layers // self._ca_cadence for k in range(num_ca_blocks): self.cross_attention_layers[str(k)] = CrossAttentionBlock( dim=config.dim, n_heads=n_h, n_kv_heads=n_kv, ffn_hidden_dim=config.computed_ffn_hidden_dim, norm_type=config.norm_type, activation=config.activation, ) # Final normalization. Used by the non-MoT path. MoT uses # per-modality ``mot_norms`` instead; ``self.norm`` is built # regardless so cross-arch DCP loads can carry ``norm.weight`` # uniformly. When MoT is active, ``self.norm`` is unused in # forward and is frozen so it does not appear as an orphan # trainable parameter (its grad would always be ``None``). self.norm = build_norm(config.norm_type, config.dim, eps=config.norm_eps) self.mot_norms: nn.ModuleDict = nn.ModuleDict() if self._mot_modalities: self.mot_norms = nn.ModuleDict( { m: build_norm(config.norm_type, config.dim, eps=config.norm_eps) for m in self._mot_modalities } ) for p in self.norm.parameters(): p.requires_grad_(False) # Output head (can be None for PP non-final stages) self.output_head: OutputHead | None = OutputHead(config.dim, config.vocab_size) # Weight tying can_tie = self.token_embedding is not None and self.output_head is not None if config.tie_embeddings and can_tie: self.output_head.tie_weights(self.token_embedding) # Precompute RoPE cos/sin tables — stored as plain attributes (not buffers) # so model.to(bf16) doesn't cast them from float32. # Skip when on meta device (no data); call init_weights_and_freqs() later. self._rope_cos = None self._rope_sin = None if not any(p.is_meta for p in self.parameters()): self._rope_cos, self._rope_sin = precompute_rope_frequencies( head_dim=config.head_dim, max_seq_len=config.max_seq_len, theta=config.rope_theta, ) init_weights(self, config)
[docs] def init_weights_and_freqs(self) -> None: """Initialize weights and RoPE frequencies after meta-device materialization. Called after ``model.to_empty(device=...)`` to fill in parameter values and compute RoPE frequency table. Safe to call on already-initialized models (skips if freqs are already computed). """ if self._rope_cos is None: self._rope_cos, self._rope_sin = precompute_rope_frequencies( head_dim=self.config.head_dim, max_seq_len=self.config.max_seq_len, theta=self.config.rope_theta, ) init_weights(self, self.config)
[docs] def set_moe_step(self, step: int, max_steps: int) -> None: """Set training step on all MoE routers for adaptive bias scheduling.""" from kempnerforge.model.router import SigmoidTopKRouter for layer in self.layers.values(): if isinstance(layer.mlp, MoEMLP) and isinstance(layer.mlp.router, SigmoidTopKRouter): layer.mlp.router.set_step(step, max_steps)
[docs] def get_moe_aux_loss(self) -> torch.Tensor: """Collect auxiliary losses from all MoE layers. Returns 0 if dense.""" total = torch.tensor(0.0, device=next(self.parameters()).device) for layer in self.layers.values(): if isinstance(layer.mlp, MoEMLP): total = total + layer.mlp.aux_loss return total
[docs] def get_expert_counts(self) -> dict[int, torch.Tensor]: """Collect per-layer expert utilization for flat MoE layers. Returns ``{layer_idx: (num_experts,) tensor}`` for layers whose MLP is a ``MoEMLP`` (the standard, single-pool MoE). Returns ``{}`` for dense models and for MoMa: MoMa's per-modality groups have a different shape (per modality, per expert) and surface through ``get_moma_expert_counts`` instead. """ counts = {} for name, layer in self.layers.items(): if isinstance(layer.mlp, MoEMLP): counts[int(name)] = layer.mlp.expert_counts return counts
[docs] def get_moma_expert_counts(self) -> dict[int, dict[str, torch.Tensor]]: """Collect per-layer, per-modality expert utilization for MoMa layers. Returns ``{layer_idx: {modality: (num_experts_for_modality,) tensor}}`` for every layer whose MLP is a ``MoMaFFN``; returns ``{}`` otherwise. Each modality's expert count tensor reflects the most recent forward through that layer's expert-choice router (paper Figure 5-style utilization). Counts on a fresh model (no forward yet) are the router's init zeros. """ counts: dict[int, dict[str, torch.Tensor]] = {} for name, layer in self.layers.items(): if isinstance(layer.mlp, MoMaFFN): # nn.ModuleDict.__getitem__ returns Module; cast back to the # concrete expert-group type so pyright sees the # ``expert_counts`` Tensor rather than ``Tensor | Module``. # The cast is safe because ``MoMaFFN.__init__`` only ever # writes ``ExpertChoiceMoE`` values into ``experts``. counts[int(name)] = { m: cast(ExpertChoiceMoE, layer.mlp.experts[m]).expert_counts for m in layer.mlp.modalities } return counts
[docs] def forward( self, tokens: torch.Tensor | None = None, *, modality: ModalityContext | None = None, kv_caches: list[KVCache] | None = None, doc_ids: torch.Tensor | None = None, ) -> torch.Tensor: """Forward pass. Exactly one of ``tokens`` or ``modality.inputs_embeds`` must be provided. Modality-injection routes (``prefix_embeds``, ``output_slice``, ``image_features``, ``image_mask``, ``modality_ids``) are grouped on the optional ``ModalityContext`` arg; see ``kempnerforge/model/modality.py`` for the full intra-context invariant table. Args: tokens: Integer token ids, shape ``(batch, seq_len)``. modality: Optional ``ModalityContext`` bundling pre-embedded inputs, prefix embeds, output slicing, image features, and modality routing tags for VLM arches. ``None`` is the pure text-only forward. kv_caches: Optional list of KVCache (one per layer) for generation. When provided, RoPE positions are offset by the current cache fill level. Cross-arg invariant: ``kv_caches`` forbids ``modality.prefix_embeds``, ``modality.output_slice``, ``modality.image_features``, and ``modality.modality_ids`` (all training-only). doc_ids: Optional per-token document IDs for packed sequences, shape ``(batch, seq_len)``. Enables block-diagonal causal attention that isolates documents within packed sequences. Returns: Logits tensor of shape ``(batch, out_seq_len, vocab_size)`` where ``out_seq_len == seq_len`` normally or the sliced length when ``modality.output_slice`` is set. """ inputs_embeds = modality.inputs_embeds if modality is not None else None prefix_embeds = modality.prefix_embeds if modality is not None else None output_slice = modality.output_slice if modality is not None else None image_features = modality.image_features if modality is not None else None image_mask = modality.image_mask if modality is not None else None modality_ids = modality.modality_ids if modality is not None else None if (tokens is None) == (inputs_embeds is None): raise ValueError( "Transformer.forward requires exactly one of tokens or modality.inputs_embeds" ) if kv_caches is not None: if output_slice is not None: raise ValueError( "modality.output_slice is training-only; cannot be combined with kv_caches" ) if prefix_embeds is not None: # If we ever allowed both, the RoPE slice below would start at # start_pos (the cache's text fill level) but seq_len would # include the prefix — positions would double-count the prefix # at every decode step. Training-only. raise ValueError( "modality.prefix_embeds is training-only; cannot be combined with kv_caches" ) if image_features is not None: raise ValueError( "modality.image_features is training-only; cannot be combined with kv_caches" ) if modality_ids is not None: # MoT routes per-token through per-modality projections via # modality_ids; KV-cache decode has no equivalent semantics # (cache positions are pre-routed). Training-only for v1. raise ValueError( "modality.modality_ids is training-only; cannot be combined with kv_caches" ) # modality_ids dtype/shape checks. Shape against the residual seq_len # is checked downstream (when the residual is built); dtype is # checked here so the error fires before any compute. if modality_ids is not None and modality_ids.dtype != torch.long: raise ValueError( f"modality.modality_ids.dtype must be torch.long, got {modality_ids.dtype}" ) h = ( self.token_embedding(tokens) # type: ignore[reportOptionalCall] if tokens is not None else inputs_embeds ) assert h is not None # narrowed by the XOR check above if prefix_embeds is not None: # Cast to the text-embedding dtype so the concat does not promote. if prefix_embeds.dtype != h.dtype: prefix_embeds = prefix_embeds.to(h.dtype) h = torch.cat([prefix_embeds, h], dim=1) seq_len = h.shape[1] # Determine position offset from KV cache fill level start_pos = kv_caches[0].seq_len if kv_caches is not None else 0 # Slice RoPE frequencies for current positions (device transfer cached) if self._rope_cos.device != h.device: # type: ignore[reportOptionalMemberAccess] self._rope_cos = self._rope_cos.to(h.device) # type: ignore[reportOptionalMemberAccess] self._rope_sin = self._rope_sin.to(h.device) # type: ignore[reportOptionalMemberAccess] cos = self._rope_cos[start_pos : start_pos + seq_len] # type: ignore[reportOptionalSubscript] sin = self._rope_sin[start_pos : start_pos + seq_len] # type: ignore[reportOptionalSubscript] # MoMa path: single residual stream + shared SDPA + per-modality # MoE FFN groups. modality_ids tags every position and the # ``MoMaFFN`` uses these tags to dispatch tokens to per-modality # expert groups (level-1 deterministic routing); within each # group, expert-choice + Sigmoid routing picks experts # (level-2 learned routing). EC routing is non-causal, so # ``kv_caches`` is rejected upstream (training-only in v1). if self._moma_modalities: if modality_ids is None: raise ValueError( "MoMa model requires modality.modality_ids (got None). Build the " "ModalityContext via MoMaStrategy or set modality_ids explicitly." ) if modality_ids.shape != h.shape[:2]: raise ValueError( f"modality.modality_ids shape {tuple(modality_ids.shape)} does not " f"match residual shape {tuple(h.shape[:2])}" ) for layer in self.layers.values(): h = layer(h, cos, sin, modality_ids, doc_ids=doc_ids) h = self.norm(h) # MoT path: position-based image-then-text split, per-modality # streams through the MoTBlock stack, single global SDPA per # layer. modality_ids is required (presence + shape checked # against the residual). v1 uses position-based routing; the # tags are validated for shape but not value-matched against # positions, so a future per-token scatter/gather can land # without changing the public interface. elif self._mot_modalities: if modality_ids is None: raise ValueError( "MoT model requires modality.modality_ids (got None). Build the " "ModalityContext via MoTStrategy or set modality_ids explicitly." ) if modality_ids.shape != h.shape[:2]: raise ValueError( f"modality.modality_ids shape {tuple(modality_ids.shape)} does not " f"match residual shape {tuple(h.shape[:2])}" ) n_image = self._mot_n_image t_image = n_image t_text = h.shape[1] - n_image streams: dict[str, torch.Tensor] = { "image": h[:, :t_image, :], "text": h[:, t_image:, :], } # Per-modality RoPE: each modality counts positions from 0 # within its own axis. Image and text share the same RoPE # table since head_dim is shared. rope = { "image": (cos[:t_image], sin[:t_image]), "text": (cos[:t_text], sin[:t_text]), } for layer in self.layers.values(): streams = layer(streams, rope) streams = {m: self.mot_norms[m](streams[m]) for m in self._mot_modalities} # Re-concat in image-then-text order to match the residual # layout the rest of forward expects (output_slice + head). h = torch.cat([streams["image"], streams["text"]], dim=1) else: # Transformer blocks. When the model has cross-attention layers # (CrossAttentionConfig + nonzero cadence), a CrossAttentionBlock # fires after text block index i iff (i+1) % _ca_cadence == 0. # _ca_cadence == 0 (text-only / Joint-Decoder) makes the inner # branch dead, so the JD path stays bit-equal to today's. ca_iter = iter(self.cross_attention_layers.values()) if self._ca_cadence else None for i, layer in enumerate(self.layers.values()): cache = kv_caches[i] if kv_caches is not None else None h = layer(h, cos, sin, kv_cache=cache, doc_ids=doc_ids) if ca_iter is not None and (i + 1) % self._ca_cadence == 0: ca = next(ca_iter, None) if ca is not None: if image_features is None: raise ValueError( "Cross-Attention block fired but modality.image_features is None. " "Cross-Attention models require image_features in the " "ModalityContext." ) if image_features.dtype != h.dtype: image_features = image_features.to(h.dtype) h = ca(h, image_features, image_mask) # Final norm h = self.norm(h) # Optional slice before the output head (training-only kwarg) if output_slice is not None: h = h[:, output_slice, :] # Output projection if self.output_head is not None: h = self.output_head(h) return h
@registry.register_model("transformer") def _build_transformer(model_config: ModelConfig) -> Transformer: return Transformer(model_config)