kempnerforge.distributed.parallel¶
Parallelism application: TP, AC, Float8, FSDP2, and model building.
Applies parallelism to Transformer models in the correct order.
- Application order (critical — wrong order causes silent correctness bugs):
  1. Tensor parallelism (apply_tensor_parallel) — must see raw blocks
  2. Expert parallelism (apply_expert_parallel) — partitions MoE experts
  3. Float8 training (apply_float8) — wraps Linear → Float8Linear
  4. Activation checkpointing (apply_ac) — wraps blocks in CheckpointWrapper
  5. FSDP2 (apply_fsdp2) — shards everything (uses float8 all-gather if enabled)
For convenience, build_parallel_model combines all steps including
model creation, meta-device initialization, and optional torch.compile.
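A minimal sketch of this call order, assuming a Transformer instance and a DeviceMesh are already in hand. The signatures of the apply_* helpers are the ones documented below; the import location of ActivationCheckpointing and the TP/EP helper signatures are assumptions and shown only as commented placeholders:

    from kempnerforge.distributed.parallel import apply_ac, apply_float8, apply_fsdp2

    # 1-2. TP and EP come first so they see the raw TransformerBlocks
    #      (signatures not documented on this page; placeholders only):
    # apply_tensor_parallel(model, ...)
    # apply_expert_parallel(model, ...)

    # 3. Float8: wrap nn.Linear -> Float8Linear before AC / FSDP2.
    apply_float8(model, enable_fsdp_float8_all_gather=True)

    # 4. Activation checkpointing before sharding.
    apply_ac(model, mode=ActivationCheckpointing.selective)  # import path assumed

    # 5. FSDP2 last: shards the blocks and the remaining top-level parameters.
    apply_fsdp2(model, device_mesh)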
Functions
- apply_ac: Apply activation checkpointing to the model.
- apply_float8: Apply Float8 training (torchao) to the model.
- apply_fsdp2: Apply FSDP2 (fully_shard) to a Transformer model.
- build_parallel_model: Build a Transformer with parallelism applied in the correct order.
- default_mp_policy: Mixed-precision policy with param_dtype compute and fp32 gradient reduction.
- get_dp_mesh: Extract the data-parallel sub-mesh from a DeviceMesh.
- has_dp_mesh: Check whether the DeviceMesh contains any data-parallel dimensions.
- kempnerforge.distributed.parallel.has_dp_mesh(device_mesh)[source]¶
Check whether the DeviceMesh contains any data-parallel dimensions.
- Parameters:
device_mesh (torch.distributed.device_mesh.DeviceMesh)
- Return type:
bool
- kempnerforge.distributed.parallel.get_dp_mesh(device_mesh)[source]¶
Extract the data-parallel sub-mesh from a DeviceMesh.
Returns a 1D mesh (pure sharding) or 2D mesh (replicate + shard / HSDP) depending on which dimensions are present.
Raises ValueError if no DP dimensions exist (e.g., pure TP mesh). Use has_dp_mesh to check first.
- Parameters:
device_mesh (torch.distributed.device_mesh.DeviceMesh)
- Return type:
torch.distributed.device_mesh.DeviceMesh
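Illustrative usage of the two mesh helpers. The mesh shape and dimension names are assumptions, chosen to match the dp_shard/tp naming mentioned under build_parallel_model:

    from torch.distributed.device_mesh import init_device_mesh
    from kempnerforge.distributed.parallel import get_dp_mesh, has_dp_mesh

    # Hypothetical 8-GPU layout: 4-way FSDP sharding x 2-way tensor parallelism.
    # (Requires torch.distributed to be initialized, e.g. via torchrun.)
    mesh = init_device_mesh("cuda", (4, 2), mesh_dim_names=("dp_shard", "tp"))

    if has_dp_mesh(mesh):            # guard first; a pure-TP mesh would raise
        dp_mesh = get_dp_mesh(mesh)  # 1D (shard only) or 2D (replicate + shard / HSDP)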
- kempnerforge.distributed.parallel.default_mp_policy(param_dtype=torch.bfloat16)[source]¶
Mixed-precision policy: param_dtype compute, fp32 gradient reduction.
- Parameters:
param_dtype (torch.dtype)
- Return type:
torch.distributed._composable.fsdp.MixedPrecisionPolicy
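For example, a policy with fp16 compute instead of the bf16 default (a sketch; model and device_mesh are assumed to already exist):

    import torch
    from kempnerforge.distributed.parallel import apply_fsdp2, default_mp_policy

    policy = default_mp_policy(param_dtype=torch.float16)  # fp16 compute, fp32 grad reduce
    apply_fsdp2(model, device_mesh, mp_policy=policy)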
- kempnerforge.distributed.parallel.apply_ac(model, mode)[source]¶
Apply activation checkpointing to the model.
Must be called BEFORE apply_fsdp2.
- Parameters:
model (Transformer) – Transformer model.
mode (ActivationCheckpointing) – Checkpointing mode — “none”, “full”, or “selective”. full: checkpoint every TransformerBlock (maximum memory savings). selective: checkpoint only Attention modules (balanced trade-off).
- Return type:
None
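A short sketch of choosing a mode (the import location of ActivationCheckpointing is an assumption; the ordering constraint is the one stated above):

    # Selective: checkpoint only Attention modules (balanced trade-off).
    apply_ac(model, mode=ActivationCheckpointing.selective)

    # Full: checkpoint every TransformerBlock (maximum memory savings,
    # more recomputation in the backward pass).
    # apply_ac(model, mode=ActivationCheckpointing.full)

    apply_fsdp2(model, device_mesh)  # FSDP2 must come after AC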
- kempnerforge.distributed.parallel.apply_float8(model, enable_fsdp_float8_all_gather=True)[source]¶
Apply Float8 training (torchao) to the model.
Converts nn.Linear modules to Float8Linear for E4M3 forward / E5M2 backward with dynamic tensorwise scaling. Master weights remain in bf16.
Must be called AFTER apply_tensor_parallel / apply_expert_parallel and BEFORE apply_ac / apply_fsdp2.
MoE expert modules (experts and shared_expert) are excluded because they use grouped GEMM (torch._grouped_mm) which bypasses Float8Linear.forward().
- Parameters:
model (Transformer) – Transformer model.
enable_fsdp_float8_all_gather (bool) – If True, FSDP2 all-gathers use float8 (halves communication volume). Requires FSDP2 to be applied after. Must be False when TP is active — the float8 weight wrapper calls aten.is_pinned on DTensors, which has no sharding strategy yet.
- Return type:
None
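A hedged sketch of the float8 all-gather switch; the only thing that changes between the two calls is whether TP is active:

    # Pure FSDP2 (no TP): float8 all-gather halves all-gather communication volume.
    apply_float8(model, enable_fsdp_float8_all_gather=True)

    # TP active: the float8 all-gather path does not work with DTensor weights,
    # so it must be disabled.
    # apply_float8(model, enable_fsdp_float8_all_gather=False)

    # ...then apply_ac and apply_fsdp2 as usual.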
- kempnerforge.distributed.parallel.apply_fsdp2(model, device_mesh, mp_policy=None, reshard_after_forward=True)[source]¶
Apply FSDP2 (fully_shard) to a Transformer model.
Shards each TransformerBlock independently, then wraps the top-level model for remaining parameters (embeddings, final norm, output head).
Must be called AFTER apply_ac and apply_tensor_parallel.
EP interaction: Blocks with expert parallelism get per-sub-module wrapping (attention and MoE individually) instead of per-block wrapping. Per-block wrapping would cause FSDP2’s reduce-scatter to fire between EP’s backward all-to-all calls (deadlock). Per-sub-module wrapping avoids this: the MoE reduce-scatter fires after the entire MoE backward (both EP all-to-alls complete), while attention reduce-scatter is EP-free.
- Parameters:
model (Transformer) – Transformer model to shard.
device_mesh (DeviceMesh) – Full DeviceMesh (dp sub-mesh is extracted automatically).
mp_policy (MixedPrecisionPolicy | None) – Mixed precision policy. Defaults to bf16 params + fp32 reduce.
reshard_after_forward (bool | int) – Whether to free gathered params after forward. True = always reshard (saves memory, default). False = keep gathered (useful when PP needs params across microbatches). int = rate-limit the number of concurrent all-gathers.
- Return type:
None
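Usage sketch covering the reshard_after_forward options (values and defaults as documented above; model and device_mesh are assumed to exist):

    import torch
    from kempnerforge.distributed.parallel import apply_fsdp2, default_mp_policy

    # Default: bf16 params with fp32 gradient reduction, reshard after every forward.
    apply_fsdp2(model, device_mesh)

    # Keep gathered params between microbatches (e.g. for pipeline schedules),
    # trading memory for fewer all-gathers:
    # apply_fsdp2(model, device_mesh, reshard_after_forward=False)

    # Explicit mixed-precision policy:
    # apply_fsdp2(model, device_mesh, mp_policy=default_mp_policy(torch.bfloat16))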
- kempnerforge.distributed.parallel.build_parallel_model(model_config, device, device_mesh, *, ac_mode=ActivationCheckpointing.none, mp_policy=None, param_dtype=torch.bfloat16, compile_model=False, fp8=False)[source]¶
Build a Transformer with parallelism applied in the correct order.
- Handles the supported configurations automatically:
  - TP enabled: meta-device init → TP → EP → [Float8] → AC → FSDP → materialize
  - TP disabled: create on device → EP → [Float8] → AC → FSDP
This is the non-PP model-building path. For pipeline parallelism, use build_stage_module and apply parallelism directly.
- Parameters:
model_config – ModelConfig for the Transformer.
device (torch.device) – Target device for the model.
device_mesh (DeviceMesh | None) – Full DeviceMesh (may contain tp, dp_shard, dp_replicate dims).
ac_mode (ActivationCheckpointing) – Activation checkpointing mode.
mp_policy (MixedPrecisionPolicy | None) – FSDP2 mixed-precision policy. Defaults to bf16 params + fp32 reduce.
param_dtype (torch.dtype) – Dtype for model parameters.
compile_model (bool) – Whether to torch.compile the model.
fp8 (bool) – Whether to enable Float8 mixed precision (torchao).
- Returns:
The parallelized model, ready for training.
- Return type:
Transformer
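An end-to-end sketch (the mesh layout and the ModelConfig construction are assumptions; the keyword arguments are the ones documented above, and the ActivationCheckpointing import path is assumed):

    import torch
    from torch.distributed.device_mesh import init_device_mesh
    from kempnerforge.distributed.parallel import build_parallel_model

    # Hypothetical 2D mesh: 4-way FSDP sharding x 2-way tensor parallelism.
    mesh = init_device_mesh("cuda", (4, 2), mesh_dim_names=("dp_shard", "tp"))

    model = build_parallel_model(
        model_config,                    # your ModelConfig instance
        device=torch.device("cuda"),
        device_mesh=mesh,
        ac_mode=ActivationCheckpointing.selective,
        param_dtype=torch.bfloat16,
        compile_model=True,
        fp8=False,
    )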