kempnerforge.distributed.parallel

Parallelism application: TP, AC, Float8, FSDP2, and model building.

Applies parallelism to Transformer models in the correct order.

Application order (critical — wrong order causes silent correctness bugs):
  1. Tensor parallelism (apply_tensor_parallel) — must see raw blocks

  2. Expert parallelism (apply_expert_parallel) — partitions MoE experts

  3. Float8 training (apply_float8) — wraps Linear → Float8Linear

  4. Activation checkpointing (apply_ac) — wraps blocks in CheckpointWrapper

  5. FSDP2 (apply_fsdp2) — shards everything (uses float8 all-gather if enabled)

For convenience, build_parallel_model combines all steps including model creation, meta-device initialization, and optional torch.compile.
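The required ordering can be made explicit with a small checklist helper (hypothetical, not part of this module) that validates a sequence of applied steps against the canonical order:

```python
# Hypothetical helper: validates that parallelism steps follow the
# required order (TP -> EP -> Float8 -> AC -> FSDP2). Any subset of
# steps is fine as long as their relative order is preserved.
REQUIRED_ORDER = ("tp", "ep", "float8", "ac", "fsdp2")

def check_application_order(applied_steps):
    indices = [REQUIRED_ORDER.index(step) for step in applied_steps]
    if indices != sorted(indices):
        raise ValueError(f"parallelism steps out of order: {applied_steps}")

check_application_order(["tp", "float8", "fsdp2"])  # valid subset, in order
```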

Functions

apply_ac(model, mode)

Apply activation checkpointing to the model.

apply_float8(model[, ...])

Apply Float8 training (torchao) to the model.

apply_fsdp2(model, device_mesh[, mp_policy, ...])

Apply FSDP2 (fully_shard) to a Transformer model.

build_parallel_model(model_config, device, ...)

Build a Transformer with parallelism applied in the correct order.

default_mp_policy([param_dtype])

Mixed-precision policy: param_dtype compute, fp32 gradient reduction.

get_dp_mesh(device_mesh)

Extract the data-parallel sub-mesh from a DeviceMesh.

has_dp_mesh(device_mesh)

Check whether the DeviceMesh contains any data-parallel dimensions.

kempnerforge.distributed.parallel.has_dp_mesh(device_mesh)[source]

Check whether the DeviceMesh contains any data-parallel dimensions.

Parameters:

device_mesh (torch.distributed.device_mesh.DeviceMesh)

Return type:

bool

kempnerforge.distributed.parallel.get_dp_mesh(device_mesh)[source]

Extract the data-parallel sub-mesh from a DeviceMesh.

Returns a 1D mesh (pure sharding) or 2D mesh (replicate + shard / HSDP) depending on which dimensions are present.

Raises ValueError if no DP dimensions exist (e.g., pure TP mesh). Use has_dp_mesh to check first.

Parameters:

device_mesh (torch.distributed.device_mesh.DeviceMesh)

Return type:

torch.distributed.device_mesh.DeviceMesh
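The dimension-name logic behind these two functions can be sketched in isolation. The dim names "dp_replicate" and "dp_shard" are assumptions about this library's mesh layout, inferred from the HSDP description above; a real DeviceMesh exposes its names via mesh.mesh_dim_names:

```python
# Sketch, assuming the mesh names its data-parallel dimensions
# "dp_replicate" (HSDP replication) and "dp_shard" (FSDP sharding).
DP_DIM_NAMES = ("dp_replicate", "dp_shard")

def dp_dims_present(mesh_dim_names):
    """Return the DP dims found, preserving (replicate, shard) order."""
    return tuple(d for d in DP_DIM_NAMES if d in mesh_dim_names)

# A (dp_shard, tp) mesh has a 1D DP sub-mesh; a pure-TP mesh has none.
assert dp_dims_present(("dp_shard", "tp")) == ("dp_shard",)
assert dp_dims_present(("tp",)) == ()
```

On a real mesh, get_dp_mesh would then slice the DeviceMesh by the found names to produce the 1D or 2D sub-mesh, and raise ValueError when the tuple is empty.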

kempnerforge.distributed.parallel.default_mp_policy(param_dtype=torch.bfloat16)[source]

Mixed-precision policy: param_dtype compute, fp32 gradient reduction.

Parameters:

param_dtype (torch.dtype)

Return type:

torch.distributed._composable.fsdp.MixedPrecisionPolicy

kempnerforge.distributed.parallel.apply_ac(model, mode)[source]

Apply activation checkpointing to the model.

Must be called BEFORE apply_fsdp2.

Parameters:
  • model (Transformer) – Transformer model.

  • mode (ActivationCheckpointing) – Checkpointing mode: “none”, “full”, or “selective”. “full” checkpoints every TransformerBlock (maximum memory savings); “selective” checkpoints only Attention modules (a balanced trade-off).

Return type:

None
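The “full” mode can be sketched on CPU with PyTorch's checkpoint_wrapper, using a toy stand-in for a TransformerBlock (the real apply_ac wraps the model's actual blocks):

```python
import torch
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
)

class ToyBlock(nn.Module):
    """Toy stand-in for a TransformerBlock."""
    def __init__(self, dim=8):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, x):
        return torch.relu(self.linear(x))

# "full" mode: wrap every block, so activations are freed after forward
# and recomputed during backward.
blocks = nn.ModuleList([checkpoint_wrapper(ToyBlock()) for _ in range(2)])
x = torch.randn(4, 8, requires_grad=True)
out = x
for block in blocks:
    out = block(out)
out.sum().backward()  # recomputes activations inside each wrapper
```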

kempnerforge.distributed.parallel.apply_float8(model, enable_fsdp_float8_all_gather=True)[source]

Apply Float8 training (torchao) to the model.

Converts nn.Linear modules to Float8Linear for E4M3 forward / E5M2 backward with dynamic tensorwise scaling. Master weights remain in bf16.

Must be called AFTER apply_tensor_parallel / apply_expert_parallel and BEFORE apply_ac / apply_fsdp2.

MoE expert modules (experts and shared_expert) are excluded because they use grouped GEMM (torch._grouped_mm) which bypasses Float8Linear.forward().

Parameters:
  • model (Transformer) – Transformer model.

  • enable_fsdp_float8_all_gather (bool) – If True, FSDP2 all-gathers use float8 (halves communication volume). Requires FSDP2 to be applied after. Must be False when TP is active — the float8 weight wrapper calls aten.is_pinned on DTensors, which has no sharding strategy yet.

Return type:

None
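The expert-exclusion rule can be sketched as a name-based module filter of the kind torchao's float8 conversion accepts (the real torchao callback also receives the module itself; the qualified names below are illustrative):

```python
# Sketch of the filter that keeps MoE expert modules out of Float8
# conversion: any module whose qualified name contains "experts" or
# "shared_expert" uses grouped GEMM and must stay in high precision.
EXCLUDED_PARTS = ("experts", "shared_expert")

def float8_filter(fqn):
    return not any(part in EXCLUDED_PARTS for part in fqn.split("."))

assert float8_filter("layers.0.attention.wq")        # converted
assert not float8_filter("layers.0.moe.experts.w1")  # excluded
```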

kempnerforge.distributed.parallel.apply_fsdp2(model, device_mesh, mp_policy=None, reshard_after_forward=True)[source]

Apply FSDP2 (fully_shard) to a Transformer model.

Shards each TransformerBlock independently, then wraps the top-level model for remaining parameters (embeddings, final norm, output head).

Must be called AFTER apply_ac and apply_tensor_parallel.

EP interaction: Blocks with expert parallelism get per-sub-module wrapping (attention and MoE individually) instead of per-block wrapping. Per-block wrapping would cause FSDP2’s reduce-scatter to fire between EP’s backward all-to-all calls (deadlock). Per-sub-module wrapping avoids this: the MoE reduce-scatter fires after the entire MoE backward (both EP all-to-alls complete), while attention reduce-scatter is EP-free.

Parameters:
  • model (Transformer) – Transformer model to shard.

  • device_mesh (DeviceMesh) – Full DeviceMesh (dp sub-mesh is extracted automatically).

  • mp_policy (MixedPrecisionPolicy | None) – Mixed precision policy. Defaults to bf16 params + fp32 reduce.

  • reshard_after_forward (bool | int) – Whether to free gathered parameters after forward. True = always reshard (saves memory, default). False = keep parameters gathered (useful when pipeline parallelism reuses them across microbatches). int = reshard to that smaller world size after forward (a compromise between the two).

Return type:

None
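The wrapping plan described above can be sketched as a pure function. The sub-module names "attention" and "moe" and the path format are illustrative; the real function calls fully_shard on the modules themselves, root last:

```python
def plan_fully_shard(block_names, ep_block_names):
    """Return the ordered list of module paths to pass to fully_shard.

    EP blocks are wrapped per-sub-module (attention and MoE separately)
    so the MoE reduce-scatter fires only after both EP all-to-alls in
    backward; non-EP blocks are wrapped whole. The root model is wrapped
    last to catch embeddings, final norm, and the output head.
    """
    plan = []
    for name in block_names:
        if name in ep_block_names:
            plan += [f"{name}.attention", f"{name}.moe"]
        else:
            plan.append(name)
    plan.append("<root>")
    return plan

plan_fully_shard(["layers.0", "layers.1"], {"layers.1"})
# -> ["layers.0", "layers.1.attention", "layers.1.moe", "<root>"]
```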

kempnerforge.distributed.parallel.build_parallel_model(model_config, device, device_mesh, *, ac_mode=ActivationCheckpointing.none, mp_policy=None, param_dtype=torch.bfloat16, compile_model=False, fp8=False)[source]

Build a Transformer with parallelism applied in the correct order.

Handles two configurations automatically:
  • TP enabled: meta-device init → TP → EP → [Float8] → AC → FSDP → materialize

  • TP disabled: create on device → EP → [Float8] → AC → FSDP

This is the non-PP model-building path. For pipeline parallelism, use build_stage_module and apply the parallelism functions directly.

Parameters:
  • model_config – ModelConfig for the Transformer.

  • device (torch.device) – Target device for the model.

  • device_mesh (DeviceMesh | None) – Full DeviceMesh (may contain tp, dp_shard, dp_replicate dims).

  • ac_mode (ActivationCheckpointing) – Activation checkpointing mode.

  • mp_policy (MixedPrecisionPolicy | None) – FSDP2 mixed-precision policy. Defaults to bf16 params + fp32 reduce.

  • param_dtype (torch.dtype) – Dtype for model parameters.

  • compile_model (bool) – Whether to torch.compile the model.

  • fp8 (bool) – Whether to enable Float8 mixed precision (torchao).

Returns:

The parallelized model, ready for training.

Return type:

torch.nn.Module
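The two configuration paths can be sketched as a step planner (hypothetical helper; step names mirror the bullets above, and "materialize" stands for moving meta-device parameters onto real devices after sharding):

```python
def build_steps(tp_enabled, fp8=False, ac=False):
    """Return the ordered build steps for build_parallel_model's two paths."""
    # TP requires meta-device init so raw blocks can be partitioned
    # before any parameters are materialized.
    steps = ["meta_init", "tp"] if tp_enabled else ["device_init"]
    steps.append("ep")
    if fp8:
        steps.append("float8")
    if ac:
        steps.append("ac")
    steps.append("fsdp")
    if tp_enabled:
        steps.append("materialize")  # meta tensors become real after sharding
    return steps

build_steps(tp_enabled=True, fp8=True, ac=True)
# -> ["meta_init", "tp", "ep", "float8", "ac", "fsdp", "materialize"]
```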