Source code for kempnerforge.data.dataloader
"""Distributed, stateful DataLoader for KempnerForge.
Wraps PyTorch DataLoader with:
- Distributed-aware setup (correct worker count, pinned memory)
- Stateful iteration tracking for checkpoint/resume
- Integration with DistributedSampler for rank-partitioned data
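
Example (a minimal usage sketch; ``train_dataset`` and the checkpoint path
are hypothetical)::

    loader = StatefulDataLoader(train_dataset, batch_size=32)
    for batch in loader:  # one full pass == one epoch
        ...               # forward/backward/optimizer step
    torch.save({"dataloader": loader.state_dict()}, "ckpt.pt")

    # Later, resume from the saved position:
    loader.load_state_dict(torch.load("ckpt.pt")["dataloader"])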
"""
from __future__ import annotations
import logging
import torch
from torch.utils.data import DataLoader, Dataset
from kempnerforge.config.schema import DataConfig
from kempnerforge.data.sampler import DistributedSampler, MixtureSampler
logger = logging.getLogger(__name__)
class StatefulDataLoader:
"""Stateful wrapper around PyTorch DataLoader.
Tracks iteration progress so training can resume from the exact
position after a checkpoint load.
Args:
dataset: Dataset to load from.
batch_size: Per-device micro-batch size.
sampler: Distributed sampler (created automatically if None).
config: Data pipeline configuration.
"""
def __init__(
self,
dataset: Dataset,
batch_size: int,
sampler: DistributedSampler | MixtureSampler | None = None,
config: DataConfig | None = None,
) -> None:
config = config or DataConfig()
self.dataset = dataset
self.batch_size = batch_size
self.sampler = sampler or DistributedSampler(dataset)
self._dataloader = DataLoader(
dataset,
batch_size=batch_size,
sampler=self.sampler,
num_workers=config.num_workers,
pin_memory=config.pin_memory,
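            # PyTorch requires prefetch_factor=None and persistent_workers=False
            # when num_workers == 0, hence the guards below.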
prefetch_factor=config.prefetch_factor if config.num_workers > 0 else None,
persistent_workers=config.num_workers > 0,
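            # Dropping the trailing partial batch keeps per-rank batch counts
            # equal, so ranks stay in lockstep at epoch boundaries.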
drop_last=True,
)
# State tracking
self._epoch = 0
self._batches_yielded = 0
self._iterator = None
logger.info(
f"StatefulDataLoader: batch_size={batch_size}, "
f"workers={config.num_workers}, pin_memory={config.pin_memory}"
)
def __iter__(self):
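        # Seed the sampler with the current epoch so each epoch produces a
        # different but reproducible shuffle on every rank.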
self.sampler.set_epoch(self._epoch)
self._iterator = iter(self._dataloader)
self._batches_yielded = 0
return self
def __next__(self) -> dict[str, torch.Tensor]:
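        # Guard: next() called before iter(); there is no active epoch.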
if self._iterator is None:
raise StopIteration
try:
batch = next(self._iterator)
self._batches_yielded += 1
return batch
except StopIteration:
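            # Epoch exhausted: bump the epoch counter so the next __iter__
            # call reshuffles, reset progress, and re-raise to end the loop.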
self._epoch += 1
self._batches_yielded = 0
self._iterator = None
raise
def __len__(self) -> int:
return len(self._dataloader)
def state_dict(self) -> dict:
"""Return checkpoint state. Keys: ``epoch``, ``batches_yielded``, ``sampler``."""
return {
"epoch": self._epoch,
"batches_yielded": self._batches_yielded,
"sampler": self.sampler.state_dict(),
}
def load_state_dict(self, state: dict) -> None:
"""Restore from checkpoint. Restores sampler state and skips to saved batch position."""
self._epoch = state.get("epoch", 0)
batches_yielded = state.get("batches_yielded", 0)
# Set sampler state for resumption
if "sampler" in state:
self.sampler.load_state_dict(state["sampler"])
# Skip ahead to the correct position in the current epoch
if batches_yielded > 0:
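            # One batch consumes batch_size per-rank samples, so fast-forward
            # the sampler past everything already seen this epoch.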
self.sampler.set_skip(batches_yielded * self.batch_size)
logger.info(f"Resumed DataLoader: epoch={self._epoch}, skip_batches={batches_yielded}")