tmrc Package

class tmrc.core.models.gpt.GPT(config, platform: Platform)[source]

Bases: Module

Basic implementation of GPT-style variant architectures from OpenAI. This class follows Karpathy’s implementation almost exactly. Minor changes were made to:

  • validate config

  • simplify optimization, fuse optimizer step into backward pass [TODO]

  • simplify overall class, e.g., no longer use GPT weight init

  • use the Platform class to manage device training

  • move FlashAttention -> FlexAttention [TODO]

forward(idx, targets=None, doc_ids=None)[source]
Parameters:
  • idx – Token ids of the tokenized input; shape (B, L)

  • targets – Token ids of the expected next tokens; shape (B, L). Used for training

  • doc_ids – Document ids of the provided tokens; shape (B, L). Used for training

Shape legend:

B: Batch size
L: Sequence length

static validate_config(config)[source]

Some basic sanity checks for the model config.
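
Example (a minimal sketch): constructing the model and running a training-style forward pass. The config object, the vocabulary size, and the return convention of forward are assumptions not specified in this reference.

    import torch
    from tmrc.core.models.gpt import GPT
    from tmrc.core.utils.platform import Platform

    model_config = ...  # Hydra/OmegaConf model config loaded elsewhere (fields not documented here)
    platform = Platform()

    GPT.validate_config(model_config)   # basic sanity checks on the config
    model = GPT(model_config, platform)

    B, L, vocab_size = 2, 128, 50304    # hypothetical sizes
    idx = torch.randint(0, vocab_size, (B, L))
    targets = torch.randint(0, vocab_size, (B, L))
    doc_ids = torch.zeros(B, L, dtype=torch.long)  # one document per sample

    # Training-style forward pass; the exact return values (e.g., logits and loss)
    # follow Karpathy-style GPT conventions and are an assumption here.
    out = model(idx, targets=targets, doc_ids=doc_ids)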

class tmrc.core.models.components.decoder.Block(hidden_size: int, attn_target: Callable[[int, float], SelfAttentionBase], ffn_target: Callable[[int, float], Module], norm_bias: bool, dropout: float)[source]

Bases: Module

Basic transformer block.

forward(x: Tensor, mask_data: Tensor | None, *, mask_key: str | None)[source]
Parameters:
  • x – Token hidden states; shape (B, L, E)

  • mask_data – Mask data used for constructing the attention mask. Shape depends on the type of mask; see the documentation of the SelfAttention subclass used.

  • mask_key – Optional key for the attention layer’s mask cache; see SelfAttentionBase.generate_mask.

Returns:

Attention values; shape (B, L, E)

Shape legend:

B: Batch size
L: Sequence length
E: Embedding dimension
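
Example (a minimal sketch): wiring a Block with an attention target and a feed-forward target. The keyword arguments passed to CausalSelfAttention and MLP are assumptions based on the constructor signatures documented elsewhere on this page.

    import torch
    from tmrc.core.models.components.decoder import Block, CausalSelfAttention, MLP

    hidden_size, num_heads, dropout = 256, 8, 0.0   # hypothetical sizes

    # attn_target / ffn_target are callables of (hidden_size, dropout),
    # per the Block signature above.
    def attn_target(size, p):
        return CausalSelfAttention(hidden_size=size, num_attention_heads=num_heads,
                                   qkv_bias=False, proj_bias=True, dropout=p)

    def ffn_target(size, p):
        return MLP(size, intermediate_size=4 * size, dropout=p)

    block = Block(hidden_size, attn_target, ffn_target, norm_bias=True, dropout=dropout)

    x = torch.randn(2, 16, hidden_size)          # (B, L, E)
    y = block(x, mask_data=None, mask_key=None)  # (B, L, E); causal mask needs no mask_data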

class tmrc.core.models.components.decoder.CausalSelfAttention(*args, **kwargs)[source]

Bases: SelfAttentionFlash

Self-attention layer that uses a causal mask.

Input:

query: Token hidden states; shape (B, L, E)
mask_data (ignored): Causal mask does not need any input data

Output:

Attention values; shape (B, L, E)

Shape legend:

B: Batch size
L: Sequence length
E: Embedding dimension

class tmrc.core.models.components.decoder.DocumentCausalSelfAttention(*args, **kwargs)[source]

Bases: SelfAttentionFlex

Self-attention layer that uses a document-causal mask.

Input:

query: Token hidden states; shape (B, L, E)
mask_data: Document ids for the input tokens; shape (B, L)

Output:

Attention values; shape (B, L, E)

Shape legend:

B: Batch size
L: Sequence length
E: Embedding dimension

class tmrc.core.models.components.decoder.MLP(input_size: int, intermediate_size: int, activation: Callable[[], Module] = <class 'torch.nn.modules.activation.GELU'>, bias: bool = True, dropout: float = 0.0)[source]

Bases: Module

Basic MLP block.

forward(x)[source]
Parameters:

x – Input features; shape (*, dim), where * is any number of dimensions

Returns:

Output features; shape (*, dim), where * is any number of dimensions
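
Example (a minimal sketch with hypothetical sizes):

    import torch
    from torch import nn
    from tmrc.core.models.components.decoder import MLP

    mlp = MLP(input_size=256, intermediate_size=1024, activation=nn.GELU,
              bias=True, dropout=0.0)

    x = torch.randn(2, 16, 256)   # (*, dim)
    y = mlp(x)                    # (*, dim); shape is preserved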

class tmrc.core.models.components.decoder.SelfAttentionBase(hidden_size: int, num_attention_heads: int, qkv_bias: bool, proj_bias: bool, dropout: float)[source]

Bases: Generic[T], Module

Base class for all self-attention implementations.

All subclasses should be one of two forms:
  1. A class for a specific attention computation algorithm (e.g., FlashAttention and FlexAttention). These should inherit directly from SelfAttentionBase and implement the compute_attention function.

  2. A class for a specific mask. These should inherit from one of the subclasses of type 1, depending on which attention algorithm is optimal for that mask. These should implement the _generate_mask function.

Each subclass contains a mask cache (denoted MASK_CACHE) so that masks can be shared by all instances of the class. This is useful when a mask is computationally intensive to construct, such as the block masks used for FlexAttention.

A mask can be cached by specifying a mask_key in the forward pass. If the mask_key is already in the cache, the mask stored under that key is reused. Otherwise, a new mask is constructed and stored in the cache, as in the sketch below.

The mask cache should be cleared regularly to avoid memory problems. It is up to the user to call the clear_mask_cache function to free up memory.
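
Example (a minimal sketch): reusing a cached mask across forward passes and clearing the cache afterwards. The constructor arguments follow the SelfAttentionBase signature; the sizes and the mask key are hypothetical.

    import torch
    from tmrc.core.models.components.decoder import DocumentCausalSelfAttention

    attn = DocumentCausalSelfAttention(hidden_size=256, num_attention_heads=8,
                                       qkv_bias=False, proj_bias=True, dropout=0.0)

    B, L = 2, 16
    tokens = torch.randn(B, L, 256)
    doc_ids = torch.tensor([[0] * 8 + [1] * 8] * B)  # two documents per sequence

    # The first call builds the (potentially expensive) mask and stores it under
    # "train_mask"; subsequent calls with the same key reuse the cached mask.
    out = attn(tokens, mask_data=doc_ids, mask_key="train_mask")
    out = attn(tokens, mask_data=doc_ids, mask_key="train_mask")

    # Free cached masks once they are no longer needed.
    DocumentCausalSelfAttention.clear_mask_cache()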

classmethod clear_mask_cache()[source]

Clears the mask cache for the current class.

compute_attention(query: Tensor, key: Tensor, value: Tensor, attn_mask: T | None = None) → Tensor[source]

Handles the attention calculation. Should be implemented in all subclasses of type 1 (see SelfAttentionBase class description).

Parameters:
  • query – Query tensor; shape (B, H, L, E)

  • key – Key tensor; shape (B, H, L, E)

  • value – Value tensor; shape (B, H, L, E)

  • attn_mask – Mask to use for the attention calculation. 1 means to attend to the token and 0 means not to attend.

Returns:

Attention output; shape (B, H, L, E)

Shape legend:

B: Batch size
H: Number of heads
L: Sequence length
E: Embedding dimension per head

forward(tokens: Tensor, mask_data: Tensor | None = None, *, mask_key: str | None = None)[source]
Parameters:
  • tokens – Token hidden states; shape (B, L, E)

  • mask_data – Mask data needed to construct the mask; shape depends on the mask

  • mask_key – Optional key for the mask cache; see generate_mask.

Returns:

Attention values; shape (B, L, E)

Shape legend:

B: Batch size
L: Sequence length
E: Embedding dimension

generate_mask(mask_data: Tensor, *, mask_key: str | None = None) → T | None[source]

Generates the mask for the attention calculation based on the given mask data. If mask_key is provided, the attention computation will reuse the cached mask with the given key if it is available. If it is not available, a new mask will be constructed and stored in the cache under that key.

Parameters:
  • mask_data – Data describing the mask to construct.

  • mask_key – Key to use for the mask cache.

class tmrc.core.models.components.decoder.SelfAttentionFlash(*args, is_causal=True, **kwargs)[source]

Bases: SelfAttentionBase[Tensor]

Self-attention layer that uses FlashAttention for the attention computation. Should be used for masks where FlashAttention provides optimal performance.

compute_attention(query: Tensor, key: Tensor, value: Tensor, attn_mask: Tensor | None = None) → Tensor[source]

Computes attention via the FlashAttention algorithm.

class tmrc.core.models.components.decoder.SelfAttentionFlex(*args, **kwargs)[source]

Bases: SelfAttentionBase[BlockMask]

Self-attention layer that uses FlexAttention for the attention computation. Should be used for masks where FlexAttention provides optimal performance.

compute_attention(query: Tensor, key: Tensor, value: Tensor, attn_mask: BlockMask | None = None) → Tensor[source]

Computes attention using the FlexAttention algorithm.

class tmrc.core.models.components.decoder.SelfAttentionManual(hidden_size: int, num_attention_heads: int, qkv_bias: bool, proj_bias: bool, dropout: float)[source]

Bases: SelfAttentionBase[Tensor]

Self-attention layer that uses a manual attention computation. This module is here as an example. Users should use SelfAttentionFlash or SelfAttentionFlex for optimal performance.

compute_attention(query: Tensor, key: Tensor, value: Tensor, attn_mask: Tensor | None = None) → Tensor[source]

Computes attention using a manual (explicit) implementation; see the class description above.

Parameters:
  • query – Query tensor; shape (B, H, L, E)

  • key – Key tensor; shape (B, H, L, E)

  • value – Value tensor; shape (B, H, L, E)

  • attn_mask – Mask to use for the attention calculation. 1 means to attend to the token and 0 means not to attend.

Returns:

Attention output; shape (B, H, L, E)

Shape legend:

B: Batch size
H: Number of heads
L: Sequence length
E: Embedding dimension per head
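
Example (a minimal sketch): calling compute_attention directly with an explicit causal mask. The mask shape (L, L), its boolean dtype, and the hidden sizes are assumptions for illustration.

    import torch
    from tmrc.core.models.components.decoder import SelfAttentionManual

    B, H, L, E = 2, 4, 16, 32   # E is the per-head embedding dimension

    attn = SelfAttentionManual(hidden_size=H * E, num_attention_heads=H,
                               qkv_bias=False, proj_bias=True, dropout=0.0)

    q = torch.randn(B, H, L, E)
    k = torch.randn(B, H, L, E)
    v = torch.randn(B, H, L, E)

    # 1 = attend, 0 = do not attend; a lower-triangular matrix gives causal masking.
    causal_mask = torch.tril(torch.ones(L, L, dtype=torch.bool))

    out = attn.compute_attention(q, k, v, attn_mask=causal_mask)  # (B, H, L, E)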

class tmrc.core.models.components.decoder.SwiGLUFFN(input_size: int, intermediate_size: int, bias: bool = True, norm_layer: Callable[[int], Module] | None = None)[source]

Bases: Module

Feed-forward block using the SwiGLU activation.

forward(x: Tensor) → Tensor[source]
Parameters:

x – Input features; shape (*, dim), where * is any number of dimensions

Returns:

Output features; shape (*, dim), where * is any number of dimensions
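
Example (a minimal sketch with hypothetical sizes; the optional norm_layer is shown for illustration):

    import torch
    from torch import nn
    from tmrc.core.models.components.decoder import SwiGLUFFN

    ffn = SwiGLUFFN(input_size=256, intermediate_size=1024, bias=True,
                    norm_layer=nn.LayerNorm)

    x = torch.randn(2, 16, 256)   # (*, dim)
    y = ffn(x)                    # (*, dim)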

class tmrc.core.models.components.quantizer.VectorQuantizer(num_tokens: int, embedding_size: int, beta: float)[source]

Bases: Module

Handles quantizing a continuous representation into a discrete one. Follows the implementation described in the VQ-VAE paper (https://arxiv.org/abs/1711.00937).

forward(z: Tensor)[source]
Parameters:

z – Input vector to quantize; shape (B, E)

Returns:

Continuous representation of the discrete token for z; shape (B,E)

Shape legend:

B: Batch size
E: Embedding dimension

quantize(z: Tensor) → int[source]

Returns the index of the closest embedding for each input vector.

Parameters:

z – Input vector to quantize; shape (B, E)

Returns:

Discrete token ids; shape (B,)

Shape legend:

B: Batch size
E: Embedding dimension
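
Example (a minimal sketch): quantizing a batch of continuous vectors. The codebook size, embedding size, and commitment weight (beta) are hypothetical; the return convention shown follows the Returns entries above.

    import torch
    from tmrc.core.models.components.quantizer import VectorQuantizer

    vq = VectorQuantizer(num_tokens=512, embedding_size=64, beta=0.25)

    z = torch.randn(8, 64)   # (B, E) continuous encodings

    z_q = vq(z)              # (B, E) quantized vectors (per the Returns entry above)
    ids = vq.quantize(z)     # (B,) discrete token ids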

tmrc.core.training.data.create_dataloaders(datasets_path: str, context_length: int, batch_size: int, platform: Platform)[source]

Note: until tatm supports a validation set, the first config.training.batch_size samples are used as the validation set.

  • [TODO] implement proper validation set

  • [TODO] more data checks to validate batch sizes, document ids, etc. are consistent across all batches

Parameters:
  • datasets_path – Location of the datasets to use

  • context_length – Context length of the model to train

  • batch_size – Batch size for training

  • platform – Platform describing distributed settings and devices
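
Example (a minimal sketch): the dataset path is a placeholder, and unpacking the result into a train and a validation loader is an assumption; check the function’s actual return value.

    from tmrc.core.training.data import create_dataloaders
    from tmrc.core.utils.platform import Platform

    platform = Platform()

    train_loader, val_loader = create_dataloaders(
        datasets_path="/path/to/datasets",  # placeholder path to tatm-prepared data
        context_length=1024,
        batch_size=32,
        platform=platform,
    )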

class tmrc.core.training.train.ProfilerParams(wait: int, warmup: int, active: int, repeat: int)[source]

Bases: object

class tmrc.core.training.train.TrainingParams(train_steps: int, epochs: int, autocast_precision: str, val_interval: int, artifacts_path: str, save_model: bool, log_interval: int)[source]

Bases: object

tmrc.core.training.train.get_dist_model(model, config, platform)[source]
Parameters:
  • model – Model to wrap for distributed training

  • config – Hydra config to get the distributed strategy

  • platform – Platform describing distributed settings and devices

tmrc.core.training.train.init_wandb(project: str, config: dict)[source]
Parameters:
  • project – Name of the wandb project

  • config – Training config to be shown for the current run on wandb

tmrc.core.training.train.log_model_info(model: Module)[source]
Parameters:

model – Model whose information is printed for debugging

tmrc.core.training.train.save_model_periodic(model: Module, save_dir: str, interval: int, stop_event: Event)[source]

Thread function that periodically saves the model.

tmrc.core.training.train.train(model: Module, platform: Platform, train_loader: DataLoader, val_loader: DataLoader, optimizer: Optimizer, training_params: TrainingParams, train_sampler: DistributedSampler | None = None, profiler_params: ProfilerParams | None = None)[source]
Parameters:
  • model – Model to train

  • platform – Platform describing devices to train with

  • train_loader – Data loader for the training data

  • val_loader – Data loader for the validation data

  • optimizer – Optimizer for training

  • training_params – Training-related parameters detailing training steps, checkpointing, and logging

  • wandb_project – Name of the project on wandb

  • wandb_config – Training config to show for the current run on wandb

  • profiler_params – Profiling-related parameters. Handles the tracing schedule
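
Example (a minimal sketch): wiring the pieces together. The model, platform, and data loaders are assumed to come from the earlier sketches on this page; the optimizer choice and all parameter values (including the autocast precision string) are hypothetical.

    import torch
    from tmrc.core.training.train import ProfilerParams, TrainingParams, train

    # `model`, `platform`, `train_loader`, and `val_loader` are assumed to exist,
    # e.g. from the GPT and create_dataloaders sketches earlier on this page.
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

    training_params = TrainingParams(
        train_steps=1000, epochs=1, autocast_precision="bf16",
        val_interval=100, artifacts_path="artifacts/", save_model=True,
        log_interval=10,
    )
    profiler_params = ProfilerParams(wait=1, warmup=1, active=3, repeat=1)

    train(model, platform, train_loader, val_loader, optimizer,
          training_params, profiler_params=profiler_params)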

class tmrc.core.utils.platform.Platform[source]

Bases: object

A basic singleton platform abstraction to manage distributed training or, if not running on multiple nodes, to identify the single GPU device.

get_gpu_memory_info() → dict[source]

Returns the GPU memory information of the local rank.

get_gpu_memory_peak() → float | None[source]

Returns the peak GPU memory usage of the local rank. If distributed, Rank 0 will return the maximum peak GPU memory of all ranks.
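
Example (a minimal sketch; the exact keys of the returned dict are not documented here):

    from tmrc.core.utils.platform import Platform

    platform = Platform()

    mem_info = platform.get_gpu_memory_info()   # memory stats for the local rank
    peak = platform.get_gpu_memory_peak()       # peak usage; may be None
    print(mem_info, peak)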

tmrc.core.utils.registry.register_optimizer(key=None)

Optimizer registry