kempnerforge.data.sampler
Distributed sampler with deterministic shuffling and skip-ahead.
Partitions data across data-parallel ranks, with:
- Epoch-based re-shuffling with deterministic seeds
- Skip-ahead for exact resumption after checkpoint load
- Handling of uneven dataset sizes (drop last partial batch)
Classes
DistributedSampler | Deterministic distributed sampler with skip-ahead support.
MixtureSampler | Weighted sampler over a MixtureDataset.
- class kempnerforge.data.sampler.DistributedSampler
Bases: torch.utils.data.Sampler
Deterministic distributed sampler with skip-ahead support.
Partitions dataset indices across data-parallel ranks. Each rank sees a unique, non-overlapping subset of the data.
- Parameters:
dataset – Dataset to sample from.
num_replicas – Number of data-parallel ranks (default: world_size).
rank – Current rank (default: from dist).
shuffle – Whether to shuffle indices.
seed – Base random seed for deterministic shuffling.
drop_last – Drop samples that don’t divide evenly across ranks.
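The partitioning described by these parameters can be sketched in plain Python. This is a minimal illustration, not the library's code: the helper `rank_indices`, the `seed + epoch` seeding rule, and the strided assignment of indices to ranks are all assumptions made for the example.

```python
import random


def rank_indices(dataset_len, num_replicas, rank, shuffle=True, seed=0,
                 epoch=0, drop_last=True):
    """Return the indices one rank would see (hypothetical helper)."""
    indices = list(range(dataset_len))
    if shuffle:
        # Assumption: the per-epoch seed is seed + epoch, so every rank
        # builds the same permutation without communicating.
        random.Random(seed + epoch).shuffle(indices)
    if drop_last:
        # Trim the tail so the total divides evenly across ranks.
        usable = (dataset_len // num_replicas) * num_replicas
        indices = indices[:usable]
    # Assumption: strided assignment (rank, rank + N, rank + 2N, ...).
    # Without drop_last this sketch leaves shards uneven.
    return indices[rank::num_replicas]


# 10 samples over 4 ranks: 2 samples are dropped, each rank gets 2.
shards = [rank_indices(10, num_replicas=4, rank=r, seed=42) for r in range(4)]
```

Because every rank derives the same permutation from the same seed and epoch, the shards are disjoint by construction and no coordination between ranks is needed.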
- set_epoch(epoch)
Set the epoch for deterministic re-shuffling.
- Parameters:
epoch (int)
- Return type:
None
- set_skip(skip)
Set number of samples to skip (for resumption after checkpoint).
- Parameters:
skip (int)
- Return type:
None
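The idea behind skip-ahead can be sketched as follows. The example assumes that `skip` counts samples already consumed on the current rank and that the per-rank order for an epoch is reproducible. The helper `iter_rank` is hypothetical and does not reflect the library's internals.

```python
import random


def iter_rank(indices, skip=0):
    """Yield a rank's indices, starting after the first `skip` samples."""
    # Slicing the index list means skipped samples are never loaded.
    yield from indices[skip:]


# A fixed per-rank order for one epoch (hypothetical, seeded for repeatability).
order = list(range(12))
random.Random(0).shuffle(order)

full_run = list(iter_rank(order))

# Suppose a checkpoint was written after 5 samples had been consumed.
consumed = 5
resumed = list(iter_rank(order, skip=consumed))
```

Under these assumptions, the resumed run yields exactly the samples the uninterrupted run would have produced after the checkpoint, which is what makes resumption exact.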
- class kempnerforge.data.sampler.MixtureSampler
Bases: torch.utils.data.Sampler
Weighted sampler over a MixtureDataset.
Each sub-dataset’s indices are partitioned across ranks (like DistributedSampler). The weights control what fraction of the epoch is drawn from each dataset: datasets with higher weight are oversampled.
- Parameters:
cumulative_sizes – [0, len(ds0), len(ds0)+len(ds1), ...] from MixtureDataset.cumulative_sizes.
weights – Per-dataset sampling weights (normalized internally).
num_replicas – Number of data-parallel ranks.
rank – Current rank.
shuffle – Whether to shuffle indices.
seed – Base random seed.
drop_last – Drop samples that don’t divide evenly across ranks.
temperature – Weight temperature (1.0 = as-is, >1 → more uniform).
- __init__(cumulative_sizes, weights, num_replicas=None, rank=None, shuffle=True, seed=0, drop_last=True, temperature=1.0)
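The sketch below illustrates two of the constructor's inputs: the `cumulative_sizes` layout and temperature-scaled weights. All three helper functions are hypothetical. The formula `w ** (1 / T)` followed by renormalization is an assumption chosen to match the documented behaviour (1.0 leaves weights as-is, values above 1 flatten them); the library may use a different form.

```python
from bisect import bisect_right
from itertools import accumulate


def cumulative_sizes(lengths):
    """Build [0, len(ds0), len(ds0)+len(ds1), ...] from dataset lengths."""
    return [0, *accumulate(lengths)]


def dataset_of(global_index, cum_sizes):
    """Map a global index back to the sub-dataset that owns it."""
    return bisect_right(cum_sizes, global_index) - 1


def tempered_weights(weights, temperature=1.0):
    """Normalize weights; assumed form is w ** (1 / T), renormalized."""
    scaled = [w ** (1.0 / temperature) for w in weights]
    total = sum(scaled)
    return [s / total for s in scaled]


cum = cumulative_sizes([100, 50, 10])       # [0, 100, 150, 160]
as_is = tempered_weights([8.0, 1.0, 1.0])   # [0.8, 0.1, 0.1]
flatter = tempered_weights([8.0, 1.0, 1.0], temperature=2.0)
```

With `temperature=2.0` the dominant dataset's share falls below 0.8 and the smaller datasets gain share, which is the "more uniform" effect the parameter description refers to.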
- set_epoch(epoch)
Set epoch for deterministic re-shuffling.
- Parameters:
epoch (int)
- Return type:
None
- set_skip(skip)
Set number of samples to skip (for resumption after checkpoint).
- Parameters:
skip (int)
- Return type:
None