Source code for kempnerforge.data.dataloader

"""Distributed, stateful DataLoader for KempnerForge.

Wraps PyTorch DataLoader with:
  - Distributed-aware setup (correct worker count, pinned memory)
  - Stateful iteration tracking for checkpoint/resume
  - Integration with DistributedSampler for rank-partitioned data
"""

from __future__ import annotations

import logging
from collections.abc import Callable

import torch
from torch.utils.data import DataLoader, Dataset

from kempnerforge.config.schema import DataConfig
from kempnerforge.data.sampler import DistributedSampler, MixtureSampler

logger = logging.getLogger(__name__)


class StatefulDataLoader:
    """Stateful wrapper around PyTorch DataLoader.

    Tracks iteration progress so training can resume from the exact position
    after a checkpoint load.

    Args:
        dataset: Dataset to load from.
        batch_size: Per-device micro-batch size.
        sampler: Distributed sampler (created automatically if None).
        config: Data pipeline configuration.
        collate_fn: Optional custom batch collator. When None, uses PyTorch's
            default collation. VLM training passes ``VLMCollator`` so that the
            fixed-length padding and ``image_positions`` slot reach the batch.
    """
    def __init__(
        self,
        dataset: Dataset,
        batch_size: int,
        sampler: DistributedSampler | MixtureSampler | None = None,
        config: DataConfig | None = None,
        collate_fn: Callable | None = None,
    ) -> None:
        config = config or DataConfig()
        self.dataset = dataset
        self.batch_size = batch_size
        self.sampler = sampler or DistributedSampler(dataset)

        loader_kwargs: dict = {
            "batch_size": batch_size,
            "sampler": self.sampler,
            "num_workers": config.num_workers,
            "pin_memory": config.pin_memory,
            "prefetch_factor": config.prefetch_factor if config.num_workers > 0 else None,
            "persistent_workers": config.num_workers > 0,
            "drop_last": True,
        }
        if collate_fn is not None:
            loader_kwargs["collate_fn"] = collate_fn
        self._dataloader = DataLoader(dataset, **loader_kwargs)

        # State tracking
        self._epoch = 0
        self._batches_yielded = 0
        self._iterator = None

        logger.info(
            f"StatefulDataLoader: batch_size={batch_size}, "
            f"workers={config.num_workers}, pin_memory={config.pin_memory}"
        )
    def __iter__(self):
        self.sampler.set_epoch(self._epoch)
        # Re-apply skip on every iter() so double-resume within the same epoch
        # stays aligned. The sampler consumes _skip once per iter(), and
        # _batches_yielded persists across save/load so the skip is re-computable.
        if self._batches_yielded > 0:
            self.sampler.set_skip(self._batches_yielded * self.batch_size)
        self._iterator = iter(self._dataloader)
        return self

    def __next__(self) -> dict[str, torch.Tensor]:
        if self._iterator is None:
            raise StopIteration
        try:
            batch = next(self._iterator)
            self._batches_yielded += 1
            return batch
        except StopIteration:
            self._epoch += 1
            self._batches_yielded = 0
            self._iterator = None
            raise

    def __len__(self) -> int:
        return len(self._dataloader)
    def state_dict(self) -> dict:
        """Return checkpoint state.

        Keys: ``epoch``, ``batches_yielded``, ``sampler``.
        """
        return {
            "epoch": self._epoch,
            "batches_yielded": self._batches_yielded,
            "sampler": self.sampler.state_dict(),
        }
    def load_state_dict(self, state: dict) -> None:
        """Restore from checkpoint.

        ``__iter__`` re-applies the sampler skip from ``_batches_yielded``, so
        double-resume within the same epoch stays aligned.
        """
        self._epoch = state.get("epoch", 0)
        self._batches_yielded = state.get("batches_yielded", 0)
        # Set sampler state for resumption
        if "sampler" in state:
            self.sampler.load_state_dict(state["sampler"])
        logger.info(
            f"Resumed DataLoader: epoch={self._epoch}, skip_batches={self._batches_yielded}"
        )
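

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module API): a minimal
# example of the checkpoint/resume cycle provided by ``state_dict`` and
# ``load_state_dict``. ``_ToyDataset`` is a hypothetical stand-in, and the
# sketch assumes DataConfig defaults and that DistributedSampler falls back
# to a single-process (rank 0, world size 1) partition when
# torch.distributed is not initialized.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class _ToyDataset(Dataset):
        def __len__(self) -> int:
            return 64

        def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
            return {"input_ids": torch.tensor([idx])}

    loader = StatefulDataLoader(_ToyDataset(), batch_size=4)
    it = iter(loader)
    for _ in range(3):
        next(it)  # consume a few batches

    snapshot = loader.state_dict()  # e.g. {"epoch": 0, "batches_yielded": 3, "sampler": ...}

    # After a restart, rebuild the loader and restore its position; the next
    # iter() re-applies the sampler skip (3 batches * 4 samples = 12 samples).
    resumed = StatefulDataLoader(_ToyDataset(), batch_size=4)
    resumed.load_state_dict(snapshot)
    first_batch_after_resume = next(iter(resumed))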