"""Vision-to-LLM adapter modules (the "connector").
The adapter projects vision features (shape ``(B, num_tokens, feature_dim)``)
into the LLM embedding space (shape ``(B, out_tokens, model.dim)``). It sits
between the vision encoder and the transformer in ``VLMWrapper``.
Two families:
- **Projection adapters** keep the token count (``out_tokens == num_tokens``):
``mlp_2layer`` (default, the canonical LLaVA-family 2-layer MLP) and
``linear`` (single ``nn.Linear``, an ablation baseline).
- **Pooling adapters** reduce the token count by pooling the square patch grid
before projecting: ``avgpool`` (window-average, the cheapest reducer) and
``attentional_pool`` (Molmo2-style per-window multi-head attention with the
window mean as query). Pooling is what makes many-frame video fit the
sequence budget: a 27×27 SigLIP grid (729 tokens) pools to 81 tokens at a
3×3 window.
Every adapter is a ``VisionAdapter`` exposing ``output_num_tokens(n_in)`` so the
build path can size the residual stream (and MoT's positional split) without a
dry-run forward. Adapters register themselves under the ``adapter`` registry
category.
"""
from __future__ import annotations
import math
from typing import Any
import torch
import torch.nn as nn
import torch.nn.functional as F
from kempnerforge.config.registry import registry
_ADAPTER_ACTIVATIONS: dict[str, type[nn.Module]] = {
"gelu": nn.GELU,
"silu": nn.SiLU,
"relu": nn.ReLU,
}
# Registry keys of adapters that pool the patch grid (reduce token count). The
# config layer (``AdapterConfig.output_num_tokens``) consults this to predict
# the post-adapter token count without building the module. Keep in sync with
# the registered pooling builders below.
POOLING_ADAPTER_TYPES: tuple[str, ...] = ("avgpool", "attentional_pool")
# Pooling adapters whose ``forward`` requires the patch grid be divisible by the
# window (no ragged edge windows). Their token count must enforce the same so a
# ragged config is rejected at config/build time, not at the first training step.
DIVISIBLE_ONLY_POOL_TYPES: tuple[str, ...] = ("attentional_pool",)
[docs]
def pooled_token_count(
    num_input_tokens: int, window: int, *, require_divisible: bool = False
) -> int:
    """Return how many tokens a ``window×window`` pool leaves on a square grid.

    The ``num_input_tokens`` patch tokens are treated as a ``side × side`` map
    (``side = sqrt(num_input_tokens)``). Windows tile that map with ceil edges,
    so the result is ``ceil(side / window) ** 2``; an edge window that overhangs
    the grid pools only the patches it actually covers (Molmo2 §A: "the bottom
    and far-right image patches are pooled with a reduced number of patches").

    Pooling adapters must emit exactly this many tokens from ``forward``: the
    build path relies on this value to size MoT's positional split.

    Args:
        num_input_tokens: Patch tokens per image; must be a positive perfect
            square.
        window: Pooling window side; must be positive.
        require_divisible: When true (connectors that cannot pool ragged
            edges, e.g. ``attentional_pool``), reject grids whose side is not a
            multiple of ``window`` so the bad config fails at config/build time
            rather than in ``forward`` at the first step.

    Returns:
        The post-pool token count.

    Raises:
        ValueError: Non-positive ``window`` or ``num_input_tokens``, a
            non-square token count, or a ragged grid with
            ``require_divisible=True``.
    """
    if window <= 0:
        raise ValueError(f"pool window must be positive (got {window})")
    if num_input_tokens <= 0:
        raise ValueError(f"num_input_tokens must be positive (got {num_input_tokens})")

    side = _grid_side(num_input_tokens)
    is_ragged = side % window != 0
    if is_ragged and require_divisible:
        raise ValueError(
            f"this pooling connector requires the patch grid ({side}x{side}) be "
            f"divisible by the pool window ({window}); got a ragged grid "
            f"(num_tokens={num_input_tokens}). Use avgpool for ragged grids, or pick "
            "a divisible window."
        )

    # Ceil: a partially filled trailing window still produces one token per axis.
    windows_per_side = math.ceil(side / window)
    return windows_per_side ** 2
