# Source code for kempnerforge.checkpoint.async_save
"""Async checkpointing for non-blocking saves.
Uses ``dcp.async_save()`` to snapshot state to CPU and write to disk
in the background, returning control to the training loop immediately.
Modes:
- disabled: Synchronous save (simple, for debugging).
- async: Standard async via dcp.async_save().
- async_with_pinned_mem: Async with pinned memory staging for faster GPU→CPU.
"""
from __future__ import annotations
import logging
from typing import Any
import torch.distributed.checkpoint as dcp
from kempnerforge.config.schema import AsyncCheckpointMode
logger = logging.getLogger(__name__)
[docs]
class AsyncCheckpointer:
    """Non-blocking checkpoint saver.

    Wraps ``dcp.save()`` / ``dcp.async_save()`` and keeps track of the future
    returned by the most recent async save. Each new save first waits for the
    previous async save to complete, so at most one save is in flight.

    Args:
        mode: Checkpoint mode (disabled/async/async_with_pinned_mem).
    """

    def __init__(self, mode: AsyncCheckpointMode = AsyncCheckpointMode.disabled) -> None:
        self.mode = mode
        # Future returned by the last dcp.async_save() call; None when no
        # async save is outstanding.
        self._pending_future: Any = None

    def save(self, state_dict: dict, checkpoint_id: str, process_group=None) -> None:
        """Save distributed state, potentially asynchronously.

        Args:
            state_dict: DCP-compatible state dict (model + optimizer).
            checkpoint_id: Checkpoint directory path.
            process_group: Process group for DCP. Required for PP where each
                stage has a different state dict — pass a group scoped to ranks
                within the same PP stage. None uses the default global group.

        Raises:
            ValueError: If ``self.mode`` is not one of the supported modes.
            Exception: Whatever the previous pending async save raised; it is
                surfaced by the initial ``wait()`` before the new save starts.
        """
        # Wait for any pending async save to complete first.
        self.wait()

        if self.mode == AsyncCheckpointMode.disabled:
            dcp.save(state_dict, checkpoint_id=checkpoint_id, process_group=process_group)
            logger.info("Sync checkpoint saved: %s", checkpoint_id)
        elif self.mode in (AsyncCheckpointMode.async_, AsyncCheckpointMode.async_pinned):
            # NOTE(review): both async modes take the same dcp.async_save()
            # path here; no pinned-memory staging is configured in this
            # class — confirm whether that is expected to happen elsewhere.
            self._pending_future = dcp.async_save(
                state_dict,
                checkpoint_id=checkpoint_id,
                process_group=process_group,
            )
            logger.info("Async checkpoint started: %s", checkpoint_id)
        else:
            # Silently skipping a checkpoint would go unnoticed until a
            # restore is attempted, so fail loudly instead.
            raise ValueError(f"Unsupported async checkpoint mode: {self.mode!r}")

    def wait(self) -> None:
        """Block until any pending async save completes.

        Does nothing when no async save is outstanding.

        Raises:
            Exception: Whatever the background save raised. The failed future
                is dropped before the exception propagates, so the error is
                reported once and later saves are not blocked by it.
        """
        future = self._pending_future
        if future is None:
            return
        # Detach the future before blocking on it: if result() raises, a
        # stale failed future must not be re-raised by every later call.
        self._pending_future = None
        future.result()
        logger.info("Async checkpoint completed")

    @property
    def is_pending(self) -> bool:
        """Whether an async save has been started and not yet waited on.

        The future itself is not polled: a save that has already finished in
        the background still counts as pending until ``wait()`` is called.
        """
        return self._pending_future is not None