kempnerforge.model.vlm

Vision-language model wrapper.

The wrapper composes a VisionEncoder (HF or test stub), a registered adapter (MLP2LayerAdapter by default; LinearAdapter is also available via the adapter registry) that projects image features into the LLM embedding space, and the existing Transformer.

The arch-specific work (composing pixel_values + input_ids into a ModalityContext) lives on a ModalityStrategy held by the wrapper. Adding a new arch is one new strategy class decorated with @registry.register_modality_strategy plus one new VLMConfig subclass; adding a new adapter is one new builder under @registry.register_adapter. No edits to VLMWrapper.forward, no isinstance ladder.

Strategies registered today:

  • "joint_decoder" — image embeds prepended to the text sequence via ModalityContext.prefix_embeds + output_slice.

  • "cross_attention" — image embeds passed via ModalityContext.image_features to the CrossAttentionBlock``s inside ``Transformer.

  • "mot" — Mixture-of-Transformers. Same residual-stream layout as Joint-Decoder (image-then-text concat, output_slice trims image positions before the head), plus a per-position modality_ids tag that the MoTBlock stack consumes for routing.

inner_transformer(model) is the explicit unwrap helper the training loop uses when it needs to reach Transformer-internal state (set_moe_step, get_moe_aux_loss, …). Callers that expect the raw Transformer interface route through this helper rather than relying on attribute fallthrough on VLMWrapper.
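
Adding a new arch, per the registration note above, looks roughly like this. The sketch is hypothetical: the arch name "my_arch", the class body, the import paths, and the assumption that the decorator takes the arch string are all illustrative; only the decorator and ModalityContext names come from this module:

    from kempnerforge import registry                    # assumed import path
    from kempnerforge.model.vlm import ModalityContext   # assumed import path

    @registry.register_modality_strategy("my_arch")
    class MyArchStrategy:
        def prepare(self, wrapper, pixel_values, input_ids):
            feats = wrapper.vision_encoder(pixel_values)
            img_embeds = wrapper.adapter(feats)
            # Joint-Decoder-style layout: image embeds as a prefix, trimmed
            # off again before the LM head via output_slice.
            return ModalityContext(
                prefix_embeds=img_embeds,
                output_slice=slice(img_embeds.shape[1], None),
            )

        def num_image_tokens(self, wrapper):
            # arch-specific count; the attribute below is hypothetical
            return wrapper.vision_encoder.num_patches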

Functions

build_modality_strategy(vlm)

Resolve vlm.arch to its registered ModalityStrategy.

build_vlm_wrapper(model_config, ...)

Build a VLMWrapper from the four top-level configs.

inner_transformer(model)

Return the underlying Transformer, unwrapping VLMWrapper and torch.compile.

Classes

CrossAttentionStrategy

Cross-Attention: image embeds flow as K/V into separate cross-attention blocks inside the transformer; the residual stream itself carries text only.

JointDecoderStrategy

Joint-Decoder: image embeds prepended to the text sequence.

MoTStrategy

Mixture-of-Transformers: image-then-text residual layout (same as Joint-Decoder) plus a per-position modality_ids tag.

ModalityStrategy

Composes raw VLM inputs into a ModalityContext.

VLMWrapper

VLM wrapper, arch-driven by a ModalityStrategy.

class kempnerforge.model.vlm.ModalityStrategy[source]

Bases: Protocol

Composes raw VLM inputs into a ModalityContext. One strategy per arch, registered via @registry.register_modality_strategy.

Strategies are stateless (hold no parameters) and read submodules off the VLMWrapper they receive. They are NOT registered as submodules of the wrapper, so FSDP2 does not wrap them and DCP does not serialize them.
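
The "not a submodule" point is ordinary PyTorch behavior rather than anything kempnerforge-specific; a small stand-alone illustration (not kempnerforge code):

    import torch.nn as nn

    class Holder(nn.Module):
        def __init__(self, strategy):
            super().__init__()
            self.linear = nn.Linear(4, 4)   # nn.Module attribute: registered as a submodule
            self.strategy = strategy        # plain object: invisible to module traversal

    h = Holder(object())
    assert list(h.children()) == [h.linear]                            # only the Linear is a child
    assert all(not k.startswith("strategy") for k in h.state_dict())   # nothing extra to checkpoint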

prepare(wrapper, pixel_values, input_ids)[source]
Parameters:
Return type:

ModalityContext

num_image_tokens(wrapper)[source]
Parameters:

wrapper (VLMWrapper)

Return type:

int

__init__(*args, **kwargs)

class kempnerforge.model.vlm.JointDecoderStrategy[source]

Bases: object

Joint-Decoder: image embeds prepended to the text sequence.

Forward path: feats = vision_encoder(pixel_values); img_embeds = adapter(feats); ModalityContext(prefix_embeds, output_slice). The transformer runs over the concatenated (image, text) sequence and output_slice trims the image positions before the LM head.
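
An illustrative sketch of that flow (attribute names follow the __init__ parameters of VLMWrapper; output_slice being a Python slice is an assumption):

    feats = wrapper.vision_encoder(pixel_values)     # (B, n_img, d_vision)
    img_embeds = wrapper.adapter(feats)              # (B, n_img, d_model)
    n_img = img_embeds.shape[1]
    ctx = ModalityContext(
        prefix_embeds=img_embeds,                    # prepended ahead of the text embeddings
        output_slice=slice(n_img, None),             # drop the image positions before the LM head
    )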

prepare(wrapper, pixel_values, input_ids)[source]
Parameters:
Return type:

ModalityContext

num_image_tokens(wrapper)[source]
Parameters:

wrapper (VLMWrapper)

Return type:

int

class kempnerforge.model.vlm.CrossAttentionStrategy[source]

Bases: object

Cross-Attention: image embeds flow as K/V into separate cross-attention blocks inside the transformer; the residual stream itself carries text only.

Forward path: feats = vision_encoder(pixel_values); img_embeds = adapter(feats); ModalityContext(image_features, image_mask=None). image_mask=None means “all image tokens valid”; multi-image variants will fill it in later.
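
A sketch of that flow (same caveats as the Joint-Decoder sketch above; the field names are the ones listed in the forward path):

    feats = wrapper.vision_encoder(pixel_values)     # (B, n_img, d_vision)
    img_embeds = wrapper.adapter(feats)              # (B, n_img, d_model)
    ctx = ModalityContext(
        image_features=img_embeds,                   # consumed as K/V by the cross-attention blocks
        image_mask=None,                             # None means all image tokens are valid
    )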

prepare(wrapper, pixel_values, input_ids)[source]
Parameters:
Return type:

ModalityContext

num_image_tokens(wrapper)[source]
Parameters:

wrapper (VLMWrapper)

Return type:

int

class kempnerforge.model.vlm.MoTStrategy[source]

Bases: object

Mixture-of-Transformers: image-then-text residual layout (same as Joint-Decoder) plus a per-position modality_ids tag.

Forward path: feats = vision_encoder(pixel_values); img_embeds = adapter(feats); ModalityContext(prefix_embeds, output_slice, modality_ids).

modality_ids is built position-based: 0 for the first num_image_tokens positions and 1 for the rest. The MoT forward path uses position-based slicing for v1 routing (the tags are validated for shape but not value-matched against positions), so a future per-token scatter/gather can land without changing the public interface.

output_slice trims the image prefix off the residual before the LM head, matching JointDecoderStrategy.
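
A sketch of the modality_ids layout (0 = image, 1 = text; shapes and the exact torch calls are illustrative):

    import torch

    feats = wrapper.vision_encoder(pixel_values)
    img_embeds = wrapper.adapter(feats)
    n_img = img_embeds.shape[1]
    n_txt = input_ids.shape[1]
    modality_ids = torch.cat([
        torch.zeros(n_img, dtype=torch.long),        # image prefix positions
        torch.ones(n_txt, dtype=torch.long),         # text positions
    ])
    ctx = ModalityContext(
        prefix_embeds=img_embeds,
        output_slice=slice(n_img, None),             # same trim as Joint-Decoder
        modality_ids=modality_ids,
    )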

prepare(wrapper, pixel_values, input_ids)[source]
Parameters:
Return type:

ModalityContext

num_image_tokens(wrapper)[source]
Parameters:

wrapper (VLMWrapper)

Return type:

int

kempnerforge.model.vlm.build_modality_strategy(vlm)[source]

Resolve vlm.arch to its registered ModalityStrategy.

Pure registry lookup; no isinstance ladder, no special cases. Adding a new arch is a single @registry.register_modality_strategy decorator on a new strategy class.
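
For example (the arch string is one of the registered strategies listed at the top of this page):

    strategy = build_modality_strategy(vlm_config)   # e.g. vlm_config.arch == "cross_attention"
    # -> a CrossAttentionStrategy instance, ready to hand to VLMWrapper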

Parameters:

vlm (VLMConfig)

Return type:

ModalityStrategy

class kempnerforge.model.vlm.VLMWrapper[source]

Bases: Module

VLM wrapper, arch-driven by a ModalityStrategy.

Forward: (pixel_values, input_ids, labels) -> (logits, labels). The strategy composes a ModalityContext from the raw inputs and the wrapper’s submodules; Transformer.forward consumes the context. num_image_tokens is arch-aware and delegates to the strategy.
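
A usage sketch, assuming a wrapper already built (e.g. via build_vlm_wrapper below) and batched tensors at hand:

    logits, labels_out = wrapper(pixel_values, input_ids, labels)
    # for prefix-style archs the image positions are already trimmed, so logits align with the text tokens
    n_img = wrapper.num_image_tokens                 # arch-aware; delegates to the strategy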

__init__(vision_encoder, adapter, transformer, strategy)[source]
Parameters:
Return type:

None

property num_image_tokens: int
forward(pixel_values, input_ids, labels=None)[source]
Parameters:
Return type:

tuple[torch.Tensor, torch.Tensor | None]

kempnerforge.model.vlm.inner_transformer(model)[source]

Return the underlying Transformer, unwrapping VLMWrapper and torch.compile.

Training-loop call sites that need to reach Transformer internals (set_moe_step, get_moe_aux_loss, get_expert_counts, future methods) route through this helper. Explicit unwrap is predictable under torch.compile and FSDP2 wrapping; it also makes the VLM branch visible at the call site rather than buried in __getattr__.
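
A typical call-site sketch (the MoE method names come from the docstring; the surrounding loop details are illustrative):

    core = inner_transformer(model)    # bare Transformer, VLMWrapper, or torch.compile'd module all unwrap here
    core.set_moe_step(step)            # step: current training step
    aux_loss = core.get_moe_aux_loss()
    counts = core.get_expert_counts()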

Parameters:

model (torch.nn.Module)

Return type:

torch.nn.Module

kempnerforge.model.vlm.build_vlm_wrapper(model_config, vision_config, adapter_config, vlm_config)[source]

Build a VLMWrapper from the four top-level configs.

Used by tests and by build_parallel_model. Constructs the vision encoder via the registry (HF weights loaded on CPU), builds an adapter via the adapter registry at the LLM dim, looks up the right ModalityStrategy by arch, and composes them with a raw Transformer. Callers that need meta-device / FSDP / freeze handling go through build_parallel_model instead.

All four configs are required: the schema flip lifted the vision / adapter / VLM sections out of ModelConfig and made them parallel siblings.
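
A usage sketch (config construction elided; the four arguments mirror the signature above):

    wrapper = build_vlm_wrapper(model_config, vision_config, adapter_config, vlm_config)
    # plain nn.Module on CPU; meta-device / FSDP / freeze handling goes through build_parallel_model instead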

Parameters:
Return type:

VLMWrapper