kempnerforge.distributed.parallel¶
Parallelism application: TP, AC, Float8, FSDP2, and model building.
Applies parallelism to Transformer models in the correct order.
- Application order (critical: the wrong order causes silent correctness bugs):
1. Tensor parallelism (apply_tensor_parallel): must see raw blocks
2. Expert parallelism (apply_expert_parallel): partitions MoE experts
3. Float8 training (apply_float8): wraps Linear → Float8Linear
4. Activation checkpointing (apply_ac): wraps blocks in CheckpointWrapper
5. FSDP2 (apply_fsdp2): shards everything (uses float8 all-gather if enabled)
For convenience, build_parallel_model combines all steps including
model creation, meta-device initialization, and optional torch.compile.
It also dispatches to a VLM branch when vlm_config is provided.
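For cases where build_parallel_model is too coarse, the hedged sketch below walks the manual sequence in the required order. The mesh dimension names, the apply_tensor_parallel / apply_expert_parallel signatures, and the ActivationCheckpointing import path are assumptions, not verified API.

```python
import torch
from kempnerforge.distributed.parallel import (
    apply_ac,
    apply_float8,
    apply_fsdp2,
    default_mp_policy,
)

# Assumed: `model` is a Transformer built on the meta device and
# `device_mesh` has "tp" and "dp_shard" dims; the TP/EP helper calls
# below are illustrative sketches of the order, not exact signatures.
apply_tensor_parallel(model, device_mesh["tp"])   # 1. TP must see raw blocks
apply_expert_parallel(model, device_mesh)         # 2. EP partitions MoE experts
apply_float8(model, enable_fsdp_float8_all_gather=False)  # 3. False while TP is active
apply_ac(model, mode=ActivationCheckpointing.full)        # 4. AC wraps blocks
apply_fsdp2(model, device_mesh, mp_policy=default_mp_policy())  # 5. FSDP2 shards last
```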
Functions

| Function | Description |
| --- | --- |
| apply_ac | Apply activation checkpointing to the model. |
| apply_float8 | Apply Float8 training (torchao) to the model. |
| apply_fsdp2 | Apply FSDP2 (fully_shard) to a Transformer model. |
| build_parallel_model | Build a Transformer (or a VLMWrapper) with parallelism applied. |
| default_mp_policy | Mixed-precision policy: param_dtype compute, fp32 gradient reduction. |
| get_dp_mesh | Extract the data-parallel sub-mesh from a DeviceMesh. |
| has_dp_mesh | Check whether the DeviceMesh contains any data-parallel dimensions. |
- kempnerforge.distributed.parallel.has_dp_mesh(device_mesh)[source]¶
Check whether the DeviceMesh contains any data-parallel dimensions.
- Parameters:
device_mesh (torch.distributed.device_mesh.DeviceMesh)
- Return type:
bool
- kempnerforge.distributed.parallel.get_dp_mesh(device_mesh)[source]¶
Extract the data-parallel sub-mesh from a DeviceMesh.
Returns a 1D mesh (pure sharding) or 2D mesh (replicate + shard / HSDP) depending on which dimensions are present.
Raises ValueError if no DP dimensions exist (e.g., pure TP mesh). Use has_dp_mesh to check first.
- Parameters:
device_mesh (torch.distributed.device_mesh.DeviceMesh)
- Return type:
torch.distributed.device_mesh.DeviceMesh
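A minimal usage sketch of the check-then-extract pattern; the mesh shape and 8-GPU layout are illustrative assumptions.

```python
from torch.distributed.device_mesh import init_device_mesh
from kempnerforge.distributed.parallel import get_dp_mesh, has_dp_mesh

# Illustrative 2D mesh: 4 dp_shard ranks x 2 tp ranks (8 GPUs assumed).
mesh = init_device_mesh("cuda", (4, 2), mesh_dim_names=("dp_shard", "tp"))

if has_dp_mesh(mesh):
    dp_mesh = get_dp_mesh(mesh)  # 1D here: just the dp_shard dim
else:
    ...  # pure-TP mesh: get_dp_mesh would raise ValueError
```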
- kempnerforge.distributed.parallel.default_mp_policy(param_dtype=torch.bfloat16)[source]¶
Mixed-precision policy: param_dtype compute, fp32 gradient reduction.
cast_forward_inputs=True ensures FSDP2 casts input tensors to the declared param_dtype at each wrapped module’s forward boundary. The VLM path relies on this so image embeddings produced by the adapter (bf16) reach the sharded transformer with matching dtype without needing the caller to do manual casts. The default on MixedPrecisionPolicy is False, so we set it explicitly here to pin the contract.
- Parameters:
param_dtype (torch.dtype)
- Return type:
torch.distributed._composable.fsdp.MixedPrecisionPolicy
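For reference, a sketch of the policy this helper plausibly constructs; the keyword names are the actual MixedPrecisionPolicy fields.

```python
import torch
from torch.distributed._composable.fsdp import MixedPrecisionPolicy

# Equivalent-in-spirit construction: bf16 compute, fp32 gradient
# reduction, and explicit input casting at each wrapped forward.
policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
    cast_forward_inputs=True,
)
```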
- kempnerforge.distributed.parallel.apply_ac(model, mode)[source]¶
Apply activation checkpointing to the model.
Must be called BEFORE apply_fsdp2.
- Parameters:
model (Transformer) – Transformer model.
mode (ActivationCheckpointing) – Checkpointing mode: “none”, “full”, or “selective”. “full” checkpoints every TransformerBlock (maximum memory savings); “selective” checkpoints only Attention modules (balanced trade-off).
- Return type:
None
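As a hedged sketch, “full” mode amounts to wrapping each block in a checkpoint wrapper before sharding; the model.layers attribute is an assumption about the Transformer’s structure.

```python
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
)

# "full" mode, roughly: recompute every TransformerBlock's activations
# in backward. Must run before apply_fsdp2 so FSDP2 shards the wrappers.
for layer_id, block in model.layers.named_children():
    model.layers.register_module(layer_id, checkpoint_wrapper(block))
```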
- kempnerforge.distributed.parallel.apply_float8(model, enable_fsdp_float8_all_gather=True)[source]¶
Apply Float8 training (torchao) to the model.
Converts nn.Linear modules to Float8Linear for E4M3 forward / E5M2 backward with dynamic tensorwise scaling. Master weights remain in bf16.
Must be called AFTER apply_tensor_parallel / apply_expert_parallel and BEFORE apply_ac / apply_fsdp2.
MoE expert modules (experts and shared_expert) are excluded because they use grouped GEMM (torch._grouped_mm) which bypasses Float8Linear.forward().
- Parameters:
model (Transformer) – Transformer model.
enable_fsdp_float8_all_gather (bool) – If True, FSDP2 all-gathers use float8 (halves communication volume). Requires FSDP2 to be applied after. Must be False when TP is active — the float8 weight wrapper calls aten.is_pinned on DTensors, which has no sharding strategy yet.
- Return type:
None
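A hedged sketch of the conversion step via torchao; the expert-name filter is inferred from the MoE exclusion note above, not verified against the implementation.

```python
from torchao.float8 import Float8LinearConfig, convert_to_float8_training

def _skip_moe_experts(module, fqn):
    # Grouped-GEMM expert modules bypass Float8Linear.forward(), so
    # exclude them (module names inferred from the docstring).
    return "experts" not in fqn and "shared_expert" not in fqn

convert_to_float8_training(
    model,
    config=Float8LinearConfig(enable_fsdp_float8_all_gather=True),
    module_filter_fn=_skip_moe_experts,
)
```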
- kempnerforge.distributed.parallel.apply_fsdp2(model, device_mesh, mp_policy=None, reshard_after_forward=True)[source]¶
Apply FSDP2 (fully_shard) to a Transformer model.
Shards each TransformerBlock independently (via _fsdp_wrap_transformer_blocks, so the EP-MoE per-sub-module wrap is shared with the VLM path), then wraps the top-level model for the remaining parameters (embeddings, final norm, output head).
Must be called AFTER apply_ac and apply_tensor_parallel.
EP interaction: Blocks with expert parallelism get per-sub-module wrapping (attention and MoE individually) instead of per-block wrapping. Per-block wrapping would cause FSDP2’s reduce-scatter to fire between EP’s backward all-to-all calls (deadlock). Per-sub-module wrapping avoids this: the MoE reduce-scatter fires after the entire MoE backward (both EP all-to-alls complete), while attention reduce-scatter is EP-free.
- Parameters:
model (Transformer) – Transformer model to shard.
device_mesh (DeviceMesh) – Full DeviceMesh (dp sub-mesh is extracted automatically).
mp_policy (MixedPrecisionPolicy | None) – Mixed precision policy. Defaults to bf16 params + fp32 reduce.
reshard_after_forward (bool | int) – Whether to free gathered params after forward. True = always reshard (saves memory, default). False = keep gathered (useful when PP needs params across microbatches). int = rate-limit the number of concurrent all-gathers.
- Return type:
None
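A minimal sketch of the wrap order described above, assuming the blocks live in model.layers; the EP per-sub-module case is noted in comments rather than implemented.

```python
from torch.distributed._composable.fsdp import fully_shard
from kempnerforge.distributed.parallel import default_mp_policy, get_dp_mesh

dp_mesh = get_dp_mesh(device_mesh)
policy = default_mp_policy()

# Per-block wrapping first (EP blocks would instead wrap attention and
# MoE sub-modules individually, per the deadlock note above) ...
for block in model.layers.children():
    fully_shard(block, mesh=dp_mesh, mp_policy=policy)

# ... then the top-level wrap picks up embeddings, final norm, and head.
fully_shard(model, mesh=dp_mesh, mp_policy=policy)
```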
- kempnerforge.distributed.parallel.build_parallel_model(model_config, device, device_mesh, *, vision_config=None, adapter_config=None, vlm_config=None, ac_mode=ActivationCheckpointing.none, mp_policy=None, param_dtype=torch.bfloat16, compile_model=False, fp8=False)[source]¶
Build a Transformer (or a VLMWrapper) with parallelism applied.
Dispatches on vlm_config (None -> text-only path). Non-VLM configurations follow the original order:
- TP enabled: meta-device init -> TP -> EP -> [Float8] -> AC -> FSDP -> materialize
- TP disabled: create on device -> EP -> [Float8] -> AC -> FSDP
VLM configurations follow the order documented on _build_vlm.
This is the non-PP model building path. For pipeline parallelism, use build_stage_module + apply parallelism directly.
- Parameters:
model_config – ModelConfig for the Transformer.
device (torch.device) – Target device for the model.
device_mesh (DeviceMesh | None) – Full DeviceMesh (may contain tp, dp_shard, dp_replicate dims).
vision_config – VisionEncoderConfig (required iff vlm_config is set).
adapter_config – AdapterConfig (required iff vlm_config is set).
vlm_config – VLMConfig. None for a pure text-only run.
ac_mode (ActivationCheckpointing) – Activation checkpointing mode.
mp_policy (MixedPrecisionPolicy | None) – FSDP2 mixed-precision policy. Defaults to bf16 params + fp32 reduce.
param_dtype (torch.dtype) – Dtype for model parameters.
compile_model (bool) – Whether to torch.compile the model.
fp8 (bool) – Whether to enable Float8 mixed precision (torchao).
- Returns:
The parallelized model, ready for training.
- Return type:
Transformer | VLMWrapper
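A hedged usage sketch for the text-only path; the config object, mesh shape, and the ActivationCheckpointing import path are assumptions.

```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from kempnerforge.distributed.parallel import build_parallel_model

# Assumed: `model_config` is a ModelConfig and 8 GPUs are available.
mesh = init_device_mesh("cuda", (4, 2), mesh_dim_names=("dp_shard", "tp"))

model = build_parallel_model(
    model_config,
    device=torch.device("cuda"),
    device_mesh=mesh,
    ac_mode=ActivationCheckpointing.selective,  # enum import path assumed
    param_dtype=torch.bfloat16,
    compile_model=False,
    fp8=False,  # leave False here: TP is active in this mesh
)
```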