Dataset Splitting
The TatmDataset class includes functionality for creating simple index-based training and validation splits that can be used to separate the dataset into training and validation sets. As implemented, users can specify either a number of items or a fraction of the dataset to be used for validation. The split is deterministic, based on each item's index in the full dataset, and will not change between runs or if the same dataset is loaded multiple times.
Example: loading a dataset and splitting it into training and validation sets:
from tatm import get_dataset
dataset = get_dataset("my_data", context_length=512, val_split_size=0.1)
print(len(dataset))
# 1000
The preceding code will load the dataset and tell the dataset object to prepare to split it into training and validation sets, where the validation set will be 10% of the full dataset. However, if we call len on the dataset
at this point, we will see that the dataset is still the full dataset.
If we want to use the training split, we have two possible approaches. The first is to call the set_split
method on the dataset object and pass in the string "train" as the argument. The second is to pass "train" as the
split argument when initializing the dataset.
dataset.set_split("train")
print(len(dataset))
# 900
train_dataset = get_dataset("my_data", context_length=512, val_split_size=0.1, split="train")
print(len(train_dataset))
# 900
We can also use the same approach to get the validation set.
dataset.set_split("val")
print(len(dataset))
# 100
val_dataset = get_dataset("my_data", context_length=512, val_split_size=150, split="val") # we can also pass in a number of items to use for the validation set
print(len(val_dataset))
# 150
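Because the split is determined purely by item index in the full dataset, loading the same dataset with the same split settings always produces the same splits. The following is a minimal sketch of checking this, using only the get_dataset call shown above and assuming item equality comparisons behave as in the examples in this section:
# Load the validation split twice; the deterministic split means both
# objects return the same item for the same index.
val_a = get_dataset("my_data", context_length=512, val_split_size=0.1, split="val")
val_b = get_dataset("my_data", context_length=512, val_split_size=0.1, split="val")
print(val_a[0] == val_b[0])
# True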
Note that we can use the set_split
method to switch between the training and validation sets at any time. If we want to operate on the full dataset, we can call set_split(None)
or pass None
as the split argument when initializing the dataset. If we have loaded a dataset without defining a split size, we can still create a split by calling the create_split
method with the desired split size. This creates a new split over the current dataset with the specified size. Note that a split must be defined in one of these ways before calling set_split
or passing a split argument when initializing the dataset.
dataset = get_dataset("my_data", context_length=512)
print(len(dataset))
# 1000
dataset.create_split(0.1) # create a split with 10% of the dataset reserved for validation
print(len(dataset))
# 1000
dataset.set_split("train")
print(len(dataset))
# 900
dataset.set_split("val")
print(len(dataset))
# 100
dataset.set_split(None) # set the split to None to use the full dataset
print(len(dataset))
# 1000
Once the splits are created and a split is selected, the indices used to return items from the dataset are remapped so that only items from the active split are returned. Note that this also means that, in the case of the validation split, indices are remapped so that the first item in the validation split is returned by calling dataset[0]
.
dataset = get_dataset("my_data", context_length=512, val_split_size=0.2)
dataset.set_split(None)  # operate on the full dataset, so indices are not remapped
val_dataset = get_dataset("my_data", context_length=512, val_split_size=0.2, split="val")
print(dataset[800] == val_dataset[0])  # with 1000 items and a 20% validation split, index 800 is the first validation item
# True
With these features in place, we can easily create training and validation sets for our dataset and use them in training and evaluation loops as drop-in replacements for the full dataset.
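For example, here is a minimal sketch of wrapping the splits in PyTorch DataLoaders. It assumes TatmDataset behaves as a standard map-style dataset (it supports len and indexing, as shown above); the batch size and any collation requirements are illustrative and may differ in practice.
from torch.utils.data import DataLoader

train_dataset = get_dataset("my_data", context_length=512, val_split_size=0.1, split="train")
val_dataset = get_dataset("my_data", context_length=512, val_split_size=0.1, split="val")

# Wrap each split in a DataLoader; shuffle only the training data.
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)

for batch in train_loader:
    ...  # training step

for batch in val_loader:
    ...  # evaluation step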