Source code for kempnerforge.distributed.parallel

"""Parallelism application: TP, AC, Float8, FSDP2, and model building.

Applies parallelism to Transformer models in the correct order.

Application order (critical — wrong order causes silent correctness bugs):
  1. Tensor parallelism (apply_tensor_parallel) — must see raw blocks
  2. Expert parallelism (apply_expert_parallel) — partitions MoE experts
  3. Float8 training (apply_float8) — wraps Linear → Float8Linear
  4. Activation checkpointing (apply_ac) — wraps blocks in CheckpointWrapper
  5. FSDP2 (apply_fsdp2) — shards everything (uses float8 all-gather if enabled)

For convenience, ``build_parallel_model`` combines all steps including
model creation, meta-device initialization, and optional torch.compile.
"""

from __future__ import annotations

import logging
from functools import partial

import torch
from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)
from torch.distributed.device_mesh import DeviceMesh

from kempnerforge.config.registry import registry
from kempnerforge.config.schema import ActivationCheckpointing
from kempnerforge.model.transformer import Transformer, TransformerBlock

logger = logging.getLogger(__name__)


def has_dp_mesh(device_mesh: DeviceMesh) -> bool:
    """Check whether the DeviceMesh contains any data-parallel dimensions."""
    dim_names = device_mesh.mesh_dim_names
    return "dp_shard" in dim_names or "dp_replicate" in dim_names  # type: ignore[reportOperatorIssue]

def get_dp_mesh(device_mesh: DeviceMesh) -> DeviceMesh:
    """Extract the data-parallel sub-mesh from a DeviceMesh.

    Returns a 1D mesh (pure sharding) or 2D mesh (replicate + shard / HSDP)
    depending on which dimensions are present.

    Raises ValueError if no DP dimensions exist (e.g., pure TP mesh).
    Use ``has_dp_mesh`` to check first.
    """
    dim_names = device_mesh.mesh_dim_names
    has_replicate = "dp_replicate" in dim_names  # type: ignore[reportOperatorIssue]
    has_shard = "dp_shard" in dim_names  # type: ignore[reportOperatorIssue]
    if has_replicate and has_shard:
        # 2D HSDP: first dim = replicate, second dim = shard
        return device_mesh["dp_replicate", "dp_shard"]
    elif has_shard:
        return device_mesh["dp_shard"]
    elif has_replicate:
        return device_mesh["dp_replicate"]
    else:
        raise ValueError(
            f"No DP dimensions in mesh {dim_names}. "
            "Use has_dp_mesh() to check before calling get_dp_mesh()."
        )

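# Illustrative sketch (not part of the original module): how the DP-mesh helpers
# above would typically be used on an HSDP mesh. The mesh shape and dim names are
# hypothetical; assumes torch.distributed is already initialized (e.g. via torchrun)
# before init_device_mesh is called.
def _example_dp_mesh_extraction() -> None:
    from torch.distributed.device_mesh import init_device_mesh

    # 16 ranks laid out as 2 (replicate) x 4 (shard) x 2 (tp).
    mesh = init_device_mesh(
        "cuda", (2, 4, 2), mesh_dim_names=("dp_replicate", "dp_shard", "tp")
    )
    if has_dp_mesh(mesh):
        # 2-D ("dp_replicate", "dp_shard") sub-mesh -> HSDP sharding for FSDP2.
        dp_mesh = get_dp_mesh(mesh)
        assert dp_mesh.mesh_dim_names == ("dp_replicate", "dp_shard")
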
def default_mp_policy(param_dtype: torch.dtype = torch.bfloat16) -> MixedPrecisionPolicy:
    """Mixed-precision policy: param_dtype compute, fp32 gradient reduction."""
    return MixedPrecisionPolicy(
        param_dtype=param_dtype,
        reduce_dtype=torch.float32,
    )

def apply_ac(model: Transformer, mode: ActivationCheckpointing) -> None:
    """Apply activation checkpointing to the model.

    Must be called BEFORE apply_fsdp2.

    Args:
        model: Transformer model.
        mode: Checkpointing mode — "none", "full", or "selective".
            full: checkpoint every TransformerBlock (maximum memory savings).
            selective: checkpoint only Attention modules (balanced trade-off).
    """
    if mode == ActivationCheckpointing.none:
        return

    wrapper_fn = partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT)

    if mode == ActivationCheckpointing.full:
        apply_activation_checkpointing(
            model,
            checkpoint_wrapper_fn=wrapper_fn,
            check_fn=lambda m: isinstance(m, TransformerBlock),
        )
        logger.info("Applied full activation checkpointing (per TransformerBlock)")
    elif mode == ActivationCheckpointing.selective:
        from kempnerforge.model.attention import Attention

        apply_activation_checkpointing(
            model,
            checkpoint_wrapper_fn=wrapper_fn,
            check_fn=lambda m: isinstance(m, Attention),
        )
        logger.info("Applied selective activation checkpointing (Attention only)")

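# Illustrative sketch (not part of the original module): inspecting how many
# modules were wrapped after apply_ac. CheckpointWrapper is the class produced
# by checkpoint_wrapper above; "full" mode yields one per TransformerBlock,
# "selective" one per Attention module. Assumes `model` is an already-built
# Transformer.
def _example_count_ac_wrappers(model: Transformer) -> int:
    from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
        CheckpointWrapper,
    )

    return sum(isinstance(m, CheckpointWrapper) for m in model.modules())
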
def apply_float8(model: Transformer, enable_fsdp_float8_all_gather: bool = True) -> None:
    """Apply Float8 training (torchao) to the model.

    Converts nn.Linear modules to Float8Linear for E4M3 forward / E5M2 backward
    with dynamic tensorwise scaling. Master weights remain in bf16.

    Must be called AFTER apply_tensor_parallel / apply_expert_parallel and
    BEFORE apply_ac / apply_fsdp2.

    MoE expert modules (experts and shared_expert) are excluded because they use
    grouped GEMM (torch._grouped_mm), which bypasses Float8Linear.forward().

    Args:
        model: Transformer model.
        enable_fsdp_float8_all_gather: If True, FSDP2 all-gathers use float8
            (halves communication volume). Requires FSDP2 to be applied after.
            Must be False when TP is active — the float8 weight wrapper calls
            aten.is_pinned on DTensors, which has no sharding strategy yet.
    """
    import dataclasses

    from torchao.float8 import (
        Float8LinearConfig,
        Float8LinearRecipeName,
        convert_to_float8_training,
    )

    config = dataclasses.replace(
        Float8LinearConfig.from_recipe_name(Float8LinearRecipeName.TENSORWISE),
        enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
    )

    def _filter_fn(module: torch.nn.Module, fqn: str) -> bool:
        """Skip modules that shouldn't use Float8.

        Excluded:
        - Expert Linears: use grouped GEMM (torch._grouped_mm), not Linear.forward()
        - Router gate: small output dim (num_experts) often not divisible by 16,
          which torch._scaled_mm requires. Also not compute-bound.
        """
        if "experts" in fqn or "shared_expert" in fqn:
            return False
        return "router" not in fqn

    convert_to_float8_training(model, config=config, module_filter_fn=_filter_fn)
    logger.info(
        f"Applied Float8 training: recipe=TENSORWISE, "
        f"fsdp_float8_all_gather={enable_fsdp_float8_all_gather}"
    )

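# Illustrative sketch (not part of the original module): which fully-qualified
# module names the Float8 filter keeps or skips. The FQNs below are hypothetical
# examples of what a MoE Transformer might contain; the predicate mirrors the
# nested _filter_fn in apply_float8.
def _example_float8_filter() -> None:
    def keeps_float8(fqn: str) -> bool:
        if "experts" in fqn or "shared_expert" in fqn:
            return False
        return "router" not in fqn

    assert keeps_float8("layers.0.attention.wq")              # dense Linear -> Float8Linear
    assert keeps_float8("layers.0.mlp.w1")                     # dense MLP -> Float8Linear
    assert not keeps_float8("layers.1.mlp.experts.w1")         # grouped GEMM, skipped
    assert not keeps_float8("layers.1.mlp.shared_expert.w1")   # grouped GEMM, skipped
    assert not keeps_float8("layers.1.mlp.router.gate")        # tiny output dim, skipped
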
def _has_ep_moe(module: torch.nn.Module) -> bool:
    """Check if a module contains MoE with expert parallelism active."""
    from kempnerforge.model.moe import MoEMLP

    return any(isinstance(m, MoEMLP) and m.ep_world_size > 1 for m in module.modules())

def apply_fsdp2(
    model: Transformer,
    device_mesh: DeviceMesh,
    mp_policy: MixedPrecisionPolicy | None = None,
    reshard_after_forward: bool | int = True,
) -> None:
    """Apply FSDP2 (fully_shard) to a Transformer model.

    Shards each TransformerBlock independently, then wraps the top-level model
    for remaining parameters (embeddings, final norm, output head).

    Must be called AFTER apply_ac and apply_tensor_parallel.

    **EP interaction**: Blocks with expert parallelism get per-sub-module
    wrapping (attention and MoE individually) instead of per-block wrapping.
    Per-block wrapping would cause FSDP2's reduce-scatter to fire between EP's
    backward all-to-all calls (deadlock). Per-sub-module wrapping avoids this:
    the MoE reduce-scatter fires after the entire MoE backward (both EP
    all-to-alls complete), while the attention reduce-scatter is EP-free.

    Args:
        model: Transformer model to shard.
        device_mesh: Full DeviceMesh (dp sub-mesh is extracted automatically).
        mp_policy: Mixed precision policy. Defaults to bf16 params + fp32 reduce.
        reshard_after_forward: Whether to free gathered params after forward.
            True = always reshard (saves memory, default).
            False = keep gathered (useful when PP needs params across microbatches).
            int = rate-limit the number of concurrent all-gathers.
    """
    if not has_dp_mesh(device_mesh):
        logger.info("No DP dimensions in mesh — skipping FSDP2")
        return

    dp_mesh = get_dp_mesh(device_mesh)
    policy = mp_policy or default_mp_policy()

    # Shard each transformer block as an independent FSDP unit.
    #
    # EP-MoE blocks get per-sub-module wrapping: attention and MoE (layer.mlp)
    # are each individually fully_shard()'d. This avoids the deadlock that
    # per-block wrapping would cause (FSDP2 reduce-scatter firing between EP's
    # two backward all-to-all calls), while still getting per-block param
    # resharding and gradient reduction overlap — unlike the previous approach
    # of deferring all MoE params to the top-level wrap.
    #
    # Safety: FSDP2 bucketizes all params in a fully_shard() unit and fires
    # reduce-scatter only after the last param's grad is computed. For
    # fully_shard(layer.mlp), reduce-scatter fires after the entire MoE
    # backward (after both EP all-to-all calls complete). All dp_shard peers
    # share the same EP rank, so they reach reduce-scatter in the same phase.
    ep_sub_wrapped = 0
    for layer in model.layers.values():
        if _has_ep_moe(layer):
            fully_shard(  # type: ignore[reportCallIssue]
                layer.attention,  # type: ignore[reportArgumentType]
                mesh=dp_mesh,
                mp_policy=policy,
                reshard_after_forward=reshard_after_forward,
            )
            fully_shard(  # type: ignore[reportCallIssue]
                layer.mlp,  # type: ignore[reportArgumentType]
                mesh=dp_mesh,
                mp_policy=policy,
                reshard_after_forward=reshard_after_forward,
            )
            ep_sub_wrapped += 1
            continue
        fully_shard(
            layer,
            mesh=dp_mesh,
            mp_policy=policy,
            reshard_after_forward=reshard_after_forward,
        )

    # Top-level shard covers remaining params (embeddings, final norm, output
    # head, and layer norms from EP-MoE blocks).
    fully_shard(
        model,
        mesh=dp_mesh,
        mp_policy=policy,
        reshard_after_forward=reshard_after_forward,
    )

    logger.info(
        f"Applied FSDP2: dp_mesh={dp_mesh.mesh_dim_names}, blocks={len(model.layers)}, "
        f"mp={policy}"
        + (f", ep_moe_blocks_sub_wrapped={ep_sub_wrapped}" if ep_sub_wrapped else "")
    )

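# Illustrative sketch (not part of the original module): listing which modules
# became FSDP2 units after apply_fsdp2. fully_shard() swaps the module's class
# to an FSDPModule subclass, so isinstance() identifies the shard units (the
# root model, each dense block, and attention/mlp of EP-MoE blocks). Assumes
# `model` has already gone through apply_fsdp2.
def _example_list_fsdp_units(model: Transformer) -> list[str]:
    from torch.distributed._composable.fsdp import FSDPModule

    return [name for name, m in model.named_modules() if isinstance(m, FSDPModule)]
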
def build_parallel_model(
    model_config,
    device: torch.device,
    device_mesh: DeviceMesh | None,
    *,
    ac_mode: ActivationCheckpointing = ActivationCheckpointing.none,
    mp_policy: MixedPrecisionPolicy | None = None,
    param_dtype: torch.dtype = torch.bfloat16,
    compile_model: bool = False,
    fp8: bool = False,
) -> torch.nn.Module:
    """Build a Transformer with parallelism applied in the correct order.

    Handles two configurations automatically:
    - TP enabled: meta-device init → TP → EP → [Float8] → AC → FSDP → materialize
    - TP disabled: create on device → EP → [Float8] → AC → FSDP

    This is the non-PP model-building path. For pipeline parallelism, use
    ``build_stage_module`` and apply parallelism directly.

    Args:
        model_config: ModelConfig for the Transformer.
        device: Target device for the model.
        device_mesh: Full DeviceMesh (may contain tp, dp_shard, dp_replicate dims).
        ac_mode: Activation checkpointing mode.
        mp_policy: FSDP2 mixed-precision policy. Defaults to bf16 params + fp32 reduce.
        param_dtype: Dtype for model parameters.
        compile_model: Whether to torch.compile the model.
        fp8: Whether to enable Float8 mixed precision (torchao).

    Returns:
        The parallelized model, ready for training.
    """
    from kempnerforge.distributed.expert_parallel import apply_expert_parallel
    from kempnerforge.distributed.tensor_parallel import apply_tensor_parallel

    tp_enabled = device_mesh is not None and "tp" in device_mesh.mesh_dim_names  # type: ignore[reportOperatorIssue]
    model_builder = registry.get_model(model_config.model_type)

    if tp_enabled:
        # Meta-device init: create the model with zero memory, apply parallelisms,
        # then materialize only the local shards on GPU.
        with torch.device("meta"):
            model = model_builder(model_config)
        apply_tensor_parallel(model, device_mesh)  # type: ignore[reportArgumentType]
        apply_expert_parallel(model, device_mesh)
        if fp8:
            apply_float8(model)
        apply_ac(model, ac_mode)
        if device_mesh is not None:
            apply_fsdp2(model, device_mesh, mp_policy=mp_policy)
        model.to_empty(device=device)
        model.init_weights_and_freqs()
        model.to(dtype=param_dtype)
    else:
        model = model_builder(model_config).to(device=device, dtype=param_dtype)
        apply_expert_parallel(model, device_mesh)
        if fp8:
            apply_float8(model)
        apply_ac(model, ac_mode)
        if device_mesh is not None:
            apply_fsdp2(model, device_mesh, mp_policy=mp_policy)

    if compile_model:
        logger.info("Compiling model with torch.compile...")
        model = torch.compile(model)

    n_params = sum(p.numel() for p in model.parameters())
    logger.info(f"Model: {n_params:,} parameters")
    return model  # type: ignore[reportReturnType]

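# Illustrative usage sketch (not part of the original module): the convenience
# path on a 2-D dp_shard x tp mesh. The mesh shape, dim names, and `cfg` are
# hypothetical; assumes torchrun has initialized the process group and that
# cfg.model_type is registered with the model registry. For pipeline parallelism,
# skip build_parallel_model and apply the steps manually on each stage module in
# the documented order: apply_tensor_parallel -> apply_expert_parallel ->
# apply_float8 (optional) -> apply_ac -> apply_fsdp2.
def _example_build(cfg) -> torch.nn.Module:
    from torch.distributed.device_mesh import init_device_mesh

    # 16 ranks: 8-way FSDP2 sharding x 2-way tensor parallelism.
    mesh = init_device_mesh("cuda", (8, 2), mesh_dim_names=("dp_shard", "tp"))
    return build_parallel_model(
        cfg,
        device=torch.device("cuda"),
        device_mesh=mesh,
        ac_mode=ActivationCheckpointing.selective,
        compile_model=False,
        fp8=False,
    )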