"""Metrics collection, accumulation, and reporting.
MetricsTracker aggregates per-step metrics (loss, grad norm, throughput,
MFU, memory) and dispatches them to configured logging backends (stdout,
WandB, TensorBoard) at a configurable interval.
"""
from __future__ import annotations
import time
from dataclasses import dataclass
from typing import Any
from kempnerforge.config.schema import JobConfig, MetricsConfig
from kempnerforge.metrics.logger import format_metrics, get_logger
from kempnerforge.metrics.memory import get_memory_stats, get_memory_utilization
from kempnerforge.metrics.mfu import compute_mfu, get_gpu_peak_tflops
logger = get_logger(__name__)
[docs]
@dataclass
class StepMetrics:
"""Metrics for a single training step."""
loss: float = 0.0
grad_norm: float = 0.0
lr: float = 0.0
tokens_per_sec: float = 0.0
mfu: float = 0.0
step_time_sec: float = 0.0
allocated_gb: float = 0.0
peak_gb: float = 0.0
reserved_gb: float = 0.0
total_gb: float = 0.0
mem_utilization: float = 0.0
[docs]
class MetricsTracker:
"""Collects, smooths, and reports training metrics.
Timing is handled internally — call ``start_step()`` before and
``end_step()`` after each training step. Metrics are logged to
all configured backends at the configured interval.
Args:
config: Full job config (used for MFU calculation and backend selection).
num_gpus: Number of GPUs for MFU denominator.
gpu_peak_tflops: Per-GPU peak TFLOPS. If None, auto-detected.
"""
[docs]
def __init__(
self,
config: JobConfig,
num_gpus: int = 1,
gpu_peak_tflops: float | None = None,
) -> None:
self.metrics_config = config.metrics
self.model_config = config.model
self.seq_len = config.train.seq_len
self.num_gpus = num_gpus
self.gpu_peak_tflops = gpu_peak_tflops or get_gpu_peak_tflops()
# Smoothed metrics (exponential moving average)
self._ema_alpha = 0.1
self._smoothed: dict[str, float] = {}
# Per-step timing
self._step_start: float = 0.0
# Logging backends (initialized lazily)
self._backends: list[_LoggingBackend] = []
self._backends_initialized = False
def _init_backends(self, config: JobConfig) -> None:
"""Lazily initialize logging backends (rank 0 only)."""
if self._backends_initialized:
return
self._backends_initialized = True
import torch.distributed as dist
if dist.is_initialized() and dist.get_rank() != 0:
return
mc = config.metrics
if mc.enable_wandb:
self._backends.append(WandBBackend(mc))
if mc.enable_tensorboard:
self._backends.append(TensorBoardBackend(mc))
[docs]
def start_step(self) -> None:
"""Mark the beginning of a training step."""
self._step_start = time.perf_counter()
[docs]
def end_step(
self,
step: int,
loss: float,
grad_norm: float,
lr: float,
tokens_in_step: int,
) -> StepMetrics | None:
"""Mark the end of a training step and optionally log metrics.
Args:
step: Current training step number.
loss: Loss value for this step.
grad_norm: Gradient norm (after clipping).
lr: Current learning rate.
tokens_in_step: Total tokens processed in this step (across all GPUs).
Returns:
StepMetrics if this step was a logging step, None otherwise.
"""
step_time = time.perf_counter() - self._step_start
tokens_per_sec = tokens_in_step / step_time if step_time > 0 else 0.0
# Compute MFU
mfu = compute_mfu(
self.model_config,
tokens_per_sec=tokens_per_sec,
num_gpus=self.num_gpus,
gpu_peak_tflops=self.gpu_peak_tflops,
seq_len=self.seq_len,
)
# Memory stats
mem_stats = get_memory_stats()
mem_util = get_memory_utilization()
metrics = StepMetrics(
loss=loss,
grad_norm=grad_norm,
lr=lr,
tokens_per_sec=tokens_per_sec,
mfu=mfu,
step_time_sec=step_time,
allocated_gb=mem_stats["allocated_gb"],
peak_gb=mem_stats["peak_gb"],
reserved_gb=mem_stats["reserved_gb"],
total_gb=mem_stats["total_gb"],
mem_utilization=mem_util,
)
# Update smoothed metrics
self._update_smoothed("loss", loss)
self._update_smoothed("tokens_per_sec", tokens_per_sec)
self._update_smoothed("mfu", mfu)
self._update_smoothed("step_time", step_time)
# Log at interval
if step % self.metrics_config.log_interval == 0 or step == 1:
self._log_step(step, metrics)
return metrics
return None
def _update_smoothed(self, key: str, value: float) -> None:
"""Update exponential moving average for a metric."""
if key not in self._smoothed:
self._smoothed[key] = value
else:
alpha = self._ema_alpha
self._smoothed[key] = alpha * value + (1 - alpha) * self._smoothed[key]
def _log_step(self, step: int, metrics: StepMetrics) -> None:
"""Log metrics to stdout and all backends."""
# Stdout logging
log_dict: dict[str, str | float | int] = {
"loss": f"{metrics.loss:.4f}",
"lr": f"{metrics.lr:.2e}",
"grad_norm": f"{metrics.grad_norm:.3f}",
"tok/s": f"{metrics.tokens_per_sec:,.0f}",
"mfu": f"{metrics.mfu:.1%}",
"mem": (f"{metrics.peak_gb:.1f}/{metrics.total_gb:.0f}GB"),
"step_time": f"{metrics.step_time_sec:.2f}s",
}
logger.info(format_metrics(step, log_dict))
# Backend logging (numeric dict)
backend_dict = {
"train/loss": metrics.loss,
"train/grad_norm": metrics.grad_norm,
"train/lr": metrics.lr,
"train/tokens_per_sec": metrics.tokens_per_sec,
"train/mfu": metrics.mfu,
"train/step_time_sec": metrics.step_time_sec,
"gpu/allocated_gb": metrics.allocated_gb,
"gpu/peak_gb": metrics.peak_gb,
"gpu/reserved_gb": metrics.reserved_gb,
"gpu/mem_utilization": metrics.mem_utilization,
}
# Smoothed metrics
for key, val in self._smoothed.items():
backend_dict[f"smoothed/{key}"] = val
for backend in self._backends:
backend.log(backend_dict, step=step)
[docs]
def log_eval(self, metrics: dict[str, float], step: int) -> None:
"""Log eval metrics to all backends and stdout."""
logger.info(format_metrics(step, metrics)) # type: ignore[reportArgumentType]
for backend in self._backends:
backend.log(metrics, step=step)
[docs]
def init_backends(self, config: JobConfig) -> None:
"""Initialize logging backends (call after distributed setup)."""
self._init_backends(config)
[docs]
def close(self) -> None:
"""Flush and close all logging backends."""
for backend in self._backends:
backend.close()
# ---------------------------------------------------------------------------
# Logging backends
# ---------------------------------------------------------------------------
class _LoggingBackend:
"""Base class for metrics logging backends."""
def log(self, metrics: dict[str, float], step: int) -> None:
raise NotImplementedError
def close(self) -> None:
pass
[docs]
class WandBBackend(_LoggingBackend):
"""Weights & Biases logging backend.
Initializes a WandB run on first log call.
"""
[docs]
def __init__(self, config: MetricsConfig) -> None:
self._config = config
self._run = None
def _ensure_init(self) -> None:
if self._run is not None:
return
try:
import wandb
init_kwargs: dict[str, Any] = {
"project": self._config.wandb_project,
"name": self._config.wandb_run_name,
"resume": "allow",
}
if self._config.wandb_run_id:
init_kwargs["id"] = self._config.wandb_run_id
self._run = wandb.init(**init_kwargs)
self._config.wandb_run_id = self._run.id
logger.info(f"WandB initialized: {self._run.url}")
except ImportError:
logger.warning("wandb not installed — disabling WandB backend")
self._run = False # Sentinel: tried and failed
except Exception as e: # wandb.init() can raise many third-party errors (network, auth)
logger.warning(f"WandB init failed: {e}")
self._run = False
[docs]
def log(self, metrics: dict[str, float], step: int) -> None:
self._ensure_init()
if self._run is False:
return
import wandb
wandb.log(metrics, step=step)
[docs]
def close(self) -> None:
if self._run and self._run is not False:
import wandb
wandb.finish()
[docs]
class TensorBoardBackend(_LoggingBackend):
"""TensorBoard logging backend."""
[docs]
def __init__(self, config: MetricsConfig) -> None:
self._config = config
self._writer = None
def _ensure_init(self) -> None:
if self._writer is not None:
return
try:
from torch.utils.tensorboard import SummaryWriter
self._writer = SummaryWriter(log_dir=self._config.tensorboard_dir)
logger.info(f"TensorBoard writer → {self._config.tensorboard_dir}")
except ImportError:
logger.warning("tensorboard not installed — disabling TensorBoard backend")
self._writer = False
[docs]
def log(self, metrics: dict[str, float], step: int) -> None:
self._ensure_init()
if self._writer is False:
return
for key, val in metrics.items():
self._writer.add_scalar(key, val, global_step=step) # type: ignore[reportOptionalMemberAccess, reportAttributeAccessIssue]
[docs]
def close(self) -> None:
if self._writer and self._writer is not False:
self._writer.close() # type: ignore[reportAttributeAccessIssue]