Source code for kempnerforge.model.transformer
"""Transformer model for KempnerForge.
Architecture: Llama-style pre-norm transformer.
Token Embedding → [TransformerBlock × N] → Final Norm → Output Head
Design choices:
- ModuleDict (not ModuleList) for layers — preserves FQNs for DCP checkpointing.
- Embedding and output head are optional (can be None for PP middle stages).
- Forward is a simple loop over blocks — pipeline-parallelism friendly.
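Minimal usage sketch (the ``ModelConfig`` construction is elided here; its
full field set is defined in ``kempnerforge.config.schema``)::

    config = ...  # a populated ModelConfig
    model = Transformer(config)
    tokens = torch.randint(0, config.vocab_size, (2, 128))  # (batch, seq_len)
    logits = model(tokens)  # (batch, seq_len, vocab_size)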
"""
from __future__ import annotations
import torch
import torch.nn as nn
from kempnerforge.config.registry import registry
from kempnerforge.config.schema import ModelConfig
from kempnerforge.config.vlm import CrossAttentionConfig, MoTConfig, VLMConfig
from kempnerforge.model.attention import Attention, KVCache
from kempnerforge.model.cross_attention import CrossAttentionBlock
from kempnerforge.model.embedding import OutputHead, TokenEmbedding
from kempnerforge.model.init import init_weights
from kempnerforge.model.mlp import build_mlp
from kempnerforge.model.modality import ModalityContext
from kempnerforge.model.moe import MoEMLP, build_moe
from kempnerforge.model.mot import MoTBlock
from kempnerforge.model.norm import build_norm
from kempnerforge.model.position import precompute_rope_frequencies
class TransformerBlock(nn.Module):
"""Single transformer block with pre-norm architecture.
Structure: norm → attention → residual, norm → mlp → residual
"""
def __init__(self, config: ModelConfig, layer_idx: int) -> None:
super().__init__()
self.layer_idx = layer_idx
self.attention_norm = build_norm(config.norm_type, config.dim, eps=config.norm_eps)
self.attention = Attention(
dim=config.dim,
n_heads=config.n_heads,
n_kv_heads=config.n_kv_heads, # type: ignore[reportArgumentType]
head_dim=config.head_dim,
qk_norm=config.qk_norm,
sdpa_backend=config.sdpa_backend,
)
self.mlp_norm = build_norm(config.norm_type, config.dim, eps=config.norm_eps)
# MoE placement: with moe_frequency=1, all layers are MoE.
# With moe_frequency=2, layers 1,3,5... are MoE (layer 0 stays dense).
use_moe = config.is_moe and ((layer_idx + 1) % config.moe_frequency == 0)
if use_moe:
self.mlp = build_moe(
dim=config.dim,
hidden_dim=config.computed_ffn_hidden_dim,
num_experts=config.num_experts,
top_k=config.moe_top_k,
activation=config.activation,
router_type=config.moe_router,
shared_experts=config.moe_shared_experts,
capacity_factor=config.moe_capacity_factor,
gradient_scale=config.moe_gradient_scale,
sequence_aux_loss_weight=config.moe_sequence_aux_loss_weight,
bias_schedule=config.moe_bias_schedule,
packed_experts=config.moe_packed_experts,
)
else:
self.mlp = build_mlp(
dim=config.dim,
hidden_dim=config.computed_ffn_hidden_dim,
activation=config.activation,
)
def forward(
self,
x: torch.Tensor,
rope_cos: torch.Tensor,
rope_sin: torch.Tensor,
*,
kv_cache: KVCache | None = None,
doc_ids: torch.Tensor | None = None,
) -> torch.Tensor:
# Pre-norm attention with residual
x = x + self.attention(
self.attention_norm(x), rope_cos, rope_sin, kv_cache=kv_cache, doc_ids=doc_ids
)
# Pre-norm MLP with residual
x = x + self.mlp(self.mlp_norm(x))
return x
class Transformer(nn.Module):
"""Full transformer model built from ModelConfig.
Embedding → TransformerBlocks → Norm → Output Head
"""
def __init__(
self,
config: ModelConfig,
*,
vlm_config: VLMConfig | None = None,
num_image_tokens: int = 0,
) -> None:
super().__init__()
self.config = config
# Token embedding (can be None for PP middle stages)
self.token_embedding: TokenEmbedding | None = TokenEmbedding(config.vocab_size, config.dim)
# MoT branch: build MoTBlocks instead of TransformerBlocks. v1
# enforces equal head counts across modalities (single global
# SDPA over the concatenated multi-modality sequence).
# num_image_tokens flows in from the vision encoder via the VLM
# build path; it is unused for non-MoT arches but kept as a single
# constructor arg so the signature is uniform across arches.
self._mot_modalities: tuple[str, ...] = ()
self._mot_n_image: int = 0
if isinstance(vlm_config, MoTConfig):
text_n_kv_heads = config.n_kv_heads if config.n_kv_heads is not None else config.n_heads
img_n_heads, img_n_kv_heads = vlm_config.resolved_image_heads(
config.n_heads, text_n_kv_heads
)
if img_n_heads != config.n_heads or img_n_kv_heads != text_n_kv_heads:
raise ValueError(
"MoT v1 requires equal head counts across modalities (single global SDPA): "
f"image=({img_n_heads}, {img_n_kv_heads}) vs "
f"text=({config.n_heads}, {text_n_kv_heads})"
)
self._mot_modalities = vlm_config.mot_modalities
self._mot_n_image = num_image_tokens
self.layers = nn.ModuleDict(
{
str(i): MoTBlock(config, modalities=self._mot_modalities, layer_idx=i)
for i in range(config.n_layers)
}
)
else:
# Transformer blocks — use ModuleDict to preserve FQNs for DCP
self.layers = nn.ModuleDict(
{str(i): TransformerBlock(config, layer_idx=i) for i in range(config.n_layers)}
)
        # Cross-Attention layers (only populated when vlm_config is a
        # CrossAttentionConfig). An empty ModuleDict registers no
        # state_dict keys, so Joint-Decoder (text-only) checkpoints load
        # unchanged on builds where this dict ends up empty.
self.cross_attention_layers: nn.ModuleDict = nn.ModuleDict()
self._ca_cadence: int = 0
if isinstance(vlm_config, CrossAttentionConfig):
self._ca_cadence = vlm_config.cross_attention_every_n_layers
n_h, n_kv = vlm_config.resolved_heads(config.n_heads)
num_ca_blocks = config.n_layers // self._ca_cadence
for k in range(num_ca_blocks):
self.cross_attention_layers[str(k)] = CrossAttentionBlock(
dim=config.dim,
n_heads=n_h,
n_kv_heads=n_kv,
ffn_hidden_dim=config.computed_ffn_hidden_dim,
norm_type=config.norm_type,
activation=config.activation,
)
# Final normalization. Used by the non-MoT path. MoT uses
# per-modality ``mot_norms`` instead; ``self.norm`` is built
# regardless so cross-arch DCP loads can carry ``norm.weight``
# uniformly. When MoT is active, ``self.norm`` is unused in
# forward and is frozen so it does not appear as an orphan
# trainable parameter (its grad would always be ``None``).
self.norm = build_norm(config.norm_type, config.dim, eps=config.norm_eps)
self.mot_norms: nn.ModuleDict = nn.ModuleDict()
if self._mot_modalities:
self.mot_norms = nn.ModuleDict(
{
m: build_norm(config.norm_type, config.dim, eps=config.norm_eps)
for m in self._mot_modalities
}
)
for p in self.norm.parameters():
p.requires_grad_(False)
# Output head (can be None for PP non-final stages)
self.output_head: OutputHead | None = OutputHead(config.dim, config.vocab_size)
# Weight tying
can_tie = self.token_embedding is not None and self.output_head is not None
if config.tie_embeddings and can_tie:
self.output_head.tie_weights(self.token_embedding)
# Precompute RoPE cos/sin tables — stored as plain attributes (not buffers)
# so model.to(bf16) doesn't cast them from float32.
# Skip when on meta device (no data); call init_weights_and_freqs() later.
self._rope_cos = None
self._rope_sin = None
if not any(p.is_meta for p in self.parameters()):
self._rope_cos, self._rope_sin = precompute_rope_frequencies(
head_dim=config.head_dim,
max_seq_len=config.max_seq_len,
theta=config.rope_theta,
)
init_weights(self, config)
def init_weights_and_freqs(self) -> None:
"""Initialize weights and RoPE frequencies after meta-device materialization.
Called after ``model.to_empty(device=...)`` to fill in parameter values
        and compute the RoPE frequency table. Safe to call on already-initialized models
(skips if freqs are already computed).
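        Sketch of the meta-device flow (the exact call site lives in the
        training entry point, not in this module, so this is illustrative)::

            with torch.device("meta"):
                model = Transformer(config)   # parameters allocated on meta, no data
            model.to_empty(device="cuda")      # materialize storage, values uninitialized
            model.init_weights_and_freqs()     # fill in weights and RoPE tables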
"""
if self._rope_cos is None:
self._rope_cos, self._rope_sin = precompute_rope_frequencies(
head_dim=self.config.head_dim,
max_seq_len=self.config.max_seq_len,
theta=self.config.rope_theta,
)
init_weights(self, self.config)
def set_moe_step(self, step: int, max_steps: int) -> None:
"""Set training step on all MoE routers for adaptive bias scheduling."""
from kempnerforge.model.router import SigmoidTopKRouter
for layer in self.layers.values():
if isinstance(layer.mlp, MoEMLP) and isinstance(layer.mlp.router, SigmoidTopKRouter):
layer.mlp.router.set_step(step, max_steps)
def get_moe_aux_loss(self) -> torch.Tensor:
"""Collect auxiliary losses from all MoE layers. Returns 0 if dense."""
total = torch.tensor(0.0, device=next(self.parameters()).device)
for layer in self.layers.values():
if isinstance(layer.mlp, MoEMLP):
total = total + layer.mlp.aux_loss
return total
def get_expert_counts(self) -> dict[int, torch.Tensor]:
"""Collect per-layer expert utilization. Returns {} if dense."""
counts = {}
for name, layer in self.layers.items():
if isinstance(layer.mlp, MoEMLP):
counts[int(name)] = layer.mlp.expert_counts
return counts
def forward(
self,
tokens: torch.Tensor | None = None,
*,
modality: ModalityContext | None = None,
kv_caches: list[KVCache] | None = None,
doc_ids: torch.Tensor | None = None,
) -> torch.Tensor:
"""Forward pass.
Exactly one of ``tokens`` or ``modality.inputs_embeds`` must be
provided. Modality-injection routes (``prefix_embeds``,
``output_slice``, ``image_features``, ``image_mask``,
``modality_ids``) are grouped on the optional
``ModalityContext`` arg; see ``kempnerforge/model/modality.py``
for the full intra-context invariant table.
Args:
tokens: Integer token ids, shape ``(batch, seq_len)``.
modality: Optional ``ModalityContext`` bundling pre-embedded
inputs, prefix embeds, output slicing, image features,
and modality routing tags for VLM arches. ``None`` is the
pure text-only forward.
kv_caches: Optional list of KVCache (one per layer) for
generation. When provided, RoPE positions are offset by
the current cache fill level. Cross-arg invariant:
``kv_caches`` forbids ``modality.prefix_embeds``,
``modality.output_slice``, ``modality.image_features``,
and ``modality.modality_ids``
(all training-only).
doc_ids: Optional per-token document IDs for packed sequences,
shape ``(batch, seq_len)``. Enables block-diagonal causal
attention that isolates documents within packed sequences.
        Returns:
            Logits tensor of shape ``(batch, out_seq_len, vocab_size)``,
            where ``out_seq_len == seq_len`` normally, or the sliced
            length when ``modality.output_slice`` is set. When
            ``output_head`` is ``None`` (PP non-final stages), the final
            hidden states of shape ``(batch, out_seq_len, dim)`` are
            returned instead.
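        Example of the packed-sequence path (``model``/``config`` as in the
        module-level sketch; values are illustrative)::

            # Two length-3 documents packed into one row. Block-diagonal
            # causal attention keeps doc 1 from attending into doc 0.
            doc_ids = torch.tensor([[0, 0, 0, 1, 1, 1]])
            tokens = torch.randint(0, config.vocab_size, (1, 6))
            logits = model(tokens, doc_ids=doc_ids)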
"""
inputs_embeds = modality.inputs_embeds if modality is not None else None
prefix_embeds = modality.prefix_embeds if modality is not None else None
output_slice = modality.output_slice if modality is not None else None
image_features = modality.image_features if modality is not None else None
image_mask = modality.image_mask if modality is not None else None
modality_ids = modality.modality_ids if modality is not None else None
if (tokens is None) == (inputs_embeds is None):
raise ValueError(
"Transformer.forward requires exactly one of tokens or modality.inputs_embeds"
)
if kv_caches is not None:
if output_slice is not None:
raise ValueError(
"modality.output_slice is training-only; cannot be combined with kv_caches"
)
if prefix_embeds is not None:
# If we ever allowed both, the RoPE slice below would start at
# start_pos (the cache's text fill level) but seq_len would
# include the prefix — positions would double-count the prefix
# at every decode step. Training-only.
raise ValueError(
"modality.prefix_embeds is training-only; cannot be combined with kv_caches"
)
if image_features is not None:
raise ValueError(
"modality.image_features is training-only; cannot be combined with kv_caches"
)
if modality_ids is not None:
# MoT routes per-token through per-modality projections via
# modality_ids; KV-cache decode has no equivalent semantics
# (cache positions are pre-routed). Training-only for v1.
raise ValueError(
"modality.modality_ids is training-only; cannot be combined with kv_caches"
)
# modality_ids dtype/shape checks. Shape against the residual seq_len
# is checked downstream (when the residual is built); dtype is
# checked here so the error fires before any compute.
if modality_ids is not None and modality_ids.dtype != torch.long:
raise ValueError(
f"modality.modality_ids.dtype must be torch.long, got {modality_ids.dtype}"
)
h = (
self.token_embedding(tokens) # type: ignore[reportOptionalCall]
if tokens is not None
else inputs_embeds
)
assert h is not None # narrowed by the XOR check above
if prefix_embeds is not None:
# Cast to the text-embedding dtype so the concat does not promote.
if prefix_embeds.dtype != h.dtype:
prefix_embeds = prefix_embeds.to(h.dtype)
h = torch.cat([prefix_embeds, h], dim=1)
seq_len = h.shape[1]
# Determine position offset from KV cache fill level
start_pos = kv_caches[0].seq_len if kv_caches is not None else 0
# Slice RoPE frequencies for current positions (device transfer cached)
if self._rope_cos.device != h.device: # type: ignore[reportOptionalMemberAccess]
self._rope_cos = self._rope_cos.to(h.device) # type: ignore[reportOptionalMemberAccess]
self._rope_sin = self._rope_sin.to(h.device) # type: ignore[reportOptionalMemberAccess]
cos = self._rope_cos[start_pos : start_pos + seq_len] # type: ignore[reportOptionalSubscript]
sin = self._rope_sin[start_pos : start_pos + seq_len] # type: ignore[reportOptionalSubscript]
# MoT path: position-based image-then-text split, per-modality
# streams through the MoTBlock stack, single global SDPA per
# layer. modality_ids is required (presence + shape checked
# against the residual). v1 uses position-based routing; the
# tags are validated for shape but not value-matched against
# positions, so a future per-token scatter/gather can land
# without changing the public interface.
if self._mot_modalities:
if modality_ids is None:
raise ValueError(
"MoT model requires modality.modality_ids (got None). Build the "
"ModalityContext via MoTStrategy or set modality_ids explicitly."
)
if modality_ids.shape != h.shape[:2]:
raise ValueError(
f"modality.modality_ids shape {tuple(modality_ids.shape)} does not "
f"match residual shape {tuple(h.shape[:2])}"
)
n_image = self._mot_n_image
t_image = n_image
t_text = h.shape[1] - n_image
streams: dict[str, torch.Tensor] = {
"image": h[:, :t_image, :],
"text": h[:, t_image:, :],
}
# Per-modality RoPE: each modality counts positions from 0
# within its own axis. Image and text share the same RoPE
# table since head_dim is shared.
rope = {
"image": (cos[:t_image], sin[:t_image]),
"text": (cos[:t_text], sin[:t_text]),
}
for layer in self.layers.values():
streams = layer(streams, rope)
streams = {m: self.mot_norms[m](streams[m]) for m in self._mot_modalities}
# Re-concat in image-then-text order to match the residual
# layout the rest of forward expects (output_slice + head).
h = torch.cat([streams["image"], streams["text"]], dim=1)
else:
# Transformer blocks. When the model has cross-attention layers
# (CrossAttentionConfig + nonzero cadence), a CrossAttentionBlock
# fires after text block index i iff (i+1) % _ca_cadence == 0.
            # _ca_cadence == 0 (text-only / Joint-Decoder) makes the inner
            # branch dead code, so the Joint-Decoder path stays bit-identical
            # to the plain text-only forward.
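            # Worked example: n_layers=8 with _ca_cadence=2 builds
            # n_layers // 2 == 4 CrossAttentionBlocks and fires them after
            # text blocks 1, 3, 5, and 7.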
ca_iter = iter(self.cross_attention_layers.values()) if self._ca_cadence else None
for i, layer in enumerate(self.layers.values()):
cache = kv_caches[i] if kv_caches is not None else None
h = layer(h, cos, sin, kv_cache=cache, doc_ids=doc_ids)
if ca_iter is not None and (i + 1) % self._ca_cadence == 0:
ca = next(ca_iter, None)
if ca is not None:
if image_features is None:
raise ValueError(
"Cross-Attention block fired but modality.image_features is None. "
"Cross-Attention models require image_features in the "
"ModalityContext."
)
if image_features.dtype != h.dtype:
image_features = image_features.to(h.dtype)
h = ca(h, image_features, image_mask)
# Final norm
h = self.norm(h)
# Optional slice before the output head (training-only kwarg)
if output_slice is not None:
h = h[:, output_slice, :]
# Output projection
if self.output_head is not None:
h = self.output_head(h)
return h
@registry.register_model("transformer")
def _build_transformer(model_config: ModelConfig) -> Transformer:
return Transformer(model_config)