kempnerforge.distributed.tensor_parallel
Tensor parallelism sharding plans for model components.
Applies column-parallel and row-parallel sharding to attention and MLP projections using PyTorch’s DTensor-based tensor parallelism.
- Sharding strategy (with SequenceParallel); see the plan sketch after this list:
  - Token embedding: Replicated (the first SequenceParallel norm handles Replicate → Shard(1))
  - Attention Q/K/V: ColwiseParallel (split heads across TP ranks)
  - Attention O: RowwiseParallel (gather heads, reduce-scatter to Shard(1))
  - MLP gate/up: ColwiseParallel (split hidden dim)
  - MLP down: RowwiseParallel (gather hidden dim, reduce-scatter to Shard(1))
  - Norm layers: SequenceParallel (operate on sequence-sharded activations)
  - Final norm: SequenceParallel
  - Output head: ColwiseParallel (split vocab, gather to Replicate for loss)
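For orientation, here is a minimal sketch of what such a plan can look like with PyTorch's DTensor tensor-parallel API. The module names (attention.wq, mlp.w1, and so on) and the single-tensor forward signatures are assumptions for a Llama-style block, not kempnerforge's actual layout.

```python
# A minimal sketch assuming Llama-style module names; not kempnerforge's actual plan.
from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    PrepareModuleInput,
    RowwiseParallel,
    SequenceParallel,
)

block_plan = {
    # Norms run directly on sequence-sharded (Shard(1)) activations.
    "attention_norm": SequenceParallel(),
    "ffn_norm": SequenceParallel(),
    # All-gather the Shard(1) activations once at each sub-module boundary
    # (assumes the sub-module forward takes a single positional tensor).
    "attention": PrepareModuleInput(
        input_layouts=Shard(1), desired_input_layouts=Replicate()
    ),
    "mlp": PrepareModuleInput(
        input_layouts=Shard(1), desired_input_layouts=Replicate()
    ),
    # Q/K/V split heads across TP ranks; O reduce-scatters back to Shard(1).
    "attention.wq": ColwiseParallel(),
    "attention.wk": ColwiseParallel(),
    "attention.wv": ColwiseParallel(),
    "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
    # Gate/up split the hidden dim; down reduce-scatters back to Shard(1).
    "mlp.w1": ColwiseParallel(),
    "mlp.w3": ColwiseParallel(),
    "mlp.w2": RowwiseParallel(output_layouts=Shard(1)),
}

top_level_plan = {
    # Token embedding has no entry: it stays replicated, and the first block's
    # SequenceParallel norm performs the Replicate -> Shard(1) transition.
    "norm": SequenceParallel(),
    "output": ColwiseParallel(
        input_layouts=Shard(1),
        output_layouts=Replicate(),  # gather full-vocab logits for the loss
    ),
}
```

A plan like this would be applied with parallelize_module(block, tp_mesh, block_plan) per transformer block, and parallelize_module(model, tp_mesh, top_level_plan) at the top level.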
SequenceParallel keeps activations sharded along the sequence dimension between blocks, reducing activation memory at the norm layers to 1/tp of the replicated size and replacing the all-reduce in RowwiseParallel with a reduce-scatter.
The token embedding stays replicated because RowwiseParallel on nn.Embedding doesn’t correctly redistribute output to Shard(1) — it relabels without scattering, inflating the global sequence dimension. The first block’s SequenceParallel norm naturally handles the Replicate → Shard(1) transition.
SequenceParallel is disabled when tie_embeddings=True (ColwiseParallel on the output head imposes incompatible sharding on the shared weight).
Pipeline parallel stages use basic TP (no SequenceParallel) to avoid DTensors at PP stage boundaries.
Functions
- apply_tensor_parallel(model, device_mesh): Apply tensor parallelism to a Transformer or PipelineStageModule.
- get_tp_mesh(device_mesh): Extract the TP sub-mesh from a DeviceMesh.
- kempnerforge.distributed.tensor_parallel.get_tp_mesh(device_mesh)
Extract the TP sub-mesh from a DeviceMesh.
Returns None if no ‘tp’ dimension exists.
- Parameters:
device_mesh (DeviceMesh)
- Return type:
DeviceMesh | None
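A rough functional equivalent using DeviceMesh slicing, shown for illustration only (this is a sketch, not the library's implementation):

```python
from torch.distributed.device_mesh import DeviceMesh

def get_tp_mesh_sketch(device_mesh: DeviceMesh) -> DeviceMesh | None:
    # Return the 1D 'tp' sub-mesh if the mesh was built with a 'tp' dimension.
    names = device_mesh.mesh_dim_names
    if names is None or "tp" not in names:
        return None
    return device_mesh["tp"]
```

For example, on a mesh created with init_device_mesh("cuda", (4, 2), mesh_dim_names=("dp", "tp")), this returns the size-2 'tp' sub-mesh for the calling rank.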
- kempnerforge.distributed.tensor_parallel.apply_tensor_parallel(model, device_mesh)
Apply tensor parallelism to a Transformer or PipelineStageModule.
Parallelizes attention/MLP projections, norm layers (SequenceParallel), and output head. Token embedding stays replicated. Should be called BEFORE apply_ac and apply_fsdp2.
PP stages use basic TP without SequenceParallel to avoid DTensors at stage boundaries. SequenceParallel is also disabled when weights are tied.
- Parameters:
model (torch.nn.Module) – Transformer or PipelineStageModule to parallelize.
device_mesh (torch.distributed.device_mesh.DeviceMesh) – Full DeviceMesh with a ‘tp’ dimension.
- Return type:
None
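A hedged usage sketch of the call ordering described above. The mesh shape, the model constructor, and the apply_ac / apply_fsdp2 call sites are illustrative assumptions; only apply_tensor_parallel's signature comes from this page.

```python
from torch.distributed.device_mesh import init_device_mesh

from kempnerforge.distributed.tensor_parallel import apply_tensor_parallel

# Hypothetical 2D mesh: 4-way data parallel x 2-way tensor parallel on 8 GPUs.
device_mesh = init_device_mesh("cuda", (4, 2), mesh_dim_names=("dp", "tp"))

model = build_transformer()                # hypothetical model constructor
apply_tensor_parallel(model, device_mesh)  # TP first, using the 'tp' sub-mesh
# ... then activation checkpointing (apply_ac) and FSDP2 (apply_fsdp2),
# in that order, per the note above.
```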