def _grid_side(num_tokens: int) -> int:
"""Side length of the square patch grid, or raise if not a perfect square."""
grid = math.isqrt(num_tokens)
if grid * grid != num_tokens:
raise ValueError(
f"pooling requires a square patch grid, but num_tokens={num_tokens} is "
"not a perfect square. Use a vision encoder that strips any CLS token so "
"the patch tokens form a square grid."
)
return grid
[docs]
class VisionAdapter(nn.Module):
    """Common base for the vision→LLM connectors.

    Subclasses implement ``forward`` as ``(B, N, in_dim) -> (B, M, out_dim)``
    with ``M == output_num_tokens(N)``. Projection adapters keep ``M == N``;
    pooling adapters shrink it. The build path calls ``output_num_tokens`` to
    size the residual stream and MoT's positional split without running a
    forward pass, so the two must agree exactly.
    """

    def output_num_tokens(self, num_input_tokens: int) -> int:
        """Return the per-image output token count for ``num_input_tokens`` inputs.

        The base implementation is the identity, which is correct for
        projection adapters; pooling adapters override it.
        """
        return num_input_tokens

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # pragma: no cover
        """Must be overridden; the base class defines no mapping."""
        raise NotImplementedError
[docs]
class MLP2LayerAdapter(VisionAdapter):
    """Two-layer MLP connector from image-feature dim to LLM embedding dim.

    Computes ``proj2(act(proj1(x)))`` with ``proj1: in_dim -> hidden`` and
    ``proj2: hidden -> out_dim`` (both with bias). The token count is unchanged.

    ``reset_parameters`` lets callers that materialize the adapter from the
    meta device re-apply the standard ``nn.Linear`` initialization.
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        hidden_dim: int | None = None,
        activation: str = "gelu",
    ) -> None:
        """Build the two projections.

        Args:
            in_dim: Input feature dim; must be positive.
            out_dim: Output embedding dim; must be positive.
            hidden_dim: Hidden width. ``None`` or any non-positive value falls
                back to ``out_dim``.
            activation: Key into the module's activation table
                (``"gelu"``, ``"silu"``, ``"relu"``).

        Raises:
            ValueError: Non-positive ``in_dim``/``out_dim`` or an unknown
                ``activation``.
        """
        super().__init__()
        if not (in_dim > 0 and out_dim > 0):
            raise ValueError("MLP2LayerAdapter in_dim and out_dim must be positive")
        act_cls = _ADAPTER_ACTIVATIONS.get(activation)
        if act_cls is None:
            raise ValueError(
                f"Unknown adapter activation: {activation!r}. Options: {list(_ADAPTER_ACTIVATIONS)}"
            )

        width = out_dim
        if hidden_dim is not None and hidden_dim > 0:
            width = hidden_dim

        # Creation order (proj1 before proj2) fixes the parameter-init RNG order
        # and the state_dict key order.
        self.proj1 = nn.Linear(in_dim, width, bias=True)
        self.act = act_cls()
        self.proj2 = nn.Linear(width, out_dim, bias=True)

    def reset_parameters(self) -> None:
        """Re-apply ``nn.Linear``'s default init to both projections.

        Intended for use after ``to_empty(device=...)`` on a meta-device build.
        """
        for linear in (self.proj1, self.proj2):
            linear.reset_parameters()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``(..., in_dim)`` features to ``(..., out_dim)``."""
        hidden = self.act(self.proj1(x))
        return self.proj2(hidden)
