# Source code for kempnerforge.data.vlm_dataset

"""VLM dataset and collator (Joint-Decoder).

``HuggingFaceVLMDataset`` wraps a HuggingFace image-text dataset and
produces the ``VLMSample`` contract:

- ``pixel_values``: ``(3, H, W)`` float tensor, resized to ``image_size``
  and normalized with the provided mean/std.
- ``input_ids``: ``(T,)`` int64 tensor, right-padded to ``max_text_len``.
- ``labels``: ``(T,)`` int64 tensor matching ``input_ids`` with ``-100``
  on padding positions and (optionally) on prompt positions when
  ``prompt_field`` is set.

``VLMCollator`` stacks a list of samples into a batch. All batches are
padded to the same fixed ``max_text_len`` regardless of batch content so
different ranks see identical tensor shapes (no NCCL desync under
FSDP2). The collator also emits ``image_positions: (B,)`` zeros; this
slot is reserved for a future multi-image extension and is unused by
the Joint-Decoder wrapper today.

The HF ``datasets`` and ``transformers`` packages are imported lazily
so this module is safe to import without them (e.g. in unit tests that
don't exercise the dataset path).
"""

from __future__ import annotations

import logging
from typing import Any

import torch
from torch.utils.data import Dataset

logger = logging.getLogger(__name__)

# SigLIP-style normalization defaults; pass explicit mean/std if your
# vision encoder was trained with CLIP normalization.
DEFAULT_IMAGE_MEAN = (0.5, 0.5, 0.5)
DEFAULT_IMAGE_STD = (0.5, 0.5, 0.5)


def _pil_to_tensor(
    img: Any,
    image_size: int,
    mean: tuple[float, float, float],
    std: tuple[float, float, float],
) -> torch.Tensor:
    """Resize, convert to (3, H, W) float tensor, and normalize.

    Accepts a PIL ``Image``. Converts to RGB if needed so grayscale /
    RGBA inputs do not drop into the encoder with the wrong channel
    count.
    """
    from PIL import Image

    if not isinstance(img, Image.Image):
        raise TypeError(
            f"Expected a PIL.Image, got {type(img).__name__}. "
            "If the dataset returns raw bytes, decode with PIL first."
        )
    img = img.convert("RGB").resize((image_size, image_size), Image.Resampling.BILINEAR)
    import numpy as np

    arr = np.asarray(img, dtype=np.float32) / 255.0  # (H, W, 3)
    t = torch.from_numpy(arr).permute(2, 0, 1).contiguous()  # (3, H, W)
    mean_t = torch.tensor(mean, dtype=t.dtype).view(3, 1, 1)
    std_t = torch.tensor(std, dtype=t.dtype).view(3, 1, 1)
    return (t - mean_t) / std_t


