from tatm.data import get_dataset, torch_collate_fn
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from tmrc import logger
from tmrc.core.utils.platform import Platform
def create_dataloaders(
datasets_path: str, context_length: int, batch_size: int, platform: Platform
):
"""
Note: until `tatm` supports validation set, we will use
the first `config.training.batch_size` samples as validation set
- [TODO] implement proper validation set
- [TODO] more data checks to validate batch sizes, document ids, etc.
are consistent across all batches
Args:
datasets_path: Location of the datasets to use
context_length: Context length of the model to train
batch_size: Batch size for training
"""
full_dataset = get_dataset(datasets_path, context_length=context_length)
logger.info(f"Dataset length: {len(full_dataset)}")
val_dataset = [full_dataset[d] for d in range(batch_size)]
platform.synchronize()
logger.info("Creating distibuted samplers")
train_sampler = (
DistributedSampler(
full_dataset, num_replicas=platform.world_size, rank=platform.global_rank
)
if platform.distributed
else None
)
val_sampler = (
DistributedSampler(
val_dataset,
num_replicas=platform.world_size,
rank=platform.global_rank,
shuffle=False,
)
if platform.distributed
else None
)
logger.info("Creating dataloaders")
train_loader = DataLoader(
full_dataset,
batch_size=batch_size,
num_workers=4,
pin_memory=True,
collate_fn=torch_collate_fn,
sampler=train_sampler,
)
val_loader = DataLoader(
val_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=4,
pin_memory=True,
collate_fn=torch_collate_fn,
sampler=val_sampler,
)
return train_loader, val_loader, train_sampler
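# Usage sketch (illustrative only): the Platform() construction and the values
# below are assumptions for this example, not part of the module; only the
# create_dataloaders signature and its (train_loader, val_loader, train_sampler)
# return value come from the code above.
if __name__ == "__main__":
    platform = Platform()  # assumed no-argument construction; check the actual tmrc Platform API
    train_loader, val_loader, train_sampler = create_dataloaders(
        datasets_path="/path/to/tokenized/datasets",  # hypothetical path
        context_length=1024,
        batch_size=32,
        platform=platform,
    )
    for epoch in range(3):
        if train_sampler is not None:
            # DistributedSampler only reshuffles across epochs when told the epoch number
            train_sampler.set_epoch(epoch)
        for batch in train_loader:
            pass  # forward/backward pass would go here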