[docs]
class LinearAdapter(VisionAdapter):
    """Connector made of a single ``nn.Linear(in_dim, out_dim)`` with bias.

    There is no hidden layer and no nonlinearity, and the token count is
    unchanged. Intended as an ablation baseline for ``MLP2LayerAdapter``.
    """

    def __init__(self, in_dim: int, out_dim: int) -> None:
        """Build the projection.

        Raises:
            ValueError: ``in_dim`` or ``out_dim`` is not positive.
        """
        super().__init__()
        if not (in_dim > 0 and out_dim > 0):
            raise ValueError("LinearAdapter in_dim and out_dim must be positive")
        self.proj = nn.Linear(in_dim, out_dim, bias=True)

    def reset_parameters(self) -> None:
        """Re-apply ``nn.Linear``'s default initialization to the projection."""
        self.proj.reset_parameters()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``(..., in_dim)`` features to ``(..., out_dim)``."""
        projected = self.proj(x)
        return projected
[docs]
class AvgPoolAdapter(VisionAdapter):
    """Average-pool a square patch grid by a window, then project.

    ``(B, N, in_dim)`` patch tokens (``N == grid**2``) are averaged over
    ``window × window`` spatial windows (ceil edges; partial edge windows
    average only the real patches they cover), giving ``(B, M, in_dim)`` with
    ``M == ceil(grid/window)**2``. A ``Linear`` then maps ``in_dim -> out_dim``.

    This is the cheapest token-count reducer (LLaVA-NeXT / sibling-repo style).
    ``window`` can be overridden per ``forward`` call, so one connector can pool
    images (e.g. 2×2) and video frames (3×3) with the same projection weights.
    """

    def __init__(self, in_dim: int, out_dim: int, pool_window: int = 2) -> None:
        """Build the post-pool projection.

        Args:
            in_dim: Input feature dim; must be positive.
            out_dim: Output embedding dim; must be positive.
            pool_window: Default pooling window side; must be positive.

        Raises:
            ValueError: Any of the above is not positive.
        """
        super().__init__()
        if in_dim <= 0 or out_dim <= 0:
            raise ValueError("AvgPoolAdapter in_dim and out_dim must be positive")
        if pool_window <= 0:
            raise ValueError(f"AvgPoolAdapter pool_window must be positive (got {pool_window})")
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.pool_window = pool_window
        self.proj = nn.Linear(in_dim, out_dim, bias=True)

    def reset_parameters(self) -> None:
        """Re-apply ``nn.Linear``'s default initialization to the projection."""
        self.proj.reset_parameters()

    def output_num_tokens(self, num_input_tokens: int) -> int:
        """Return ``ceil(grid / pool_window) ** 2`` for the default window."""
        return pooled_token_count(num_input_tokens, self.pool_window)

    def forward(self, x: torch.Tensor, pool_window: int | None = None) -> torch.Tensor:
        """Average ``x`` over ``window × window`` patch windows, then project.

        Args:
            x: ``(B, N, C)`` patch tokens with ``N`` a perfect square.
            pool_window: Per-call window override; ``None`` uses
                ``self.pool_window``.

        Returns:
            ``(B, ceil(grid / window) ** 2, out_dim)``.

        Raises:
            ValueError: Non-positive window, or ``N`` not a perfect square.
        """
        w = pool_window if pool_window is not None else self.pool_window
        if w <= 0:
            raise ValueError(f"pool_window must be positive (got {w})")
        b, n, c = x.shape
        grid = _grid_side(n)
        per = math.ceil(grid / w)
        pad = per * w - grid
        x = x.view(b, grid, grid, c)
        # The real-vs-padded layout is the same for every sample, so the mask
        # is built once with a singleton batch dim and broadcast. It is kept
        # in float32 so window areas stay exact even for bf16 inputs (bf16
        # cannot represent e.g. 17 * 17 = 289 and would round it to 288).
        mask = torch.ones(1, grid, grid, 1, dtype=torch.float32, device=x.device)
        if pad:
            # F.pad pads from the last dim backward: (C:0,0)(W:0,pad)(H:0,pad).
            spec = (0, 0, 0, pad, 0, pad)
            x = F.pad(x, spec)
            mask = F.pad(mask, spec)
        # Group into windows and average over real (unpadded) cells only.
        sums = x.view(b, per, w, per, w, c).sum(dim=(2, 4))  # (B, per, per, C)
        counts = mask.view(1, per, w, per, w, 1).sum(dim=(2, 4)).clamp_(min=1)  # (1, per, per, 1)
        # The division promotes to float32; cast back so the projection sees
        # the input dtype.
        pooled = (sums / counts).to(sums.dtype).reshape(b, per * per, c)
        return self.proj(pooled)
