Source code for tmrc.core.training.data

from tatm.data import get_dataset, torch_collate_fn
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from tmrc import logger
from tmrc.core.utils.platform import Platform


def create_dataloaders(
    datasets_path: str, context_length: int, batch_size: int, platform: Platform
):
    """Create training and validation dataloaders.

    Note: until `tatm` supports a validation set, we use the first
    `config.training.batch_size` samples as the validation set.
    - [TODO] implement a proper validation set
    - [TODO] add more data checks to validate that batch sizes, document ids, etc.
      are consistent across all batches

    Args:
        datasets_path: Location of the datasets to use
        context_length: Context length of the model to train
        batch_size: Batch size for training
        platform: Platform providing distributed training state
            (world size, global rank, synchronization)
    """
    full_dataset = get_dataset(datasets_path, context_length=context_length)
    logger.info(f"Dataset length: {len(full_dataset)}")
    # Temporary validation set: the first `batch_size` samples of the full dataset
    val_dataset = [full_dataset[d] for d in range(batch_size)]

    platform.synchronize()
    logger.info("Creating distributed samplers")
    train_sampler = (
        DistributedSampler(
            full_dataset, num_replicas=platform.world_size, rank=platform.global_rank
        )
        if platform.distributed
        else None
    )
    val_sampler = (
        DistributedSampler(
            val_dataset,
            num_replicas=platform.world_size,
            rank=platform.global_rank,
            shuffle=False,
        )
        if platform.distributed
        else None
    )

    logger.info("Creating dataloaders")
    train_loader = DataLoader(
        full_dataset,
        batch_size=batch_size,
        num_workers=4,
        pin_memory=True,
        collate_fn=torch_collate_fn,
        sampler=train_sampler,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
        pin_memory=True,
        collate_fn=torch_collate_fn,
        sampler=val_sampler,
    )

    return train_loader, val_loader, train_sampler
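
A minimal usage sketch for this function. The dataset path, hyperparameter values, and Platform construction below are assumptions for illustration; actual initialization depends on the tmrc configuration.

# Hypothetical usage sketch; dataset path and Platform setup are assumed, not taken from the source.
from tmrc.core.utils.platform import Platform
from tmrc.core.training.data import create_dataloaders

platform = Platform()  # assumed default constructor; see tmrc.core.utils.platform

train_loader, val_loader, train_sampler = create_dataloaders(
    datasets_path="/path/to/tokenized/dataset",  # hypothetical path
    context_length=1024,
    batch_size=32,
    platform=platform,
)

# In a distributed run, set the sampler epoch before iterating so shuffling
# differs across epochs.
if train_sampler is not None:
    train_sampler.set_epoch(0)

for batch in train_loader:
    ...  # training step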