Source code for kempnerforge.data.dataset

"""Dataset implementations for KempnerForge.

Three dataset types:
  - MemoryMappedDataset: Pre-tokenized numpy files with zero-copy mmap access.
  - HuggingFaceDataset: HuggingFace datasets with eager loading and sequence packing.
  - StreamingHuggingFaceDataset: Streaming HuggingFace datasets for very large corpora
    that don't fit in memory. On-the-fly tokenization with sequence packing.

A MixtureDataset wrapper concatenates several of these sources into a single
index space for weighted mixing and tags each sample with its source dataset.

All implement a stateful interface (state_dict / load_state_dict) for
resumption after checkpoint loads.
"""

from __future__ import annotations

import bisect
import contextlib
import logging
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import Dataset

logger = logging.getLogger(__name__)


def _compute_packed_output(tokens: np.ndarray, eos_token_id: int) -> dict[str, torch.Tensor]:
    """Compute input_ids, labels, and doc_ids for a packed token sequence.

    Detects document boundaries from EOS tokens. Each document includes its
    trailing EOS. Cross-document label positions are masked with -100 so the
    loss function can ignore them (predictions at the boundary between two
    documents are meaningless).

    Args:
        tokens: Token array of shape ``(seq_len + 1,)`` (extra token for label offset).
        eos_token_id: Token ID that marks document boundaries.

    Returns:
        Dict with ``input_ids`` (seq_len,), ``labels`` (seq_len, with -100 at
        cross-document boundaries), and ``doc_ids`` (seq_len, integer document
        assignment per input token for attention masking).
    """
    # Assign a document ID to each token: increment after every EOS
    doc_ids = np.zeros(len(tokens), dtype=np.int64)
    doc_id = 0
    for i in range(len(tokens)):
        doc_ids[i] = doc_id
        if tokens[i] == eos_token_id:
            doc_id += 1

    token_tensor = torch.from_numpy(tokens.copy()).long()
    doc_id_tensor = torch.from_numpy(doc_ids.copy())

    input_ids = token_tensor[:-1]
    labels = token_tensor[1:].clone()
    input_doc_ids = doc_id_tensor[:-1]
    label_doc_ids = doc_id_tensor[1:]

    # Mask labels at cross-document boundaries (first token of a new document
    # should not be predicted from the last token of the previous document)
    cross_boundary = input_doc_ids != label_doc_ids
    labels[cross_boundary] = -100

    return {"input_ids": input_ids, "labels": labels, "doc_ids": input_doc_ids}


class MemoryMappedDataset(Dataset):
    """Pre-tokenized dataset backed by memory-mapped numpy files.

    Expects ``.npy`` files (or raw ``.bin`` files) containing 1D arrays of
    uint16/uint32 token IDs that have been pre-packed into fixed-length
    sequences.

    File layout: each file stores a flat array of tokens. The dataset splits
    them into non-overlapping chunks of ``seq_len`` tokens. Multiple files are
    concatenated logically.

    Args:
        data_dir: Directory containing ``.npy`` (or ``.bin``) token files.
        seq_len: Sequence length (number of tokens per sample).
        file_pattern: Glob pattern for data files.
        pack_sequences: If True, also return per-token ``doc_ids`` and mask
            labels at cross-document boundaries (see ``_compute_packed_output``).
        eos_token_id: Token ID marking document boundaries; required when
            ``pack_sequences=True``.
    """

    def __init__(
        self,
        data_dir: str,
        seq_len: int,
        file_pattern: str = "*.npy",
        pack_sequences: bool = False,
        eos_token_id: int | None = None,
    ) -> None:
        self.seq_len = seq_len
        self._pack_sequences = pack_sequences
        self._eos_token_id = eos_token_id
        if pack_sequences and eos_token_id is None:
            raise ValueError("eos_token_id is required when pack_sequences=True")

        # Discover and sort data files for deterministic ordering
        data_path = Path(data_dir)
        self._files = sorted(data_path.glob(file_pattern))
        if not self._files:
            raise FileNotFoundError(f"No files matching {file_pattern!r} in {data_dir}")

        # Detect file format from extension
        self._is_bin = self._files[0].suffix == ".bin"

        # Memory-map all files and compute cumulative offsets
        self._mmaps: list[np.ndarray] = []
        self._cumulative_samples: list[int] = [0]
        total_tokens = 0
        # If any open fails partway, close the ones we already opened so they
        # don't leak via the exception traceback (pytest, logger.exception,
        # post-mortem debuggers all pin the partial `self` and its mmaps).
        try:
            for f in self._files:
                if self._is_bin:
                    # Raw binary: flat array of tokens. Infer dtype from file size
                    # or use uint32 (most common for modern tokenizers with vocab > 65535)
                    file_size = f.stat().st_size
                    dtype = np.uint32 if file_size % 4 == 0 else np.uint16
                    n_tokens = file_size // np.dtype(dtype).itemsize
                    mmap = np.memmap(str(f), dtype=dtype, mode="r", shape=(n_tokens,))
                else:
                    mmap = np.load(str(f), mmap_mode="r")
                n_samples = len(mmap) // seq_len
                self._mmaps.append(mmap)
                total_tokens += len(mmap)
                self._cumulative_samples.append(self._cumulative_samples[-1] + n_samples)
        except Exception:
            self._close_mmaps()
            raise

        self._total_samples = self._cumulative_samples[-1]
        logger.info(
            f"MemoryMappedDataset: {len(self._files)} files, "
            f"{total_tokens:,} tokens, {self._total_samples:,} samples (seq_len={seq_len})"
        )

        # State for resumption
        self._epoch = 0

    def __len__(self) -> int:
        return self._total_samples

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        if idx < 0 or idx >= self._total_samples:
            raise IndexError(f"Index {idx} out of range [0, {self._total_samples})")

        # Binary search for the file containing this sample
        file_idx = self._find_file(idx)
        local_idx = idx - self._cumulative_samples[file_idx]

        start = local_idx * self.seq_len
        end = start + self.seq_len
        tokens = self._mmaps[file_idx][start:end].astype(np.int64)

        if self._pack_sequences:
            return _compute_packed_output(tokens, self._eos_token_id)  # type: ignore[reportArgumentType]

        token_tensor = torch.from_numpy(tokens.copy())
        # Input: tokens[:-1], Target: tokens[1:] (standard causal LM)
        return {
            "input_ids": token_tensor[:-1],
            "labels": token_tensor[1:],
        }

    def _find_file(self, idx: int) -> int:
        """Binary search for the file index containing global sample idx."""
        lo, hi = 0, len(self._files) - 1
        while lo < hi:
            mid = (lo + hi) // 2
            if self._cumulative_samples[mid + 1] <= idx:
                lo = mid + 1
            else:
                hi = mid
        return lo

    def state_dict(self) -> dict:
        """Return checkpoint state. Keys: ``epoch``, ``total_samples``."""
        return {"epoch": self._epoch, "total_samples": self._total_samples}

    def load_state_dict(self, state: dict) -> None:
        """Restore from checkpoint. Only ``epoch`` is restored; sample count is derived."""
        self._epoch = state.get("epoch", 0)

    def _close_mmaps(self) -> None:
        """Release the underlying mmap objects. Idempotent."""
        for mm in self._mmaps:
            inner = getattr(mm, "_mmap", None)
            if inner is not None and not inner.closed:
                # BufferError: live views into the mapping still exist; can't
                # force-close safely, so drop the ref and let GC finish it.
                # ValueError: already closed by another code path.
                with contextlib.suppress(BufferError, ValueError):
                    inner.close()
        self._mmaps.clear()

    def close(self) -> None:
        """Release the underlying mmaps. Preferred path; do not rely on ``__del__``."""
        self._close_mmaps()

    def __del__(self) -> None:
        """GC safety net only. Prefer explicit :meth:`close`."""
        with contextlib.suppress(Exception):
            self._close_mmaps()
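
# Illustrative usage sketch (the path and hyperparameters are placeholders, not
# part of the module): map-style access works with a standard DataLoader, and
# close() should be called when the dataset is no longer needed.
#
#     ds = MemoryMappedDataset(data_dir="/path/to/tokens", seq_len=2048)
#     loader = torch.utils.data.DataLoader(ds, batch_size=8, shuffle=True)
#     batch = next(iter(loader))   # input_ids / labels of shape (8, 2047)
#     ds.close()                   # release mmaps explicitly; don't rely on __del__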

class HuggingFaceDataset(Dataset):
    """HuggingFace dataset with eager tokenization and sequence packing.

    Loads a HuggingFace dataset, tokenizes all documents at construction time,
    and packs multiple documents into fixed-length sequences (separated by EOS
    tokens).

    Args:
        dataset_name: HuggingFace dataset name (e.g., "allenai/c4").
        dataset_config: Optional config name (e.g., "wikitext-2-raw-v1").
        split: Dataset split ("train", "validation", etc.).
        text_field: Name of the text column.
        seq_len: Sequence length for packing.
        tokenizer_path: Path or name for HuggingFace tokenizer.
        pack_sequences: If True, also return per-token ``doc_ids`` and mask
            labels at cross-document boundaries (see ``_compute_packed_output``).
    """

    def __init__(
        self,
        dataset_name: str,
        split: str,
        text_field: str,
        seq_len: int,
        tokenizer_path: str,
        dataset_config: str | None = None,
        pack_sequences: bool = False,
    ) -> None:
        from datasets import load_dataset
        from transformers import AutoTokenizer

        self.seq_len = seq_len
        self.text_field = text_field
        self._packing_enabled = pack_sequences

        # Load tokenizer
        self._tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        self._eos_id = self._tokenizer.eos_token_id or 0

        # Load dataset and pack into fixed-length sequences
        logger.info(f"Loading HuggingFace dataset: {dataset_name} ({split})")
        raw_dataset = load_dataset(dataset_name, dataset_config, split=split)
        self._packed_sequences = self._pack_sequences(raw_dataset)
        logger.info(
            f"HuggingFaceDataset: {len(self._packed_sequences)} packed sequences "
            f"(seq_len={seq_len}) from {len(raw_dataset)} documents"  # type: ignore[reportArgumentType]
        )

        # State for resumption
        self._epoch = 0
        self._sample_idx = 0

    def _pack_sequences(self, raw_dataset) -> list[np.ndarray]:
        """Tokenize and pack documents into fixed-length sequences.

        Documents are concatenated with EOS separators, then sliced into chunks
        of exactly (seq_len + 1) tokens. The +1 provides the target for the
        last input position.
        """
        chunk_size = self.seq_len + 1  # +1 for the target offset
        buffer: list[int] = []
        packed: list[np.ndarray] = []

        for example in raw_dataset:
            text = example[self.text_field]
            tokens = self._tokenizer.encode(text, add_special_tokens=False)
            if not tokens:
                continue
            buffer.extend(tokens)
            buffer.append(self._eos_id)

            # Flush full chunks from buffer
            while len(buffer) >= chunk_size:
                packed.append(np.array(buffer[:chunk_size], dtype=np.int64))
                buffer = buffer[chunk_size:]

        # Discard partial remainder (no padding; clean sequences only)
        return packed

    def __len__(self) -> int:
        return len(self._packed_sequences)

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        tokens = self._packed_sequences[idx]

        if self._packing_enabled:
            return _compute_packed_output(tokens, self._eos_id)

        token_tensor = torch.from_numpy(tokens.copy())
        return {
            "input_ids": token_tensor[:-1],
            "labels": token_tensor[1:],
        }

    def state_dict(self) -> dict:
        """Return checkpoint state. Keys: ``epoch``, ``sample_idx``, ``total_sequences``."""
        return {
            "epoch": self._epoch,
            "sample_idx": self._sample_idx,
            "total_sequences": len(self._packed_sequences),
        }

    def load_state_dict(self, state: dict) -> None:
        """Restore from checkpoint. Restores ``epoch`` and ``sample_idx``."""
        self._epoch = state.get("epoch", 0)
        self._sample_idx = state.get("sample_idx", 0)
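
# Illustrative usage sketch (dataset, config, and tokenizer names are
# placeholders): documents are tokenized and packed once at construction, so
# __getitem__ is a cheap array lookup afterwards.
#
#     ds = HuggingFaceDataset(
#         dataset_name="wikitext",
#         dataset_config="wikitext-2-raw-v1",
#         split="train",
#         text_field="text",
#         seq_len=1024,
#         tokenizer_path="gpt2",
#     )
#     sample = ds[0]   # {"input_ids": (1024,), "labels": (1024,)}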

class StreamingHuggingFaceDataset(torch.utils.data.IterableDataset):
    """Streaming HuggingFace dataset with on-the-fly tokenization and packing.

    For very large datasets that don't fit in memory. Streams documents,
    tokenizes on the fly, and packs into fixed-length sequences.

    Handles distributed training by sharding the document stream across ranks
    (each rank processes every world_size-th document). Use directly with
    ``torch.utils.data.DataLoader`` (no sampler needed; an IterableDataset
    handles its own distribution).

    Args:
        dataset_name: HuggingFace dataset name (e.g., "allenai/c4").
        split: Dataset split ("train", "validation", etc.).
        text_field: Name of the text column.
        seq_len: Sequence length for packing.
        tokenizer_path: Path or name for HuggingFace tokenizer.
        dataset_config: Optional config name (e.g., "wikitext-2-raw-v1").
        rank: Current distributed rank (for document sharding).
        world_size: Total number of ranks.
        seed: Random seed for shuffling.
        shuffle_buffer_size: Number of examples to buffer for shuffling.
        pack_sequences: If True, yield per-token ``doc_ids`` and mask labels at
            cross-document boundaries (see ``_compute_packed_output``).
    """

    def __init__(
        self,
        dataset_name: str,
        split: str,
        text_field: str,
        seq_len: int,
        tokenizer_path: str,
        dataset_config: str | None = None,
        rank: int = 0,
        world_size: int = 1,
        seed: int = 42,
        shuffle_buffer_size: int = 10000,
        pack_sequences: bool = False,
    ) -> None:
        super().__init__()
        self.dataset_name = dataset_name
        self.dataset_config = dataset_config
        self.split = split
        self.text_field = text_field
        self.seq_len = seq_len
        self.tokenizer_path = tokenizer_path
        self.rank = rank
        self.world_size = world_size
        self.seed = seed
        self.shuffle_buffer_size = shuffle_buffer_size
        self._packing_enabled = pack_sequences

        # Lazy-init tokenizer (avoid loading before fork in multiprocessing workers)
        self._tokenizer = None
        self._eos_id = None

        # State for resumption
        self._epoch = 0
        self._rank_docs_consumed = 0
        self._skip_rank_docs = 0

        logger.info(
            f"StreamingHuggingFaceDataset: {dataset_name} ({split}), "
            f"rank={rank}/{world_size}, seq_len={seq_len}"
        )

    def _ensure_tokenizer(self):
        """Lazy-load tokenizer on first use."""
        if self._tokenizer is None:
            from transformers import AutoTokenizer

            self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_path)
            self._eos_id = self._tokenizer.eos_token_id or 0

    def _load_stream(self):
        """Load a shuffled HuggingFace streaming dataset."""
        from datasets import load_dataset

        ds = load_dataset(
            self.dataset_name,
            self.dataset_config,
            split=self.split,
            streaming=True,
        )
        # Shuffle with seed + epoch for a different order each epoch
        ds = ds.shuffle(seed=self.seed + self._epoch, buffer_size=self.shuffle_buffer_size)  # type: ignore[reportCallIssue]
        return ds

    def __iter__(self):
        self._ensure_tokenizer()
        chunk_size = self.seq_len + 1  # +1 for target offset
        buffer: list[int] = []

        stream = self._load_stream()
        doc_idx = 0  # Global document counter (all ranks)
        rank_docs = 0  # Documents processed by this rank

        for example in stream:
            # Distributed sharding: each rank takes every world_size-th doc
            if doc_idx % self.world_size != self.rank:
                doc_idx += 1
                continue

            # Skip documents for resumption (fast-forward to saved position)
            if rank_docs < self._skip_rank_docs:
                rank_docs += 1
                doc_idx += 1
                continue

            text = example[self.text_field]
            tokens = self._tokenizer.encode(text, add_special_tokens=False)  # type: ignore[reportOptionalMemberAccess]
            if not tokens:
                rank_docs += 1
                doc_idx += 1
                continue

            buffer.extend(tokens)
            buffer.append(self._eos_id)  # type: ignore[reportArgumentType]
            rank_docs += 1
            doc_idx += 1
            self._rank_docs_consumed = rank_docs

            # Yield full chunks from buffer
            while len(buffer) >= chunk_size:
                chunk = buffer[:chunk_size]
                buffer = buffer[chunk_size:]
                if self._packing_enabled:
                    yield _compute_packed_output(np.array(chunk, dtype=np.int64), self._eos_id)  # type: ignore[reportArgumentType]
                else:
                    token_tensor = torch.tensor(chunk, dtype=torch.long)
                    yield {
                        "input_ids": token_tensor[:-1],
                        "labels": token_tensor[1:],
                    }

        # Epoch complete; reset for next iteration
        self._epoch += 1
        self._skip_rank_docs = 0
        self._rank_docs_consumed = 0

    def state_dict(self) -> dict:
        """Return checkpoint state. Keys: ``epoch``, ``rank_docs_consumed``."""
        return {
            "epoch": self._epoch,
            "rank_docs_consumed": self._rank_docs_consumed,
        }

    def load_state_dict(self, state: dict) -> None:
        """Restore from checkpoint. Sets skip count to fast-forward on next iteration."""
        self._epoch = state.get("epoch", 0)
        self._skip_rank_docs = state.get("rank_docs_consumed", 0)
        self._rank_docs_consumed = 0
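
# Illustrative usage sketch (names and sizes are placeholders; rank/world_size
# normally come from the distributed launcher): an IterableDataset is passed to
# DataLoader without a sampler, and its state_dict is saved with the model
# checkpoint so the stream can fast-forward on resume.
#
#     ds = StreamingHuggingFaceDataset(
#         dataset_name="allenai/c4",
#         dataset_config="en",
#         split="train",
#         text_field="text",
#         seq_len=2048,
#         tokenizer_path="gpt2",
#         rank=0,
#         world_size=1,
#     )
#     loader = torch.utils.data.DataLoader(ds, batch_size=8)
#     state = ds.state_dict()      # e.g. {"epoch": 0, "rank_docs_consumed": 1234}
#     ds.load_state_dict(state)    # skips 1234 documents on the next __iter__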

class MixtureDataset(Dataset):
    """Concatenates multiple datasets for weighted mixing.

    Global index space maps to sub-datasets via cumulative offsets. Each sample
    includes ``dataset_idx`` (integer) so the training loop can compute
    per-dataset metrics.

    Args:
        datasets: List of map-style datasets to mix.
        names: Human-readable name per dataset (for metrics logging).
    """

    def __init__(self, datasets: list[Dataset], names: list[str]) -> None:
        if len(datasets) != len(names):
            raise ValueError("datasets and names must have the same length")
        if not datasets:
            raise ValueError("At least one dataset is required")

        self._datasets = datasets
        self._names = names
        self._cumulative: list[int] = [0]
        for ds in datasets:
            self._cumulative.append(self._cumulative[-1] + len(ds))  # type: ignore[reportArgumentType]

        total = self._cumulative[-1]
        logger.info(
            f"MixtureDataset: {len(datasets)} sources, {total:,} total samples "
            f"({', '.join(f'{n}={len(d):,}' for n, d in zip(names, datasets, strict=True))})"  # type: ignore[reportArgumentType]
        )

    @property
    def cumulative_sizes(self) -> list[int]:
        """Cumulative dataset sizes: ``[0, len(ds0), len(ds0)+len(ds1), ...]``."""
        return list(self._cumulative)

    @property
    def dataset_names(self) -> list[str]:
        return list(self._names)

    def __len__(self) -> int:
        return self._cumulative[-1]

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        if idx < 0 or idx >= len(self):
            raise IndexError(f"Index {idx} out of range [0, {len(self)})")

        ds_idx = bisect.bisect_right(self._cumulative, idx) - 1
        local_idx = idx - self._cumulative[ds_idx]

        sample = self._datasets[ds_idx][local_idx]
        sample["dataset_idx"] = ds_idx
        return sample

    def state_dict(self) -> dict:
        """Return per-sub-dataset checkpoint state."""
        return {
            f"dataset_{i}": ds.state_dict()  # type: ignore[reportAttributeAccessIssue]
            for i, ds in enumerate(self._datasets)
            if hasattr(ds, "state_dict")
        }

    def load_state_dict(self, state: dict) -> None:
        """Restore per-sub-dataset state."""
        for i, ds in enumerate(self._datasets):
            key = f"dataset_{i}"
            if key in state and hasattr(ds, "load_state_dict"):
                ds.load_state_dict(state[key])  # type: ignore[reportAttributeAccessIssue]
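
# Illustrative usage sketch (code_ds / web_ds are placeholder datasets): the
# mixture exposes a single flat index space and tags each sample with the
# sub-dataset it came from, so per-source metrics can be logged downstream.
#
#     mix = MixtureDataset(datasets=[code_ds, web_ds], names=["code", "web"])
#     sample = mix[len(code_ds)]   # first sample of web_ds
#     sample["dataset_idx"]        # -> 1, i.e. mix.dataset_names[1] == "web"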