Getting Started



  • Python 3.10 is the minimum supported version.

  • tatm depends on PyTorch, which depends on CUDA. We recommend installing both PyTorch and CUDA before installing tatm. Installation instructions for PyTorch are available on the PyTorch website.
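
The prerequisites above can be checked before installing. The following is a minimal sketch that uses only the standard library plus, if present, PyTorch's torch.cuda.is_available(). The helper name check_prerequisites is hypothetical and is not part of tatm.

```python
import importlib.util
import sys

MIN_PYTHON = (3, 10)  # minimum supported version stated in the list above


def check_prerequisites() -> dict:
    """Report whether the interpreter and PyTorch look ready for tatm.

    Hypothetical helper for illustration; it is not part of tatm.
    """
    report = {
        "python_ok": sys.version_info >= MIN_PYTHON,
        "torch_installed": importlib.util.find_spec("torch") is not None,
        "cuda_available": None,  # stays None when torch is not installed
    }
    if report["torch_installed"]:
        import torch  # imported lazily so the check also runs without torch

        report["cuda_available"] = torch.cuda.is_available()
    return report


if __name__ == "__main__":
    for name, value in check_prerequisites().items():
        print(f"{name}: {value}")
```

If python_ok is False or torch_installed is False, it is probably worth resolving that before running the pip commands below.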

Installing from GitHub

To install the latest stable version of tatm from GitHub, run the following command:

pip install git+ssh://

To install the latest development version of tatm from GitHub, run the following command:

pip install git+ssh://

For a specific past version of tatm, replace main or dev with the desired version tag (e.g. v0.1.0).
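
pip selects a branch, tag, or commit of a Git repository through an @<ref> suffix on the URL. The sketch below only illustrates how that suffix is formed. The function name pip_git_requirement and the placeholder URL are assumptions for illustration; substitute the actual repository URL from the commands above.

```python
def pip_git_requirement(repo_url: str, ref: str) -> str:
    """Build a pip VCS requirement string pinned to a branch, tag, or commit.

    Hypothetical helper for illustration; it is not part of tatm or pip.
    """
    if not repo_url.startswith("git+"):
        raise ValueError("expected a pip VCS URL starting with 'git+'")
    if not ref:
        raise ValueError("ref must be a non-empty branch, tag, or commit")
    return f"{repo_url}@{ref}"


# Placeholder only: the real repository URL is not reproduced here.
REPO_URL = "git+ssh://<REPOSITORY URL>"

for ref in ("main", "dev", "v0.1.0"):
    print("pip install", pip_git_requirement(REPO_URL, ref))
```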

Installing from PyPI

The package is not yet available on PyPI. Stay tuned for updates!

Loading Tokenized Data with tatm for use with PyTorch

In the example code below, we show how to create a PyTorch dataloader with a tokenized dataset for use with a model.


If your site is set up with a metadata backend, you can use semantic names for the dataset instead of the path to the tokenized data. See Metadata Store Setup for more information.

import numpy as np
import torch
from torch.utils.data import DataLoader

from tatm.data import get_dataset, torch_collate_fn
tatm_dataset = get_dataset("<PATH TO TATM TOKENIZED DATA>", context_length=1024)
len(tatm_dataset) # number of examples in set
# 35651584
# 36507222016
# 34
# 32100
# Note that the output will vary depending on the dataset and the tokenization process, because the order in which documents are tokenized may vary.
# TatmMemmapDatasetItem(
#    token_ids=array([    7,    16,     8, ..., 14780,     8,  2537], dtype=uint16), 
#    document_ids=array([0, 0, 0, ..., 1, 1, 1], dtype=uint16)
# )

dataloader = DataLoader(tatm_dataset, batch_size=4, collate_fn=torch_collate_fn)
# {'token_ids': tensor([[    3,     2, 14309,  ...,  1644,  4179,    16],
#         [ 3731,  3229,     2,  ...,    15,     2,     3],
#         [    2, 14309,     2,  ...,   356,     5, 22218],
#         [    7,    16,     8,  ..., 14780,     8,  2537]], dtype=torch.uint16), 
#    'document_ids': tensor([[0, 0, 0,  ..., 0, 0, 0],
#         [0, 0, 0,  ..., 0, 0, 0],
#         [0, 0, 0,  ..., 0, 0, 0],
#         [0, 0, 0,  ..., 1, 1, 1]], dtype=torch.uint16)}
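
To build intuition for the shapes shown above, the following NumPy sketch splits a flat token stream into fixed-length examples and assigns each token a document ID that restarts at 0 within every example. It assumes non-overlapping windows of context_length tokens and per-example relative document IDs. These are assumptions made for illustration, and tatm's actual implementation may differ. The function window_tokens is hypothetical and is not part of tatm. Note that 35651584 × 1024 = 36507222016, which matches the first two numbers printed in the example above.

```python
import numpy as np


def window_tokens(token_ids, document_ids, context_length):
    """Split a flat token stream into non-overlapping fixed-length examples.

    Hypothetical helper for illustration; it is not part of tatm.
    """
    num_examples = len(token_ids) // context_length  # drop a trailing partial window
    usable = num_examples * context_length
    tokens = token_ids[:usable].reshape(num_examples, context_length)
    docs = document_ids[:usable].reshape(num_examples, context_length)
    # Make document IDs relative to the first document seen in each example.
    docs = docs - docs[:, :1]
    return tokens, docs


# Toy corpus: three documents of 5, 4, and 4 tokens, stored back to back.
doc_lengths = [5, 4, 4]
token_ids = np.arange(sum(doc_lengths), dtype=np.uint16)
global_doc_ids = np.repeat(np.arange(len(doc_lengths)), doc_lengths)

tokens, docs = window_tokens(token_ids, global_doc_ids, context_length=4)
print(tokens.shape)  # (3, 4): 13 tokens yield 3 full examples of 4 tokens
print(docs)
```

In this toy run the second and third examples each span a document boundary, so their document IDs switch from 0 to 1 partway through, similar to the last row of the batch shown above.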