Training Configuration
======================

The training configuration is managed through a YAML file with several main sections: ``datasets``, ``model``, ``optimizer``, ``tokenizer``, ``training``, ``profiler``, and logging settings.

Dataset Configuration
---------------------

Configuration related to input data and tokenization.

.. code-block:: yaml

    datasets:
      name: algebraic-stack
      path: "/path/to/data"
      tokenizer_used: t5-base

* ``name``: Name of the dataset being used (that was tokenized by the ``tatm`` package)
* ``path``: File system path to the tokenized dataset
* ``tokenizer_used``: The tokenizer that was used to preprocess the data

Model Configuration
-------------------

Parameters that define the model architecture.

.. code-block:: yaml

    model:
      name: gpt
      n_head: 4
      d_model: 512
      n_layer: 8
      dropout_p: 0.0
      context_length: 512
      autocast_precision: bfloat16
      mlp_scale_factor: 4
      mlp_bias: True
      attn_bias: False
      proj_bias: True
      ln_bias: True
      cls_head_bias: True
      activation: relu
      mask: causal_document

* ``name``: Model architecture type (currently supports ``gpt``)
* ``n_head``: Number of attention heads
* ``d_model``: Hidden dimension size
* ``n_layer``: Number of transformer layers
* ``dropout_p``: Dropout probability (0.0 means no dropout)
* ``context_length``: Maximum sequence length for input tokens
* ``autocast_precision``: Precision for automatic mixed precision training (options: ``float32``, ``float16``, ``bfloat16``)
* ``mlp_scale_factor``: Multiplier for the MLP hidden dimension relative to ``d_model``
* ``mlp_bias``: Include bias terms in MLP layers
* ``attn_bias``: Include bias terms in the attention computation
* ``proj_bias``: Include bias terms in projection layers
* ``ln_bias``: Include bias terms in layer normalization
* ``cls_head_bias``: Include bias terms in the classification head
* ``activation``: Activation function (options: ``relu``, ``gelu``)
* ``mask``: Attention mask type (e.g., ``causal_document`` to use causal attention plus document masking)

Optimizer Configuration
-----------------------

Parameters for the optimization algorithm.

.. code-block:: yaml

    optimizer:
      name: AdamW
      lr: 0.0001
      weight_decay: 0.01
      betas: [0.9, 0.999]
      eps: 1e-8
      precision: float32

* ``name``: Optimizer type (currently supports ``AdamW``)
* ``lr``: Learning rate
* ``weight_decay``: L2 regularization factor
* ``betas``: Beta parameters for AdamW [β1, β2]
* ``eps``: Epsilon parameter for numerical stability
* ``precision``: Optimizer state precision

Tokenizer Configuration
-----------------------

Settings for the tokenizer.

.. code-block:: yaml

    tokenizer:
      name: t5-base
      vocab_size: 32128

* ``name``: Name of the pretrained tokenizer
* ``vocab_size``: Size of the vocabulary
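The ``model`` and ``optimizer`` sections above map onto plain PyTorch objects. The following is a minimal sketch of that mapping, assuming the configuration file is ordinary YAML readable with PyYAML; the file name ``config.yaml`` and the stand-in module are hypothetical, and the package's actual loading code may differ.

.. code-block:: python

    import yaml
    import torch
    from torch import nn

    # Load the YAML configuration (file name is hypothetical).
    with open("config.yaml") as f:
        cfg = yaml.safe_load(f)

    # The MLP hidden dimension is d_model scaled by mlp_scale_factor,
    # e.g. 512 * 4 = 2048 with the values shown above.
    mlp_hidden = cfg["model"]["d_model"] * cfg["model"]["mlp_scale_factor"]

    # Stand-in parameters; the real model is built from the full ``model`` section.
    model = nn.Linear(cfg["model"]["d_model"], mlp_hidden, bias=cfg["model"]["mlp_bias"])

    # The ``optimizer`` section maps one-to-one onto torch.optim.AdamW's arguments.
    # Note: PyYAML parses ``1e-8`` (no decimal point) as a string, hence the float() cast.
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=cfg["optimizer"]["lr"],
        weight_decay=cfg["optimizer"]["weight_decay"],
        betas=tuple(cfg["optimizer"]["betas"]),
        eps=float(cfg["optimizer"]["eps"]),
    )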
Training Configuration
----------------------

Parameters controlling the training process.

.. code-block:: yaml

    training:
      epochs: 1
      train_steps: 100000
      batch_size: 256
      log_interval: 20
      val_interval: 100
      shuffle: True
      save_model: True
      save_every: 3600
      artifacts_path: /path/to/artifacts
      use_oracle: False
      torch_profiling: False
      distributed_strategy: ddp

* ``epochs``: Number of training epochs
* ``train_steps``: Maximum number of training steps (training stops at whichever comes first: ``epochs`` or ``train_steps``)
* ``batch_size``: Size of training batches
* ``log_interval``: Number of steps between logging updates
* ``val_interval``: Number of steps between validation checks
* ``shuffle``: Whether to shuffle the dataset between epochs
* ``save_model``: Whether to save model checkpoints
* ``save_every``: Checkpoint saving frequency in seconds (3600 = once an hour)
* ``artifacts_path``: Directory in which to save model checkpoints and other artifacts
* ``use_oracle``: Enable oracle mode for debugging/testing
* ``torch_profiling``: Enable the PyTorch profiler for performance analysis
* ``distributed_strategy``: Distributed training strategy (options: ``ddp``, ``fsdp``)

FSDP Configuration
------------------

Configuration for Fully Sharded Data Parallel (FSDP) training.

.. code-block:: yaml

    fsdp:
      mixed_precision: True

* ``mixed_precision``: Enable mixed precision training for FSDP (recommended for large models)

Profiler Configuration
----------------------

Parameters controlling the torch profiler.

.. code-block:: yaml

    profiler:
      wait: 1
      warmup: 3
      active: 1
      repeat: 1

* ``wait``: Number of steps to wait before starting profiling
* ``warmup``: Number of steps to warm up before profiling
* ``active``: Number of steps to actively profile
* ``repeat``: Number of times to repeat the profiling cycle

.. note::

    ``repeat: 0`` repeats the profiling cycle until the end of training (not recommended). To disable the torch profiler, set ``torch_profiling: False`` in the training configuration.

Logging Configuration
---------------------

Settings for experiment tracking.

.. code-block:: yaml

    wandb_log:
      name: tmrc_log

* ``name``: Run name for Weights & Biases logging

Hydra Configuration
-------------------

Settings for Hydra configuration management.

.. code-block:: yaml

    HydraConf:
      version_base: "1.1"

* ``version_base``: Hydra version compatibility setting
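The ``profiler`` and ``HydraConf`` values correspond to standard ``torch.profiler`` and Hydra arguments. The sketch below is a hypothetical entry point rather than the package's actual training script; the ``conf`` directory and ``config`` name are assumptions, and ``version_base`` is shown on the decorator here rather than through the ``HydraConf`` structured config.

.. code-block:: python

    import hydra
    import torch
    from omegaconf import DictConfig


    # ``version_base`` plays the same role here as ``HydraConf.version_base`` in the YAML.
    @hydra.main(version_base="1.1", config_path="conf", config_name="config")
    def main(cfg: DictConfig) -> None:
        # The ``profiler`` section maps onto torch.profiler.schedule's keyword arguments.
        schedule = torch.profiler.schedule(
            wait=cfg.profiler.wait,      # steps skipped at the start of each cycle
            warmup=cfg.profiler.warmup,  # steps traced but discarded
            active=cfg.profiler.active,  # steps recorded in the trace
            repeat=cfg.profiler.repeat,  # number of cycles; 0 repeats until training ends
        )
        with torch.profiler.profile(schedule=schedule) as prof:
            for step in range(cfg.training.train_steps):
                ...  # forward/backward/optimizer step would go here
                prof.step()  # advance the profiler through its wait/warmup/active cycle


    if __name__ == "__main__":
        main()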