# Source code for kempnerforge.metrics.memory
"""GPU memory tracking and reporting.
Provides utilities for monitoring GPU memory usage during training:
- Current / peak / reserved memory
- Memory utilization as a percentage of total
- Human-readable formatting
"""
from __future__ import annotations
import logging
import torch
logger = logging.getLogger(__name__)
def get_memory_stats(device: int = 0) -> dict[str, float]:
    """Get current GPU memory statistics in GB.

    Args:
        device: CUDA device index.

    Returns:
        Dict with allocated, peak, reserved, and total memory in GB.
        All values are 0.0 when CUDA is unavailable.
    """
    # Return floats in the no-CUDA branch so value types match the
    # annotated dict[str, float] contract in every code path.
    if not torch.cuda.is_available():
        return {"allocated_gb": 0.0, "peak_gb": 0.0, "reserved_gb": 0.0, "total_gb": 0.0}
    gb = 1024**3  # bytes per GiB
    return {
        "allocated_gb": torch.cuda.memory_allocated(device) / gb,
        "peak_gb": torch.cuda.max_memory_allocated(device) / gb,
        "reserved_gb": torch.cuda.memory_reserved(device) / gb,
        "total_gb": torch.cuda.get_device_properties(device).total_memory / gb,
    }
def get_memory_utilization(device: int = 0) -> float:
    """Get peak memory utilization as a fraction of total GPU memory.

    Args:
        device: CUDA device index.

    Returns:
        Utilization between 0.0 and 1.0 (0.0 when no GPU is present).
    """
    mem = get_memory_stats(device)
    total = mem["total_gb"]
    # A zero total means CUDA is unavailable; avoid dividing by zero.
    return mem["peak_gb"] / total if total else 0.0
def reset_peak_memory(device: int = 0) -> None:
    """Reset peak memory tracking counter.

    No-op when CUDA is unavailable.

    Args:
        device: CUDA device index.
    """
    if not torch.cuda.is_available():
        return
    torch.cuda.reset_peak_memory_stats(device)
# ---------------------------------------------------------------------------
# DeviceMemoryMonitor — per-step tracking with snapshot support
# ---------------------------------------------------------------------------
[docs]
class DeviceMemoryMonitor:
"""Tracks GPU memory usage across training steps.
Resets peak memory stats at each reporting interval so that the
peak reflects per-interval usage rather than all-time peak.
Supports memory snapshot capture at a configurable step for
debugging OOM and memory fragmentation with pytorch.org/memory_viz.
Args:
device: CUDA device index.
snapshot_step: Step at which to capture a memory snapshot. None to disable.
snapshot_dir: Directory to save snapshots.
"""
[docs]
def __init__(
self,
device: int = 0,
snapshot_step: int | None = None,
snapshot_dir: str = "memory_snapshots",
) -> None:
self.device = device
self.snapshot_step = snapshot_step
self.snapshot_dir = snapshot_dir
self._snapshot_taken = False
[docs]
def report(self, step: int) -> dict[str, float]:
"""Report memory stats for the current interval and reset peak.
Args:
step: Current training step.
Returns:
Dict with memory stats.
"""
stats = get_memory_stats(self.device)
util = get_memory_utilization(self.device)
stats["mem_utilization"] = util
logger.info(f"[step {step}] {format_memory_stats(self.device)}")
# Check if we should capture a snapshot
if (
self.snapshot_step is not None
and step == self.snapshot_step
and not self._snapshot_taken
):
self.capture_snapshot(step)
# Reset peak for next interval
reset_peak_memory(self.device)
return stats
[docs]
def capture_snapshot(self, step: int) -> str | None:
"""Capture a CUDA memory snapshot and save as pickle.
The snapshot can be visualized at https://pytorch.org/memory_viz
Args:
step: Current step (used in filename).
Returns:
Path to the saved snapshot, or None if capture failed.
"""
if not torch.cuda.is_available():
return None
try:
from pathlib import Path
snapshot_dir = Path(self.snapshot_dir)
snapshot_dir.mkdir(parents=True, exist_ok=True)
# Record and export memory snapshot
torch.cuda.memory._record_memory_history()
# Give it a moment to record current state
torch.cuda.synchronize(self.device)
snapshot = torch.cuda.memory._snapshot()
torch.cuda.memory._record_memory_history(enabled=None)
import pickle
path = snapshot_dir / f"snapshot_step_{step}_device_{self.device}.pickle"
with open(path, "wb") as f:
pickle.dump(snapshot, f)
self._snapshot_taken = True
logger.info(f"Memory snapshot saved: {path}")
return str(path)
except Exception as e: # best-effort diagnostic; CUDA, disk, or pickle errors
logger.warning(f"Memory snapshot failed: {e}")
return None