kempnerforge.distributed¶
Distributed training infrastructure.
Public API:
- init_distributed / destroy_distributed: process group lifecycle
- apply_fsdp2 / apply_ac: FSDP2 sharding and activation checkpointing
- apply_tensor_parallel: DTensor-based tensor parallelism
- Pipeline parallelism: build_stage_module, build_pipeline_stage, build_pipeline_schedule
- clip_grad_norm_: DTensor-aware gradient clipping
- kempnerforge.distributed.apply_ac(model, mode)[source]¶
Apply activation checkpointing to the model.
Must be called BEFORE apply_fsdp2.
- Parameters:
model (Transformer) – Transformer model.
mode (ActivationCheckpointing) – Checkpointing mode — “none”, “full”, or “selective”. full: checkpoint every TransformerBlock (maximum memory savings). selective: checkpoint only Attention modules (balanced trade-off).
- Return type:
None
- kempnerforge.distributed.apply_fsdp2(model, device_mesh, mp_policy=None, reshard_after_forward=True)[source]¶
Apply FSDP2 (fully_shard) to a Transformer model.
Shards each TransformerBlock independently (via _fsdp_wrap_transformer_blocks, so the EP-MoE per-sub-module wrap is shared with the VLM path), then wraps the top-level model for the remaining parameters (embeddings, final norm, output head). Must be called AFTER apply_ac and apply_tensor_parallel.
EP interaction: Blocks with expert parallelism get per-sub-module wrapping (attention and MoE individually) instead of per-block wrapping. Per-block wrapping would cause FSDP2’s reduce-scatter to fire between EP’s backward all-to-all calls (deadlock). Per-sub-module wrapping avoids this: the MoE reduce-scatter fires after the entire MoE backward (both EP all-to-alls complete), while attention reduce-scatter is EP-free.
- Parameters:
model (Transformer) – Transformer model to shard.
device_mesh (DeviceMesh) – Full DeviceMesh (dp sub-mesh is extracted automatically).
mp_policy (MixedPrecisionPolicy | None) – Mixed precision policy. Defaults to bf16 params + fp32 reduce.
reshard_after_forward (bool | int) – Whether to free gathered params after forward. True = always reshard (saves memory, default). False = keep gathered (useful when PP needs params across microbatches). int = rate-limit the number of concurrent all-gathers.
- Return type:
None
- kempnerforge.distributed.apply_tensor_parallel(model, device_mesh)[source]¶
Apply tensor parallelism to a Transformer or PipelineStageModule.
Parallelizes attention/MLP projections, norm layers (SequenceParallel), and output head. Token embedding stays replicated. Should be called BEFORE apply_ac and apply_fsdp2.
PP stages use basic TP without SequenceParallel to avoid DTensors at stage boundaries. SequenceParallel is also disabled when weights are tied.
- Parameters:
model (torch.nn.Module) – Transformer or PipelineStageModule to parallelize.
device_mesh (torch.distributed.device_mesh.DeviceMesh) – Full DeviceMesh with a ‘tp’ dimension.
- Return type:
None
- kempnerforge.distributed.build_pipeline_schedule(stage, n_microbatches, loss_fn, schedule='1f1b')[source]¶
Create a pipeline execution schedule.
- Parameters:
stage (torch.distributed.pipelining.PipelineStage) – The PipelineStage for this rank.
n_microbatches (int) – Number of microbatches per training step. Must be >= pp_size for 1F1B to fill the pipeline.
loss_fn (callable) – Loss function (applied on last stage only).
schedule (str) – Schedule type — “1f1b”, “gpipe”, or “interleaved_1f1b”. Note: “interleaved_1f1b” requires multiple stages per rank (virtual pipeline stages). Pass a list of stages instead.
- Returns:
A pipeline schedule object. Call schedule.step() in the training loop:
First stage: schedule.step(input_tensor, target=labels).
Other stages: schedule.step(target=labels) or schedule.step().
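The n_microbatches >= pp_size requirement follows from pipeline fill/drain overhead. As a rough intuition (the standard 1F1B/GPipe bubble model, not part of this API), each rank idles for pp_size - 1 warm-up plus cool-down slots out of n_microbatches + pp_size - 1 total, so the idle fraction shrinks as the microbatch count grows:

```python
def pipeline_bubble_fraction(pp_size: int, n_microbatches: int) -> float:
    """Idle ("bubble") fraction of an ideal 1F1B/GPipe schedule.

    Standard cost model: (pp_size - 1) fill/drain slots out of
    (n_microbatches + pp_size - 1) total slots per rank.
    """
    return (pp_size - 1) / (n_microbatches + pp_size - 1)
```

For example, with pp_size=4, raising n_microbatches from 4 to 16 cuts the bubble from 3/7 to 3/19 of the step.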
- kempnerforge.distributed.build_pipeline_stage(stage_module, device_mesh, device, batch_size, seq_len, param_dtype=torch.bfloat16)[source]¶
Wrap a stage module in a PipelineStage for schedule execution.
- Parameters:
stage_module (PipelineStageModule) – The model chunk for this stage.
device_mesh (torch.distributed.device_mesh.DeviceMesh) – Full DeviceMesh with a ‘pp’ dimension.
device (torch.device) – Device for this stage.
batch_size (int) – Micro-batch size (for shape inference).
seq_len (int) – Sequence length (for shape inference).
param_dtype (torch.dtype) – Dtype for intermediate activations (matches mixed precision).
- Returns:
A PipelineStage ready for use with a pipeline schedule.
- Return type:
torch.distributed.pipelining.PipelineStage
- kempnerforge.distributed.build_stage_module(config, pp_rank, pp_size)[source]¶
Build the model chunk for a specific pipeline stage.
- Parameters:
config (ModelConfig) – Model configuration.
pp_rank (int) – This process’s pipeline rank (0-indexed).
pp_size (int) – Total number of pipeline stages.
- Returns:
A PipelineStageModule containing only the parameters for this stage.
- Return type:
PipelineStageModule
- kempnerforge.distributed.clip_grad_norm_(model, max_norm, foreach=True)[source]¶
Clip gradient norm across all parameters.
Handles mixed DTensor meshes that arise when TP+FSDP produces parameters on different meshes (e.g., TP-sharded linears on (dp_shard, tp) vs FSDP-only norms on (dp_shard)). Groups gradients by mesh, computes per-group norms via stack, then combines across groups.
Falls back to torch.nn.utils.clip_grad_norm_ when there is only one mesh (pure FSDP or single-GPU).
- Parameters:
model (torch.nn.Module) – Model whose gradients to clip.
max_norm (float) – Maximum gradient norm.
foreach (bool) – Use the faster foreach implementation.
- Returns:
Total gradient norm (before clipping).
- Return type:
torch.Tensor
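Combining per-group norms is valid because the squared L2 norm is additive over disjoint parameter groups. A minimal sketch of that final combination step (combine_group_norms is a hypothetical helper illustrating the math, not part of this API):

```python
import math

def combine_group_norms(group_norms):
    """Combine per-mesh-group L2 norms into one global L2 norm.

    Since parameter groups are disjoint, ||g|| = sqrt(sum_i ||g_i||^2),
    so per-mesh norms can be reduced independently and merged at the end.
    """
    return math.sqrt(sum(n * n for n in group_norms))
```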
- kempnerforge.distributed.compute_layer_assignment(n_layers, pp_size)[source]¶
Compute which layers go to which PP stage.
Distributes layers as evenly as possible. Earlier stages get one extra layer when n_layers is not evenly divisible by pp_size.
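The distribution rule above can be sketched as follows (a re-implementation for illustration, assuming the function returns a per-stage list of layer indices; the actual return format may differ):

```python
def compute_layer_assignment(n_layers: int, pp_size: int) -> list[list[int]]:
    """Split layers across PP stages as evenly as possible.

    When n_layers is not divisible by pp_size, the first
    (n_layers % pp_size) stages each take one extra layer.
    """
    base, extra = divmod(n_layers, pp_size)
    assignment, start = [], 0
    for stage in range(pp_size):
        count = base + (1 if stage < extra else 0)
        assignment.append(list(range(start, start + count)))
        start += count
    return assignment
```

For example, 10 layers over 4 stages yields sizes 3, 3, 2, 2.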
- kempnerforge.distributed.default_mp_policy(param_dtype=torch.bfloat16)[source]¶
Mixed-precision policy: param_dtype compute, fp32 gradient reduction.
cast_forward_inputs=True ensures FSDP2 casts input tensors to the declared param_dtype at each wrapped module’s forward boundary. The VLM path relies on this so image embeddings produced by the adapter (bf16) reach the sharded transformer with matching dtype without needing the caller to do manual casts. The default on MixedPrecisionPolicy is False, so we set it explicitly here to pin the contract.
- Parameters:
param_dtype (torch.dtype)
- Return type:
torch.distributed._composable.fsdp.MixedPrecisionPolicy
- kempnerforge.distributed.destroy_distributed()[source]¶
Clean up distributed process groups.
- Return type:
None
- kempnerforge.distributed.get_dp_mesh(device_mesh)[source]¶
Extract the data-parallel sub-mesh from a DeviceMesh.
Returns a 1D mesh (pure sharding) or 2D mesh (replicate + shard / HSDP) depending on which dimensions are present.
Raises ValueError if no DP dimensions exist (e.g., pure TP mesh). Use has_dp_mesh to check first.
- Parameters:
device_mesh (torch.distributed.device_mesh.DeviceMesh)
- Return type:
DeviceMesh
- kempnerforge.distributed.get_pp_mesh(device_mesh)[source]¶
Extract the PP sub-mesh from a DeviceMesh.
Returns None if no ‘pp’ dimension exists.
- Parameters:
device_mesh (DeviceMesh)
- Return type:
DeviceMesh | None
- kempnerforge.distributed.get_pp_rank(device_mesh)[source]¶
Get the pipeline parallel rank for this process.
- Parameters:
device_mesh (torch.distributed.device_mesh.DeviceMesh)
- Return type:
int
- kempnerforge.distributed.get_pp_size(device_mesh)[source]¶
Get the pipeline parallel world size.
- Parameters:
device_mesh (torch.distributed.device_mesh.DeviceMesh)
- Return type:
int
- kempnerforge.distributed.get_tp_mesh(device_mesh)[source]¶
Extract the TP sub-mesh from a DeviceMesh.
Returns None if no ‘tp’ dimension exists.
- Parameters:
device_mesh (DeviceMesh)
- Return type:
DeviceMesh | None
- kempnerforge.distributed.get_world_info()[source]¶
Return (rank, local_rank, world_size) from environment variables.
Works with both torchrun (RANK/LOCAL_RANK/WORLD_SIZE) and direct srun launch (SLURM_PROCID/SLURM_LOCALID/SLURM_NTASKS). When running under srun, also sets RANK/LOCAL_RANK/WORLD_SIZE so that PyTorch’s env:// rendezvous can find them.
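The fallback logic can be sketched as below. This is an illustrative re-implementation under stated assumptions: the real helper takes no arguments; the env parameter here just makes the lookup explicit and testable.

```python
import os

def get_world_info(env=None):
    """Resolve (rank, local_rank, world_size), preferring torchrun vars.

    Falls back to SLURM variables under direct srun launch, and mirrors
    the values into the torchrun names so env:// rendezvous can find them.
    """
    env = os.environ if env is None else env
    if "RANK" in env:
        keys = ("RANK", "LOCAL_RANK", "WORLD_SIZE")
    else:
        keys = ("SLURM_PROCID", "SLURM_LOCALID", "SLURM_NTASKS")
    rank, local_rank, world_size = (int(env[k]) for k in keys)
    env["RANK"] = str(rank)
    env["LOCAL_RANK"] = str(local_rank)
    env["WORLD_SIZE"] = str(world_size)
    return rank, local_rank, world_size
```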
- kempnerforge.distributed.init_distributed(config, seed=42)[source]¶
Initialize distributed training and build the DeviceMesh.
- Parameters:
config (DistributedConfig) – Distributed configuration with parallelism dimensions.
seed (int) – Random seed for reproducibility.
- Returns:
DeviceMesh if world_size > 1, None for single-GPU.
- Return type:
DeviceMesh | None
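A DeviceMesh's dimensions must multiply to the world size, so the data-parallel degree is determined by the other parallelism dimensions. A minimal sketch of that consistency check (infer_dp_size is a hypothetical helper, not part of this API):

```python
def infer_dp_size(world_size: int, tp_size: int = 1, pp_size: int = 1) -> int:
    """Data-parallel degree implied by the remaining mesh dimensions.

    Assumes the mesh dimensions factor the world size exactly:
    world_size = dp * tp * pp.
    """
    denom = tp_size * pp_size
    if world_size % denom != 0:
        raise ValueError("world_size must be divisible by tp_size * pp_size")
    return world_size // denom
```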
- kempnerforge.distributed.is_rank_zero()[source]¶
Check if the current process is rank 0.
- Return type:
bool
Modules
- Expert Parallelism: partition MoE experts across an EP process group.
- Parallelism application: TP, AC, Float8, FSDP2, and model building.
- Pipeline parallelism for KempnerForge.
- Distributed process group and DeviceMesh initialization.
- Tensor parallelism sharding plans for model components.
- Distributed training utilities.