def _tokenize_and_mask(
    tokenizer: Any,
    text: str,
    max_text_len: int,
    prompt: str | None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Tokenize and build right-padded input_ids + labels.

    When ``prompt`` is provided, the prompt portion of ``labels`` is
    masked with ``-100`` (loss does not backpropagate through prompt
    tokens). Padding positions in both ``input_ids`` and ``labels`` are
    handled via ``ignore_index=-100`` on the loss.

    BPE and SentencePiece tokenizers are NOT prefix-preserving: in
    general ``tokenize(prompt) + tokenize(text)`` differs from
    ``tokenize(prompt + text)`` at the boundary (tokens can merge or
    split). To guarantee the mask lines up with the prompt boundary we
    tokenize prompt and text independently, then concatenate the id
    lists. The mask length is ``len(prompt_ids)``, so masking
    ``labels[:prompt_len]`` cannot leak a prompt token into supervision
    or erase the first target token.
    """
    if prompt is not None:
        prompt_ids = tokenizer(prompt, add_special_tokens=False)["input_ids"]
        text_ids = tokenizer(text, add_special_tokens=False)["input_ids"]
        full_ids = list(prompt_ids) + list(text_ids)
        prompt_len = min(len(prompt_ids), max_text_len)
    else:
        full_ids = list(tokenizer(text, add_special_tokens=False)["input_ids"])
        prompt_len = 0

    full_ids = full_ids[:max_text_len]
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 0

    n = len(full_ids)
    input_ids = torch.full((max_text_len,), pad_id, dtype=torch.long)
    labels = torch.full((max_text_len,), -100, dtype=torch.long)
    if n > 0:
        ids_tensor = torch.tensor(full_ids, dtype=torch.long)
        input_ids[:n] = ids_tensor
        labels[:n] = ids_tensor
        if prompt_len > 0:
            labels[:prompt_len] = -100
    return input_ids, labels


class HuggingFaceVLMDataset(Dataset):
    """Map-style HF image-text dataset for Joint-Decoder training.

    Each item is a dict with ``pixel_values`` (normalized image tensor),
    ``input_ids`` and ``labels`` (both of length ``max_text_len``).

    Args:
        dataset_name: HF dataset name (e.g. ``"sayakpaul/coco-30-val-2014"``)
            or a local directory written by ``datasets.save_to_disk``.
        split: Dataset split.
        image_field: Column name for the PIL image.
        text_field: Column name for the caption / target text.
        tokenizer_path: HF tokenizer id or local path.
        max_text_len: Fixed-length pad target; must be positive.
        prompt_field: Optional column name for a prompt that should NOT
            receive loss (e.g. the instruction in an instruction-tuned
            dataset). Prompt tokens get ``labels=-100``.
        image_size: Target square image size. Default 224.
        image_mean / image_std: Normalization stats. Defaults match
            SigLIP's ``(0.5, 0.5, 0.5)``.
        dataset_config: HF dataset config name, if required.

    Raises:
        ValueError: If ``max_text_len`` is not positive, or if any of
            ``image_field`` / ``text_field`` / ``prompt_field`` is not a
            column of the loaded dataset.
        TypeError: If the loaded object is not a map-style
            ``datasets.Dataset``.
    """

    def __init__(
        self,
        dataset_name: str,
        split: str,
        image_field: str,
        text_field: str,
        tokenizer_path: str,
        max_text_len: int,
        prompt_field: str | None = None,
        image_size: int = 224,
        image_mean: tuple[float, float, float] = DEFAULT_IMAGE_MEAN,
        image_std: tuple[float, float, float] = DEFAULT_IMAGE_STD,
        dataset_config: str | None = None,
    ) -> None:
        # Reject a bad length before the slow dataset / tokenizer loads.
        # Without this, 0 silently yields empty tensors and a negative
        # value only fails inside torch.full on the first item access.
        if max_text_len <= 0:
            raise ValueError("max_text_len must be positive")

        import os

        from datasets import Dataset as HFDataset
        from datasets import load_dataset, load_from_disk
        from transformers import AutoTokenizer

        # When dataset_name points at an existing directory on disk (absolute
        # or relative path), prefer load_from_disk. Otherwise treat it as an
        # HF Hub id. This keeps the TOML contract unchanged while letting
        # users preprocess a dataset once and load it without network access.
        is_local = os.path.isdir(dataset_name)
        if is_local:
            loaded = load_from_disk(dataset_name)
            # load_from_disk returns a Dataset if the dir holds a single
            # split, or a DatasetDict of splits; select by split name in
            # the latter case.
            ds = loaded[split] if not isinstance(loaded, HFDataset) else loaded
        else:
            ds = load_dataset(dataset_name, dataset_config, split=split)
        if not isinstance(ds, HFDataset):
            raise TypeError(
                f"Expected a map-style datasets.Dataset, got {type(ds).__name__}. "
                "Streaming or split-dict outputs are not supported by "
                "HuggingFaceVLMDataset (use `split=` to resolve a single split)."
            )

        # Catch misspelled column names here rather than as a KeyError on
        # the first __getitem__ (possibly inside a DataLoader worker).
        wanted = [image_field, text_field]
        if prompt_field is not None:
            wanted.append(prompt_field)
        missing = [name for name in wanted if name not in ds.column_names]
        if missing:
            raise ValueError(
                f"Dataset {dataset_name!r} [{split}] has no column(s) {missing}; "
                f"available columns: {list(ds.column_names)}"
            )

        self._ds: HFDataset = ds
        self._image_field = image_field
        self._text_field = text_field
        self._prompt_field = prompt_field
        self._image_size = image_size
        self._image_mean = image_mean
        self._image_std = image_std
        self._max_text_len = max_text_len
        self._tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        logger.info(
            f"HuggingFaceVLMDataset: {dataset_name} [{split}], {len(self._ds):,} samples, "
            f"image_size={image_size}, max_text_len={max_text_len}"
        )

    def __len__(self) -> int:
        """Return the number of rows in the underlying dataset."""
        return len(self._ds)

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        """Load row ``idx`` and convert it to model-ready tensors.

        Returns:
            A dict with keys ``pixel_values``, ``input_ids`` and ``labels``.

        Raises:
            TypeError: If the text column value is not a ``str`` (or the
                image column value is rejected by ``_pil_to_tensor``).
        """
        row = self._ds[idx]
        pixels = _pil_to_tensor(
            row[self._image_field], self._image_size, self._image_mean, self._image_std
        )
        prompt_val = row[self._prompt_field] if self._prompt_field is not None else None
        text_val = row[self._text_field]
        if not isinstance(text_val, str):
            raise TypeError(
                f"Dataset field {self._text_field!r} must be str, got {type(text_val).__name__}"
            )
        # A non-str prompt value (e.g. None for a row without a prompt) is
        # treated as "no prompt": the whole sequence is then supervised.
        prompt_str = prompt_val if isinstance(prompt_val, str) else None
        input_ids, labels = _tokenize_and_mask(
            self._tokenizer, text_val, self._max_text_len, prompt_str
        )
        return {"pixel_values": pixels, "input_ids": input_ids, "labels": labels}
class VLMCollator:
    """Collate per-example VLM dicts into one fixed-shape batch.

    The returned dict contains:

    - ``pixel_values``: ``(B, 3, H, W)``, the stacked image tensors.
    - ``input_ids``: ``(B, max_text_len)`` int64, filled with ``pad_id``
      past each example's tokens.
    - ``labels``: ``(B, max_text_len)`` int64, filled with ``-100`` past
      each example's labels.
    - ``image_positions``: ``(B,)`` int64, all zeros. Placeholder for a
      multi-image extension (one image per example at sequence
      position 0 today).

    Text is always padded or truncated to ``max_text_len`` rather than
    to the longest example in the batch, so every rank produces tensors
    of the same shape under FSDP2.
    """

    def __init__(self, pad_id: int, max_text_len: int) -> None:
        """Store the padding id and the fixed text length.

        Raises:
            ValueError: If ``max_text_len`` is not positive.
        """
        if max_text_len <= 0:
            raise ValueError("max_text_len must be positive")
        self.pad_id = pad_id
        self.max_text_len = max_text_len

    def __call__(self, samples: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]:
        """Build the batch dict from a non-empty list of samples.

        Raises:
            ValueError: If ``samples`` is empty.
        """
        if not samples:
            raise ValueError("VLMCollator received an empty batch")

        batch_size = len(samples)
        width = self.max_text_len
        image_batch = torch.stack([sample["pixel_values"] for sample in samples], dim=0)

        # Start from fully padded buffers, then overwrite the real prefix.
        token_batch = torch.full((batch_size, width), self.pad_id, dtype=torch.long)
        label_batch = torch.full((batch_size, width), -100, dtype=torch.long)
        # Iterating a 2-D tensor yields row views, so writes land in the batch.
        for token_row, label_row, sample in zip(token_batch, label_batch, samples):
            src_ids = sample["input_ids"]
            src_labels = sample["labels"]
            keep = min(src_ids.shape[0], width)
            token_row[:keep] = src_ids[:keep]
            label_row[:keep] = src_labels[:keep]

        return {
            "pixel_values": image_batch,
            "input_ids": token_batch,
            "labels": label_batch,
            "image_positions": torch.zeros(batch_size, dtype=torch.long),
        }