Source code for kempnerforge.model.vision

"""Vision encoders for VLM training.

A vision encoder turns ``(B, 3, H, W)`` pixel values into a bag of
``(B, num_tokens, feature_dim)`` patch tokens that the VLM adapter maps
into the language-model embedding space.

Encoders register themselves via ``registry.register_vision_encoder``.
Currently shipped:

- ``random`` — small deterministic stub for tests and smoke configs. No
  network access required. Produces reproducible noise for a given seed.
- ``siglip2`` / ``clip`` — thin wrappers around HuggingFace
  ``AutoModel.from_pretrained``. The HF imports are deferred so the
  module is importable on machines without the ``transformers`` package,
  and failures are surfaced with a clear message.
"""

from __future__ import annotations

from typing import Any

import torch
import torch.nn as nn

from kempnerforge.config.registry import registry


class VisionEncoder(nn.Module):
    """Base class for vision encoders.

    Subclasses must set ``feature_dim`` and ``num_tokens`` before returning
    from ``__init__`` and implement ``forward(pixel_values)`` to produce a
    ``(B, num_tokens, feature_dim)`` tensor.
    """

    feature_dim: int
    num_tokens: int
    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:  # pragma: no cover
        raise NotImplementedError


class RandomVisionEncoder(VisionEncoder):
    """Deterministic random-token stub.

    The output is seeded from ``pixel_values.sum()``, so the same image
    produces the same tokens across calls; it is independent of model weights,
    so it works under FSDP2 without sharding a real encoder. Used in tests and
    the ``vlm_debug.toml`` smoke config.
    """

    def __init__(self, num_tokens: int = 16, feature_dim: int = 768, seed: int = 0) -> None:
        super().__init__()
        self.num_tokens = num_tokens
        self.feature_dim = feature_dim
        self._seed = seed
        # Carry a trivial buffer so .to(device) / .to(dtype) have something
        # to move; also lets VLMWrapper confirm the module actually lives
        # on the target device.
        self.register_buffer("_anchor", torch.zeros(1, dtype=torch.float32), persistent=False)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        B = pixel_values.shape[0]
        # Derive a per-image seed from the input so the same image yields
        # the same tokens. Kept cheap: sum across spatial dims and cast.
        per_image = pixel_values.flatten(1).sum(dim=1)
        anchor = self.get_buffer("_anchor")
        out = torch.empty(
            B,
            self.num_tokens,
            self.feature_dim,
            device=pixel_values.device,
            dtype=anchor.dtype,
        )
        for i in range(B):
            gen = torch.Generator(device="cpu")
            gen.manual_seed(int(self._seed) + int(per_image[i].item() * 1e6))
            out[i] = torch.randn(
                self.num_tokens,
                self.feature_dim,
                generator=gen,
            ).to(device=pixel_values.device, dtype=anchor.dtype)
        return out
@registry.register_vision_encoder("random") def _build_random( path: str = "", num_tokens: int | None = None, feature_dim: int | None = None, **_: Any, ) -> VisionEncoder: """Builder for the test stub. ``path`` is ignored. When ``num_tokens`` / ``feature_dim`` are None the defaults are used. """ return RandomVisionEncoder( num_tokens=num_tokens if num_tokens is not None else 16, feature_dim=feature_dim if feature_dim is not None else 768, ) class _HFVisionEncoder(VisionEncoder): """Shared wrapper for HuggingFace vision encoders (SigLIP2, CLIP, ...). The HF model's vision tower produces patch tokens. CLIP-family towers prepend a CLS token at position 0; SigLIP/SigLIP2 towers do not. The ``has_cls_token`` flag tells the encoder which architecture it wraps so ``num_tokens`` reflects the actual forward output shape, and the ``strip_cls`` flag controls whether to drop position 0 before returning. """ def __init__( self, path: str, strip_cls: bool = False, has_cls_token: bool = True, ) -> None: super().__init__() if strip_cls and not has_cls_token: raise ValueError( "_HFVisionEncoder: strip_cls=True is meaningless when has_cls_token=False " "(nothing to strip). Pass strip_cls=False for SigLIP-style encoders." ) try: from transformers import AutoModel except ImportError as e: # pragma: no cover raise ImportError( "Loading HuggingFace vision encoders requires `transformers`. " "Install it or use the 'random' encoder for tests." ) from e if not path: raise ValueError("vlm.vision_encoder_path must be set for HF-backed vision encoders") model = AutoModel.from_pretrained(path) # Prefer .vision_model if present (CLIP, SigLIP family); otherwise # assume the whole loaded model is the vision tower. vision_tower = getattr(model, "vision_model", model) self.vision_tower = vision_tower self._strip_cls = strip_cls self._has_cls_token = has_cls_token # Resolve output shape from the HF config without an actual dry-run pass. cfg = getattr(vision_tower, "config", None) image_size = getattr(cfg, "image_size", 224) if cfg is not None else 224 patch_size = getattr(cfg, "patch_size", 16) if cfg is not None else 16 hidden = getattr(cfg, "hidden_size", None) if cfg is not None else None n_patches = (image_size // patch_size) ** 2 self.feature_dim = int(hidden) if hidden else -1 # -1 => resolve via dry run # Three cases for num_tokens: # - has_cls_token + strip_cls: n_patches (CLIP, dropping CLS) # - has_cls_token + not strip_cls: n_patches + 1 (CLIP, keeping CLS) # - not has_cls_token: n_patches (SigLIP/SigLIP2) if has_cls_token and not strip_cls: self.num_tokens = n_patches + 1 else: self.num_tokens = n_patches def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: out = self.vision_tower(pixel_values=pixel_values) hidden = getattr(out, "last_hidden_state", out) if self._strip_cls: hidden = hidden[:, 1:, :] return hidden @registry.register_vision_encoder("siglip2") def _build_siglip2( path: str, num_tokens: int | None = None, feature_dim: int | None = None, **_: Any, ) -> VisionEncoder: """Builder for a SigLIP2 vision tower. SigLIP2 has no CLS token.""" enc = _HFVisionEncoder(path, strip_cls=False, has_cls_token=False) if num_tokens is not None: enc.num_tokens = num_tokens if feature_dim is not None: enc.feature_dim = feature_dim return enc @registry.register_vision_encoder("clip") def _build_clip( path: str, num_tokens: int | None = None, feature_dim: int | None = None, **_: Any, ) -> VisionEncoder: """Builder for a CLIP ViT vision tower. 
CLIP output includes a CLS token; we strip it so ``num_tokens`` matches the number of image patches. """ enc = _HFVisionEncoder(path, strip_cls=True, has_cls_token=True) if num_tokens is not None: enc.num_tokens = num_tokens if feature_dim is not None: enc.feature_dim = feature_dim return enc
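

# Illustrative smoke check (a sketch, not part of the library API): it
# exercises the shape contract from the module docstring with the offline
# "random" encoder, so it needs neither `transformers` nor network access.
# The 196-token figure in the trailing comment is just the patch arithmetic
# for a hypothetical CLIP ViT-B/16 checkpoint at 224x224.
if __name__ == "__main__":
    enc = RandomVisionEncoder(num_tokens=16, feature_dim=768)
    pixels = torch.zeros(2, 3, 224, 224)
    tokens = enc(pixels)
    # (B, num_tokens, feature_dim), as promised by VisionEncoder.
    assert tokens.shape == (2, 16, 768)
    # Same pixels twice -> identical tokens, thanks to the per-image seeding
    # in RandomVisionEncoder.forward.
    assert torch.equal(tokens, enc(pixels))
    # An HF-backed CLIP tower at 224x224 with 16x16 patches would resolve
    # num_tokens = (224 // 16) ** 2 = 196 once the CLS token is stripped
    # (197 if it were kept).
    print("vision encoder smoke check passed:", tuple(tokens.shape))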