kempnerforge.distributed.tensor_parallel

Tensor parallelism sharding plans for model components.

Applies column-parallel and row-parallel sharding to attention and MLP projections using PyTorch’s DTensor-based tensor parallelism.

Sharding strategy (with SequenceParallel):
  • Token embedding: Replicated (SequenceParallel norm handles Replicate→Shard(1))

  • Attention Q/K/V: ColwiseParallel (split heads across TP ranks)

  • Attention O: RowwiseParallel (gather heads, reduce-scatter to Shard(1))

  • MLP gate/up: ColwiseParallel (split hidden dim)

  • MLP down: RowwiseParallel (gather hidden dim, reduce-scatter to Shard(1))

  • Norm layers: SequenceParallel (operate on sequence-sharded activations)

  • Final norm: SequenceParallel

  • Output head: ColwiseParallel (split vocab, gather to Replicate for loss)

SequenceParallel keeps activations sharded along the sequence dimension between blocks, reducing activation memory at the norm layers to 1/tp of the unsharded size (where tp is the tensor-parallel degree) and replacing the all-reduce in RowwiseParallel with a reduce-scatter.
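
For orientation, the per-block portion of this plan can be expressed against PyTorch's DTensor tensor-parallel API roughly as follows. This is a minimal sketch assuming PyTorch 2.4+; the submodule names (attention_norm, attention.wq, feed_forward.w1, …) and the block_tp_plan helper are illustrative placeholders, not the module's actual attribute names.

    from torch.distributed.tensor import Replicate, Shard
    from torch.distributed.tensor.parallel import (
        ColwiseParallel,
        RowwiseParallel,
        SequenceParallel,
        parallelize_module,
    )

    def block_tp_plan():
        # Per-transformer-block plan: norms operate on sequence-sharded
        # activations; projections follow the colwise/rowwise pairing above.
        return {
            "attention_norm": SequenceParallel(),
            "ffn_norm": SequenceParallel(),
            # Q/K/V: all-gather the Shard(1) activations, then split heads.
            "attention.wq": ColwiseParallel(input_layouts=Shard(1)),
            "attention.wk": ColwiseParallel(input_layouts=Shard(1)),
            "attention.wv": ColwiseParallel(input_layouts=Shard(1)),
            # Output projection: reduce-scatter back to Shard(1), not all-reduce.
            "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
            # MLP: gate/up split the hidden dim, down projection reduce-scatters.
            "feed_forward.w1": ColwiseParallel(input_layouts=Shard(1)),
            "feed_forward.w3": ColwiseParallel(input_layouts=Shard(1)),
            "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
        }

    # Model-level pieces: final norm stays sequence-sharded, output head
    # splits the vocab and gathers logits back to Replicate for the loss.
    model_tp_plan = {
        "norm": SequenceParallel(),
        "output": ColwiseParallel(input_layouts=Shard(1), output_layouts=Replicate()),
    }

    # Each block is parallelized with something like:
    # parallelize_module(block, tp_mesh, block_tp_plan())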

The token embedding stays replicated because RowwiseParallel on nn.Embedding doesn’t correctly redistribute output to Shard(1) — it relabels without scattering, inflating the global sequence dimension. The first block’s SequenceParallel norm naturally handles the Replicate → Shard(1) transition.

SequenceParallel is disabled when tie_embeddings=True (ColwiseParallel on the output head imposes incompatible sharding on the shared weight).

Pipeline parallel stages use basic TP (no SequenceParallel) to avoid DTensors at PP stage boundaries.
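
A sketch of that basic per-block plan (used on PP stages and when embeddings are tied) could look like the following, reusing the imports and illustrative module names from the sketch above. With the defaults, RowwiseParallel all-reduces and returns replicated local tensors, so stage boundaries never see DTensors.

    basic_block_plan = {
        # No SequenceParallel: norms stay replicated and are left out of the plan.
        "attention.wq": ColwiseParallel(),
        "attention.wk": ColwiseParallel(),
        "attention.wv": ColwiseParallel(),
        "attention.wo": RowwiseParallel(),   # all-reduce; replicated local output
        "feed_forward.w1": ColwiseParallel(),
        "feed_forward.w3": ColwiseParallel(),
        "feed_forward.w2": RowwiseParallel(),
    }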

Functions

apply_tensor_parallel(model, device_mesh)

Apply tensor parallelism to a Transformer or PipelineStageModule.

get_tp_mesh(device_mesh)

Extract the TP sub-mesh from a DeviceMesh.

kempnerforge.distributed.tensor_parallel.get_tp_mesh(device_mesh)

Extract the TP sub-mesh from a DeviceMesh.

Returns None if no ‘tp’ dimension exists.

Parameters:

device_mesh (DeviceMesh) – Device mesh that may contain a ‘tp’ dimension.

Return type:

DeviceMesh | None
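
The lookup is roughly the following sketch (an assumption about the behavior, not the actual implementation), using DeviceMesh’s named-dimension slicing:

    def get_tp_mesh_sketch(device_mesh):
        # Return the 1-D 'tp' sub-mesh if the mesh has a named 'tp' dimension.
        names = device_mesh.mesh_dim_names
        if names is None or "tp" not in names:
            return None
        return device_mesh["tp"]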

kempnerforge.distributed.tensor_parallel.apply_tensor_parallel(model, device_mesh)

Apply tensor parallelism to a Transformer or PipelineStageModule.

Parallelizes attention/MLP projections, norm layers (SequenceParallel), and output head. Token embedding stays replicated. Should be called BEFORE apply_ac and apply_fsdp2.

PP stages use basic TP without SequenceParallel to avoid DTensors at stage boundaries. SequenceParallel is also disabled when weights are tied.

Parameters:

model (Transformer | PipelineStageModule) – Model to parallelize.

device_mesh (DeviceMesh) – Device mesh from which the TP sub-mesh is taken.

Return type:

None
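
Hypothetical usage showing the required ordering, with an illustrative 2-way data-parallel by 4-way tensor-parallel mesh; model is assumed to be an already constructed Transformer, and the apply_ac / apply_fsdp2 calls are shown only as comments since their signatures are not part of this page.

    from torch.distributed.device_mesh import init_device_mesh

    from kempnerforge.distributed.tensor_parallel import apply_tensor_parallel

    device_mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))
    apply_tensor_parallel(model, device_mesh)  # TP first
    # apply_ac(model, ...)                     # then activation checkpointing
    # apply_fsdp2(model, device_mesh, ...)     # then FSDP2 sharding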