kempnerforge.distributed¶
Distributed training infrastructure.
Public API:
- init_distributed / destroy_distributed: process group lifecycle
- apply_fsdp2 / apply_ac: FSDP2 sharding and activation checkpointing
- apply_tensor_parallel: DTensor-based tensor parallelism
- Pipeline parallelism: build_stage_module, build_pipeline_stage, build_pipeline_schedule
- clip_grad_norm_: DTensor-aware gradient clipping
- kempnerforge.distributed.apply_ac(model, mode)[source]¶
Apply activation checkpointing to the model.
Must be called BEFORE apply_fsdp2.
- Parameters:
model (Transformer) – Transformer model.
mode (ActivationCheckpointing) – Checkpointing mode — “none”, “full”, or “selective”. full: checkpoint every TransformerBlock (maximum memory savings). selective: checkpoint only Attention modules (balanced trade-off).
- Return type:
None
- kempnerforge.distributed.apply_fsdp2(model, device_mesh, mp_policy=None, reshard_after_forward=True)[source]¶
Apply FSDP2 (fully_shard) to a Transformer model.
Shards each TransformerBlock independently, then wraps the top-level model for remaining parameters (embeddings, final norm, output head).
Must be called AFTER apply_ac and apply_tensor_parallel.
EP interaction: Blocks with expert parallelism get per-sub-module wrapping (attention and MoE individually) instead of per-block wrapping. Per-block wrapping would cause FSDP2’s reduce-scatter to fire between EP’s backward all-to-all calls (deadlock). Per-sub-module wrapping avoids this: the MoE reduce-scatter fires after the entire MoE backward (both EP all-to-alls complete), while attention reduce-scatter is EP-free.
- Parameters:
model (Transformer) – Transformer model to shard.
device_mesh (DeviceMesh) – Full DeviceMesh (dp sub-mesh is extracted automatically).
mp_policy (MixedPrecisionPolicy | None) – Mixed precision policy. Defaults to bf16 params + fp32 reduce.
reshard_after_forward (bool | int) – Whether to free gathered params after forward. True = always reshard (saves memory, default). False = keep gathered (useful when PP needs params across microbatches). int = rate-limit the number of concurrent all-gathers.
- Return type:
None
- kempnerforge.distributed.apply_tensor_parallel(model, device_mesh)[source]¶
Apply tensor parallelism to a Transformer or PipelineStageModule.
Parallelizes attention/MLP projections, norm layers (SequenceParallel), and output head. Token embedding stays replicated. Should be called BEFORE apply_ac and apply_fsdp2.
PP stages use basic TP without SequenceParallel to avoid DTensors at stage boundaries. SequenceParallel is also disabled when weights are tied.
- Parameters:
model (torch.nn.Module) – Transformer or PipelineStageModule to parallelize.
device_mesh (torch.distributed.device_mesh.DeviceMesh) – Full DeviceMesh with a ‘tp’ dimension.
- Return type:
None
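The three apply_* functions above compose in a fixed order. A minimal sketch (not from the package docs), assuming `model` is an already-constructed Transformer, `mesh` is the DeviceMesh returned by init_distributed with 'tp' and data-parallel dimensions, and that the ActivationCheckpointing mode can be passed as the string literal "selective":

```python
import torch

from kempnerforge.distributed import (
    apply_ac,
    apply_fsdp2,
    apply_tensor_parallel,
    default_mp_policy,
)

# Assumed to exist already: `model` (a constructed Transformer) and `mesh`
# (the DeviceMesh returned by init_distributed).
apply_tensor_parallel(model, mesh)      # 1. tensor parallelism first
apply_ac(model, "selective")            # 2. then activation checkpointing
apply_fsdp2(                            # 3. FSDP2 last
    model,
    mesh,
    mp_policy=default_mp_policy(torch.bfloat16),
    reshard_after_forward=True,
)
```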
- kempnerforge.distributed.build_pipeline_schedule(stage, n_microbatches, loss_fn, schedule='1f1b')[source]¶
Create a pipeline execution schedule.
- Parameters:
stage (torch.distributed.pipelining.PipelineStage) – The PipelineStage for this rank.
n_microbatches (int) – Number of microbatches per training step. Must be >= pp_size for 1F1B to fill the pipeline.
loss_fn (callable) – Loss function (applied on last stage only).
schedule (str) – Schedule type — “1f1b”, “gpipe”, or “interleaved_1f1b”. Note: “interleaved_1f1b” requires multiple stages per rank (virtual pipeline stages). Pass a list of stages instead.
- Returns:
A pipeline schedule object. Call schedule.step() in the training loop:
First stage: schedule.step(input_tensor, target=labels)
Other stages: schedule.step(target=labels) or schedule.step()
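A sketch of the per-rank step() pattern described in the return value. Here `stage`, `loss_fn`, `input_ids`, and `labels` are assumed to exist, and the split between last and middle stages is one reading of "schedule.step(target=labels) or schedule.step()":

```python
from kempnerforge.distributed import build_pipeline_schedule

# `stage` is assumed to come from build_pipeline_stage (next entry);
# `pp_rank` and `pp_size` from the get_pp_* helpers below.
schedule = build_pipeline_schedule(stage, n_microbatches=8, loss_fn=loss_fn)

if pp_rank == 0:
    schedule.step(input_ids, target=labels)   # first stage feeds the inputs
elif pp_rank == pp_size - 1:
    schedule.step(target=labels)              # last stage applies loss_fn
else:
    schedule.step()                           # middle stages relay activations
```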
- kempnerforge.distributed.build_pipeline_stage(stage_module, device_mesh, device, batch_size, seq_len, param_dtype=torch.bfloat16)[source]¶
Wrap a stage module in a PipelineStage for schedule execution.
- Parameters:
stage_module (PipelineStageModule) – The model chunk for this stage.
device_mesh (torch.distributed.device_mesh.DeviceMesh) – Full DeviceMesh with a ‘pp’ dimension.
device (torch.device) – Device for this stage.
batch_size (int) – Micro-batch size (for shape inference).
seq_len (int) – Sequence length (for shape inference).
param_dtype (torch.dtype) – Dtype for intermediate activations (matches mixed precision).
- Returns:
A PipelineStage ready for use with a pipeline schedule.
- Return type:
torch.distributed.pipelining.PipelineStage
- kempnerforge.distributed.build_stage_module(config, pp_rank, pp_size)[source]¶
Build the model chunk for a specific pipeline stage.
- Parameters:
config (ModelConfig) – Model configuration.
pp_rank (int) – This process’s pipeline rank (0-indexed).
pp_size (int) – Total number of pipeline stages.
- Returns:
A PipelineStageModule containing only the parameters for this stage.
- Return type:
PipelineStageModule
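A sketch tying the pipeline builders together on one rank; `config` (a ModelConfig), `mesh`, `micro_batch_size`, and `seq_len` are assumed to be defined elsewhere:

```python
import torch

from kempnerforge.distributed import (
    build_pipeline_stage,
    build_stage_module,
    get_pp_rank,
    get_pp_size,
)

pp_rank = get_pp_rank(mesh)
pp_size = get_pp_size(mesh)

# Build only this rank's chunk of the model, then wrap it for scheduling.
stage_module = build_stage_module(config, pp_rank, pp_size)
stage = build_pipeline_stage(
    stage_module,
    mesh,
    device=torch.device("cuda"),
    batch_size=micro_batch_size,
    seq_len=seq_len,
    param_dtype=torch.bfloat16,
)
```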
- kempnerforge.distributed.clip_grad_norm_(model, max_norm, foreach=True)[source]¶
Clip gradient norm across all parameters.
Handles mixed DTensor meshes that arise when TP+FSDP produces parameters on different meshes (e.g., TP-sharded linears on (dp_shard, tp) vs FSDP-only norms on (dp_shard)). Groups gradients by mesh, computes per-group norms via stack, then combines across groups.
Falls back to torch.nn.utils.clip_grad_norm_ when there is only one mesh (pure FSDP or single-GPU).
- Parameters:
model (torch.nn.Module) – Model whose gradients to clip.
max_norm (float) – Maximum gradient norm.
foreach (bool) – Use the faster foreach implementation.
- Returns:
Total gradient norm (before clipping).
- Return type:
torch.Tensor
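A sketch of where the clip fits in a training step; `model`, `loss`, and `optimizer` are assumed:

```python
from kempnerforge.distributed import clip_grad_norm_

loss.backward()
# Works whether parameters live on one mesh or several (TP + FSDP).
grad_norm = clip_grad_norm_(model, max_norm=1.0)
optimizer.step()
optimizer.zero_grad()
```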
- kempnerforge.distributed.compute_layer_assignment(n_layers, pp_size)[source]¶
Compute which layers go to which PP stage.
Distributes layers as evenly as possible. Earlier stages get one extra layer when n_layers is not evenly divisible by pp_size.
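An illustrative restatement of the split rule only; the actual return format of compute_layer_assignment is not specified here:

```python
# Sketch of the rule: earlier stages absorb the remainder one layer each.
def even_split(n_layers: int, pp_size: int) -> list[int]:
    base, extra = divmod(n_layers, pp_size)
    return [base + (1 if stage < extra else 0) for stage in range(pp_size)]

# e.g. 13 layers over 4 stages -> [4, 3, 3, 3] layers per stage
```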
- kempnerforge.distributed.default_mp_policy(param_dtype=torch.bfloat16)[source]¶
Mixed-precision policy: param_dtype compute, fp32 gradient reduction.
- Parameters:
param_dtype (torch.dtype)
- Return type:
torch.distributed._composable.fsdp.MixedPrecisionPolicy
- kempnerforge.distributed.destroy_distributed()[source]¶
Clean up distributed process groups.
- Return type:
None
- kempnerforge.distributed.get_dp_mesh(device_mesh)[source]¶
Extract the data-parallel sub-mesh from a DeviceMesh.
Returns a 1D mesh (pure sharding) or 2D mesh (replicate + shard / HSDP) depending on which dimensions are present.
Raises ValueError if no DP dimensions exist (e.g., pure TP mesh). Use has_dp_mesh to check first.
- Parameters:
device_mesh (torch.distributed.device_mesh.DeviceMesh)
- Return type:
DeviceMesh
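A small guard sketch; `mesh` is assumed, and has_dp_mesh is referenced above but its import path is assumed here:

```python
from kempnerforge.distributed import get_dp_mesh, has_dp_mesh

if has_dp_mesh(mesh):
    dp_mesh = get_dp_mesh(mesh)   # 1D (shard) or 2D (replicate + shard / HSDP)
```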
- kempnerforge.distributed.get_pp_mesh(device_mesh)[source]¶
Extract the PP sub-mesh from a DeviceMesh.
Returns None if no ‘pp’ dimension exists.
- Parameters:
device_mesh (DeviceMesh)
- Return type:
DeviceMesh | None
- kempnerforge.distributed.get_pp_rank(device_mesh)[source]¶
Get the pipeline parallel rank for this process.
- Parameters:
device_mesh (torch.distributed.device_mesh.DeviceMesh)
- Return type:
int
- kempnerforge.distributed.get_pp_size(device_mesh)[source]¶
Get the pipeline parallel world size.
- Parameters:
device_mesh (torch.distributed.device_mesh.DeviceMesh)
- Return type:
int
- kempnerforge.distributed.get_tp_mesh(device_mesh)[source]¶
Extract the TP sub-mesh from a DeviceMesh.
Returns None if no ‘tp’ dimension exists.
- Parameters:
device_mesh (DeviceMesh)
- Return type:
DeviceMesh | None
- kempnerforge.distributed.get_world_info()[source]¶
Return (rank, local_rank, world_size) from environment variables.
Works with both torchrun (RANK/LOCAL_RANK/WORLD_SIZE) and direct srun launch (SLURM_PROCID/SLURM_LOCALID/SLURM_NTASKS). When running under srun, also sets RANK/LOCAL_RANK/WORLD_SIZE so that PyTorch’s env:// rendezvous can find them.
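Typical use when pinning the CUDA device before initialization; a sketch:

```python
import torch

from kempnerforge.distributed import get_world_info

# Works under torchrun and under srun (SLURM_* variables are translated).
rank, local_rank, world_size = get_world_info()
torch.cuda.set_device(local_rank)
```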
- kempnerforge.distributed.init_distributed(config, seed=42)[source]¶
Initialize distributed training and build the DeviceMesh.
- Parameters:
config (DistributedConfig) – Distributed configuration with parallelism dimensions.
seed (int) – Random seed for reproducibility.
- Returns:
DeviceMesh if world_size > 1, None for single-GPU.
- Return type:
DeviceMesh | None
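A lifecycle sketch, assuming `config` is a DistributedConfig built elsewhere:

```python
from kempnerforge.distributed import (
    destroy_distributed,
    get_pp_mesh,
    get_tp_mesh,
    init_distributed,
)

mesh = init_distributed(config, seed=42)   # None on a single GPU
if mesh is not None:
    tp_mesh = get_tp_mesh(mesh)            # None if no 'tp' dimension
    pp_mesh = get_pp_mesh(mesh)            # None if no 'pp' dimension

# ... training ...

destroy_distributed()
```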
- kempnerforge.distributed.is_rank_zero()[source]¶
Check if the current process is rank 0.
- Return type:
bool
Modules
- Expert Parallelism: partition MoE experts across an EP process group.
- Parallelism application: TP, AC, Float8, FSDP2, and model building.
- Pipeline parallelism for KempnerForge.
- Distributed process group and DeviceMesh initialization.
- Tensor parallelism sharding plans for model components.
- Distributed training utilities.