kempnerforge.data.sampler

Distributed sampler with deterministic shuffling and skip-ahead.

Correctly partitions data across data-parallel ranks with:
  • Epoch-based re-shuffling with deterministic seeds

  • Skip-ahead for exact resumption after checkpoint load

  • Handling of uneven dataset sizes (drop last partial batch)
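
A minimal usage sketch of the typical training loop (the toy dataset, batch size, and epoch count are placeholders; num_replicas and rank are passed explicitly so the sketch does not require an initialized process group, whereas in real use they default to the torch.distributed values):

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from kempnerforge.data.sampler import DistributedSampler

    dataset = TensorDataset(torch.arange(1_000))                    # toy dataset
    sampler = DistributedSampler(dataset, num_replicas=2, rank=0,
                                 shuffle=True, seed=0, drop_last=True)
    loader = DataLoader(dataset, batch_size=32, sampler=sampler)

    for epoch in range(3):
        sampler.set_epoch(epoch)        # deterministic re-shuffle for this epoch
        for batch in loader:
            ...                         # training step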

Classes

DistributedSampler

Deterministic distributed sampler with skip-ahead support.

MixtureSampler

Weighted sampler over a MixtureDataset.

class kempnerforge.data.sampler.DistributedSampler[source]

Bases: torch.utils.data.Sampler

Deterministic distributed sampler with skip-ahead support.

Partitions dataset indices across data-parallel ranks. Each rank sees a unique, non-overlapping subset of the data.

Parameters:
  • dataset – Dataset to sample from.

  • num_replicas – Number of data-parallel ranks (default: world_size).

  • rank – Current rank (default: from torch.distributed).

  • shuffle – Whether to shuffle indices.

  • seed – Base random seed for deterministic shuffling.

  • drop_last – Drop samples that don’t divide evenly across ranks.
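
With drop_last=True, samples that do not divide evenly are discarded so that every rank iterates over the same number of samples; the floor-division count below is the implied per-rank size, shown for illustration rather than as a documented attribute:

    num_samples = 1_000_003
    num_replicas = 8
    samples_per_rank = num_samples // num_replicas   # 125_000 per rank; 3 samples dropped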

__init__(dataset, num_replicas=None, rank=None, shuffle=True, seed=0, drop_last=True)[source]

Return type:

None

set_epoch(epoch)[source]

Set the epoch for deterministic re-shuffling.

Parameters:

epoch (int)

Return type:

None

set_skip(skip)[source]

Set number of samples to skip (for resumption after checkpoint).

Parameters:

skip (int)

Return type:

None
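
A resumption sketch: the skip count is assumed to be the number of per-rank samples already consumed, reconstructed here from batch bookkeeping tracked by the training loop (ckpt_epoch, batches_done, and batch_size are hypothetical names, not part of this API):

    sampler.set_epoch(ckpt_epoch)                  # epoch stored in the checkpoint
    sampler.set_skip(batches_done * batch_size)    # skip samples this rank already saw
    for batch in loader:                           # iteration resumes where it left off
        ...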

state_dict()[source]

Return checkpoint state. Keys: epoch, seed, num_replicas, rank.

Return type:

dict

load_state_dict(state)[source]

Restore from checkpoint. Only epoch is restored; rank info is local.

Parameters:

state (dict)

Return type:

None
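
A checkpoint round-trip sketch; the checkpoint layout and the samples_seen counter are illustrative bookkeeping, and only the sampler calls come from this API:

    # save
    ckpt = {
        "model": model.state_dict(),
        "sampler": sampler.state_dict(),        # {"epoch", "seed", "num_replicas", "rank"}
        "samples_seen": samples_seen,           # tracked by the training loop
    }
    torch.save(ckpt, "checkpoint.pt")

    # restore (on any rank: rank/num_replicas come from the local process, not the file)
    ckpt = torch.load("checkpoint.pt")
    model.load_state_dict(ckpt["model"])
    sampler.load_state_dict(ckpt["sampler"])    # restores the epoch
    sampler.set_skip(ckpt["samples_seen"])      # resume exactly where training stopped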

class kempnerforge.data.sampler.MixtureSampler[source]

Bases: torch.utils.data.Sampler

Weighted sampler over a MixtureDataset.

Each sub-dataset’s indices are partitioned across ranks (like DistributedSampler). The weights control what fraction of the epoch is drawn from each dataset — datasets with higher weight are oversampled.

Parameters:
  • cumulative_sizes – [0, len(ds0), len(ds0)+len(ds1), ...] from MixtureDataset.cumulative_sizes.

  • weights – Per-dataset sampling weights (normalized internally).

  • num_replicas – Number of data-parallel ranks.

  • rank – Current rank.

  • shuffle – Whether to shuffle indices.

  • seed – Base random seed.

  • drop_last – Drop samples that don’t divide evenly across ranks.

  • temperature – Weight temperature (1.0 = as-is, >1 → more uniform).
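
A construction sketch, assuming a MixtureDataset whose cumulative_sizes has the [0, len(ds0), len(ds0)+len(ds1), ...] form described above; the two corpora and the 70/30 weights are placeholders:

    from torch.utils.data import DataLoader
    from kempnerforge.data.sampler import MixtureSampler

    # mixture = MixtureDataset([web_corpus, code_corpus])    # hypothetical construction
    sampler = MixtureSampler(
        cumulative_sizes=mixture.cumulative_sizes,   # e.g. [0, 900_000, 1_000_000]
        weights=[0.7, 0.3],                          # 70% of the epoch from web, 30% from code
        shuffle=True,
        seed=0,
        temperature=1.0,                             # >1 would flatten the weights toward uniform
    )
    loader = DataLoader(mixture, batch_size=32, sampler=sampler)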

__init__(cumulative_sizes, weights, num_replicas=None, rank=None, shuffle=True, seed=0, drop_last=True, temperature=1.0)[source]

Return type:

None

set_epoch(epoch)[source]

Set epoch for deterministic re-shuffling.

Parameters:

epoch (int)

Return type:

None

set_skip(skip)[source]

Set number of samples to skip (for resumption after checkpoint).

Parameters:

skip (int)

Return type:

None

state_dict()[source]

Return checkpoint state.

Return type:

dict

update_weights(weights, temperature=1.0)[source]

Update sampling weights for phase transitions.

Recomputes internal probabilities and per-dataset target counts. Takes effect on the next __iter__() call.

Parameters:
  • weights – New per-dataset sampling weights (normalized internally).

  • temperature – Weight temperature (1.0 = as-is, >1 → more uniform).

Return type:

None
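
A phase-transition sketch, e.g. shifting from a web-heavy to a code-heavy mixture at a phase boundary; the schedule, epoch counts, and variable names are illustrative:

    for epoch in range(num_epochs):
        if epoch == phase_boundary:                          # e.g. end of warmup phase
            sampler.update_weights([0.3, 0.7], temperature=1.0)
        sampler.set_epoch(epoch)
        for batch in loader:                                 # new weights apply from this __iter__()
            ...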

load_state_dict(state)[source]

Restore from checkpoint.

Parameters:

state (dict)

Return type:

None