[docs]
class AttentionalPoolAdapter(VisionAdapter):
    """Attentional pooling connector (Molmo2 §3.1).

    Each ``window × window`` patch window is reduced to one vector by a
    multi-head attention layer. The **window's mean patch is the query**, and
    the window's patches are the keys/values. The pooled vectors pass through
    ``o_proj`` and are then projected ``in_dim -> out_dim``. Output length is
    ``ceil(grid/window)**2``.

    ``window`` can be overridden per ``forward`` call (shared params across
    image 2×2 and video 3×3 pooling, per the paper). v1 requires the grid to be
    divisible by the window (no ragged edge windows); ragged attentional
    pooling is a follow-up.
    """

    def __init__(
        self, in_dim: int, out_dim: int, pool_window: int = 2, pool_heads: int = 16
    ) -> None:
        """Build the attention and output projections.

        Args:
            in_dim: Input feature dim; must be positive and divisible by
                ``pool_heads``.
            out_dim: Output embedding dim; must be positive.
            pool_window: Default pooling window side; must be positive.
            pool_heads: Number of attention heads; must be positive.

        Raises:
            ValueError: Any of the constraints above is violated.
        """
        super().__init__()
        if in_dim <= 0 or out_dim <= 0:
            raise ValueError("AttentionalPoolAdapter in_dim and out_dim must be positive")
        if pool_window <= 0:
            raise ValueError(
                f"AttentionalPoolAdapter pool_window must be positive (got {pool_window})"
            )
        if pool_heads <= 0:
            raise ValueError(
                f"AttentionalPoolAdapter pool_heads must be positive (got {pool_heads})"
            )
        if in_dim % pool_heads != 0:
            raise ValueError(
                f"AttentionalPoolAdapter in_dim ({in_dim}) must be divisible by "
                f"pool_heads ({pool_heads})"
            )
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.pool_window = pool_window
        self.pool_heads = pool_heads
        self.head_dim = in_dim // pool_heads
        # Creation order fixes the parameter-init RNG order and state_dict order.
        self.q_proj = nn.Linear(in_dim, in_dim, bias=True)
        self.k_proj = nn.Linear(in_dim, in_dim, bias=True)
        self.v_proj = nn.Linear(in_dim, in_dim, bias=True)
        self.o_proj = nn.Linear(in_dim, in_dim, bias=True)
        self.out_proj = nn.Linear(in_dim, out_dim, bias=True)

    def reset_parameters(self) -> None:
        """Re-apply ``nn.Linear``'s default init to every projection, in order."""
        projections = (self.q_proj, self.k_proj, self.v_proj, self.o_proj, self.out_proj)
        for linear in projections:
            linear.reset_parameters()

    def output_num_tokens(self, num_input_tokens: int) -> int:
        """Return the pooled token count, rejecting ragged grids like ``forward``."""
        # require_divisible mirrors forward()'s ragged-grid rejection so a bad
        # config fails at build / seq-len-check time, not at the first step.
        return pooled_token_count(num_input_tokens, self.pool_window, require_divisible=True)

    def forward(self, x: torch.Tensor, pool_window: int | None = None) -> torch.Tensor:
        """Attention-pool ``x`` per window, then project to ``out_dim``.

        Args:
            x: ``(B, N, in_dim)`` patch tokens with ``N`` a perfect square whose
                side is divisible by the window.
            pool_window: Per-call window override; ``None`` uses
                ``self.pool_window``.

        Returns:
            ``(B, (grid // window) ** 2, out_dim)``.

        Raises:
            ValueError: Non-positive window, non-square ``N``, or a grid side
                not divisible by the window.
        """
        w = self.pool_window if pool_window is None else pool_window
        if w <= 0:
            raise ValueError(f"pool_window must be positive (got {w})")
        batch, n_tokens, dim = x.shape
        grid = _grid_side(n_tokens)
        if grid % w:
            raise ValueError(
                f"attentional_pool v1 requires the patch grid ({grid}x{grid}) be divisible "
                f"by the pool window ({w}); ragged edge windows are not yet supported. "
                "Use avgpool for ragged grids, or pick a divisible window."
            )
        per = grid // w
        area = w * w
        n_win = batch * per * per  # every window is an independent attention "batch" row

        def split_heads(t: torch.Tensor, length: int) -> torch.Tensor:
            # (n_win, length, dim) -> (n_win, heads, length, head_dim)
            return t.view(n_win, length, self.pool_heads, self.head_dim).transpose(1, 2)

        # (B, grid, grid, C) -> (n_win, w*w, C) with each window's patches contiguous.
        patches = (
            x.view(batch, per, w, per, w, dim)
            .permute(0, 1, 3, 2, 4, 5)
            .reshape(n_win, area, dim)
        )
        window_mean = patches.mean(dim=1, keepdim=True)  # (n_win, 1, C): the query
        q = split_heads(self.q_proj(window_mean), 1)
        k = split_heads(self.k_proj(patches), area)
        v = split_heads(self.v_proj(patches), area)
        attended = F.scaled_dot_product_attention(q, k, v)  # (n_win, heads, 1, head_dim)
        merged = attended.transpose(1, 2).reshape(n_win, dim)  # (n_win, C)
        per_image = self.o_proj(merged).view(batch, per * per, dim)
        return self.out_proj(per_image)
