tmrc Package
- class tmrc.core.models.gpt.GPT(config, platform: Platform)[source]
Bases:
Module
Basic implementation of GPT-ish variant architectures from OpenAI. This class follows Karpathy’s implementation almost exactly. Minor changes were made to:
validate config
simplify optimization, fuse optimizer step into backward pass [TODO]
simplify overall class, e.g., no longer use GPT weight init
use the Platform class to manage device training
move FlashAttention -> FlexAttention [TODO]
- forward(idx, targets=None, doc_ids=None)[source]
- Parameters:
idx – Token ids of the tokenized input; shape (B, L)
targets – Token ids of the expected next token; shape (B, L). Used for training
doc_ids – Document ids of the provided tokens; shape (B, L). Used for training
- Shape legend:
B: Batch size; L: Sequence length
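A minimal usage sketch based on the signature above. The config path, config contents, vocabulary size, and the interpretation of the return value are assumptions, not part of the documented API:

```python
import torch
from omegaconf import OmegaConf

from tmrc.core.models.gpt import GPT
from tmrc.core.utils.platform import Platform

# Hypothetical config file; the required fields are defined by the tmrc configs.
config = OmegaConf.load("configs/model/gpt.yaml")

platform = Platform()          # singleton platform abstraction (selects the device)
model = GPT(config, platform)

B, L = 2, 16                                # batch size, sequence length
idx = torch.randint(0, 50_000, (B, L))      # token ids of the tokenized input
targets = torch.randint(0, 50_000, (B, L))  # expected next-token ids (training only)

out = model(idx, targets=targets)           # what is returned (logits, loss, ...) is not shown here
```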
- class tmrc.core.models.components.decoder.Block(hidden_size: int, attn_target: Callable[[int, float], SelfAttentionBase], ffn_target: Callable[[int, float], Module], norm_bias: bool, dropout: float)[source]
Bases:
Module
Basic transformer block.
- forward(x: Tensor, mask_data: Tensor | None, *, mask_key: str | None)[source]
- Parameters:
x – Token hidden states; shape (B, L, E)
mask_data – Mask data used for constructing the attention mask; shape depends on the type of mask. See the documentation of the SelfAttention subclass used.
mask_key – Key to use for the mask cache (see SelfAttentionBase.generate_mask)
- Returns:
Output hidden states; shape (B, L, E)
- Shape legend:
B: Batch size; L: Sequence length; E: Embedding dimension
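A construction sketch based on the type hints above. The assumption that the attn_target and ffn_target factories are called with (hidden_size, dropout), and that CausalSelfAttention forwards its keyword arguments to SelfAttentionBase, is not confirmed by the signatures alone:

```python
import torch

from tmrc.core.models.components.decoder import Block, CausalSelfAttention, MLP

# Factories matching Callable[[int, float], ...]; the (hidden_size, dropout)
# argument order is an assumption based on the type hints.
def attn_factory(hidden_size: int, dropout: float) -> CausalSelfAttention:
    # Keyword arguments are assumed to be forwarded to SelfAttentionBase.
    return CausalSelfAttention(
        hidden_size=hidden_size,
        num_attention_heads=8,
        qkv_bias=False,
        proj_bias=True,
        dropout=dropout,
    )

def ffn_factory(hidden_size: int, dropout: float) -> MLP:
    return MLP(input_size=hidden_size, intermediate_size=4 * hidden_size, dropout=dropout)

block = Block(
    hidden_size=512,
    attn_target=attn_factory,
    ffn_target=ffn_factory,
    norm_bias=True,
    dropout=0.1,
)

x = torch.randn(2, 128, 512)   # (B, L, E) token hidden states
y = block(x, None)             # mask_data is None; the causal mask needs no input data
```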
- class tmrc.core.models.components.decoder.CausalSelfAttention(*args, **kwargs)[source]
Bases:
SelfAttentionFlash
Self-attention layer that uses a causal mask.
- Input:
query: Token hidden states; shape (B, L, E)
mask_data (ignored): The causal mask does not need any input data
- Output:
Attention values; shape (B, L, E)
- Shape legend:
B: Batch size; L: Sequence length; E: Embedding dimension
- class tmrc.core.models.components.decoder.DocumentCausalSelfAttention(*args, **kwargs)[source]
Bases:
SelfAttentionFlex
Self-attention layer that uses a document-causal mask.
- Input:
query: Token hidden states; shape (B, L, E)
mask_data: Document ids for the input tokens; shape (B, L)
- Output:
Attention values; shape (B, L, E)
- Shape legend:
B: Batch size; L: Sequence length; E: Embedding dimension
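A usage sketch for document-causal attention. The constructor arguments are assumed to be forwarded to SelfAttentionBase, and the mask_key value is a hypothetical cache key:

```python
import torch

from tmrc.core.models.components.decoder import DocumentCausalSelfAttention

attn = DocumentCausalSelfAttention(
    hidden_size=512,
    num_attention_heads=8,
    qkv_bias=False,
    proj_bias=True,
    dropout=0.0,
)

B, L, E = 2, 128, 512
x = torch.randn(B, L, E)                     # token hidden states
doc_ids = torch.zeros(B, L, dtype=torch.long)
doc_ids[:, 64:] = 1                          # two packed documents per sequence

# mask_data carries the document ids; tokens attend only to earlier positions
# within the same document. mask_key lets the block mask be cached and reused.
out = attn(x, mask_data=doc_ids, mask_key="doc_causal_128")
```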
- class tmrc.core.models.components.decoder.MLP(input_size: int, intermediate_size: int, activation: Callable[[], Module] = GELU, bias: bool = True, dropout: float = 0.0)[source]
Bases:
Module
Basic MLP block.
- class tmrc.core.models.components.decoder.SelfAttentionBase(hidden_size: int, num_attention_heads: int, qkv_bias: bool, proj_bias: bool, dropout: float)[source]
Bases:
Generic[T], Module
Base class for all self-attention implementations.
- All subclasses should take one of two forms:
- A class for a specific attention computation algorithm (e.g., FlashAttention or FlexAttention). These should inherit directly from SelfAttentionBase and implement the compute_attention function.
- A class for a specific mask. These should inherit from one of the subclasses of type 1, depending on which attention algorithm is optimal for that mask, and implement the _generate_mask function.
Each subclass contains a mask cache (denoted MASK_CACHE) so that masks can be shared by all instances of the class. This is useful when a mask is computationally intensive to construct, such as the block masks used for FlexAttention.
A mask can be cached by specifying a mask_key in the forward call. If the mask_key is already in the cache, the mask stored under that key is reused; otherwise, a new mask is constructed and stored in the cache.
The mask cache should be cleared regularly to avoid memory problems. It is up to the user to call the clear_mask_cache function to free memory.
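A minimal caching sketch of the behavior described above; attn stands for any SelfAttentionBase subclass instance (constructed as in the surrounding examples), batches is a hypothetical iterable, and the exact calling convention of clear_mask_cache is an assumption:

```python
# `attn` is any SelfAttentionBase subclass; `batches` is a hypothetical iterable
# of (hidden states, document ids) pairs with a fixed shape.
for x, doc_ids in batches:
    # Reusing the same mask_key means the mask is built once and then served
    # from MASK_CACHE on subsequent calls.
    out = attn(x, mask_data=doc_ids, mask_key="fixed_shape_mask")

# Free cached masks when they are no longer needed (the user's responsibility).
attn.clear_mask_cache()
```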
- compute_attention(query: Tensor, key: Tensor, value: Tensor, attn_mask: T | None = None) → Tensor[source]
Handles the attention calculation. Should be implemented in all subclasses of type 1 (see SelfAttentionBase class description).
- Parameters:
query – Query tensor; shape (B, H, L, E)
key – Key tensor; shape (B, H, L, E)
value – Value tensor; shape (B, H, L, E)
attn_mask – Mask to use for the attention calculation. 1 means to attend to the token and 0 means to not attend.
- Returns:
Attention output; shape (B, H, L, E)
- Shape legend:
B: Batch size; H: Number of heads; L: Sequence length; E: Embedding dimension per head
- forward(tokens: Tensor, mask_data: Tensor | None = None, *, mask_key: str | None = None)[source]
- Parameters:
tokens – Token hidden states; shape (B, L, E)
mask_data – Mask data needed to construct the mask; shape depends on mask
- Returns:
Attention values; shape (B, L, E)
- Shape legend:
B: Batch size; L: Sequence length; E: Embedding dimension
- generate_mask(mask_data: Tensor, *, mask_key: str | None = None) → T | None[source]
Generates the mask for the attention calculation based on the given mask data. If mask_key is provided, the attention computation reuses the cached mask with that key if available; otherwise, a new mask is constructed and stored in the cache under that key.
- Parameters:
mask_data – Data describing the mask to construct.
mask_key – Key to use for the mask cache.
- class tmrc.core.models.components.decoder.SelfAttentionFlash(*args, is_causal=True, **kwargs)[source]
Bases:
SelfAttentionBase[Tensor]
Self-attention layer that uses FlashAttention for the attention computation. Should be used for masks where FlashAttention provides optimal performance.
- class tmrc.core.models.components.decoder.SelfAttentionFlex(*args, **kwargs)[source]
Bases:
SelfAttentionBase[BlockMask]
Self-attention layer that uses FlexAttention for the attention computation. Should be used for masks where FlexAttention provides optimal performance.
- class tmrc.core.models.components.decoder.SelfAttentionManual(hidden_size: int, num_attention_heads: int, qkv_bias: bool, proj_bias: bool, dropout: float)[source]
Bases:
SelfAttentionBase[Tensor]
Self-attention layer that uses a manual attention computation. This module is here as an example. Users should use SelfAttentionFlash or SelfAttentionFlex for optimal performance.
- compute_attention(query: Tensor, key: Tensor, value: Tensor, attn_mask: Tensor | None = None) → Tensor[source]
Handles the attention calculation. Should be implemented in all subclasses of type 1 (see SelfAttentionBase class description).
- Parameters:
query – Query tensor; shape (B, H, L, E)
key – Key tensor; shape (B, H, L, E)
value – Value tensor; shape (B, H, L, E)
attn_mask – Mask to use for the attention calculation. 1 means to attend to the token and 0 means to not attend.
- Returns:
Attention output; shape (B, H, L, E)
- Shape legend:
B: Batch size; H: Number of heads; L: Sequence length; E: Embedding dimension per head
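The mask semantics documented above (1 = attend, 0 = do not attend) correspond to standard masked scaled dot-product attention. A standalone illustration of that computation, independent of SelfAttentionManual's actual implementation:

```python
import math
import torch

def masked_sdpa(query, key, value, attn_mask=None):
    """Scaled dot-product attention over (B, H, L, E) tensors.

    attn_mask follows the documented convention: 1 = attend, 0 = do not attend.
    """
    scores = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))  # (B, H, L, L)
    if attn_mask is not None:
        scores = scores.masked_fill(attn_mask == 0, float("-inf"))
    weights = torch.softmax(scores, dim=-1)
    return weights @ value                                              # (B, H, L, E)

causal = torch.tril(torch.ones(4, 4))          # causal 0/1 mask for L = 4
q = k = v = torch.randn(2, 8, 4, 16)           # (B, H, L, E)
out = masked_sdpa(q, k, v, attn_mask=causal)
```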
- class tmrc.core.models.components.decoder.SwiGLUFFN(input_size: int, intermediate_size: int, bias: bool = True, norm_layer: Callable[[int], Module] | None = None)[source]
Bases:
Module
Feed-forward block using SwiGLU
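For reference, the standard SwiGLU feed-forward pattern is sketched below; this illustrates the general formulation only and is not necessarily how SwiGLUFFN is implemented (e.g., the norm_layer option is omitted):

```python
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUSketch(nn.Module):
    """Standard SwiGLU FFN: down(silu(gate(x)) * up(x))."""

    def __init__(self, input_size: int, intermediate_size: int, bias: bool = True):
        super().__init__()
        self.gate = nn.Linear(input_size, intermediate_size, bias=bias)
        self.up = nn.Linear(input_size, intermediate_size, bias=bias)
        self.down = nn.Linear(intermediate_size, input_size, bias=bias)

    def forward(self, x):
        return self.down(F.silu(self.gate(x)) * self.up(x))
```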
- class tmrc.core.models.components.quantizer.VectorQuantizer(num_tokens: int, embedding_size: int, beta: float)[source]
Bases:
Module
Handles quantizing a continuous representation into a discrete one. Follows the implementation described in the VQ-VAE paper (https://arxiv.org/abs/1711.00937).
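A sketch of the VQ-VAE objective the constructor parameters map onto (num_tokens codebook entries of size embedding_size, with beta weighting the commitment term). This is only the standard formulation from the cited paper; the module's actual forward interface is not shown here:

```python
import torch
import torch.nn.functional as F

def vector_quantize(z, codebook, beta):
    """Nearest-neighbor quantization with the VQ-VAE losses.

    z:        (N, D) continuous vectors
    codebook: (K, D) embedding table (K = num_tokens, D = embedding_size)
    beta:     commitment-loss weight
    """
    dists = torch.cdist(z, codebook)               # (N, K) pairwise distances
    indices = dists.argmin(dim=1)                  # nearest code per vector
    z_q = codebook[indices]                        # quantized vectors

    codebook_loss = F.mse_loss(z_q, z.detach())    # pulls codes toward encoder outputs
    commitment_loss = beta * F.mse_loss(z, z_q.detach())
    loss = codebook_loss + commitment_loss

    z_q = z + (z_q - z).detach()                   # straight-through gradient estimator
    return z_q, indices, loss
```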
- tmrc.core.training.data.create_dataloaders(datasets_path: str, context_length: int, batch_size: int, platform: Platform)[source]
Note: until tatm supports a validation set, the first config.training.batch_size samples are used as the validation set
[TODO] implement proper validation set
- [TODO] more data checks to validate that batch sizes, document ids, etc. are consistent across all batches
- Args:
datasets_path: Location of the datasets to use
context_length: Context length of the model to train
batch_size: Batch size for training
platform: Platform describing distributed settings and devices
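A usage sketch for the documented arguments; the dataset path is hypothetical and the assumption that the function returns a (train_loader, val_loader) pair is not confirmed by the signature:

```python
from tmrc.core.training.data import create_dataloaders
from tmrc.core.utils.platform import Platform

platform = Platform()

# Return value assumed to be (train_loader, val_loader); the path is hypothetical.
train_loader, val_loader = create_dataloaders(
    datasets_path="/path/to/tatm/datasets",
    context_length=1024,
    batch_size=32,
    platform=platform,
)
```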
- class tmrc.core.training.train.ProfilerParams(wait: int, warmup: int, active: int, repeat: int)[source]
Bases:
object
- class tmrc.core.training.train.TrainingParams(train_steps: int, epochs: int, autocast_precision: str, val_interval: int, artifacts_path: str, save_model: bool, log_interval: int)[source]
Bases:
object
- tmrc.core.training.train.get_dist_model(model, config, platform)[source]
- Parameters:
model – Model to wrap for distributed training
config – Hydra config to get the distributed strategy
platform – Platform describing distributed settings and devices
- tmrc.core.training.train.init_wandb(project: str, config: dict)[source]
- Parameters:
project – Name of the wandb project
config – Training config to be shown for the current run on wandb
- tmrc.core.training.train.log_model_info(model: Module)[source]
- Parameters:
model – Model to print the info of for debugging
- tmrc.core.training.train.save_model_periodic(model: Module, save_dir: str, interval: int, stop_event: Event)[source]
Thread function that periodically saves the model; it runs until stop_event is set.
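A sketch of running this as a background thread; the checkpoint directory is hypothetical, `model` comes from earlier in the training script, and the interpretation of interval as seconds is an assumption:

```python
import threading

from tmrc.core.training.train import save_model_periodic

stop_event = threading.Event()
saver = threading.Thread(
    target=save_model_periodic,
    args=(model, "checkpoints/", 600, stop_event),  # hypothetical dir; interval assumed to be seconds
    daemon=True,
)
saver.start()

# ... run training ...

stop_event.set()   # signal the saver thread to stop
saver.join()
```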
- tmrc.core.training.train.train(model: Module, platform: Platform, train_loader: DataLoader, val_loader: DataLoader, optimizer: Optimizer, training_params: TrainingParams, train_sampler: DistributedSampler | None = None, profiler_params: ProfilerParams | None = None)[source]
- Parameters:
model – Model to train
platform – Platform describing devices to train with
train_loader – Data loader for the training data
val_loader – Data loader for the validation data
optimizer – Optimizer for training
training_params – Training-related parameters detailing training steps, checkpointing, and logging
train_sampler – Optional DistributedSampler for the training data
profiler_params – Profiling-related parameters; handles the tracing schedule
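An end-to-end sketch tying the pieces above together. The config handling, dataset path, TrainingParams values (including the "bf16" precision string), and the create_dataloaders return value are all assumptions used only for illustration:

```python
import torch
from omegaconf import OmegaConf

from tmrc.core.models.gpt import GPT
from tmrc.core.training.data import create_dataloaders
from tmrc.core.training.train import TrainingParams, get_dist_model, init_wandb, train
from tmrc.core.utils.platform import Platform

platform = Platform()
model = GPT(config, platform)                    # `config` loaded as in the GPT sketch above
model = get_dist_model(model, config, platform)  # wrap according to the distributed strategy

train_loader, val_loader = create_dataloaders("/path/to/datasets", 1024, 32, platform)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

init_wandb("my-tmrc-runs", OmegaConf.to_container(config))  # hypothetical project name

params = TrainingParams(
    train_steps=10_000,
    epochs=1,
    autocast_precision="bf16",                   # accepted values are an assumption
    val_interval=500,
    artifacts_path="artifacts/",
    save_model=True,
    log_interval=50,
)

train(model, platform, train_loader, val_loader, optimizer, params)
```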
- class tmrc.core.utils.platform.Platform[source]
Bases:
object
A basic singleton platform abstraction to manage distributed training; if not running on multiple nodes, it identifies the single GPU device.
- tmrc.core.utils.registry.register_optimizer(key=None)
Optimizer registry
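The docs above only show the key argument, so the exact registration mechanics are not documented here; the decorator-factory usage below is an assumption, with a hypothetical key and optimizer factory:

```python
import torch

from tmrc.core.utils.registry import register_optimizer

# Assumed usage: register an optimizer factory under a lookup key.
@register_optimizer(key="adamw")
def build_adamw(params, lr=3e-4, weight_decay=0.1):
    return torch.optim.AdamW(params, lr=lr, weight_decay=weight_decay)
```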