kempnerforge.distributed

Distributed training infrastructure.

Public API:

  • init_distributed / destroy_distributed: process group lifecycle

  • apply_fsdp2 / apply_ac: FSDP2 sharding and activation checkpointing

  • apply_tensor_parallel: DTensor-based tensor parallelism

  • Pipeline parallelism: build_stage_module, build_pipeline_stage, build_pipeline_schedule

  • clip_grad_norm_: DTensor-aware gradient clipping
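
A minimal end-to-end sketch of how these pieces compose. The DistributedConfig construction, model builder, data loader, and optimizer are placeholders; only the call ordering (tensor parallel, then activation checkpointing, then FSDP2) is prescribed by this module:

    from kempnerforge.distributed import (
        init_distributed, destroy_distributed,
        apply_tensor_parallel, apply_ac, apply_fsdp2,
        clip_grad_norm_,
    )

    device_mesh = init_distributed(config)     # config: a DistributedConfig built elsewhere
    model = build_model(model_config)          # hypothetical model constructor

    apply_tensor_parallel(model, device_mesh)  # TP first
    apply_ac(model, mode="selective")          # then activation checkpointing
    apply_fsdp2(model, device_mesh)            # FSDP2 last

    for batch, labels in loader:               # placeholder data loader
        loss = loss_fn(model(batch), labels)
        loss.backward()
        clip_grad_norm_(model, max_norm=1.0)
        optimizer.step()
        optimizer.zero_grad()

    destroy_distributed()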

kempnerforge.distributed.apply_ac(model, mode)[source]

Apply activation checkpointing to the model.

Must be called BEFORE apply_fsdp2.

Parameters:
  • model (Transformer) – Transformer model.

  • mode (ActivationCheckpointing) – Checkpointing mode: “none”, “full”, or “selective”. “full” checkpoints every TransformerBlock (maximum memory savings); “selective” checkpoints only Attention modules (a balanced memory/recompute trade-off).

Return type:

None
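
For instance, assuming the string-valued modes listed above:

    apply_ac(model, mode="full")       # recompute every TransformerBlock: max memory savings
    # or:
    apply_ac(model, mode="selective")  # recompute only Attention: lighter backward cost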

kempnerforge.distributed.apply_fsdp2(model, device_mesh, mp_policy=None, reshard_after_forward=True)[source]

Apply FSDP2 (fully_shard) to a Transformer model.

Shards each TransformerBlock independently, then wraps the top-level model for remaining parameters (embeddings, final norm, output head).

Must be called AFTER apply_ac and apply_tensor_parallel.

EP interaction: Blocks with expert parallelism get per-sub-module wrapping (attention and MoE individually) instead of per-block wrapping. Per-block wrapping would cause FSDP2’s reduce-scatter to fire between EP’s backward all-to-all calls (deadlock). Per-sub-module wrapping avoids this: the MoE reduce-scatter fires after the entire MoE backward (both EP all-to-alls complete), while the attention reduce-scatter is EP-free (see the sketch after this entry).

Parameters:
  • model (Transformer) – Transformer model to shard.

  • device_mesh (DeviceMesh) – Full DeviceMesh (dp sub-mesh is extracted automatically).

  • mp_policy (MixedPrecisionPolicy | None) – Mixed precision policy. Defaults to bf16 params + fp32 reduce.

  • reshard_after_forward (bool | int) – Whether to free gathered params after forward. True = always reshard (saves memory, default). False = keep gathered (useful when PP needs params across microbatches). int = reshard to that smaller world size after forward (fully_shard’s int semantics).

Return type:

None
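
A hedged sketch of the wrapping strategy described above, written against the upstream fully_shard API. The attribute names (model.layers, block.attention, block.moe) and the uses_expert_parallel flag are illustrative, not KempnerForge internals:

    from kempnerforge.distributed import get_dp_mesh
    from torch.distributed._composable.fsdp import fully_shard

    dp_mesh = get_dp_mesh(device_mesh)

    for block in model.layers:
        if getattr(block, "uses_expert_parallel", False):  # hypothetical flag
            # Per-sub-module wrapping: the MoE reduce-scatter fires only after
            # the whole MoE backward (both EP all-to-alls) has completed.
            fully_shard(block.attention, mesh=dp_mesh)
            fully_shard(block.moe, mesh=dp_mesh)
        else:
            fully_shard(block, mesh=dp_mesh)               # per-block wrapping

    # Top-level wrap covers embeddings, final norm, and the output head.
    fully_shard(model, mesh=dp_mesh)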

kempnerforge.distributed.apply_tensor_parallel(model, device_mesh)[source]

Apply tensor parallelism to a Transformer or PipelineStageModule.

Parallelizes attention/MLP projections, norm layers (SequenceParallel), and output head. Token embedding stays replicated. Should be called BEFORE apply_ac and apply_fsdp2.

PP stages use basic TP without SequenceParallel to avoid DTensors at stage boundaries. SequenceParallel is also disabled when weights are tied.

Parameters:
  • model (Transformer | PipelineStageModule) – Model to parallelize.

  • device_mesh (DeviceMesh) – DeviceMesh containing the ‘tp’ dimension.

Return type:

None
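
For example, on a hypothetical 2x4 mesh (in practice init_distributed builds the mesh; the dimension names here match those used elsewhere in this module):

    from torch.distributed.device_mesh import init_device_mesh

    device_mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp_shard", "tp"))
    apply_tensor_parallel(model, device_mesh)  # must precede apply_ac / apply_fsdp2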

kempnerforge.distributed.build_pipeline_schedule(stage, n_microbatches, loss_fn, schedule='1f1b')[source]

Create a pipeline execution schedule.

Parameters:
  • stage (torch.distributed.pipelining.PipelineStage) – The PipelineStage for this rank.

  • n_microbatches (int) – Number of microbatches per training step. Must be >= pp_size for 1F1B to fill the pipeline.

  • loss_fn (callable) – Loss function (applied on last stage only).

  • schedule (str) – Schedule type: “1f1b”, “gpipe”, or “interleaved_1f1b”. Note: “interleaved_1f1b” requires multiple stages per rank (virtual pipeline stages); pass a list of stages as the stage argument in that case.

Returns:

A pipeline schedule object. Call schedule.step() in the training loop:

  • First stage: schedule.step(input_tensor, target=labels)

  • Other stages: schedule.step(target=labels) or schedule.step()
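
One training step under this contract might look like the following (rank and label handling are illustrative):

    schedule = build_pipeline_schedule(stage, n_microbatches=8, loss_fn=loss_fn)

    if get_pp_rank(device_mesh) == 0:
        schedule.step(input_ids, target=labels)  # first stage feeds the inputs
    else:
        schedule.step(target=labels)             # later stages receive activations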

kempnerforge.distributed.build_pipeline_stage(stage_module, device_mesh, device, batch_size, seq_len, param_dtype=torch.bfloat16)[source]

Wrap a stage module in a PipelineStage for schedule execution.

Parameters:
  • stage_module (PipelineStageModule) – Stage module from build_stage_module.

  • device_mesh (DeviceMesh)

  • device (torch.device)

  • batch_size (int)

  • seq_len (int)

  • param_dtype (torch.dtype)

Returns:

A PipelineStage ready for use with a pipeline schedule.

Return type:

torch.distributed.pipelining.PipelineStage

kempnerforge.distributed.build_stage_module(config, pp_rank, pp_size)[source]

Build the model chunk for a specific pipeline stage.

Parameters:
  • config (ModelConfig) – Model configuration.

  • pp_rank (int) – This process’s pipeline rank (0-indexed).

  • pp_size (int) – Total number of pipeline stages.

Returns:

A PipelineStageModule containing only the parameters for this stage.

Return type:

PipelineStageModule
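
Combining the three pipeline helpers (batch_size, seq_len, and loss_fn are placeholders):

    import torch

    pp_rank = get_pp_rank(device_mesh)
    pp_size = get_pp_size(device_mesh)
    device = torch.device("cuda")

    stage_module = build_stage_module(config, pp_rank, pp_size)
    stage = build_pipeline_stage(
        stage_module, device_mesh, device,
        batch_size=8, seq_len=2048, param_dtype=torch.bfloat16,
    )
    schedule = build_pipeline_schedule(stage, n_microbatches=8, loss_fn=loss_fn)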

kempnerforge.distributed.clip_grad_norm_(model, max_norm, foreach=True)[source]

Clip gradient norm across all parameters.

Handles mixed DTensor meshes that arise when TP+FSDP produces parameters on different meshes (e.g., TP-sharded linears on (dp_shard, tp) vs FSDP-only norms on (dp_shard)). Groups gradients by mesh, computes per-group norms via stack, then combines across groups.

Falls back to torch.nn.utils.clip_grad_norm_ when there is only one mesh (pure FSDP or single-GPU).

Parameters:
  • model (torch.nn.Module) – Model whose gradients to clip.

  • max_norm (float) – Maximum gradient norm.

  • foreach (bool) – Use the faster foreach implementation.

Returns:

Total gradient norm (before clipping).

Return type:

torch.Tensor
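
Usage mirrors torch.nn.utils.clip_grad_norm_; the returned pre-clip norm is handy for logging:

    loss.backward()
    total_norm = clip_grad_norm_(model, max_norm=1.0)  # clip before optimizer.step()
    optimizer.step()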

kempnerforge.distributed.compute_layer_assignment(n_layers, pp_size)[source]

Compute which layers go to which PP stage.

Distributes layers as evenly as possible. Earlier stages get one extra layer when n_layers is not evenly divisible by pp_size.

Parameters:
  • n_layers (int) – Total number of transformer layers.

  • pp_size (int) – Number of pipeline stages.

Returns:

List of (start_layer, end_layer) tuples, one per stage. end_layer is exclusive.

Raises:

ValueError – If pp_size > n_layers.

Return type:

list[tuple[int, int]]
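
A worked example of the distribution scheme described above:

    # 10 layers over 4 stages: base 2 per stage with remainder 2,
    # so stages 0 and 1 each take one extra layer.
    assert compute_layer_assignment(10, 4) == [(0, 3), (3, 6), (6, 8), (8, 10)]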

kempnerforge.distributed.default_mp_policy(param_dtype=torch.bfloat16)[source]

Mixed-precision policy: param_dtype compute, fp32 gradient reduction.

Parameters:

param_dtype (torch.dtype)

Return type:

torch.distributed._composable.fsdp.MixedPrecisionPolicy
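
Spelled out with the upstream policy class, the default amounts to roughly:

    import torch
    from torch.distributed._composable.fsdp import MixedPrecisionPolicy

    policy = MixedPrecisionPolicy(
        param_dtype=torch.bfloat16,  # compute and communication in bf16
        reduce_dtype=torch.float32,  # gradient reduction in fp32
    )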

kempnerforge.distributed.destroy_distributed()[source]

Clean up distributed process groups.

Return type:

None

kempnerforge.distributed.get_dp_mesh(device_mesh)[source]

Extract the data-parallel sub-mesh from a DeviceMesh.

Returns a 1D mesh (pure sharding) or 2D mesh (replicate + shard / HSDP) depending on which dimensions are present.

Parameters:

device_mesh (torch.distributed.device_mesh.DeviceMesh)

Raises:

ValueError – If no DP dimensions exist (e.g., a pure TP mesh). Use has_dp_mesh to check first.

Return type:

torch.distributed.device_mesh.DeviceMesh
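
Guarding as suggested (has_dp_mesh is referenced above but not documented in this section):

    if has_dp_mesh(device_mesh):
        dp_mesh = get_dp_mesh(device_mesh)
    else:
        dp_mesh = None  # pure-TP mesh: nothing to extract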

kempnerforge.distributed.get_pp_mesh(device_mesh)[source]

Extract the PP sub-mesh from a DeviceMesh.

Returns None if no ‘pp’ dimension exists.

Parameters:

device_mesh (DeviceMesh)

Return type:

DeviceMesh | None

kempnerforge.distributed.get_pp_rank(device_mesh)[source]

Get the pipeline parallel rank for this process.

Parameters:

device_mesh (torch.distributed.device_mesh.DeviceMesh)

Return type:

int

kempnerforge.distributed.get_pp_size(device_mesh)[source]

Get the pipeline parallel world size.

Parameters:

device_mesh (torch.distributed.device_mesh.DeviceMesh)

Return type:

int

kempnerforge.distributed.get_tp_mesh(device_mesh)[source]

Extract the TP sub-mesh from a DeviceMesh.

Returns None if no ‘tp’ dimension exists.

Parameters:

device_mesh (DeviceMesh)

Return type:

DeviceMesh | None

kempnerforge.distributed.get_world_info()[source]

Return (rank, local_rank, world_size) from environment variables.

Works with both torchrun (RANK/LOCAL_RANK/WORLD_SIZE) and direct srun launch (SLURM_PROCID/SLURM_LOCALID/SLURM_NTASKS). When running under srun, also sets RANK/LOCAL_RANK/WORLD_SIZE so that PyTorch’s env:// rendezvous can find them.

Return type:

tuple[int, int, int]
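
A common pattern before process-group initialization:

    import torch

    rank, local_rank, world_size = get_world_info()
    torch.cuda.set_device(local_rank)  # bind this process to its local GPU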

kempnerforge.distributed.init_distributed(config, seed=42)[source]

Initialize distributed training and build the DeviceMesh.

Parameters:
  • config (DistributedConfig) – Distributed configuration with parallelism dimensions.

  • seed (int) – Random seed for reproducibility.

Returns:

DeviceMesh if world_size > 1, None for single-GPU.

Return type:

DeviceMesh | None

kempnerforge.distributed.is_rank_zero()[source]

Check if the current process is rank 0.

Return type:

bool

Modules

expert_parallel – Expert Parallelism: partition MoE experts across an EP process group.

parallel – Parallelism application: TP, AC, Float8, FSDP2, and model building.

pipeline_parallel – Pipeline parallelism for KempnerForge.

setup – Distributed process group and DeviceMesh initialization.

tensor_parallel – Tensor parallelism sharding plans for model components.

utils – Distributed training utilities.