Source code for kempnerforge.model.modality
"""Modality-injection container for ``Transformer.forward``.
``ModalityContext`` groups all "inputs that flow into the existing
residual stream" plus per-token routing tags consumed by the residual
stream, so ``Transformer.forward`` stays narrow regardless of how
many architectures are active. Each VLM arch fills the fields it
needs:
- Joint-Decoder fills ``prefix_embeds + output_slice`` (image tokens
prepended to the text sequence; ``output_slice`` trims them off the
hidden state before the LM head).
- Cross-Attention fills ``image_features + image_mask`` (image K/V
flowing into separate cross-attention blocks; the residual stream
itself carries text only).
- Mixture-of-Transformers fills ``prefix_embeds + output_slice +
modality_ids``. The residual stream carries (image, text)
concatenated; ``modality_ids`` tags every position with its
modality so each layer's MoTBlock can route per-modality
projections + global self-attention.
- Pipeline-parallel middle stages fill ``inputs_embeds`` (pre-embedded
hidden state passed across stage boundaries).
Cross-arg invariants involving ``kv_caches`` (a ``Transformer.forward``
argument, not a ``ModalityContext`` field) are enforced at the top of
``Transformer.forward``, not in ``__post_init__``.
"""
from __future__ import annotations

from dataclasses import dataclass

import torch


@dataclass(frozen=True)
class ModalityContext:
"""Modality-injection container.
Invariants enforced in ``__post_init__``:
- At most one of ``inputs_embeds``, ``prefix_embeds``,
``image_features`` may be set; they are mutually exclusive
composition routes into the residual stream.
- ``image_mask`` requires ``image_features`` to be set (a
free-standing ``image_mask`` with no features is a programming
error).
- ``modality_ids`` requires ``prefix_embeds`` or ``inputs_embeds``
to be set (routing without a residual extension is meaningless).
``output_slice`` composes with the ``tokens`` path AND with the
``inputs_embeds`` path; it is not constrained intra-context. The
cross-arg constraint (``output_slice`` vs ``kv_caches``) lives on
``Transformer.forward`` instead.
"""
    inputs_embeds: torch.Tensor | None = None
    prefix_embeds: torch.Tensor | None = None
    output_slice: slice | None = None
    image_features: torch.Tensor | None = None
    image_mask: torch.Tensor | None = None
    modality_ids: torch.Tensor | None = None
    def __post_init__(self) -> None:
        residual_routes = sum(
            x is not None
            for x in (self.inputs_embeds, self.prefix_embeds, self.image_features)
        )
        if residual_routes > 1:
            raise ValueError(
                "ModalityContext: at most one of inputs_embeds, prefix_embeds, "
                "image_features may be set (mutually exclusive residual-stream routes)"
            )
        if self.image_mask is not None and self.image_features is None:
            raise ValueError("ModalityContext: image_mask requires image_features to be set")
        if (
            self.modality_ids is not None
            and self.prefix_embeds is None
            and self.inputs_embeds is None
        ):
            raise ValueError(
                "ModalityContext: modality_ids requires prefix_embeds OR "
                "inputs_embeds to be set (routing without a residual extension is meaningless)"
            )
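

# A minimal construction sketch, runnable directly. The tensor shapes, the
# ``slice(T_img, None)`` trim convention, and the modality-id values
# (0 = image, 1 = text) are illustrative assumptions, not contracts defined
# elsewhere in kempnerforge.
if __name__ == "__main__":
    B, T_img, T_txt, D = 2, 4, 6, 8  # dummy batch/sequence/model sizes

    # Joint-Decoder: image tokens prepended to the text sequence;
    # ``output_slice`` trims them off before the LM head.
    jd = ModalityContext(
        prefix_embeds=torch.zeros(B, T_img, D),
        output_slice=slice(T_img, None),
    )

    # Cross-Attention: image K/V enter via cross-attention blocks only;
    # the residual stream itself stays text-only.
    xattn = ModalityContext(
        image_features=torch.zeros(B, T_img, D),
        image_mask=torch.ones(B, T_img, dtype=torch.bool),
    )

    # Mixture-of-Transformers: (image, text) concatenated in the residual
    # stream; ``modality_ids`` tags every position for per-modality routing.
    mot = ModalityContext(
        prefix_embeds=torch.zeros(B, T_img, D),
        output_slice=slice(T_img, None),
        modality_ids=torch.cat(
            [
                torch.zeros(B, T_img, dtype=torch.long),
                torch.ones(B, T_txt, dtype=torch.long),
            ],
            dim=1,
        ),
    )

    # Pipeline-parallel middle stage: pre-embedded hidden state only.
    mid = ModalityContext(inputs_embeds=torch.zeros(B, T_img + T_txt, D))

    # The residual-stream routes are mutually exclusive: combining two of
    # them fails at construction time.
    try:
        ModalityContext(
            prefix_embeds=torch.zeros(B, T_img, D),
            image_features=torch.zeros(B, T_img, D),
        )
    except ValueError as err:
        print(f"rejected as expected: {err}")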