kempnerforge.model.vlm¶
Vision-language model wrapper.
The wrapper composes a VisionEncoder (HF or test stub), a registered
adapter (MLP2LayerAdapter by default; LinearAdapter available
via the adapter registry) projecting image features into the LLM
embedding space, and the existing Transformer. The arch-specific
work (composing pixel_values + input_ids into a
ModalityContext) lives on a ModalityStrategy that the wrapper
holds, so adding a new arch is one new strategy class decorated with
@registry.register_modality_strategy plus one new VLMConfig
subclass, and adding a new adapter is one new builder under
@registry.register_adapter. No edits to VLMWrapper.forward,
no isinstance ladder.
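To make the extension seam concrete, a minimal sketch of registering a new arch follows. The decorator names come from this module's docstring, but the arch key, whether the decorator takes that key as an argument, the import paths, and the wrapper attribute names are all assumptions, not the real API:

    # Hedged sketch of the extension seam; names marked "assumed" are not
    # confirmed by this module's public docs.
    from kempnerforge import registry                    # import path assumed
    from kempnerforge.model.vlm import ModalityContext   # import path assumed

    @registry.register_modality_strategy("my_arch")      # hypothetical arch key
    class MyArchStrategy:
        """Stateless: holds no parameters, reads submodules off the wrapper."""

        def prepare(self, wrapper, pixel_values, input_ids):
            feats = wrapper.vision_encoder(pixel_values)  # attribute name assumed
            img_embeds = wrapper.adapter(feats)
            n_img = img_embeds.shape[1]
            return ModalityContext(
                prefix_embeds=img_embeds,                 # joint-decoder-style layout
                output_slice=slice(n_img, None),          # output_slice semantics assumed
            )

        def num_image_tokens(self, wrapper):
            return wrapper.vision_encoder.num_image_tokens  # attribute name assumed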
Strategies registered today:
"joint_decoder"— image embeds prepended to the text sequence viaModalityContext.prefix_embeds+output_slice."cross_attention"— image embeds passed viaModalityContext.image_featuresto theCrossAttentionBlock``s inside ``Transformer."mot"— Mixture-of-Transformers. Same residual-stream layout as Joint-Decoder (image-then-text concat,output_slicetrims image positions before the head), plus a per-positionmodality_idstag that theMoTBlockstack consumes for routing.
inner_transformer(model) is the explicit unwrap helper used by the
training loop when it needs to reach Transformer-internal state
(set_moe_step, get_moe_aux_loss, …). Callers that expect the
raw Transformer interface pipe through this helper rather than
relying on attribute fallthrough on VLMWrapper.
Functions

build_modality_strategy(vlm)
    Resolve vlm.arch to its registered ModalityStrategy.

build_vlm_wrapper(model_config, vision_config, adapter_config, vlm_config)
    Build a VLMWrapper from the four top-level configs.

inner_transformer(model)
    Return the underlying Transformer, unwrapping VLMWrapper and torch.compile.

Classes

CrossAttentionStrategy
    Cross-Attention: image embeds flow as K/V into separate cross-attention blocks inside the transformer; the residual stream itself carries text only.

JointDecoderStrategy
    Joint-Decoder: image embeds prepended to the text sequence.

MoTStrategy
    Mixture-of-Transformers: image-then-text residual layout (same as Joint-Decoder) plus a per-position modality_ids tag.

ModalityStrategy
    Composes raw VLM inputs into a ModalityContext.

VLMWrapper
    VLM wrapper, arch-driven by a ModalityStrategy.
- class kempnerforge.model.vlm.ModalityStrategy[source]¶
Bases: Protocol

Composes raw VLM inputs into a ModalityContext. One strategy per arch, registered via @registry.register_modality_strategy.

Strategies are stateless (hold no parameters) and read submodules off the VLMWrapper they receive. They are NOT registered as submodules of the wrapper, so FSDP2 does not wrap them and DCP does not serialize them.

- prepare(wrapper, pixel_values, input_ids)[source]¶
- Parameters:
wrapper (VLMWrapper)
pixel_values (torch.Tensor)
input_ids (torch.Tensor)
- Return type:
  ModalityContext
- num_image_tokens(wrapper)[source]¶
- Parameters:
wrapper (VLMWrapper)
- Return type:
  int
- __init__(*args, **kwargs)¶
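A minimal sketch of the protocol's shape, assuming the signatures listed above (forward references used since VLMWrapper and ModalityContext live in this same module):

    from typing import Protocol
    import torch

    class ModalityStrategy(Protocol):
        """One per arch; stateless, reads submodules off the wrapper it receives."""

        def prepare(
            self,
            wrapper: "VLMWrapper",
            pixel_values: torch.Tensor,
            input_ids: torch.Tensor,
        ) -> "ModalityContext": ...

        def num_image_tokens(self, wrapper: "VLMWrapper") -> int: ...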
- class kempnerforge.model.vlm.JointDecoderStrategy[source]¶
Bases: object

Joint-Decoder: image embeds prepended to the text sequence.

Forward path:
1. feats = vision_encoder(pixel_values)
2. img_embeds = adapter(feats)
3. ModalityContext(prefix_embeds, output_slice)

The transformer runs over the concatenated (image, text) sequence and output_slice trims the image positions before the LM head.

- prepare(wrapper, pixel_values, input_ids)[source]¶
- Parameters:
wrapper (VLMWrapper)
pixel_values (torch.Tensor)
input_ids (torch.Tensor)
- Return type:
  ModalityContext
- num_image_tokens(wrapper)[source]¶
- Parameters:
wrapper (VLMWrapper)
- Return type:
  int
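A sketch of the prepare step described in the forward path above; the wrapper attribute names and the exact ModalityContext fields are assumptions:

    def prepare(self, wrapper, pixel_values, input_ids):
        feats = wrapper.vision_encoder(pixel_values)  # (B, N_img, D_vis), name assumed
        img_embeds = wrapper.adapter(feats)           # (B, N_img, D_llm)
        n_img = img_embeds.shape[1]
        # Transformer runs over [image, text]; trim the image positions
        # before the LM head so logits line up with input_ids.
        return ModalityContext(
            prefix_embeds=img_embeds,
            output_slice=slice(n_img, None),          # slice semantics assumed
        )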
- class kempnerforge.model.vlm.CrossAttentionStrategy[source]¶
Bases: object

Cross-Attention: image embeds flow as K/V into separate cross-attention blocks inside the transformer; the residual stream itself carries text only.

Forward path:
1. feats = vision_encoder(pixel_values)
2. img_embeds = adapter(feats)
3. ModalityContext(image_features, image_mask=None)

image_mask=None means “all image tokens valid”; multi-image variants will fill it in later.

- prepare(wrapper, pixel_values, input_ids)[source]¶
- Parameters:
wrapper (VLMWrapper)
pixel_values (torch.Tensor)
input_ids (torch.Tensor)
- Return type:
  ModalityContext
- num_image_tokens(wrapper)[source]¶
- Parameters:
wrapper (VLMWrapper)
- Return type:
  int
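The same sketch for the cross-attention arch; wrapper attribute names are again assumptions:

    def prepare(self, wrapper, pixel_values, input_ids):
        feats = wrapper.vision_encoder(pixel_values)  # attribute names assumed
        img_embeds = wrapper.adapter(feats)
        # Residual stream stays text-only; image embeds are consumed as K/V
        # by the CrossAttentionBlocks. image_mask=None == all tokens valid.
        return ModalityContext(image_features=img_embeds, image_mask=None)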
- class kempnerforge.model.vlm.MoTStrategy[source]¶
Bases: object

Mixture-of-Transformers: image-then-text residual layout (same as Joint-Decoder) plus a per-position modality_ids tag.

Forward path:
1. feats = vision_encoder(pixel_values)
2. img_embeds = adapter(feats)
3. ModalityContext(prefix_embeds, output_slice, modality_ids)

modality_ids is built position-based: 0 for the first num_image_tokens positions and 1 for the rest. The MoT forward path uses position-based slicing for v1 routing (the tags are validated for shape but not value-matched against positions), so a future per-token scatter/gather can land without changing the public interface. output_slice trims the image prefix off the residual before the LM head, matching JointDecoderStrategy.

- prepare(wrapper, pixel_values, input_ids)[source]¶
- Parameters:
wrapper (VLMWrapper)
pixel_values (torch.Tensor)
input_ids (torch.Tensor)
- Return type:
  ModalityContext
- num_image_tokens(wrapper)[source]¶
- Parameters:
wrapper (VLMWrapper)
- Return type:
  int
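The position-based tag construction described above, written out as a hypothetical standalone helper (not a function this module exports):

    import torch

    def build_modality_ids(num_image_tokens: int, seq_len: int, device) -> torch.Tensor:
        # 0 = image positions (the prefix), 1 = text positions.
        ids = torch.ones(seq_len, dtype=torch.long, device=device)
        ids[:num_image_tokens] = 0
        return ids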
- kempnerforge.model.vlm.build_modality_strategy(vlm)[source]¶
Resolve vlm.arch to its registered ModalityStrategy.

Pure registry lookup; no isinstance ladder, no special cases. Adding a new arch is a single @registry.register_modality_strategy decorator on a new strategy class.

- Parameters:
vlm (VLMConfig)
- Return type:
  ModalityStrategy
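Usage is a plain lookup; a sketch, assuming pre-built configs and submodules (the VLMWrapper constructor signature is documented below):

    strategy = build_modality_strategy(vlm_config)  # e.g. vlm_config.arch == "joint_decoder"
    model = VLMWrapper(vision_encoder, adapter, transformer, strategy)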
- class kempnerforge.model.vlm.VLMWrapper[source]¶
Bases: Module

VLM wrapper, arch-driven by a ModalityStrategy.

Forward: (pixel_values, input_ids, labels) -> (logits, labels). The strategy composes a ModalityContext from the raw inputs and the wrapper’s submodules; Transformer.forward consumes the context. num_image_tokens is arch-aware and delegates to the strategy.

- __init__(vision_encoder, adapter, transformer, strategy)[source]¶
- Parameters:
vision_encoder (VisionEncoder)
adapter (torch.nn.Module)
transformer (Transformer)
strategy (ModalityStrategy)
- Return type:
None
- forward(pixel_values, input_ids, labels=None)[source]¶
- Parameters:
pixel_values (torch.Tensor)
input_ids (torch.Tensor)
labels (torch.Tensor | None)
- Return type:
tuple[torch.Tensor, torch.Tensor | None]
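The forward contract above, sketched; how Transformer.forward receives the context (a keyword argument here) is an assumption:

    def forward(self, pixel_values, input_ids, labels=None):
        # Strategy composes the ModalityContext from raw inputs + submodules.
        ctx = self.strategy.prepare(self, pixel_values, input_ids)
        logits = self.transformer(input_ids, modality_context=ctx)  # kwarg name assumed
        return logits, labels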
- kempnerforge.model.vlm.inner_transformer(model)[source]¶
Return the underlying Transformer, unwrapping VLMWrapper and torch.compile.

Training-loop call sites that need to reach Transformer internals (set_moe_step, get_moe_aux_loss, get_expert_counts, future methods) route through this helper. Explicit unwrap is predictable under torch.compile and FSDP2 wrapping; it also makes the VLM branch visible at the call site rather than buried in __getattr__.

- Parameters:
model (torch.nn.Module)
- Return type:
  Transformer
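A sketch of the training-loop call sites the docstring names; step and task_loss are placeholders:

    core = inner_transformer(model)  # Transformer, VLMWrapper, or compiled module
    core.set_moe_step(step)          # method names from the docstring above
    loss = task_loss + core.get_moe_aux_loss()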
- kempnerforge.model.vlm.build_vlm_wrapper(model_config, vision_config, adapter_config, vlm_config)[source]¶
Build a VLMWrapper from the four top-level configs.

Used by tests and by build_parallel_model. Constructs the vision encoder via the registry (HF weights loaded on CPU), builds an adapter via the adapter registry at the LLM dim, looks up the right ModalityStrategy by arch, and composes them with a raw Transformer. Callers that need meta-device / FSDP / freeze handling go through build_parallel_model instead.

All four configs are required: the schema flip lifted the vision / adapter / VLM sections out of ModelConfig and made them parallel siblings.

- Parameters:
model_config (ModelConfig)
vision_config (VisionEncoderConfig)
adapter_config (AdapterConfig)
vlm_config (VLMConfig)
- Return type:
  VLMWrapper
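A sketch of the test-style construction path, assuming the four configs are already built; the forward contract is the one documented on VLMWrapper above:

    model = build_vlm_wrapper(model_config, vision_config, adapter_config, vlm_config)
    logits, labels = model(pixel_values, input_ids, labels)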