"""Distributed sampler with deterministic shuffling and skip-ahead.
Correctly partitions data across data-parallel ranks with:
- Epoch-based re-shuffling with deterministic seeds
- Skip-ahead for exact resumption after checkpoint load
- Handling of dataset sizes not divisible by the number of ranks (drop the remainder, or pad by wrapping around)
"""
from __future__ import annotations
import math
import torch
import torch.distributed as dist
from torch.utils.data import Dataset, Sampler
class DistributedSampler(Sampler[int]):
    """Deterministic distributed sampler with skip-ahead support.

    The (optionally shuffled) index list is first truncated or padded to a
    multiple of ``num_replicas``; rank ``r`` then takes positions
    ``r, r + num_replicas, r + 2 * num_replicas, ...``. Every rank therefore
    yields exactly ``num_samples`` indices per epoch.

    With ``drop_last=True`` the per-rank subsets are disjoint. With
    ``drop_last=False`` the list is padded by wrapping around, so some
    indices appear more than once across (or within) ranks.

    Args:
        dataset: Dataset to sample from; only ``len(dataset)`` is used.
        num_replicas: Number of data-parallel ranks (default: world size if
            ``torch.distributed`` is initialized, else 1).
        rank: Current rank (default: rank from ``torch.distributed`` if
            initialized, else 0).
        shuffle: Whether to shuffle indices.
        seed: Base random seed; the permutation for an epoch is seeded with
            ``seed + epoch``.
        drop_last: If true, drop the trailing samples that do not divide
            evenly across ranks; otherwise pad by wrapping around.

    Raises:
        ValueError: If ``num_replicas`` is not positive or ``rank`` is outside
            ``[0, num_replicas)``.
    """

    def __init__(
        self,
        dataset: Dataset,
        num_replicas: int | None = None,
        rank: int | None = None,
        shuffle: bool = True,
        seed: int = 0,
        drop_last: bool = True,
    ) -> None:
        if num_replicas is None:
            num_replicas = dist.get_world_size() if dist.is_initialized() else 1
        if rank is None:
            rank = dist.get_rank() if dist.is_initialized() else 0
        # Fail early with a clear message instead of a length assertion (or
        # ZeroDivisionError) later on.
        if num_replicas < 1:
            raise ValueError(f"num_replicas must be >= 1, got {num_replicas}")
        if not 0 <= rank < num_replicas:
            raise ValueError(
                f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]"
            )

        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.shuffle = shuffle
        self.seed = seed
        self.drop_last = drop_last

        # Per-rank sample count; total_size is always a multiple of num_replicas.
        total = len(dataset)  # type: ignore[reportArgumentType]
        if drop_last:
            # Drop the remainder so all ranks get the same count.
            self.num_samples = total // num_replicas
        else:
            # Round up; __iter__ pads the index list to match.
            self.num_samples = math.ceil(total / num_replicas)
        self.total_size = self.num_samples * num_replicas

        # State for resumption.
        self._epoch = 0
        self._skip = 0  # Number of per-rank samples to skip on the next __iter__.

    def set_epoch(self, epoch: int) -> None:
        """Set the epoch used to derive the shuffling seed (``seed + epoch``)."""
        self._epoch = epoch

    def set_skip(self, skip: int) -> None:
        """Set the number of this rank's samples to skip on the next iteration.

        The skip is consumed by the next ``__iter__`` call and then reset to
        zero. Values ``<= 0`` leave the iteration untouched.
        """
        self._skip = skip

    def __iter__(self):
        """Return an iterator over this rank's indices for the current epoch."""
        size = len(self.dataset)  # type: ignore[reportArgumentType]
        if self.shuffle:
            # Same seed on every rank -> same permutation on every rank.
            g = torch.Generator()
            g.manual_seed(self.seed + self._epoch)
            indices = torch.randperm(size, generator=g).tolist()
        else:
            indices = list(range(size))

        if self.drop_last:
            indices = indices[: self.total_size]
        else:
            padding = self.total_size - len(indices)
            if padding > 0:
                # Wrap around, repeating the list as often as needed: the
                # padding can be longer than the dataset itself (e.g. one
                # sample spread over four ranks). padding > 0 implies the
                # list is non-empty, so the division is safe.
                reps = math.ceil(padding / len(indices))
                indices += (indices * reps)[:padding]

        # Partition: each rank gets every num_replicas-th element.
        indices = indices[self.rank : self.total_size : self.num_replicas]
        assert len(indices) == self.num_samples, (
            f"expected {self.num_samples} indices for rank {self.rank}, got {len(indices)}"
        )

        # Skip-ahead for resumption; applies to a single iteration only.
        if self._skip > 0:
            indices = indices[self._skip :]
            self._skip = 0
        return iter(indices)

    def __len__(self) -> int:
        """Return the per-rank sample count of a full epoch (ignores any pending skip)."""
        return self.num_samples

    def state_dict(self) -> dict:
        """Return checkpoint state. Keys: ``epoch``, ``seed``, ``num_replicas``, ``rank``."""
        return {
            "epoch": self._epoch,
            "seed": self.seed,
            "num_replicas": self.num_replicas,
            "rank": self.rank,
        }

    def load_state_dict(self, state: dict) -> None:
        """Restore from checkpoint. Only ``epoch`` is restored; rank info is local.

        A missing ``epoch`` key resets the epoch to 0. The skip offset is not
        part of the state and must be set separately via :meth:`set_skip`.
        """
        self._epoch = state.get("epoch", 0)