@registry.register_adapter("mlp_2layer")
def _build_mlp_2layer(
    in_dim: int,
    out_dim: int,
    hidden_dim: int | None = None,
    activation: str = "gelu",
    **_ignored: Any,
) -> VisionAdapter:
    """Registry builder for ``mlp_2layer``.

    Any other keyword arguments are accepted and ignored.
    """
    adapter = MLP2LayerAdapter(
        in_dim=in_dim, out_dim=out_dim, hidden_dim=hidden_dim, activation=activation
    )
    return adapter
@registry.register_adapter("linear")
def _build_linear(in_dim: int, out_dim: int, **_ignored: Any) -> VisionAdapter:
    """Registry builder for ``linear``; any other keyword arguments are ignored."""
    adapter = LinearAdapter(in_dim=in_dim, out_dim=out_dim)
    return adapter
@registry.register_adapter("avgpool")
def _build_avgpool(
    in_dim: int, out_dim: int, pool_window: int = 2, **_ignored: Any
) -> VisionAdapter:
    """Registry builder for ``avgpool``; any other keyword arguments are ignored."""
    adapter = AvgPoolAdapter(in_dim=in_dim, out_dim=out_dim, pool_window=pool_window)
    return adapter
@registry.register_adapter("attentional_pool")
def _build_attentional_pool(
    in_dim: int,
    out_dim: int,
    pool_window: int = 2,
    pool_heads: int = 16,
    **_ignored: Any,
) -> VisionAdapter:
    """Registry builder for ``attentional_pool``.

    Any other keyword arguments are accepted and ignored.
    """
    adapter = AttentionalPoolAdapter(
        in_dim=in_dim,
        out_dim=out_dim,
        pool_window=pool_window,
        pool_heads=pool_heads,
    )
    return adapter
[docs]
def build_adapter(adapter_config, in_dim: int, out_dim: int) -> VisionAdapter:
    """Construct the adapter registered under ``adapter_config.type``.

    Args:
        adapter_config: ``AdapterConfig`` or any object exposing ``type`` and
            ``extra_kwargs()``; the latter's mapping is forwarded to the builder
            as keyword arguments.
        in_dim: Source feature dim (the vision encoder's ``feature_dim``).
        out_dim: Target embedding dim (the transformer's ``dim``).

    Returns:
        A ``VisionAdapter`` mapping ``(B, N, in_dim) -> (B, M, out_dim)`` with
        ``M == adapter.output_num_tokens(N)``.
    """
    # The builder is resolved before extra_kwargs() is called, as in the
    # original ordering.
    make_adapter = registry.get_adapter(adapter_config.type)
    options = adapter_config.extra_kwargs()
    return make_adapter(in_dim=in_dim, out_dim=out_dim, **options)