kempnerforge.config.vlm

VLM (vision-language model) configuration.

VLMConfig carries the arch-level knobs of the vision-language model: which architecture to wire (arch), the fixed text padding length (max_text_len), and the freeze policy (freeze, freeze_schedule). The vision encoder and adapter are described by sibling top-level sections (VisionEncoderConfig in config/vision.py, AdapterConfig in config/adapter.py).

In TOML, [vlm] is a top-level section, parallel to [model], [vision_encoder], and [adapter]. When [vlm] is absent the job is a pure text run.
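
For orientation, a minimal sketch of the top-level layout (the [vlm] keys shown are documented fields; the values and the elided sibling sections are illustrative):

[model]
# ... text backbone ...

[vision_encoder]
# ... see config/vision.py ...

[adapter]
# ... see config/adapter.py ...

[vlm]
arch = "cross_attention"
max_text_len = 2048
cross_attention_every_n_layers = 4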

Architecture is a discriminated union on the arch field:

  • "joint_decoder" image tokens prepended to the text sequence.

  • "cross_attention" image K/V flows in via separate cross-attention blocks at a configurable cadence.

  • "mot" Mixture-of-Transformers: per-modality Q/K/V/O + per- modality FFN at every layer, single global self-attention.

Each arch gets its own VLMConfig subclass, registered via registry.register_vlm_config. The TOML loader dispatches on arch to instantiate the right subclass; programmatic callers use VLMConfig.for_arch(arch_name, **fields).
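
As a sketch of the registration pattern (the decorator and the arch-override convention are as documented; the subclass name, arch value, and import paths are hypothetical):

from dataclasses import dataclass

from kempnerforge.config import registry          # hypothetical import path
from kempnerforge.config.vlm import VLMConfig

@registry.register_vlm_config
@dataclass
class MyArchConfig(VLMConfig):
    # Overriding the field default makes "my_arch" the discriminator value
    # that the TOML loader (and for_arch) dispatch on.
    arch: str = "my_arch"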

FreezeSpec / FreezeStage are consumed by kempnerforge/training/freeze.py.

Classes

CrossAttentionConfig

Cross-Attention: image K/V flows into separate cross-attention blocks inserted at a configurable cadence.

FreezeSpec

A single freeze directive.

FreezeStage

A freeze directive that applies from start_step onward.

JointDecoderConfig

Joint-Decoder: image tokens prepended to the text sequence.

MoTConfig

Mixture-of-Transformers: per-modality Q/K/V/O projections + per-modality FFN at every layer; single global self-attention mixes all modality streams (Liang et al. 2024, Algorithm 1).

VLMConfig

Base VLM configuration.

class kempnerforge.config.vlm.FreezeSpec[source]

Bases: object

A single freeze directive.

module is an alias (key in a pattern map such as DEFAULT_MODULE_PATTERNS) or a raw fnmatch pattern matching fully-qualified parameter names.

module: str
frozen: bool = True
__init__(module, frozen=True)
Parameters:
  • module (str)

  • frozen (bool)

Return type:

None
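
Example (the alias form resolves through module_patterns; the raw-pattern form is matched by fnmatch against fully-qualified parameter names, and the pattern shown is illustrative):

>>> FreezeSpec("vision_encoder")                       # alias into the pattern map
FreezeSpec(module='vision_encoder', frozen=True)
>>> FreezeSpec("transformer.layers.0.*", frozen=True)  # raw fnmatch pattern
FreezeSpec(module='transformer.layers.0.*', frozen=True)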

class kempnerforge.config.vlm.FreezeStage[source]

Bases: object

A freeze directive that applies from start_step onward.

Used for staged training recipes where the trainable subset changes across training phases. The list of stages on VLMConfig is expected to be in strictly increasing start_step order.

start_step: int
specs: tuple[FreezeSpec, ...]
__init__(start_step, specs)
Parameters:
  • start_step (int)

  • specs (tuple[FreezeSpec, ...])

Return type:

None
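
Example of a two-phase recipe (the step values and alias choice are illustrative): keep the vision encoder frozen from step 0, then unfreeze it at step 1000:

>>> schedule = [
...     FreezeStage(start_step=0, specs=(FreezeSpec("vision_encoder"),)),
...     FreezeStage(start_step=1000, specs=(FreezeSpec("vision_encoder", frozen=False),)),
... ]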

class kempnerforge.config.vlm.VLMConfig[source]

Bases: object

Base VLM configuration.

Subclasses register themselves via @registry.register_vlm_config and override the arch field’s default. Use VLMConfig.for_arch(arch_name, **fields) to construct programmatically; the TOML loader dispatches on arch automatically.

Field summary (full per-field docs are picked up from autodoc):

  • arch VLM architecture discriminator. Subclasses set this via field default; direct construction with an arch name not backed by a registered subclass raises.

  • max_text_len fixed text padding length used by VLMCollator. Enforces rank-consistent batches under FSDP2.

  • freeze static freeze specs applied once at build time.

  • freeze_schedule step-boundary freeze transitions.

  • module_patterns map of module alias ("transformer", "vision_encoder", "adapter", plus arch-specific additions) to fnmatch pattern list.
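
For illustration, module_patterns has roughly this shape (the aliases are the documented ones; the pattern strings are stand-ins, not the actual DEFAULT_MODULE_PATTERNS entries):

module_patterns = {
    "transformer":    ["transformer.*"],
    "vision_encoder": ["vision_encoder.*"],
    "adapter":        ["adapter.*"],
    # plus arch-specific additions, e.g. "cross_attention" or "mot"
}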

arch: str = 'joint_decoder'
max_text_len: int = 512
freeze: list[FreezeSpec]
freeze_schedule: list[FreezeStage]
module_patterns: dict[str, list[str]]
residual_stream_image_tokens(num_tokens)[source]

Number of image tokens this arch places in the residual stream.

Used by JobConfig to validate that model.max_seq_len and train.seq_len are large enough to fit residual_stream_image_tokens + max_text_len along the attention sequence dimension.

  • Joint-Decoder / MoT: num_tokens (image tokens prepended to text).

  • Cross-Attention: 0 (residual stream is text-only; image features flow side-channel into CA blocks).

Parameters:

num_tokens (int) – The vision encoder’s resolved num_tokens. Pass 0 when it is not known yet (the “infer at build time” sentinel); cross-checks that depend on a concrete value will skip and re-run at build time.

Return type:

int
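
Example (illustrative; a resolved num_tokens of 256 is assumed):

>>> JointDecoderConfig().residual_stream_image_tokens(256)
256
>>> CrossAttentionConfig().residual_stream_image_tokens(256)
0
>>> MoTConfig().residual_stream_image_tokens(0)  # sentinel: defer cross-check to build time
0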

classmethod for_arch(arch, **kwargs)[source]

Resolve arch to a registered subclass and instantiate.

Raises:
  • ValueError – arch is not registered.

  • NotImplementedError – arch is reserved (in _RESERVED_ARCHS). This matches the loader's semantics, so the error type is independent of the construction site.

Parameters:
  • arch (str)

  • kwargs – field overrides forwarded to the resolved subclass's constructor.

Return type:

VLMConfig

Example

>>> cfg = VLMConfig.for_arch(
...     "cross_attention",
...     max_text_len=2048,
...     cross_attention_every_n_layers=4,
... )
__init__(arch='joint_decoder', max_text_len=512, freeze=<factory>, freeze_schedule=<factory>, module_patterns=<factory>)
Parameters:
  • arch (str)

  • max_text_len (int)

  • freeze (list[FreezeSpec])

  • freeze_schedule (list[FreezeStage])

  • module_patterns (dict[str, list[str]])

Return type:

None

class kempnerforge.config.vlm.JointDecoderConfig[source]

Bases: VLMConfig

Joint-Decoder: image tokens prepended to the text sequence.

No additional fields beyond VLMConfig. The arch is wired through VLMWrapper + ModalityContext.prefix_embeds + output_slice.

arch: str = 'joint_decoder'
__init__(arch='joint_decoder', max_text_len=512, freeze=<factory>, freeze_schedule=<factory>, module_patterns=<factory>)
Parameters:
  • arch (str)

  • max_text_len (int)

  • freeze (list[FreezeSpec])

  • freeze_schedule (list[FreezeStage])

  • module_patterns (dict[str, list[str]])

Return type:

None

class kempnerforge.config.vlm.CrossAttentionConfig[source]

Bases: VLMConfig

Cross-Attention: image K/V flows into separate cross-attention blocks inserted at a configurable cadence.

The CA-specific module alias "cross_attention" is added to module_patterns so freeze targeting works out of the box.
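
For example, a recipe that trains only the inserted CA blocks might pair this alias with a frozen backbone (a sketch; how multiple specs compose is defined by kempnerforge/training/freeze.py, not here):

freeze = [
    FreezeSpec("transformer"),                    # freeze the text backbone
    FreezeSpec("cross_attention", frozen=False),  # keep CA blocks trainable
]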

arch: str = 'cross_attention'
cross_attention_every_n_layers: int = 4
cross_attention_n_heads: int = 0
cross_attention_n_kv_heads: int = 0
module_patterns: dict[str, list[str]]
residual_stream_image_tokens(num_tokens)[source]

Cross-Attention does not extend the residual stream.

Image features flow as K/V into separate CrossAttentionBlocks; the residual itself carries text only. So the seq_len cross-check skips num_tokens and just enforces seq_len >= max_text_len. The num_tokens argument is accepted for signature parity with the base method but ignored.

Parameters:

num_tokens (int)

Return type:

int

resolved_heads(model_n_heads)[source]

Resolve zero-defaults against the text backbone’s head count.

Returns (n_heads, n_kv_heads) such that the CrossAttentionBlock constructor never observes 0.

Resolution rule:

  • n_heads = self.cross_attention_n_heads or model_n_heads

  • n_kv_heads = self.cross_attention_n_kv_heads or n_heads

Parameters:

model_n_heads (int)

Return type:

tuple[int, int]
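
Example (a 32-head text backbone is assumed):

>>> CrossAttentionConfig().resolved_heads(32)                            # both fall back to the backbone
(32, 32)
>>> CrossAttentionConfig(cross_attention_n_heads=16).resolved_heads(32)  # n_kv_heads falls back to n_heads
(16, 16)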

__init__(arch='cross_attention', max_text_len=512, freeze=<factory>, freeze_schedule=<factory>, module_patterns=<factory>, cross_attention_every_n_layers=4, cross_attention_n_heads=0, cross_attention_n_kv_heads=0)
Parameters:
  • arch (str)

  • max_text_len (int)

  • freeze (list[FreezeSpec])

  • freeze_schedule (list[FreezeStage])

  • module_patterns (dict[str, list[str]])

  • cross_attention_every_n_layers (int)

  • cross_attention_n_heads (int)

  • cross_attention_n_kv_heads (int)

Return type:

None

class kempnerforge.config.vlm.MoTConfig[source]

Bases: VLMConfig

Mixture-of-Transformers: per-modality Q/K/V/O projections + per-modality FFN at every layer; single global self-attention mixes all modality streams (Liang et al. 2024, Algorithm 1).

Image tokens are prepended to the text sequence in the residual stream (image-then-text concat order). modality_ids tags every position with its source modality; the operator routes per-token through the per-modality projection / FFN copy for that position.

The MoT-specific module alias "mot" is added to module_patterns so freeze targeting works out of the box: FreezeSpec("mot", True) freezes the per-modality main stack (transformer.layers.*) without touching the embedding / output head / final norms.

arch: str = 'mot'
mot_modalities: tuple[str, ...] = ('image', 'text')
mot_image_n_heads: int = 0
mot_image_n_kv_heads: int = 0
mot_warm_start_from_text: bool = False
mot_warm_start_path: str = ''
module_patterns: dict[str, list[str]]
residual_stream_image_tokens(num_tokens)[source]

MoT prepends num_tokens image tokens to the text sequence (same residual-stream layout as Joint-Decoder).

Parameters:

num_tokens (int)

Return type:

int

resolved_image_heads(model_n_heads, model_n_kv_heads=0)[source]

Resolve zero-defaults against the text backbone’s head counts.

Returns (n_heads, n_kv_heads) such that the operator’s per-modality projection sizes are never built from 0.

Resolution rule:

  • n_heads = self.mot_image_n_heads or model_n_heads

  • n_kv_heads = self.mot_image_n_kv_heads or model_n_kv_heads or n_heads

v1 note: the global-SDPA design requires equal head counts across modalities; Transformer.__init__ asserts that the resolved tuple matches the text backbone (it raises on a per-modality override). The fields are present so a future per-modality relaxation can land without a config-shape change.

Parameters:
  • model_n_heads (int)

  • model_n_kv_heads (int)

Return type:

tuple[int, int]
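
Example (a 32-head / 8-KV-head text backbone is assumed; under the v1 note above, a per-modality override like the second call would be rejected later, in Transformer.__init__):

>>> MoTConfig().resolved_image_heads(32, 8)
(32, 8)
>>> MoTConfig(mot_image_n_heads=16).resolved_image_heads(32, 8)
(16, 8)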

__init__(arch='mot', max_text_len=512, freeze=<factory>, freeze_schedule=<factory>, module_patterns=<factory>, mot_modalities=('image', 'text'), mot_image_n_heads=0, mot_image_n_kv_heads=0, mot_warm_start_from_text=False, mot_warm_start_path='')
Parameters:
  • arch (str)

  • max_text_len (int)

  • freeze (list[FreezeSpec])

  • freeze_schedule (list[FreezeStage])

  • module_patterns (dict[str, list[str]])

  • mot_modalities (tuple[str, ...])

  • mot_image_n_heads (int)

  • mot_image_n_kv_heads (int)

  • mot_warm_start_from_text (bool)

  • mot_warm_start_path (str)

Return type:

None