Source code for kempnerforge.config.vlm

"""VLM (vision-language model) configuration.

``VLMConfig`` carries the arch-level knobs of the vision-language
model: which architecture to wire (``arch``), the fixed text padding
length, and the freeze policy. The vision encoder and adapter are
described by sibling top-level sections (``VisionEncoderConfig`` in
``config/vision.py``, ``AdapterConfig`` in ``config/adapter.py``).

In TOML, ``[vlm]`` is a top-level section, parallel to ``[model]``,
``[vision_encoder]``, and ``[adapter]``. When ``[vlm]`` is absent, the
job is a text-only run.
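
A minimal sketch of that layout (illustrative only: the keys under
``[vision_encoder]`` and ``[adapter]`` live in their own modules, and the
real loader entry point may differ; ``tomllib`` requires Python 3.11+)::

    import tomllib

    job_toml = '''
    [model]
    max_seq_len = 4096

    [vlm]
    arch = "joint_decoder"
    max_text_len = 512

    [vision_encoder]
    # see config/vision.py

    [adapter]
    # see config/adapter.py
    '''
    raw = tomllib.loads(job_toml)
    vlm_cfg = VLMConfig.for_arch(raw["vlm"].pop("arch"), **raw["vlm"])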

Architecture is a discriminated union on the ``arch`` field:

- ``"joint_decoder"`` image tokens prepended to the text sequence.
- ``"cross_attention"`` image K/V flows in via separate
  cross-attention blocks at a configurable cadence.
- ``"mot"`` Mixture-of-Transformers: per-modality Q/K/V/O + per-
  modality FFN at every layer, single global self-attention.

Each arch gets its own ``VLMConfig`` subclass, registered via
``registry.register_vlm_config``. The TOML loader dispatches on
``arch`` to instantiate the right subclass; programmatic callers use
``VLMConfig.for_arch(arch_name, **fields)``.
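
For example, a sketch of the dispatch (only ``arch`` and the fields
defined in this module are used; the return types follow the
registrations below)::

    jd = VLMConfig.for_arch("joint_decoder")                    # -> JointDecoderConfig
    ca = VLMConfig.for_arch("cross_attention",
                            cross_attention_every_n_layers=2)   # -> CrossAttentionConfig
    mot = VLMConfig.for_arch("mot", max_text_len=1024)          # -> MoTConfig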

``FreezeSpec`` / ``FreezeStage`` are consumed by
``kempnerforge/training/freeze.py``.
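
A staged-freeze sketch (step counts are placeholders; the runtime
semantics live in ``training/freeze.py``)::

    cfg = VLMConfig.for_arch(
        "joint_decoder",
        freeze=[FreezeSpec("vision_encoder", frozen=True)],
        freeze_schedule=[
            # unfreeze the vision encoder after a warmup phase
            FreezeStage(start_step=2_000,
                        specs=(FreezeSpec("vision_encoder", frozen=False),)),
        ],
    )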
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any

from kempnerforge.config.registry import registry

# Default aliases: module-name keys expand to fnmatch patterns. Each alias
# matches both the bare module attribute (a hypothetical param directly on
# ``adapter``/``vision_encoder`` etc.) and all nested children, so freezing
# by alias cannot silently miss a parameter.
DEFAULT_MODULE_PATTERNS: dict[str, list[str]] = {
    "transformer": ["transformer", "transformer.*"],
    "vision_encoder": ["vision_encoder", "vision_encoder.*"],
    "adapter": ["adapter", "adapter.*"],
}
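
# For example, the "adapter" alias above expands to ["adapter", "adapter.*"],
# which fnmatch-matches a fully-qualified name like "adapter.proj.weight"
# (hypothetical parameter name) but not "vision_encoder.adapter_norm.weight".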


@dataclass(frozen=True)
class FreezeSpec:
    """A single freeze directive.

    ``module`` is an alias (a key in a pattern map such as
    ``DEFAULT_MODULE_PATTERNS``) or a raw fnmatch pattern matching
    fully-qualified parameter names.
    """

    module: str
    frozen: bool = True


@dataclass(frozen=True)
class FreezeStage:
    """A freeze directive that applies from ``start_step`` onward.

    Used for staged training recipes where the trainable subset changes
    across training phases. The list of stages on ``VLMConfig`` is
    expected to be in strictly monotonic ``start_step`` order.
    """

    start_step: int
    specs: tuple[FreezeSpec, ...]


# Reserved arches: known names not yet implemented. The loader and ``for_arch``
# raise ``NotImplementedError`` rather than ``ValueError`` so TOMLs that aim at
# a future arch get a clear message about it.
_RESERVED_ARCHS: tuple[str, ...] = ()


@dataclass
class VLMConfig:
    """Base VLM configuration.

    Subclasses register themselves via ``@registry.register_vlm_config``
    and override the ``arch`` field's default. Use
    ``VLMConfig.for_arch(arch_name, **fields)`` to construct
    programmatically; the TOML loader dispatches on ``arch`` automatically.

    Field summary (full per-field docs are picked up from autodoc):

    - ``arch``: VLM architecture discriminator. Subclasses set this via
      the field default; direct construction with an arch name not backed
      by a registered subclass raises.
    - ``max_text_len``: fixed text padding length used by ``VLMCollator``.
      Enforces rank-consistent batches under FSDP2.
    - ``freeze``: static freeze specs applied once at build time.
    - ``freeze_schedule``: step-boundary freeze transitions.
    - ``module_patterns``: map of module alias (``"transformer"``,
      ``"vision_encoder"``, ``"adapter"``, plus arch-specific additions)
      to fnmatch pattern list.
    """

    arch: str = "joint_decoder"
    max_text_len: int = 512
    freeze: list[FreezeSpec] = field(
        default_factory=lambda: [FreezeSpec("vision_encoder", True)]
    )
    freeze_schedule: list[FreezeStage] = field(default_factory=list)
    module_patterns: dict[str, list[str]] = field(
        default_factory=lambda: {k: list(v) for k, v in DEFAULT_MODULE_PATTERNS.items()}
    )

    def __post_init__(self) -> None:
        if self.arch in _RESERVED_ARCHS:
            raise NotImplementedError(
                f"vlm.arch={self.arch!r} is reserved; not yet implemented. "
                f"Reserved: {sorted(_RESERVED_ARCHS)}."
            )
        registered = tuple(registry.list_vlm_configs())
        if self.arch not in registered:
            raise ValueError(
                f"Unknown vlm.arch: {self.arch!r}. "
                f"Registered: {sorted(registered)}. "
                f"Reserved (not yet implemented): {sorted(_RESERVED_ARCHS)}."
            )
        if self.max_text_len <= 0:
            raise ValueError("vlm.max_text_len must be positive")
        if self.freeze_schedule:
            steps = [s.start_step for s in self.freeze_schedule]
            if steps != sorted(steps) or len(steps) != len(set(steps)):
                raise ValueError(
                    "vlm.freeze_schedule start_steps must be strictly monotonic"
                )

    def residual_stream_image_tokens(self, num_tokens: int) -> int:
        """Number of image tokens this arch places in the residual stream.

        Used by ``JobConfig`` to validate that ``model.max_seq_len`` and
        ``train.seq_len`` are large enough to fit
        ``residual_stream_image_tokens + max_text_len`` along the
        attention sequence dimension.

        - Joint-Decoder / MoT: ``num_tokens`` (image tokens prepended to
          the text).
        - Cross-Attention: ``0`` (the residual stream is text-only; image
          features flow side-channel into the CA blocks).

        Args:
            num_tokens: The vision encoder's resolved ``num_tokens``.
                Pass ``0`` when it is not known yet (the "infer at build
                time" sentinel); cross-checks that depend on a concrete
                value will skip and re-run at build time.
        """
        return num_tokens

    @classmethod
    def for_arch(cls, arch: str, **kwargs: Any) -> VLMConfig:
        """Resolve ``arch`` to a registered subclass and instantiate it.

        Raises:
            ValueError: ``arch`` is not registered.
            NotImplementedError: ``arch`` is reserved (in
                ``_RESERVED_ARCHS``); this matches loader semantics so the
                error type is independent of the construction site.

        Example:
            >>> cfg = VLMConfig.for_arch(
            ...     "cross_attention",
            ...     max_text_len=2048,
            ...     cross_attention_every_n_layers=4,
            ... )
        """
        if arch in _RESERVED_ARCHS:
            raise NotImplementedError(
                f"vlm.arch={arch!r} is reserved; not yet implemented. "
                f"Reserved: {sorted(_RESERVED_ARCHS)}."
            )
        try:
            sub = registry.get_vlm_config(arch)
        except KeyError as e:
            raise ValueError(
                f"Unknown vlm.arch: {arch!r}. "
                f"Registered: {sorted(registry.list_vlm_configs())}. "
                f"Reserved (not yet implemented): {sorted(_RESERVED_ARCHS)}."
            ) from e
        return sub(**kwargs)
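

# Worked example of the seq-len cross-check described in
# ``residual_stream_image_tokens`` (numbers are illustrative): with 256 image
# tokens and max_text_len=512, a Joint-Decoder / MoT run needs
# model.max_seq_len >= 256 + 512 = 768, while a Cross-Attention run only needs
# model.max_seq_len >= 512.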


@registry.register_vlm_config("joint_decoder")
@dataclass
class JointDecoderConfig(VLMConfig):
    """Joint-Decoder: image tokens prepended to the text sequence.

    No additional fields beyond ``VLMConfig``. The arch is wired through
    ``VLMWrapper`` + ``ModalityContext.prefix_embeds`` + ``output_slice``.
    """

    arch: str = "joint_decoder"


@registry.register_vlm_config("cross_attention")
@dataclass
class CrossAttentionConfig(VLMConfig):
    """Cross-Attention: image K/V flows into separate cross-attention
    blocks inserted at a configurable cadence.

    The CA-specific module alias ``"cross_attention"`` is added to
    ``module_patterns`` so freeze targeting works out of the box.
    """

    arch: str = "cross_attention"
    cross_attention_every_n_layers: int = 4
    cross_attention_n_heads: int = 0
    cross_attention_n_kv_heads: int = 0
    module_patterns: dict[str, list[str]] = field(
        default_factory=lambda: {
            **{k: list(v) for k, v in DEFAULT_MODULE_PATTERNS.items()},
            "cross_attention": [
                "transformer.cross_attention_layers",
                "transformer.cross_attention_layers.*",
            ],
        }
    )

    def __post_init__(self) -> None:
        super().__post_init__()
        if self.cross_attention_every_n_layers <= 0:
            raise ValueError(
                "vlm.cross_attention_every_n_layers must be positive "
                f"(got {self.cross_attention_every_n_layers})"
            )
        if self.cross_attention_n_heads < 0 or self.cross_attention_n_kv_heads < 0:
            raise ValueError(
                "vlm.cross_attention_n_heads and cross_attention_n_kv_heads "
                "must be non-negative"
            )

    def residual_stream_image_tokens(self, num_tokens: int) -> int:  # noqa: ARG002
        """Cross-Attention does not extend the residual stream.

        Image features flow as K/V into separate CrossAttentionBlocks; the
        residual itself carries text only, so the seq_len cross-check skips
        ``num_tokens`` and just enforces ``seq_len >= max_text_len``. The
        ``num_tokens`` argument is accepted for signature parity with the
        base method but is ignored.
        """
        return 0

    def resolved_heads(self, model_n_heads: int) -> tuple[int, int]:
        """Resolve zero-defaults against the text backbone's head count.

        Returns ``(n_heads, n_kv_heads)`` such that the
        ``CrossAttentionBlock`` constructor never observes 0. Resolution
        rule:

        - ``n_heads = self.cross_attention_n_heads or model_n_heads``
        - ``n_kv_heads = self.cross_attention_n_kv_heads or n_heads``
        """
        if model_n_heads <= 0:
            raise ValueError(f"model_n_heads must be positive (got {model_n_heads})")
        n_heads = self.cross_attention_n_heads or model_n_heads
        n_kv_heads = self.cross_attention_n_kv_heads or n_heads
        return n_heads, n_kv_heads
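

# ``resolved_heads`` examples (assuming a hypothetical text backbone with 32 heads):
#   CrossAttentionConfig().resolved_heads(32)                             -> (32, 32)
#   CrossAttentionConfig(cross_attention_n_heads=16).resolved_heads(32)   -> (16, 16)
#   CrossAttentionConfig(cross_attention_n_heads=16,
#                        cross_attention_n_kv_heads=4).resolved_heads(32) -> (16, 4)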


@registry.register_vlm_config("mot")
@dataclass
class MoTConfig(VLMConfig):
    """Mixture-of-Transformers: per-modality Q/K/V/O projections and a
    per-modality FFN at every layer; a single global self-attention mixes
    all modality streams (Liang et al. 2024, Algorithm 1).

    Image tokens are prepended to the text sequence in the residual
    stream (image-then-text concat order). ``modality_ids`` tags every
    position with its source modality; the operator routes each token
    through the per-modality projection / FFN copy for that position.

    The MoT-specific module alias ``"mot"`` is added to
    ``module_patterns`` so freeze targeting works out of the box:
    ``FreezeSpec("mot", True)`` freezes the per-modality main stack
    (``transformer.layers.*``) without touching the embedding / output
    head / final norms.
    """

    arch: str = "mot"
    mot_modalities: tuple[str, ...] = ("image", "text")
    mot_image_n_heads: int = 0
    mot_image_n_kv_heads: int = 0
    mot_warm_start_from_text: bool = False
    mot_warm_start_path: str = ""
    module_patterns: dict[str, list[str]] = field(
        default_factory=lambda: {
            **{k: list(v) for k, v in DEFAULT_MODULE_PATTERNS.items()},
            "mot": [
                "transformer.layers",
                "transformer.layers.*",
            ],
        }
    )

    def __post_init__(self) -> None:
        super().__post_init__()
        if len(self.mot_modalities) < 2:
            raise ValueError(
                f"vlm.mot_modalities must have at least 2 entries (got {self.mot_modalities!r})"
            )
        if "text" not in self.mot_modalities:
            raise ValueError(
                f"vlm.mot_modalities must include 'text' (got {self.mot_modalities!r})"
            )
        if "image" not in self.mot_modalities:
            raise ValueError(
                f"vlm.mot_modalities must include 'image' (got {self.mot_modalities!r})"
            )
        if len(set(self.mot_modalities)) != len(self.mot_modalities):
            raise ValueError(
                f"vlm.mot_modalities must not contain duplicates (got {self.mot_modalities!r})"
            )
        if self.mot_image_n_heads < 0 or self.mot_image_n_kv_heads < 0:
            raise ValueError(
                "vlm.mot_image_n_heads and mot_image_n_kv_heads must be non-negative"
            )
        if self.mot_warm_start_from_text and not self.mot_warm_start_path:
            raise ValueError(
                "vlm.mot_warm_start_from_text=True requires vlm.mot_warm_start_path to be a "
                "non-empty filesystem path to a torch-saved JD or text-only state dict"
            )

    def residual_stream_image_tokens(self, num_tokens: int) -> int:
        """MoT prepends ``num_tokens`` image tokens to the text sequence
        (same residual-stream layout as Joint-Decoder).
        """
        return num_tokens

    def resolved_image_heads(
        self, model_n_heads: int, model_n_kv_heads: int = 0
    ) -> tuple[int, int]:
        """Resolve zero-defaults against the text backbone's head counts.

        Returns ``(n_heads, n_kv_heads)`` such that the operator's
        per-modality projection sizes are never built from 0. Resolution
        rule:

        - ``n_heads = self.mot_image_n_heads or model_n_heads``
        - ``n_kv_heads = self.mot_image_n_kv_heads or model_n_kv_heads or n_heads``

        v1 note: the global-SDPA design requires equal head counts across
        modalities; ``Transformer.__init__`` asserts the resolved tuple
        matches the text backbone (and raises on a per-modality override).
        The fields are present so a future per-modality relaxation can
        land without a config-shape change.
        """
        if model_n_heads <= 0:
            raise ValueError(f"model_n_heads must be positive (got {model_n_heads})")
        n_heads = self.mot_image_n_heads or model_n_heads
        n_kv_heads = self.mot_image_n_kv_heads or model_n_kv_heads or n_heads
        return n_heads, n_kv_heads
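

# ``resolved_image_heads`` examples (assuming a hypothetical 32-head / 8-KV-head
# text backbone; under the v1 equal-heads constraint only the default resolution
# passes the ``Transformer.__init__`` assertion):
#   MoTConfig().resolved_image_heads(32, 8)                      -> (32, 8)
#   MoTConfig(mot_image_n_heads=16).resolved_image_heads(32, 8)  -> (16, 8)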