kempnerforge.config.vlm

VLM (vision-language model) configuration.

VLMConfig carries the arch-level knobs of the vision-language model: which architecture to wire (arch), the fixed text padding length, and the freeze policy. The vision encoder and adapter are described by sibling top-level sections (VisionEncoderConfig in config/vision.py, AdapterConfig in config/adapter.py).

In TOML, [vlm] is a top-level section, parallel to [model], [vision_encoder], and [adapter]. When [vlm] is absent the job is a pure text run.

Architecture is a discriminated union on the arch field:

  • "joint_decoder": image tokens prepended to the text sequence.

  • "cross_attention": image K/V flows in via separate cross-attention blocks at a configurable cadence.

  • "mot" (Mixture-of-Transformers): per-modality Q/K/V/O + per-modality FFN at every layer, single global self-attention.

  • "moma" (Mixture of Modality-Aware Experts): shared Q/K/V/O + per-modality MoE FFN groups at every layer. Tokens are routed deterministically by modality (level 1), then by a learned expert-choice + Sigmoid router within their modality group (level 2). Lin et al. 2024 (arXiv:2407.21770).

Each arch gets its own VLMConfig subclass, registered via registry.register_vlm_config. The TOML loader dispatches on arch to instantiate the right subclass; programmatic callers use VLMConfig.for_arch(arch_name, **fields).
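The arch dispatch described above can be pictured as a dict keyed by each subclass's arch default. The block below is a minimal, self-contained sketch of that idea, not kempnerforge's registry: register_vlm_config and _VLM_CONFIGS here are hypothetical stand-ins, and the reserved-arch (NotImplementedError) path is omitted.

```python
from dataclasses import dataclass, fields

# Hypothetical stand-ins for kempnerforge's registry; not its real internals.
_VLM_CONFIGS: dict[str, type] = {}


def register_vlm_config(cls):
    """Register a dataclass config under the default value of its arch field."""
    arch_default = next(f.default for f in fields(cls) if f.name == "arch")
    _VLM_CONFIGS[arch_default] = cls
    return cls


@dataclass
class VLMConfig:
    arch: str = "joint_decoder"
    max_text_len: int = 512

    @classmethod
    def for_arch(cls, arch: str, **kwargs):
        # Unknown arch names raise ValueError, as documented for for_arch.
        try:
            subclass = _VLM_CONFIGS[arch]
        except KeyError:
            raise ValueError(f"unknown VLM arch {arch!r}") from None
        return subclass(**kwargs)


@register_vlm_config
@dataclass
class JointDecoderConfig(VLMConfig):
    arch: str = "joint_decoder"


@register_vlm_config
@dataclass
class CrossAttentionConfig(VLMConfig):
    arch: str = "cross_attention"
    cross_attention_every_n_layers: int = 4
```

In this sketch the arch default is the single source of truth: the arch value picks the subclass, and every other key becomes a constructor argument of that subclass.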

FreezeSpec / FreezeStage are consumed by kempnerforge/training/freeze.py.

Classes

CrossAttentionConfig

Cross-Attention: image K/V flows into separate cross-attention blocks inserted at a configurable cadence.

FreezeSpec

A single freeze directive.

FreezeStage

A freeze directive that applies from start_step onward.

JointDecoderConfig

Joint-Decoder: image tokens prepended to the text sequence.

MoMaConfig

Mixture of Modality-Aware Experts (MoMa): shared self-attention + per-modality MoE FFN groups (Lin et al. 2024, arXiv:2407.21770).

MoTConfig

Mixture-of-Transformers: per-modality Q/K/V/O projections + per-modality FFN at every layer; single global self-attention mixes all modality streams (Liang et al. 2024, Algorithm 1).

VLMConfig

Base VLM configuration.

class kempnerforge.config.vlm.FreezeSpec

Bases: object

A single freeze directive.

module is an alias (key in a pattern map such as DEFAULT_MODULE_PATTERNS) or a raw fnmatch pattern matching fully-qualified parameter names.

module: str
frozen: bool = True
__init__(module, frozen=True)
Parameters:
  • module (str)

  • frozen (bool)

Return type:

None
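One way to read the alias-or-pattern rule for module: look the string up in the alias map and, if it is not a key, treat it as an fnmatch pattern. The sketch below assumes exactly that behavior; MODULE_PATTERNS, resolve_patterns, matching_params, and the parameter names are hypothetical and do not reproduce the contents of DEFAULT_MODULE_PATTERNS.

```python
from fnmatch import fnmatchcase

# Hypothetical alias map; the real DEFAULT_MODULE_PATTERNS may differ.
MODULE_PATTERNS: dict[str, list[str]] = {
    "vision_encoder": ["vision_encoder.*"],
    "adapter": ["adapter.*"],
    "transformer": ["transformer.*"],
}


def resolve_patterns(module: str, patterns: dict[str, list[str]]) -> list[str]:
    """An alias expands to its pattern list; any other string is used as a raw pattern."""
    return patterns.get(module, [module])


def matching_params(module: str, param_names: list[str],
                    patterns: dict[str, list[str]]) -> list[str]:
    """Return the fully-qualified parameter names a FreezeSpec.module would target."""
    pats = resolve_patterns(module, patterns)
    # fnmatch's "*" also matches ".", so "transformer.layers.*" covers nested names.
    return [name for name in param_names if any(fnmatchcase(name, p) for p in pats)]
```

fnmatchcase is used so matching is case-sensitive on every platform; whether the real code uses fnmatch or fnmatchcase is not stated on this page.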

class kempnerforge.config.vlm.FreezeStage

Bases: object

A freeze directive that applies from start_step onward.

Used for staged training recipes where the trainable subset changes across training phases. The stage list on VLMConfig (freeze_schedule) is expected to be sorted by strictly increasing start_step.

start_step: int
specs: tuple[FreezeSpec, ...]
__init__(start_step, specs)
Parameters:
  • start_step (int)

  • specs (tuple[FreezeSpec, ...])

Return type:

None
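Because stages are ordered by strictly increasing start_step, the stage in effect at a given step can be found with a binary search. This is a sketch under two assumptions not stated on this page: the most recently started stage is the one that applies, and no stage applies before the first start_step. The dataclasses mirror only the documented fields; validate_schedule and active_stage are hypothetical helpers.

```python
from bisect import bisect_right
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)  # frozen=True is a choice of this sketch
class FreezeSpec:
    module: str
    frozen: bool = True


@dataclass(frozen=True)
class FreezeStage:
    start_step: int
    specs: tuple[FreezeSpec, ...]


def validate_schedule(schedule: list[FreezeStage]) -> None:
    """Hypothetical check: start_step must strictly increase along the list."""
    starts = [stage.start_step for stage in schedule]
    for prev, cur in zip(starts, starts[1:]):
        if cur <= prev:
            raise ValueError(f"start_step must strictly increase: {prev} then {cur}")


def active_stage(schedule: list[FreezeStage], step: int) -> Optional[FreezeStage]:
    """Return the last stage with start_step <= step, or None before the first stage."""
    validate_schedule(schedule)
    starts = [stage.start_step for stage in schedule]
    idx = bisect_right(starts, step) - 1
    return schedule[idx] if idx >= 0 else None
```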

class kempnerforge.config.vlm.VLMConfig

Bases: object

Base VLM configuration.

Subclasses register themselves via @registry.register_vlm_config and override the arch field’s default. Use VLMConfig.for_arch(arch_name, **fields) to construct programmatically; the TOML loader dispatches on arch automatically.

Field summary (full per-field docs are picked up from autodoc):

  • arch: VLM architecture discriminator. Subclasses set this via the field default; direct construction with an arch name not backed by a registered subclass raises.

  • max_text_len: fixed text padding length used by VLMCollator. Enforces rank-consistent batches under FSDP2.

  • freeze: static freeze specs applied once at build time.

  • freeze_schedule: step-boundary freeze transitions.

  • module_patterns: map from module alias ("transformer", "vision_encoder", "adapter", plus arch-specific additions) to fnmatch pattern list.

arch: str = 'joint_decoder'
max_text_len: int = 512
freeze: list[FreezeSpec]
freeze_schedule: list[FreezeStage]
module_patterns: dict[str, list[str]]
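Padding every sample to the same max_text_len means each rank builds text tensors of identical shape, whatever the longest sample in its local batch happens to be. A minimal sketch of that step, assuming right-padding, truncation of over-long samples, and a caller-chosen pad_id; pad_to_fixed_len is a hypothetical helper and the real VLMCollator may differ on all three points.

```python
def pad_to_fixed_len(token_ids: list[int], max_text_len: int,
                     pad_id: int) -> tuple[list[int], list[int]]:
    """Hypothetical helper: truncate or right-pad to exactly max_text_len.

    Returns (padded_ids, attention_mask) where the mask is 1 for real tokens
    and 0 for padding.
    """
    ids = token_ids[:max_text_len]
    n_pad = max_text_len - len(ids)
    return ids + [pad_id] * n_pad, [1] * len(ids) + [0] * n_pad
```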
residual_stream_image_tokens(num_tokens)

Number of image tokens this arch places in the residual stream.

Used by JobConfig to validate that model.max_seq_len and train.seq_len are large enough to fit residual_stream_image_tokens + max_text_len along the attention sequence dimension.

  • Joint-Decoder / MoT / MoMa: num_tokens (image tokens prepended to text).

  • Cross-Attention: 0 (residual stream is text-only; image features flow side-channel into CA blocks).

Parameters:

num_tokens (int) – The vision encoder’s resolved num_tokens. Pass 0 when it is not known yet (the “infer at build time” sentinel); cross-checks that depend on a concrete value will skip and re-run at build time.

Return type:

int
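The JobConfig cross-check described above amounts to a lower bound on both sequence-length settings. The sketch below hard-codes the per-arch rule from this page instead of calling the real config classes; check_seq_lens and this free-function residual_stream_image_tokens are hypothetical, and the error message is illustrative.

```python
def residual_stream_image_tokens(arch: str, num_tokens: int) -> int:
    # Joint-Decoder, MoT and MoMa prepend image tokens to the residual stream;
    # Cross-Attention feeds image features in as side-channel K/V instead.
    return 0 if arch == "cross_attention" else num_tokens


def check_seq_lens(arch: str, num_tokens: int, max_text_len: int,
                   max_seq_len: int, seq_len: int) -> None:
    """Hypothetical check: both settings must fit image prefix + padded text.

    With num_tokens == 0 (not known yet) this only enforces the text-only bound;
    the full check would have to run again once num_tokens is resolved.
    """
    required = residual_stream_image_tokens(arch, num_tokens) + max_text_len
    for name, value in (("model.max_seq_len", max_seq_len), ("train.seq_len", seq_len)):
        if value < required:
            raise ValueError(f"{name}={value} is smaller than {required} for arch={arch!r}")
```

For example, a Joint-Decoder run with 256 image tokens and max_text_len = 512 needs both settings to be at least 768, while the same numbers under Cross-Attention only need 512.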

classmethod for_arch(arch, **kwargs)

Resolve arch to a registered subclass and instantiate.

Raises:
  • ValueError – arch is not registered.

  • NotImplementedError – arch is reserved (listed in _RESERVED_ARCHS). This matches the loader's semantics, so the error type does not depend on where the config is constructed.

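Parameters:
  • arch (str) – registered architecture name.

  • **kwargs – field values forwarded to the subclass constructor.
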
Return type:

VLMConfig

Example

>>> from kempnerforge.config.vlm import VLMConfig
>>> cfg = VLMConfig.for_arch(
...     "cross_attention",
...     max_text_len=2048,
...     cross_attention_every_n_layers=4,
... )
__init__(arch='joint_decoder', max_text_len=512, freeze=<factory>, freeze_schedule=<factory>, module_patterns=<factory>)
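Parameters:
  • arch (str)

  • max_text_len (int)

  • freeze (list[FreezeSpec])

  • freeze_schedule (list[FreezeStage])

  • module_patterns (dict[str, list[str]])
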
Return type:

None

class kempnerforge.config.vlm.JointDecoderConfig

Bases: VLMConfig

Joint-Decoder: image tokens prepended to the text sequence.

No additional fields beyond VLMConfig. The arch is wired through VLMWrapper + ModalityContext.prefix_embeds + output_slice.
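The Joint-Decoder layout can be pictured as a concatenation along the sequence axis followed by a slice that keeps only the text positions. This numpy sketch assumes that is what ModalityContext.prefix_embeds and output_slice are used for; prepend_image_tokens and text_positions are hypothetical names, and shapes follow a [batch, seq, dim] convention.

```python
import numpy as np


def prepend_image_tokens(image_embeds: np.ndarray, text_embeds: np.ndarray) -> np.ndarray:
    """Hypothetical: concat [B, N_img, D] image embeddings before [B, T, D] text embeddings."""
    return np.concatenate([image_embeds, text_embeds], axis=1)


def text_positions(hidden: np.ndarray, num_image_tokens: int) -> np.ndarray:
    """Hypothetical: drop the image prefix so only text positions reach the head / loss."""
    return hidden[:, num_image_tokens:, :]
```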

arch: str = 'joint_decoder'
__init__(arch='joint_decoder', max_text_len=512, freeze=<factory>, freeze_schedule=<factory>, module_patterns=<factory>)
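Parameters:
  • arch (str)

  • max_text_len (int)

  • freeze (list[FreezeSpec])

  • freeze_schedule (list[FreezeStage])

  • module_patterns (dict[str, list[str]])
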
Return type:

None

class kempnerforge.config.vlm.CrossAttentionConfig

Bases: VLMConfig

Cross-Attention: image K/V flows into separate cross-attention blocks inserted at a configurable cadence.

The CA-specific module alias "cross_attention" is added to module_patterns so freeze targeting works out of the box.

arch: str = 'cross_attention'
cross_attention_every_n_layers: int = 4
cross_attention_n_heads: int = 0
cross_attention_n_kv_heads: int = 0
module_patterns: dict[str, list[str]]
residual_stream_image_tokens(num_tokens)

Cross-Attention does not extend the residual stream.

Image features flow as K/V into separate CrossAttentionBlocks; the residual stream itself carries only text. The seq_len cross-check therefore reduces to seq_len >= max_text_len. The num_tokens argument is accepted only for signature parity with the base method and is ignored.

Parameters:

num_tokens (int)

Return type:

int

resolved_heads(model_n_heads)

Resolve zero-defaults against the text backbone’s head count.

Returns (n_heads, n_kv_heads) such that the CrossAttentionBlock constructor never observes 0.

Resolution rule:

  • n_heads = self.cross_attention_n_heads or model_n_heads

  • n_kv_heads = self.cross_attention_n_kv_heads or n_heads

Parameters:

model_n_heads (int)

Return type:

tuple[int, int]
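The resolution rule relies on 0 being falsy, so Python's or falls through to the next value. A free-function sketch of the same rule (the documented method reads the two cross_attention_* fields from self instead of taking them as arguments):

```python
def resolved_heads(cross_attention_n_heads: int, cross_attention_n_kv_heads: int,
                   model_n_heads: int) -> tuple[int, int]:
    # 0 means "inherit": heads fall back to the text backbone,
    # KV heads fall back to the resolved head count.
    n_heads = cross_attention_n_heads or model_n_heads
    n_kv_heads = cross_attention_n_kv_heads or n_heads
    return n_heads, n_kv_heads
```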

__init__(arch='cross_attention', max_text_len=512, freeze=<factory>, freeze_schedule=<factory>, module_patterns=<factory>, cross_attention_every_n_layers=4, cross_attention_n_heads=0, cross_attention_n_kv_heads=0)
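Parameters:
  • arch (str)

  • max_text_len (int)

  • freeze (list[FreezeSpec])

  • freeze_schedule (list[FreezeStage])

  • module_patterns (dict[str, list[str]])

  • cross_attention_every_n_layers (int)

  • cross_attention_n_heads (int)

  • cross_attention_n_kv_heads (int)
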
Return type:

None

class kempnerforge.config.vlm.MoTConfig

Bases: VLMConfig

Mixture-of-Transformers: per-modality Q/K/V/O projections + per-modality FFN at every layer; single global self-attention mixes all modality streams (Liang et al. 2024, Algorithm 1).

Image tokens are prepended to the text sequence in the residual stream (image-then-text concat order). modality_ids tags every position with its source modality; the operator routes per-token through the per-modality projection / FFN copy for that position.
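Per-token routing through per-modality copies can be sketched as a masked gather, project, scatter. This numpy illustration is not the MoT operator: per_modality_linear is hypothetical, and it assumes modality_ids holds integer indices into the weight list (for example 0 for "image" and 1 for "text", following mot_modalities order).

```python
import numpy as np


def per_modality_linear(x: np.ndarray, modality_ids: np.ndarray,
                        weights: list[np.ndarray]) -> np.ndarray:
    """Hypothetical: apply weights[m] to every position whose modality_ids value is m.

    x: [T, D_in]; modality_ids: [T] ints in range(len(weights)); weights[m]: [D_in, D_out].
    """
    out = np.zeros((x.shape[0], weights[0].shape[1]), dtype=x.dtype)
    for m, w in enumerate(weights):
        mask = modality_ids == m
        # Gather this modality's positions, project with its own copy, scatter back.
        out[mask] = x[mask] @ w
    return out
```

Because the masks are built per position, the same code handles an image prefix and an interleaved layout alike.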

The MoT-specific module alias "mot" is added to module_patterns so freeze targeting works out of the box: FreezeSpec("mot", True) freezes the per-modality main stack (transformer.layers.*) without touching the embedding / output head / final norms.

arch: str = 'mot'
mot_modalities: tuple[str, ...] = ('image', 'text')
mot_image_n_heads: int = 0
mot_image_n_kv_heads: int = 0
mot_warm_start_from_text: bool = False
mot_warm_start_path: str = ''
module_patterns: dict[str, list[str]]
residual_stream_image_tokens(num_tokens)

MoT prepends num_tokens image tokens to the text sequence (same residual-stream layout as Joint-Decoder).

Parameters:

num_tokens (int)

Return type:

int

resolved_image_heads(model_n_heads, model_n_kv_heads=0)

Resolve zero-defaults against the text backbone’s head counts.

Returns (n_heads, n_kv_heads) such that the operator’s per-modality projection sizes are never built from 0.

Resolution rule:

  • n_heads = self.mot_image_n_heads or model_n_heads

  • n_kv_heads = self.mot_image_n_kv_heads or model_n_kv_heads or n_heads

v1 note: the global-SDPA design requires equal head counts across modalities, so Transformer.__init__ asserts that the resolved tuple matches the text backbone and raises if a per-modality override makes them differ. The fields exist so that a future per-modality relaxation can land without changing the config shape.

Parameters:
  • model_n_heads (int)

  • model_n_kv_heads (int)

Return type:

tuple[int, int]
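The same falsy-zero pattern with one extra fallback level, plus a stand-in for the v1 equality check. Both functions are hypothetical sketches: the real check lives in Transformer.__init__, and its exact error type is not documented on this page (ValueError is this sketch's choice).

```python
def resolved_image_heads(mot_image_n_heads: int, mot_image_n_kv_heads: int,
                         model_n_heads: int, model_n_kv_heads: int = 0) -> tuple[int, int]:
    n_heads = mot_image_n_heads or model_n_heads
    # KV heads fall back to the backbone's KV heads, then to the resolved n_heads.
    n_kv_heads = mot_image_n_kv_heads or model_n_kv_heads or n_heads
    return n_heads, n_kv_heads


def check_v1_heads(image_heads: tuple[int, int], model_n_heads: int,
                   model_n_kv_heads: int = 0) -> None:
    # Hypothetical stand-in: one global SDPA over all modalities needs the
    # image head layout to equal the text backbone's in v1.
    text_heads = (model_n_heads, model_n_kv_heads or model_n_heads)
    if image_heads != text_heads:
        raise ValueError(f"per-modality heads {image_heads} != text backbone {text_heads}")
```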

__init__(arch='mot', max_text_len=512, freeze=<factory>, freeze_schedule=<factory>, module_patterns=<factory>, mot_modalities=('image', 'text'), mot_image_n_heads=0, mot_image_n_kv_heads=0, mot_warm_start_from_text=False, mot_warm_start_path='')
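Parameters:
  • arch (str)

  • max_text_len (int)

  • freeze (list[FreezeSpec])

  • freeze_schedule (list[FreezeStage])

  • module_patterns (dict[str, list[str]])

  • mot_modalities (tuple[str, ...])

  • mot_image_n_heads (int)

  • mot_image_n_kv_heads (int)

  • mot_warm_start_from_text (bool)

  • mot_warm_start_path (str)
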
Return type:

None

class kempnerforge.config.vlm.MoMaConfig

Bases: VLMConfig

Mixture of Modality-Aware Experts (MoMa): shared self-attention + per-modality MoE FFN groups (Lin et al. 2024, arXiv:2407.21770).

Each transformer layer is a pre-norm block with:

  • Standard Attention (one set of Q/K/V/O across modalities) running a single global SDPA over the concatenated image+text sequence.

  • A MoMaFFN that routes tokens in two stages:

    1. Deterministic by modality (level 1): each token’s modality_ids value selects which modality expert group processes it.

    2. Learned expert-choice + Sigmoid (level 2): within the modality group, each expert independently picks its top-k tokens by sigmoid score (with optional Gumbel-Sigmoid noise during training; paper Eq. 5). Token output is the sum of selected experts’ outputs weighted by their sigmoid scores.
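The two routing levels can be sketched in numpy: an outer loop splits tokens by modality_ids, and inside each group every expert takes its top-k tokens by sigmoid score. This is a simplified illustration, not MoMaFFN: Gumbel-Sigmoid noise is omitted, experts are arbitrary callables, k is taken as int(capacity_factor * group_size) (at least 1), and expert_choice_ffn / moma_ffn are hypothetical names.

```python
import numpy as np


def expert_choice_ffn(x, router_w, experts, capacity_factor):
    """Hypothetical level 2 for ONE modality group.

    x: [T, D] tokens of the group; router_w: [D, E]; experts: E callables [k, D] -> [k, D].
    """
    scores = 1.0 / (1.0 + np.exp(-(x @ router_w)))  # sigmoid router scores, [T, E]
    k = max(1, int(capacity_factor * x.shape[0]))   # tokens picked per expert
    out = np.zeros_like(x)
    for e, expert in enumerate(experts):
        top = np.argsort(-scores[:, e])[:k]         # this expert's top-k tokens
        # Weight the expert's output by its own score; a token picked by several
        # experts sums their contributions, one picked by none gets zero FFN output.
        out[top] += scores[top, e:e + 1] * expert(x[top])
    return out


def moma_ffn(x, modality_ids, groups):
    """Hypothetical level 1: dispatch tokens to their modality group by modality_ids.

    groups maps a modality id to (router_w, experts, capacity_factor).
    """
    out = np.zeros_like(x)
    for m, (router_w, experts, capacity_factor) in groups.items():
        mask = modality_ids == m
        if mask.any():
            out[mask] = expert_choice_ffn(x[mask], router_w, experts, capacity_factor)
    return out
```

Because each expert's top-k is taken over the whole group, changing a later token can change which experts an earlier token receives; this is the non-causality described in the inference note.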

Image tokens are prepended to the text sequence (same residual layout as Joint-Decoder and MoT). modality_ids tags every position; the FFN uses these tags for scatter/gather dispatch (works for arbitrary interleaved layouts, not just image-prefix).

Differs from "mot": MoT has per-modality Q/K/V/O and per-modality FFN. MoMa has shared Q/K/V/O and per-modality MoE FFN groups (multiple experts per modality, learned routing within each group).

Inference note: expert-choice routing is non-causal (each expert’s top-k depends on all tokens in the batch). v1 supports training only; autoregressive generation requires auxiliary routers (paper §2.4), deferred to a follow-up.

The MoMa-specific module alias "moma" is added to module_patterns so freeze targeting works out of the box: FreezeSpec("moma", True) freezes the per-modality MoE stack (transformer.layers.*) without touching the embedding, output head, or final norm.

arch: str = 'moma'
moma_modalities: tuple[str, ...] = ('image', 'text')
moma_experts_per_modality: dict[str, int]
moma_capacity_factor: float = 0.0
moma_gumbel_noise: bool = True
module_patterns: dict[str, list[str]]
residual_stream_image_tokens(num_tokens)

MoMa prepends num_tokens image tokens to the text sequence (same residual-stream layout as Joint-Decoder).

Parameters:

num_tokens (int)

Return type:

int

effective_capacity_factor(modality)

Resolve the per-expert capacity factor for modality.

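Paper default (moma_capacity_factor == 0): returns 1 / |E^M|, where |E^M| is the number of experts in modality M's group, so the experts' combined capacity equals the modality's token count and each expert takes an equal share (perfect balance under expert-choice routing). Explicit positive values pass through unchanged.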

Parameters:

modality (str)

Return type:

float
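The capacity rule in standalone form, taking the two relevant fields as arguments. The experts dict is made up for illustration, and how non-positive explicit values are validated is not covered here.

```python
def effective_capacity_factor(moma_capacity_factor: float,
                              moma_experts_per_modality: dict[str, int],
                              modality: str) -> float:
    if moma_capacity_factor == 0.0:
        # Paper default: 1 / |E^M|, an equal share of the group's tokens per expert.
        return 1.0 / moma_experts_per_modality[modality]
    return moma_capacity_factor  # explicit values pass through unchanged


experts = {"image": 4, "text": 2}  # hypothetical group sizes
```

With 4 image experts and 1,024 image tokens in a batch, each expert picks 0.25 × 1024 = 256 tokens, so the group fills 4 × 256 = 1024 token slots: one per token on average, even though expert-choice routing may give some tokens several experts and others none.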

__init__(arch='moma', max_text_len=512, freeze=<factory>, freeze_schedule=<factory>, module_patterns=<factory>, moma_modalities=('image', 'text'), moma_experts_per_modality=<factory>, moma_capacity_factor=0.0, moma_gumbel_noise=True)
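Parameters:
  • arch (str)

  • max_text_len (int)

  • freeze (list[FreezeSpec])

  • freeze_schedule (list[FreezeStage])

  • module_patterns (dict[str, list[str]])

  • moma_modalities (tuple[str, ...])

  • moma_experts_per_modality (dict[str, int])

  • moma_capacity_factor (float)

  • moma_gumbel_noise (bool)
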
Return type:

None