class MixtureSampler(Sampler[int]):
    """Weighted sampler over a :class:`MixtureDataset`.

    Each sub-dataset's indices are partitioned across ranks by striding
    (like ``DistributedSampler``). The ``weights`` control what fraction of
    the epoch is drawn from each dataset: a dataset whose target count
    exceeds its per-rank share is oversampled by wrapping around, one whose
    target is smaller is truncated.

    The per-rank epoch length is the sum of the per-rank shares of all
    sub-datasets. Sub-datasets that contribute no per-rank samples (empty, or
    smaller than ``num_replicas`` with ``drop_last=True``) cannot be drawn
    from; their weight is redistributed over the remaining datasets so that
    ``len(sampler)`` matches the number of indices actually yielded.

    Args:
        cumulative_sizes: ``[0, len(ds0), len(ds0)+len(ds1), ...]``
            from ``MixtureDataset.cumulative_sizes``.
        weights: Per-dataset sampling weights (normalized internally); one
            entry per sub-dataset.
        num_replicas: Number of data-parallel ranks (default: world size if
            ``torch.distributed`` is initialized, else 1).
        rank: Current rank (default: rank from ``torch.distributed`` if
            initialized, else 0).
        shuffle: Whether to shuffle indices.
        seed: Base random seed; each epoch is seeded with ``seed + epoch``.
        drop_last: Drop samples that don't divide evenly across ranks;
            otherwise pad each sub-dataset by wrapping around.
        temperature: Weight temperature (1.0 = as-is, >1 → more uniform).

    Raises:
        ValueError: If ``num_replicas`` is not positive, ``rank`` is outside
            ``[0, num_replicas)``, or ``weights`` does not have one entry per
            sub-dataset.
    """

    def __init__(
        self,
        cumulative_sizes: list[int],
        weights: list[float],
        num_replicas: int | None = None,
        rank: int | None = None,
        shuffle: bool = True,
        seed: int = 0,
        drop_last: bool = True,
        temperature: float = 1.0,
    ) -> None:
        if num_replicas is None:
            num_replicas = dist.get_world_size() if dist.is_initialized() else 1
        if rank is None:
            rank = dist.get_rank() if dist.is_initialized() else 0
        # An out-of-range rank would otherwise silently yield a wrong subset.
        if num_replicas < 1:
            raise ValueError(f"num_replicas must be >= 1, got {num_replicas}")
        if not 0 <= rank < num_replicas:
            raise ValueError(
                f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]"
            )

        n = max(len(cumulative_sizes) - 1, 0)
        self._dataset_sizes = [cumulative_sizes[i + 1] - cumulative_sizes[i] for i in range(n)]
        # Global index of the first element of each sub-dataset.
        self._offsets = list(cumulative_sizes[:n])

        self.num_replicas = num_replicas
        self.rank = rank
        self.shuffle = shuffle
        self.seed = seed
        self.drop_last = drop_last

        self._epoch = 0
        self._skip = 0  # Number of samples to skip on the next __iter__.

        # Sets _probs, _target_counts and num_samples.
        self._apply_weights(weights, temperature)

    def _apply_weights(self, weights: list[float], temperature: float) -> None:
        """Normalize ``weights`` and recompute the per-dataset target counts."""
        n = len(self._dataset_sizes)
        if len(weights) != n:
            raise ValueError(f"Expected {n} weights, got {len(weights)}")

        # Temperature scaling in log space, then normalize to probabilities.
        if n == 0:
            probs: list[float] = []
        elif temperature != 1.0:
            # Clamp so zero weights do not hit log(0); subtract the max for
            # numerical stability before exponentiating.
            log_w = [math.log(max(w, 1e-12)) / temperature for w in weights]
            max_lw = max(log_w)
            scaled = [math.exp(lw - max_lw) for lw in log_w]
            total = sum(scaled)
            probs = [s / total for s in scaled]
        else:
            total_w = sum(weights)
            probs = [w / total_w for w in weights]

        # Per-dataset, per-rank number of distinct samples available.
        if self.drop_last:
            avail = [size // self.num_replicas for size in self._dataset_sizes]
        else:
            avail = [math.ceil(size / self.num_replicas) for size in self._dataset_sizes]
        total_per_rank = sum(avail)

        # A dataset with nothing available on a rank cannot be drawn from, not
        # even by oversampling. Giving it a non-zero target would make
        # __len__ over-report, so its probability mass is redistributed.
        # Only renormalize when needed so the common case keeps the exact
        # floating-point values (and thus the exact rounding) of ``probs``.
        alloc = probs
        if any(a <= 0 and p > 0 for a, p in zip(avail, probs)):
            kept = [p if a > 0 else 0.0 for a, p in zip(avail, probs)]
            mass = sum(kept)
            if mass > 0:
                alloc = [k / mass for k in kept]
            else:
                # All weight sits on undrawable datasets: nothing to yield.
                alloc = kept
                total_per_rank = 0

        # Weighted allocation: how many samples from each dataset per epoch.
        targets = [round(p * total_per_rank) for p in alloc]
        # Fix rounding so the counts add up to the epoch length exactly,
        # adjusting the highest-probability datasets first.
        diff = total_per_rank - sum(targets)
        order = sorted(range(n), key=lambda i: -alloc[i])
        step = 1 if diff > 0 else -1
        for i in range(abs(diff)):
            targets[order[i % n]] += step

        # Assign only after everything succeeded so a failed update leaves
        # the sampler in its previous, consistent state.
        self._probs = probs
        self._target_counts = targets
        self.num_samples = sum(targets)

    def set_epoch(self, epoch: int) -> None:
        """Set the epoch used to derive the shuffling seed (``seed + epoch``)."""
        self._epoch = epoch

    def set_skip(self, skip: int) -> None:
        """Set the number of this rank's samples to skip on the next iteration.

        The skip is consumed by the next ``__iter__`` call and then reset to
        zero. Values ``<= 0`` leave the iteration untouched.
        """
        self._skip = skip

    def __len__(self) -> int:
        """Return the per-rank sample count of a full epoch (ignores any pending skip)."""
        return self.num_samples

    def __iter__(self):
        """Return an iterator over this rank's global indices for the current epoch."""
        # Same seed on every rank -> identical permutations on every rank,
        # which keeps the strided partitions consistent across ranks.
        g = torch.Generator().manual_seed(self.seed + self._epoch)
        result: list[int] = []
        for size, offset, target in zip(
            self._dataset_sizes, self._offsets, self._target_counts
        ):
            if target <= 0 or size == 0:
                continue

            # Shuffled local indices for this dataset.
            if self.shuffle:
                indices = torch.randperm(size, generator=g).tolist()
            else:
                indices = list(range(size))

            # Make the length a multiple of num_replicas, then stride.
            if self.drop_last:
                usable = size - (size % self.num_replicas)
                indices = indices[:usable]
            else:
                padding = -len(indices) % self.num_replicas
                if padding:
                    # The padding may exceed the dataset size (tiny dataset,
                    # many ranks), so repeat the list as often as needed;
                    # otherwise high ranks would end up with fewer samples.
                    reps = math.ceil(padding / len(indices))
                    indices = indices + (indices * reps)[:padding]
            rank_indices = indices[self.rank :: self.num_replicas]
            if not rank_indices:
                continue

            # Draw ``target`` samples, wrapping around for oversampling.
            reps = math.ceil(target / len(rank_indices))
            drawn = (rank_indices * reps)[:target]

            # Convert to global MixtureDataset indices.
            result.extend(idx + offset for idx in drawn)

        # Shuffle all indices together for random interleaving.
        if self.shuffle:
            perm = torch.randperm(len(result), generator=g)
            result = [result[p] for p in perm.tolist()]

        # Skip-ahead for resumption; applies to a single iteration only.
        if self._skip > 0:
            result = result[self._skip :]
            self._skip = 0
        return iter(result)

    def state_dict(self) -> dict:
        """Return checkpoint state. Keys: ``epoch``, ``seed``, ``num_replicas``, ``rank``."""
        return {
            "epoch": self._epoch,
            "seed": self.seed,
            "num_replicas": self.num_replicas,
            "rank": self.rank,
        }

    def update_weights(self, weights: list[float], temperature: float = 1.0) -> None:
        """Update sampling weights for phase transitions.

        Recomputes internal probabilities and per-dataset target counts
        (and therefore ``len(sampler)``). Takes effect on the next
        ``__iter__()`` call.

        Raises:
            ValueError: If ``weights`` does not have one entry per sub-dataset.
        """
        self._apply_weights(weights, temperature)

    def load_state_dict(self, state: dict) -> None:
        """Restore from checkpoint. Only ``epoch`` is restored (default 0 if missing)."""
        self._epoch = state.get("epoch", 0)