# Training Configuration
The training configuration is managed through a YAML file with several main sections: `datasets`, `model`, `optimizer`, `tokenizer`, `training`, `profiler`, and `logging` settings.
## Dataset Configuration
Configuration related to input data and tokenization.
```yaml
datasets:
  name: algebraic-stack
  path: "/path/to/data"
  tokenizer_used: t5-base
```

name
: Name of the dataset being used (that was tokenized by the `tatm` package)

path
: File system path to the tokenized dataset

tokenizer_used
: The tokenizer that was used to preprocess the data
## Model Configuration
Parameters that define the model architecture.
```yaml
model:
  name: gpt
  n_head: 4
  d_model: 512
  n_layer: 8
  dropout_p: 0.0
  context_length: 512
  autocast_precision: bfloat16
  mlp_scale_factor: 4
  mlp_bias: True
  attn_bias: False
  proj_bias: True
  ln_bias: True
  cls_head_bias: True
  activation: relu
  mask: causal_document
```
name
: Model architecture type (currently supports `gpt`)

n_head
: Number of attention heads

d_model
: Hidden dimension size

n_layer
: Number of transformer layers

dropout_p
: Dropout probability (0.0 means no dropout)

context_length
: Maximum sequence length for input tokens

autocast_precision
: Precision for automatic mixed precision training (options: `float32`, `float16`, `bfloat16`; see the sketch after this list)

mlp_scale_factor
: Multiplier for the MLP hidden dimension relative to `d_model`

mlp_bias
: Include bias terms in MLP layers

attn_bias
: Include bias terms in the attention computation

proj_bias
: Include bias terms in projection layers

ln_bias
: Include bias terms in layer normalization

cls_head_bias
: Include bias terms in the classification head

activation
: Activation function (options: `relu`, `gelu`)

mask
: Attention mask type (e.g., `causal_document` to use causal attention + document masking)
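To make the MLP and precision fields concrete, here is a minimal sketch of how they typically map onto PyTorch. The `cfg` dict and the bare `nn.Sequential` MLP are illustrative assumptions for this sketch, not the library's actual implementation:

```python
import torch
import torch.nn as nn

# Illustrative values taken from the config above.
cfg = {"d_model": 512, "mlp_scale_factor": 4, "mlp_bias": True,
       "autocast_precision": "bfloat16"}

hidden = cfg["d_model"] * cfg["mlp_scale_factor"]  # 512 * 4 = 2048
mlp = nn.Sequential(
    nn.Linear(cfg["d_model"], hidden, bias=cfg["mlp_bias"]),
    nn.ReLU(),  # selected via the `activation` field
    nn.Linear(hidden, cfg["d_model"], bias=cfg["mlp_bias"]),
)

x = torch.randn(2, cfg["d_model"])
# `autocast_precision` typically selects the dtype handed to torch.autocast.
with torch.autocast(device_type="cpu", dtype=getattr(torch, cfg["autocast_precision"])):
    y = mlp(x)
```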
## Optimizer Configuration
Parameters for the optimization algorithm.
```yaml
optimizer:
  name: AdamW
  lr: 0.0001
  weight_decay: 0.01
  betas: [0.9, 0.999]
  eps: 1e-8
  precision: float32
```
name
: Optimizer type (currently supports `AdamW`; see the sketch after this list)

lr
: Learning rate

weight_decay
: L2 regularization factor

betas
: Beta parameters for AdamW, [β1, β2]

eps
: Epsilon parameter for numerical stability

precision
: Optimizer state precision
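These fields line up one-to-one with the arguments of `torch.optim.AdamW`. A minimal sketch (the `model` here is a stand-in, not the library's actual wiring):

```python
import torch

model = torch.nn.Linear(512, 512)  # placeholder model for illustration
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=0.0001,
    weight_decay=0.01,
    betas=(0.9, 0.999),
    eps=1e-8,
)
```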
## Tokenizer Configuration
Settings for the tokenizer.
```yaml
tokenizer:
  name: t5-base
  vocab_size: 32128
```
name
: Name of the pretrained tokenizer

vocab_size
: Size of the vocabulary
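Assuming `name` refers to a Hugging Face pretrained tokenizer (which the `t5-base` default suggests), the fields can be cross-checked like this:

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("t5-base")
print(tok.vocab_size)  # 32100 base tokens; embedding tables are commonly
                       # padded up (here to 32128) for hardware-friendly sizes
```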
## Training Configuration
Parameters controlling the training process.
```yaml
training:
  epochs: 1
  train_steps: 100000
  batch_size: 256
  log_interval: 20
  val_interval: 100
  shuffle: True
  save_model: True
  save_every: 3600
  artifacts_path: /path/to/artifacts
  use_oracle: False
  torch_profiling: False
  distributed_strategy: ddp
```
epochs
: Number of training epochs

train_steps
: Maximum number of training steps (training stops at whichever limit is reached first: `epochs` or `train_steps`; see the sketch after this list)

batch_size
: Size of training batches

log_interval
: Number of steps between logging updates

val_interval
: Number of steps between validation checks

shuffle
: Whether to shuffle the dataset between epochs

save_model
: Whether to save model checkpoints

save_every
: Checkpoint saving frequency in seconds (3600 = once an hour)

artifacts_path
: Directory to save model checkpoints and other artifacts

use_oracle
: Enable oracle mode for debugging/testing

torch_profiling
: Enable the PyTorch profiler for performance analysis

distributed_strategy
: Distributed training strategy (options: `ddp`, `fsdp`)
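The following loop skeleton is an illustrative assumption, not the library's actual training loop; it shows how the `epochs`/`train_steps` stopping rule and the time-based `save_every` checkpointing typically interact:

```python
import time

epochs, train_steps, save_every = 1, 100_000, 3600  # values from the config above

global_step = 0
last_save = time.time()
for epoch in range(epochs):
    for batch in range(1000):  # stand-in for iterating the dataloader
        global_step += 1
        # ... forward/backward/optimizer step would go here ...
        if time.time() - last_save >= save_every:
            last_save = time.time()
            # save a checkpoint under `artifacts_path`
        if global_step >= train_steps:
            break  # train_steps limit reached before the epoch limit
    if global_step >= train_steps:
        break
```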
## FSDP Configuration
Configuration for Fully Sharded Data Parallel (FSDP) training.
```yaml
fsdp:
  mixed_precision: True
```
mixed_precision
: Enable mixed precision training for FSDP (recommended for large models)
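One common way a `mixed_precision: True` flag is realized is via FSDP's `MixedPrecision` policy; the dtype choices below are assumptions for the sketch, not the library's confirmed defaults:

```python
import torch
from torch.distributed.fsdp import MixedPrecision

mp_policy = MixedPrecision(
    param_dtype=torch.bfloat16,   # parameters computed/communicated in bf16
    reduce_dtype=torch.bfloat16,  # gradient reduction in bf16
    buffer_dtype=torch.bfloat16,  # buffers cast to bf16
)
# The policy would then be passed as FSDP(model, mixed_precision=mp_policy, ...).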
## Profiler Configuration
Parameters controlling the torch profiler.
```yaml
profiler:
  wait: 1
  warmup: 3
  active: 1
  repeat: 1
```
wait
: Number of steps to wait before starting profiling

warmup
: Number of steps to warm up before profiling

active
: Number of steps to actively profile

repeat
: Number of times to repeat the profiling cycle
**Note:** `repeat: 0` repeats the profiling cycle until the end of training (not recommended). To disable the torch profiler, set `torch_profiling: False` in the training configuration.
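The four fields match the arguments of `torch.profiler.schedule`; wiring them up directly, as sketched here, is an assumption about the library's internals:

```python
import torch
from torch.profiler import profile, schedule, ProfilerActivity

prof_schedule = schedule(wait=1, warmup=3, active=1, repeat=1)

with profile(activities=[ProfilerActivity.CPU], schedule=prof_schedule) as prof:
    for step in range(10):
        torch.randn(64, 64) @ torch.randn(64, 64)  # stand-in for a training step
        prof.step()  # advance the profiler through wait/warmup/active phases
```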
## Logging Configuration
Settings for experiment tracking.
```yaml
wandb_log:
  name: tmrc_log
```
name
: Run name for Weights & Biases logging
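Presumably the value is passed straight to `wandb.init`; a minimal sketch (`mode="offline"` added here only so the sketch runs without an account):

```python
import wandb

run = wandb.init(name="tmrc_log", mode="offline")
```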
## Hydra Configuration
Settings for Hydra configuration management.
```yaml
HydraConf:
  version_base: "1.1"
```
version_base
: Hydra version compatibility setting
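Since the configuration is managed through Hydra, a training entry point typically loads it as below; the `"configs"` directory and `"train_config"` name are placeholders, not the library's actual layout:

```python
import hydra
from omegaconf import DictConfig

@hydra.main(config_path="configs", config_name="train_config", version_base="1.1")
def main(cfg: DictConfig) -> None:
    print(cfg.model.d_model)        # 512
    print(cfg.optimizer.lr)         # 0.0001
    print(cfg.training.batch_size)  # 256

if __name__ == "__main__":
    main()
```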