kempnerforge.distributed

Distributed training infrastructure.

Public API:

  • init_distributed / destroy_distributed: process group lifecycle

  • apply_fsdp2 / apply_ac: FSDP2 sharding and activation checkpointing

  • apply_tensor_parallel: DTensor-based tensor parallelism

  • Pipeline parallelism: build_stage_module, build_pipeline_stage, build_pipeline_schedule

  • clip_grad_norm_: DTensor-aware gradient clipping
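
A minimal sketch of how these pieces compose in a training script. The model constructor, data loader, loss function, optimizer, and config objects below are placeholders, and the string value passed to apply_ac is assumed to be accepted; the call order follows the constraints documented for each function in this section.

    from kempnerforge.distributed import (
        init_distributed, destroy_distributed,
        apply_tensor_parallel, apply_ac, apply_fsdp2, clip_grad_norm_,
    )

    device_mesh = init_distributed(dist_config)   # None when world_size == 1
    model = build_model(model_config)             # hypothetical model constructor

    if device_mesh is not None:
        # Ordering constraint: TP first, then activation checkpointing, then FSDP2.
        apply_tensor_parallel(model, device_mesh)
        apply_ac(model, mode="selective")         # mode value assumed accepted as a string
        apply_fsdp2(model, device_mesh)

    for input_ids, labels in loader:              # hypothetical data loader
        loss = loss_fn(model(input_ids), labels)
        loss.backward()
        clip_grad_norm_(model, max_norm=1.0)
        optimizer.step()
        optimizer.zero_grad()

    destroy_distributed()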

kempnerforge.distributed.apply_ac(model, mode)[source]

Apply activation checkpointing to the model.

Must be called BEFORE apply_fsdp2.

Parameters:
  • model (Transformer) – Transformer model.

  • mode (ActivationCheckpointing) – Checkpointing mode: "none", "full", or "selective". "full" checkpoints every TransformerBlock (maximum memory savings); "selective" checkpoints only Attention modules (balanced trade-off); "none" applies no checkpointing.

Return type:

None

kempnerforge.distributed.apply_fsdp2(model, device_mesh, mp_policy=None, reshard_after_forward=True)[source]

Apply FSDP2 (fully_shard) to a Transformer model.

Shards each TransformerBlock independently (via _fsdp_wrap_transformer_blocks so the EP-MoE per-sub-module wrap is shared with the VLM path), then wraps the top-level model for remaining parameters (embeddings, final norm, output head).

Must be called AFTER apply_ac and apply_tensor_parallel.

EP interaction: Blocks with expert parallelism get per-sub-module wrapping (attention and MoE individually) instead of per-block wrapping. Per-block wrapping would cause FSDP2’s reduce-scatter to fire between EP’s backward all-to-all calls (deadlock). Per-sub-module wrapping avoids this: the MoE reduce-scatter fires after the entire MoE backward (both EP all-to-alls complete), while attention reduce-scatter is EP-free.

Parameters:
  • model (Transformer) – Transformer model to shard.

  • device_mesh (DeviceMesh) – Full DeviceMesh (dp sub-mesh is extracted automatically).

  • mp_policy (MixedPrecisionPolicy | None) – Mixed precision policy. Defaults to bf16 params + fp32 reduce.

  • reshard_after_forward (bool | int) – Whether to free gathered params after forward. True = always reshard (saves memory, default). False = keep gathered (useful when PP needs params across microbatches). int = rate-limit the number of concurrent all-gathers.

Return type:

None
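
A sketch of customizing the sharding call, assuming model and device_mesh are already set up and apply_ac has already run; default_mp_policy is documented later in this section.

    import torch
    from kempnerforge.distributed import apply_fsdp2, default_mp_policy

    apply_fsdp2(
        model,
        device_mesh,
        mp_policy=default_mp_policy(torch.bfloat16),  # bf16 params, fp32 gradient reduce
        reshard_after_forward=False,                  # keep params gathered, e.g. for PP microbatches
    )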

kempnerforge.distributed.apply_tensor_parallel(model, device_mesh)[source]

Apply tensor parallelism to a Transformer or PipelineStageModule.

Parallelizes attention/MLP projections, norm layers (SequenceParallel), and output head. Token embedding stays replicated. Should be called BEFORE apply_ac and apply_fsdp2.

PP stages use basic TP without SequenceParallel to avoid DTensors at stage boundaries. SequenceParallel is also disabled when weights are tied.

Parameters:
  • model (Transformer | PipelineStageModule) – Model or pipeline stage module to parallelize.

  • device_mesh (DeviceMesh) – DeviceMesh providing the tensor-parallel dimension.

Return type:

None
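
A sketch for the pipeline-stage case, assuming the stage module comes from build_stage_module (documented below); the same call works on a full Transformer.

    from kempnerforge.distributed import apply_tensor_parallel, build_stage_module

    # TP must come before apply_ac / apply_fsdp2. On a PipelineStageModule the
    # basic (non-SequenceParallel) plan is used, as noted above.
    stage_module = build_stage_module(model_config, pp_rank=0, pp_size=2)
    apply_tensor_parallel(stage_module, device_mesh)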

kempnerforge.distributed.build_pipeline_schedule(stage, n_microbatches, loss_fn, schedule='1f1b')[source]

Create a pipeline execution schedule.

Parameters:
  • stage (torch.distributed.pipelining.PipelineStage) – The PipelineStage for this rank.

  • n_microbatches (int) – Number of microbatches per training step. Must be >= pp_size for 1F1B to fill the pipeline.

  • loss_fn (callable) – Loss function (applied on last stage only).

  • schedule (str) – Schedule type: "1f1b", "gpipe", or "interleaved_1f1b". Note that "interleaved_1f1b" requires multiple stages per rank (virtual pipeline stages); pass a list of stages instead of a single stage.

Returns:

A pipeline schedule object. Call schedule.step() in the training loop:

  • First stage: schedule.step(input_tensor, target=labels)

  • Other stages: schedule.step(target=labels) or schedule.step()

Return type:

A pipeline schedule from torch.distributed.pipelining (e.g. Schedule1F1B or ScheduleGPipe)
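
A sketch of driving the schedule for one training step; the stage, input_ids, labels, and loss_fn are assumed to exist already.

    from kempnerforge.distributed import (
        build_pipeline_schedule, get_pp_rank, get_pp_size,
    )

    schedule = build_pipeline_schedule(stage, n_microbatches=8, loss_fn=loss_fn)
    pp_rank, pp_size = get_pp_rank(device_mesh), get_pp_size(device_mesh)

    if pp_rank == 0:
        schedule.step(input_ids, target=labels)   # first stage feeds the inputs
    elif pp_rank == pp_size - 1:
        schedule.step(target=labels)              # last stage computes the loss
    else:
        schedule.step()                           # middle stages relay activations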

kempnerforge.distributed.build_pipeline_stage(stage_module, device_mesh, device, batch_size, seq_len, param_dtype=torch.bfloat16)[source]

Wrap a stage module in a PipelineStage for schedule execution.

Parameters:
  • stage_module (PipelineStageModule) – Stage module built by build_stage_module.

  • device_mesh (DeviceMesh) – Full DeviceMesh.

  • device (torch.device) – Device this stage runs on.

  • batch_size (int) – Batch size.

  • seq_len (int) – Sequence length.

  • param_dtype (torch.dtype) – Parameter dtype. Defaults to torch.bfloat16.

Returns:

A PipelineStage ready for use with a pipeline schedule.

Return type:

torch.distributed.pipelining.PipelineStage

kempnerforge.distributed.build_stage_module(config, pp_rank, pp_size)[source]

Build the model chunk for a specific pipeline stage.

Parameters:
  • config (ModelConfig) – Model configuration.

  • pp_rank (int) – This process’s pipeline rank (0-indexed).

  • pp_size (int) – Total number of pipeline stages.

Returns:

A PipelineStageModule containing only the parameters for this stage.

Return type:

PipelineStageModule
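
Putting the three pipeline helpers together (a sketch; model_config, micro_batch_size, seq_len, and loss_fn are placeholders).

    import torch
    from kempnerforge.distributed import (
        build_stage_module, build_pipeline_stage, build_pipeline_schedule,
        get_pp_rank, get_pp_size,
    )

    pp_rank, pp_size = get_pp_rank(device_mesh), get_pp_size(device_mesh)

    # Materialize only this rank's layers, wrap them as a PipelineStage,
    # then build a 1F1B schedule over the stage.
    stage_module = build_stage_module(model_config, pp_rank, pp_size)
    stage = build_pipeline_stage(
        stage_module, device_mesh, torch.device("cuda"),
        batch_size=micro_batch_size, seq_len=seq_len,
    )
    schedule = build_pipeline_schedule(stage, n_microbatches=8, loss_fn=loss_fn)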

kempnerforge.distributed.clip_grad_norm_(model, max_norm, foreach=True)[source]

Clip gradient norm across all parameters.

Handles mixed DTensor meshes that arise when TP+FSDP produces parameters on different meshes (e.g., TP-sharded linears on (dp_shard, tp) vs FSDP-only norms on (dp_shard)). Groups gradients by mesh, computes per-group norms via stack, then combines across groups.

Falls back to torch.nn.utils.clip_grad_norm_ when there is only one mesh (pure FSDP or single-GPU).

Parameters:
  • model (torch.nn.Module) – Model whose gradients to clip.

  • max_norm (float) – Maximum gradient norm.

  • foreach (bool) – Use the faster foreach implementation.

Returns:

Total gradient norm (before clipping).

Return type:

torch.Tensor
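
A typical placement in the training step (a sketch; max_norm=1.0 is an arbitrary example value, and loss/optimizer are placeholders).

    from kempnerforge.distributed import clip_grad_norm_, is_rank_zero

    # Clip after backward, before the optimizer step; the returned value is
    # the pre-clipping norm, which is convenient for logging.
    loss.backward()
    grad_norm = clip_grad_norm_(model, max_norm=1.0)
    optimizer.step()
    optimizer.zero_grad()
    if is_rank_zero():
        print(f"grad_norm={grad_norm}")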

kempnerforge.distributed.compute_layer_assignment(n_layers, pp_size)[source]

Compute which layers go to which PP stage.

Distributes layers as evenly as possible. Earlier stages get one extra layer when n_layers is not evenly divisible by pp_size.

Parameters:
  • n_layers (int) – Total number of transformer layers.

  • pp_size (int) – Number of pipeline stages.

Returns:

List of (start_layer, end_layer) tuples, one per stage. end_layer is exclusive.

Raises:

ValueError – If pp_size > n_layers.

Return type:

list[tuple[int, int]]
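
For example, assuming the split behaves exactly as described, 10 layers over 4 stages divide 3/3/2/2:

    from kempnerforge.distributed import compute_layer_assignment

    assignment = compute_layer_assignment(10, 4)
    # expected: [(0, 3), (3, 6), (6, 8), (8, 10)] -- earlier stages absorb the remainder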

kempnerforge.distributed.default_mp_policy(param_dtype=torch.bfloat16)[source]

Mixed-precision policy: param_dtype compute, fp32 gradient reduction.

cast_forward_inputs=True ensures FSDP2 casts input tensors to the declared param_dtype at each wrapped module’s forward boundary. The VLM path relies on this so image embeddings produced by the adapter (bf16) reach the sharded transformer with matching dtype without needing the caller to do manual casts. The default on MixedPrecisionPolicy is False, so we set it explicitly here to pin the contract.

Parameters:

param_dtype (torch.dtype)

Return type:

torch.distributed._composable.fsdp.MixedPrecisionPolicy
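
The returned policy should be roughly equivalent to constructing MixedPrecisionPolicy by hand (a sketch based on the description above):

    import torch
    from torch.distributed._composable.fsdp import MixedPrecisionPolicy

    # bf16 compute, fp32 gradient reduction, inputs cast at wrapped-module boundaries.
    policy = MixedPrecisionPolicy(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.float32,
        cast_forward_inputs=True,
    )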

kempnerforge.distributed.destroy_distributed()[source]

Clean up distributed process groups.

Return type:

None

kempnerforge.distributed.get_dp_mesh(device_mesh)[source]

Extract the data-parallel sub-mesh from a DeviceMesh.

Returns a 1D mesh (pure sharding) or 2D mesh (replicate + shard / HSDP) depending on which dimensions are present.

Raises ValueError if no DP dimensions exist (e.g., pure TP mesh). Use has_dp_mesh to check first.

Parameters:

device_mesh (torch.distributed.device_mesh.DeviceMesh)

Return type:

torch.distributed.device_mesh.DeviceMesh
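
A guarded extraction, assuming has_dp_mesh (referenced above) is importable from the same namespace:

    from kempnerforge.distributed import get_dp_mesh, has_dp_mesh  # has_dp_mesh location assumed

    # A pure-TP mesh has no data-parallel dimension, so check before extracting.
    if has_dp_mesh(device_mesh):
        dp_mesh = get_dp_mesh(device_mesh)
        dp_world_size = dp_mesh.size()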

kempnerforge.distributed.get_pp_mesh(device_mesh)[source]

Extract the PP sub-mesh from a DeviceMesh.

Returns None if no ‘pp’ dimension exists.

Parameters:

device_mesh (DeviceMesh)

Return type:

DeviceMesh | None

kempnerforge.distributed.get_pp_rank(device_mesh)[source]

Get the pipeline parallel rank for this process.

Parameters:

device_mesh (torch.distributed.device_mesh.DeviceMesh)

Return type:

int

kempnerforge.distributed.get_pp_size(device_mesh)[source]

Get the pipeline parallel world size.

Parameters:

device_mesh (torch.distributed.device_mesh.DeviceMesh)

Return type:

int

kempnerforge.distributed.get_tp_mesh(device_mesh)[source]

Extract the TP sub-mesh from a DeviceMesh.

Returns None if no ‘tp’ dimension exists.

Parameters:

device_mesh (DeviceMesh)

Return type:

DeviceMesh | None
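
The mesh accessors handle absent dimensions by returning None, so they can be combined directly:

    from kempnerforge.distributed import get_tp_mesh, get_pp_mesh, get_pp_rank, get_pp_size

    tp_mesh = get_tp_mesh(device_mesh)   # None if the mesh has no 'tp' dimension
    pp_mesh = get_pp_mesh(device_mesh)   # None if the mesh has no 'pp' dimension

    if pp_mesh is not None:
        print(f"pipeline stage {get_pp_rank(device_mesh)} of {get_pp_size(device_mesh)}")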

kempnerforge.distributed.get_world_info()[source]

Return (rank, local_rank, world_size) from environment variables.

Works with both torchrun (RANK/LOCAL_RANK/WORLD_SIZE) and direct srun launch (SLURM_PROCID/SLURM_LOCALID/SLURM_NTASKS). When running under srun, also sets RANK/LOCAL_RANK/WORLD_SIZE so that PyTorch’s env:// rendezvous can find them.

Return type:

tuple[int, int, int]
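
A typical use during process setup, assuming one GPU per process:

    import torch
    from kempnerforge.distributed import get_world_info

    # Works the same under torchrun and direct srun launches.
    rank, local_rank, world_size = get_world_info()
    torch.cuda.set_device(local_rank)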

kempnerforge.distributed.init_distributed(config, seed=42)[source]

Initialize distributed training and build the DeviceMesh.

Parameters:
  • config (DistributedConfig) – Distributed configuration with parallelism dimensions.

  • seed (int) – Random seed for reproducibility.

Returns:

DeviceMesh if world_size > 1, None for single-GPU.

Return type:

DeviceMesh | None
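
A sketch of the single-GPU fallback, assuming a DistributedConfig instance named dist_config:

    from kempnerforge.distributed import init_distributed, destroy_distributed, is_rank_zero

    device_mesh = init_distributed(dist_config, seed=1234)

    if device_mesh is None:
        # world_size == 1: plain single-GPU training, no sharding applied.
        pass
    elif is_rank_zero():
        print(f"device mesh: {device_mesh}")

    # ... training ...

    destroy_distributed()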

kempnerforge.distributed.is_rank_zero()[source]

Check if the current process is rank 0.

Return type:

bool

Modules

  • expert_parallel – Expert Parallelism: partition MoE experts across an EP process group.

  • parallel – Parallelism application: TP, AC, Float8, FSDP2, and model building.

  • pipeline_parallel – Pipeline parallelism for KempnerForge.

  • setup – Distributed process group and DeviceMesh initialization.

  • tensor_parallel – Tensor parallelism sharding plans for model components.

  • utils – Distributed training utilities.