kempnerforge.metrics.mfu
Model FLOPs Utilization (MFU) computation.
Implements the PaLM paper formula for estimating achieved FLOPS relative to hardware peak, with auto-detection of GPU capabilities.
MFU = achieved_tflops / peak_tflops

where:

    model_flops_per_token = 6*P + 12*L*D*S   (forward + backward)
    achieved_tflops = model_flops_per_token * tokens_per_sec / 1e12
    peak_tflops = gpu_peak_tflops * num_gpus
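For example (illustrative numbers, not from the source): a dense model with P = 7e9 active parameters, L = 32, D = 4096, and S = 4096, training at 50,000 tokens/sec on 16 GPUs with a 312-TFLOPS bf16 peak each, gives roughly:

    model_flops_per_token = 6*7e9 + 12*32*4096*4096 ≈ 4.84e10
    achieved_tflops ≈ 4.84e10 * 50000 / 1e12 ≈ 2422
    MFU ≈ 2422 / (16 * 312) ≈ 0.49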
Functions
compute_mfu(config, tokens_per_sec[, num_gpus, gpu_peak_tflops, seq_len]) | Compute Model FLOPs Utilization.
estimate_model_flops_per_token(config[, seq_len]) | Estimate FLOPS per token for forward + backward pass.
get_gpu_peak_tflops([device]) | Auto-detect GPU peak bf16 TFLOPS.
- kempnerforge.metrics.mfu.get_gpu_peak_tflops(device=0)
Auto-detect GPU peak bf16 TFLOPS.
Tries to match the GPU name against known models. Falls back to a conservative estimate based on compute capability.
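A minimal sketch of this kind of detection, assuming PyTorch is available; the lookup table lists published dense bf16 peaks for two common GPUs, and the fallback values and cutoffs are invented for illustration, not taken from the library:

    import torch

    # Illustrative table only: published dense bf16 peak TFLOPS for two
    # common GPUs; the real module may cover more models.
    _KNOWN_BF16_PEAKS = {
        "H100": 989.0,  # H100 SXM, dense bf16
        "A100": 312.0,  # A100, dense bf16
    }

    def get_gpu_peak_tflops(device: int = 0) -> float:
        name = torch.cuda.get_device_name(device)
        for key, tflops in _KNOWN_BF16_PEAKS.items():
            if key in name:
                return tflops
        # Unknown GPU: fall back to a conservative estimate keyed on
        # compute capability (assumed cutoffs, not from the source).
        major, _ = torch.cuda.get_device_capability(device)
        return 100.0 if major >= 8 else 50.0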
- kempnerforge.metrics.mfu.estimate_model_flops_per_token(config, seq_len=None)
Estimate FLOPS per token for forward + backward pass.
Uses the PaLM paper approximation:
    model_flops_per_token = 6*P + 12*L*D*S

where P is the active (non-embedding) parameter count, L the number of layers, D the model dimension, and S the sequence length.

For MoE models, active params are used (top_k experts per layer, not all experts). Embeddings are excluded (a table lookup, not a matmul); the output projection is included. The 12*L*D*S attention term does not discount GQA, because FlashAttention expands GQA internally, so the hardware performs full attention compute. Router FLOPS (dim × num_experts) are intentionally omitted as negligible.
- Parameters:
config (ModelConfig) – Model configuration.
seq_len (int | None) – Actual training sequence length. Falls back to config.max_seq_len if not provided.
- Returns:
Estimated FLOPS per token.
- Return type:
int
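A minimal sketch of this estimate, assuming a hypothetical ModelConfig with fields num_params, num_layers, dim, and max_seq_len (the field names are illustrative, not the library's actual attributes):

    from dataclasses import dataclass

    # Hypothetical stand-in for kempnerforge's ModelConfig.
    @dataclass
    class ModelConfig:
        num_params: int   # P: active, non-embedding parameter count
        num_layers: int   # L
        dim: int          # D: model dimension
        max_seq_len: int  # default S when seq_len is not given

    def estimate_model_flops_per_token(config: ModelConfig, seq_len: int | None = None) -> int:
        # PaLM approximation for forward + backward: 6*P + 12*L*D*S.
        s = seq_len if seq_len is not None else config.max_seq_len
        return 6 * config.num_params + 12 * config.num_layers * config.dim * s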
- kempnerforge.metrics.mfu.compute_mfu(config, tokens_per_sec, num_gpus=1, gpu_peak_tflops=None, seq_len=None)
Compute Model FLOPs Utilization.
- Parameters:
config (ModelConfig) – Model configuration.
tokens_per_sec (float) – Global throughput (tokens/sec across all GPUs).
num_gpus (int) – Number of GPUs.
gpu_peak_tflops (float | None) – Peak bf16 TFLOPS per GPU. Auto-detected if None.
seq_len (int | None) – Actual training sequence length for attention FLOPS. Falls back to config.max_seq_len if not provided.
- Returns:
MFU as a fraction (0.0 to 1.0).
- Return type:
float
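Putting the pieces together, a minimal sketch of the computation under the definitions above (the body is an assumption, not the library's implementation; the example throughput is made up):

    # Sketch: MFU = achieved_tflops / (per-GPU peak * num_gpus).
    def compute_mfu(config, tokens_per_sec, num_gpus=1, gpu_peak_tflops=None, seq_len=None):
        flops_per_token = estimate_model_flops_per_token(config, seq_len)
        achieved_tflops = flops_per_token * tokens_per_sec / 1e12
        if gpu_peak_tflops is None:
            gpu_peak_tflops = get_gpu_peak_tflops()  # per-GPU bf16 peak
        return achieved_tflops / (gpu_peak_tflops * num_gpus)

    # Example call (made-up numbers):
    # mfu = compute_mfu(cfg, tokens_per_sec=1.2e6, num_gpus=8, seq_len=4096)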