"""Vision-language model wrapper.
The wrapper composes a ``VisionEncoder`` (HF or test stub), a registered
adapter (``MLP2LayerAdapter`` by default; ``LinearAdapter`` available
via the ``adapter`` registry) projecting image features into the LLM
embedding space, and the existing ``Transformer``. The arch-specific
work (composing ``pixel_values`` + ``input_ids`` into a
``ModalityContext``) lives on a ``ModalityStrategy`` that the wrapper
holds, so adding a new arch is one new strategy decorator on
``@registry.register_modality_strategy`` plus one new ``VLMConfig``
subclass, and adding a new adapter is one new builder under
``@registry.register_adapter``. No edits to ``VLMWrapper.forward``,
no ``isinstance`` ladder.
Strategies registered today:
- ``"joint_decoder"`` โ image embeds prepended to the text sequence
via ``ModalityContext.prefix_embeds`` + ``output_slice``.
- ``"cross_attention"`` โ image embeds passed via
``ModalityContext.image_features`` to the ``CrossAttentionBlock``s
inside ``Transformer``.
- ``"mot"`` โ Mixture-of-Transformers. Same residual-stream layout as
Joint-Decoder (image-then-text concat, ``output_slice`` trims image
positions before the head), plus a per-position ``modality_ids``
tag that the ``MoTBlock`` stack consumes for routing.
``inner_transformer(model)`` is the explicit unwrap helper used by the
training loop when it needs to reach Transformer-internal state
(``set_moe_step``, ``get_moe_aux_loss``, ...). Callers that expect the
raw ``Transformer`` interface pipe through this helper rather than
relying on attribute fallthrough on ``VLMWrapper``.
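
A minimal sketch of the extension point (``"my_arch"`` and
``MyArchStrategy`` are illustrative names, not registered in this
module)::

    @registry.register_modality_strategy("my_arch")
    class MyArchStrategy:
        def prepare(self, wrapper, pixel_values, input_ids):
            img_embeds = _project_image_features(wrapper, pixel_values)
            return ModalityContext(image_features=img_embeds, image_mask=None)

        def num_image_tokens(self, wrapper):
            return 0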
"""
from __future__ import annotations

from collections.abc import Iterable
from typing import Protocol

import torch
import torch.nn as nn

from kempnerforge.config.adapter import AdapterConfig
from kempnerforge.config.registry import registry
from kempnerforge.config.schema import ModelConfig
from kempnerforge.config.vision import VisionEncoderConfig
from kempnerforge.config.vlm import FreezeSpec, VLMConfig
from kempnerforge.model.adapter import build_adapter
from kempnerforge.model.modality import ModalityContext
from kempnerforge.model.transformer import Transformer
from kempnerforge.model.vision import VisionEncoder


class ModalityStrategy(Protocol):
    """Composes raw VLM inputs into a ``ModalityContext``. One strategy
    per arch, registered via ``@registry.register_modality_strategy``.

    Strategies are stateless (hold no parameters) and read submodules
    off the ``VLMWrapper`` they receive. They are NOT registered as
    submodules of the wrapper, so FSDP2 does not wrap them and DCP
    does not serialize them.
    """

    def prepare(
        self,
        wrapper: VLMWrapper,
        pixel_values: torch.Tensor,
        input_ids: torch.Tensor,
    ) -> ModalityContext: ...

    def num_image_tokens(self, wrapper: VLMWrapper) -> int: ...


def _project_image_features(wrapper: VLMWrapper, pixel_values: torch.Tensor) -> torch.Tensor:
    """Encode + adapt image features. Cast at the encoder/adapter
    boundary so the encoder can stay in its HF dtype (often fp32) while
    the adapter and transformer run in bf16 without an inner dtype
    clash.
    """
    feats = wrapper.vision_encoder(pixel_values)
    # Adapter-agnostic dtype lookup: use the first adapter parameter's
    # dtype so any registered adapter (mlp_2layer, linear, ...) works
    # without coupling to a specific submodule attribute.
    adapter_dtype = next(wrapper.adapter.parameters()).dtype
    if feats.dtype != adapter_dtype:
        feats = feats.to(adapter_dtype)
    return wrapper.adapter(feats)
@registry.register_modality_strategy("joint_decoder")
class JointDecoderStrategy:
"""Joint-Decoder: image embeds prepended to the text sequence.
Forward path: ``feats = vision_encoder(pixel_values)``;
``img_embeds = adapter(feats)``; ``ModalityContext(prefix_embeds,
output_slice)``. The transformer runs over the concatenated
``(image, text)`` sequence and ``output_slice`` trims the image
positions before the LM head.
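
    For concreteness (shapes and the ``trunk``/``lm_head`` names are
    illustrative, not this module's API): with ``n = 4`` image tokens
    and ``t = 3`` text tokens, the residual stream has ``n + t = 7``
    positions and the head sees only the text ones::

        h = trunk_output                  # (B, 7, dim): image then text
        h = h[:, slice(4, None)]          # (B, 3, dim): output_slice trims
        logits = lm_head(h)               # (B, 3, vocab)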
"""

    def prepare(
        self,
        wrapper: VLMWrapper,
        pixel_values: torch.Tensor,
        input_ids: torch.Tensor,  # noqa: ARG002
    ) -> ModalityContext:
        img_embeds = _project_image_features(wrapper, pixel_values)
        n = wrapper.vision_encoder.num_tokens
        return ModalityContext(prefix_embeds=img_embeds, output_slice=slice(n, None))

    def num_image_tokens(self, wrapper: VLMWrapper) -> int:
        return wrapper.vision_encoder.num_tokens
@registry.register_modality_strategy("cross_attention")
class CrossAttentionStrategy:
"""Cross-Attention: image embeds flow as K/V into separate
cross-attention blocks inside the transformer; the residual stream
itself carries text only.
Forward path: ``feats = vision_encoder(pixel_values)``;
``img_embeds = adapter(feats)``; ``ModalityContext(image_features,
image_mask=None)``. ``image_mask=None`` means "all image tokens
valid"; multi-image variants will fill it in later.
"""

    def prepare(
        self,
        wrapper: VLMWrapper,
        pixel_values: torch.Tensor,
        input_ids: torch.Tensor,  # noqa: ARG002
    ) -> ModalityContext:
        img_embeds = _project_image_features(wrapper, pixel_values)
        return ModalityContext(image_features=img_embeds, image_mask=None)

    def num_image_tokens(self, wrapper: VLMWrapper) -> int:  # noqa: ARG002
        # Cross-Attention does not extend the residual stream.
        return 0
@registry.register_modality_strategy("mot")
class MoTStrategy:
"""Mixture-of-Transformers: image-then-text residual layout (same as
Joint-Decoder) plus a per-position ``modality_ids`` tag.
Forward path: ``feats = vision_encoder(pixel_values)``;
``img_embeds = adapter(feats)``;
``ModalityContext(prefix_embeds, output_slice, modality_ids)``.
``modality_ids`` is built position-based: ``0`` for the first
``num_image_tokens`` positions and ``1`` for the rest. The MoT
forward path uses position-based slicing for v1 routing (the tags
are validated for shape but not value-matched against positions),
so a future per-token scatter/gather can land without changing the
public interface.
``output_slice`` trims the image prefix off the residual before
the LM head, matching ``JointDecoderStrategy``.
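
    A concrete (illustrative) instance: with ``num_image_tokens = 2``
    and a batch of one 3-token text sequence, ``prepare`` produces::

        modality_ids = tensor([[0, 0, 1, 1, 1]])   # image, image, text, text, text
        output_slice = slice(2, None)              # keep the 3 text positions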
"""

    def prepare(
        self,
        wrapper: VLMWrapper,
        pixel_values: torch.Tensor,
        input_ids: torch.Tensor,
    ) -> ModalityContext:
        img_embeds = _project_image_features(wrapper, pixel_values)
        n = wrapper.vision_encoder.num_tokens
        b, t_text = input_ids.shape
        modality_ids = torch.zeros(b, n + t_text, dtype=torch.long, device=input_ids.device)
        modality_ids[:, n:] = 1
        return ModalityContext(
            prefix_embeds=img_embeds,
            output_slice=slice(n, None),
            modality_ids=modality_ids,
        )

    def num_image_tokens(self, wrapper: VLMWrapper) -> int:
        return wrapper.vision_encoder.num_tokens


def build_modality_strategy(vlm: VLMConfig) -> ModalityStrategy:
    """Resolve ``vlm.arch`` to its registered ``ModalityStrategy``.

    Pure registry lookup; no ``isinstance`` ladder, no special cases.
    Adding a new arch is a single ``@registry.register_modality_strategy``
    decorator on a new strategy class.
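
    For example (assuming ``vlm.arch == "mot"``)::

        strategy = build_modality_strategy(vlm)   # -> MoTStrategy()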
"""
return registry.get_modality_strategy(vlm.arch)()


class VLMWrapper(nn.Module):
    """VLM wrapper, arch-driven by a ``ModalityStrategy``.

    Forward: ``(pixel_values, input_ids, labels) -> (logits, labels)``.
    The strategy composes a ``ModalityContext`` from the raw inputs
    and the wrapper's submodules; ``Transformer.forward`` consumes the
    context. ``num_image_tokens`` is arch-aware and delegates to the
    strategy.
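
    A usage sketch (shapes are illustrative; the configs come from the
    caller's job setup)::

        model = build_vlm_wrapper(model_cfg, vision_cfg, adapter_cfg, vlm_cfg)
        logits, labels = model(
            pixel_values,   # (B, C, H, W)
            input_ids,      # (B, T_text)
            labels,         # (B, T_text) or None
        )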
"""

    def __init__(
        self,
        vision_encoder: VisionEncoder,
        adapter: nn.Module,
        transformer: Transformer,
        strategy: ModalityStrategy,
    ) -> None:
        super().__init__()
        self.vision_encoder = vision_encoder
        self.adapter = adapter
        self.transformer = transformer
        # Strategy is a plain Python object (not nn.Module). nn.Module's
        # __setattr__ only routes Module/Parameter/Tensor attributes into
        # _modules/_parameters/_buffers, so plain objects are stored as
        # ordinary attributes and stay outside the module tree by
        # default. FSDP2 wraps modules; DCP serializes module params;
        # neither sees the strategy. test_strategy_not_in_module_tree
        # pins this contract.
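        # Concretely (an illustrative check, not run in this module):
        #   assert "strategy" not in dict(wrapper.named_modules())
        #   assert not any(k.startswith("strategy.") for k in wrapper.state_dict())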
        self.strategy = strategy

    @property
    def num_image_tokens(self) -> int:
        return self.strategy.num_image_tokens(self)

    def forward(
        self,
        pixel_values: torch.Tensor,
        input_ids: torch.Tensor,
        labels: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        # Route the text embedding through Transformer.forward so FSDP2's
        # per-module hook intercepts the token_embedding call and
        # materializes the DTensor weight before F.embedding runs. Doing
        # the embedding externally (transformer.token_embedding(input_ids))
        # bypasses FSDP and fails with "mixed torch.Tensor and DTensor".
        modality = self.strategy.prepare(self, pixel_values, input_ids)
        logits = self.transformer(tokens=input_ids, modality=modality)
        return logits, labels


def _is_encoder_frozen(specs: Iterable[FreezeSpec]) -> bool:
    """True iff every freeze spec targeting the vision encoder is ``frozen=True``.

    Partial unfreezes (e.g. ``FreezeSpec("vision_encoder.layers.11", False)``)
    leave the encoder partly trainable and return False, so ``build_vlm``
    keeps the encoder in ``.train()`` mode and does not replicate it as a
    cheap full copy under FSDP2.
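
    For example (using the positional ``FreezeSpec(module, frozen)`` form
    shown above)::

        _is_encoder_frozen([FreezeSpec("vision_encoder", True)])             # True
        _is_encoder_frozen([FreezeSpec("vision_encoder.layers.11", False)])  # False
        _is_encoder_frozen([])                                               # False: nothing targets it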
"""
relevant = [
s for s in specs if s.module == "vision_encoder" or s.module.startswith("vision_encoder.")
]
if not relevant:
return False
return all(s.frozen for s in relevant)


def build_vlm_wrapper(
    model_config: ModelConfig,
    vision_config: VisionEncoderConfig,
    adapter_config: AdapterConfig,
    vlm_config: VLMConfig,
) -> VLMWrapper:
    """Build a ``VLMWrapper`` from the four top-level configs.

    Used by tests and by ``build_parallel_model``. Constructs the
    vision encoder via the registry (HF weights loaded on CPU), builds
    an adapter via the ``adapter`` registry at the LLM ``dim``, looks
    up the right ``ModalityStrategy`` by arch, and composes them with
    a raw ``Transformer``. Callers that need meta-device / FSDP /
    freeze handling go through ``build_parallel_model`` instead.

    All four configs are required: the schema flip lifted the vision /
    adapter / VLM sections out of ``ModelConfig`` and made them parallel
    siblings.
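
    The build-time sequence-length check below enforces, with
    illustrative numbers and an arch whose strategy keeps the image
    prefix in the residual stream: ``encoder.num_tokens = 256`` and
    ``vlm.max_text_len = 1792`` require ``model.max_seq_len >=
    256 + 1792 = 2048``.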
"""
encoder_builder = registry.get_vision_encoder(vision_config.type)
encoder = encoder_builder(
vision_config.path,
num_tokens=vision_config.num_tokens if vision_config.num_tokens > 0 else None,
feature_dim=vision_config.feature_dim if vision_config.feature_dim > 0 else None,
)
# Build-time max_seq_len cross-check using the encoder's resolved
# num_tokens. ``JobConfig.__post_init__`` runs the same check at config
# time only when ``vision_encoder.num_tokens > 0``; when the user leaves
# num_tokens=0 (the "infer from encoder at build time" sentinel) the
# config-time check is skipped and the residual-stream allocation goes
# unchecked until the model actually runs. This guard fills that gap.
residual_image_tokens = vlm_config.residual_stream_image_tokens(encoder.num_tokens)
required = residual_image_tokens + vlm_config.max_text_len
if model_config.max_seq_len < required:
raise ValueError(
f"max_seq_len ({model_config.max_seq_len}) insufficient for VLM at build time: "
f"encoder.num_tokens ({encoder.num_tokens}) -> "
f"residual_image_tokens ({residual_image_tokens}) + "
f"vlm.max_text_len ({vlm_config.max_text_len}) = {required}"
)
in_dim = vision_config.feature_dim or encoder.feature_dim
adapter = build_adapter(adapter_config, in_dim=in_dim, out_dim=model_config.dim)
transformer = Transformer(
model_config, vlm_config=vlm_config, num_image_tokens=encoder.num_tokens
)
strategy = build_modality_strategy(vlm_config)
return VLMWrapper(encoder, adapter, transformer, strategy)