Tensor parallelism¶
apply_tensor_parallel(model, device_mesh) in
kempnerforge/distributed/tensor_parallel.py
shards attention and MLP weights across the tp mesh dimension using
PyTorch’s parallelize_module + DTensor.
What gets sharded¶
Per TransformerBlock:
| Module | Style | Reason |
|---|---|---|
| q/k/v_proj | ColwiseParallel | Split attention heads across TP ranks |
| o_proj | RowwiseParallel | Gather heads, reduce-scatter back to sequence-sharded |
| gate/up_proj | ColwiseParallel | Split hidden dim |
| down_proj | RowwiseParallel | Gather hidden, reduce-scatter |
| attention_norm, mlp_norm | SequenceParallel | Norm computed on sequence-sharded input (see below) |
MoE blocks omit the MLP entries: experts are replicated across TP (TP doesn’t shard expert weights); EP handles that dimension separately.
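A minimal sketch of this per-block plan, assuming PyTorch's built-in TP styles and submodule names (attention.q_proj, mlp.gate_proj, attention_norm, and so on) borrowed from the activation-flow diagram later on this page; the actual kempnerforge module layout may differ.

```python
from torch.distributed.tensor import Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    SequenceParallel,
    parallelize_module,
)


def shard_block(block, tp_mesh):
    # Hypothetical per-block plan matching the table above.
    plan = {
        # Norms run directly on the Shard(1) (sequence-sharded) activations.
        "attention_norm": SequenceParallel(),
        "mlp_norm": SequenceParallel(),
        # Colwise: all-gather the sequence dim, split head / hidden columns across TP ranks.
        "attention.q_proj": ColwiseParallel(input_layouts=Shard(1)),
        "attention.k_proj": ColwiseParallel(input_layouts=Shard(1)),
        "attention.v_proj": ColwiseParallel(input_layouts=Shard(1)),
        "mlp.gate_proj": ColwiseParallel(input_layouts=Shard(1)),
        "mlp.up_proj": ColwiseParallel(input_layouts=Shard(1)),
        # Rowwise: gather partial outputs, reduce-scatter back to sequence-sharded.
        "attention.o_proj": RowwiseParallel(output_layouts=Shard(1)),
        "mlp.down_proj": RowwiseParallel(output_layouts=Shard(1)),
    }
    parallelize_module(block, tp_mesh, plan)
```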
At the model level:
| Module | Style | When |
|---|---|---|
| token_embedding | Forward hook wraps output as a Replicate() DTensor | When sequence parallel is on |
| final norm | SequenceParallel | When sequence parallel is on |
| output head | ColwiseParallel | When sequence parallel is on and embeddings not tied |
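A rough sketch of these model-level pieces, with assumed attribute names (norm for the final norm, output for the untied head); the token_embedding forward hook itself is shown under Token embedding hook below.

```python
from torch.distributed.tensor.parallel import ColwiseParallel, SequenceParallel, parallelize_module


def shard_model_top(model, tp_mesh, tied_embeddings: bool):
    # Hypothetical model-level plan matching the table above.
    plan = {"norm": SequenceParallel()}  # final norm stays on the sequence-sharded stream
    if not tied_embeddings:
        plan["output"] = ColwiseParallel()  # untied output head only
    parallelize_module(model, tp_mesh, plan)
```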
SequenceParallel¶
When enabled, norms operate on the sequence-sharded tensor (each
TP rank holds seq_len / tp positions) rather than the
replicated one. Activations flow like this through a block:
```
residual:        Shard(1)              (seq-sharded)
attention_norm:  Shard(1) → Shard(1)   (SequenceParallel)
q/k/v_proj:      Shard(1) → Shard(2)   (Colwise all-gathers seq, shards heads)
attention:       Shard(2) compute
o_proj:          Shard(2) → Shard(1)   (Rowwise gathers heads, reduce-scatters seq)
+ residual:      Shard(1)
mlp_norm:        Shard(1) → Shard(1)
gate/up_proj:    Shard(1) → Shard(2)
down_proj:       Shard(2) → Shard(1)
+ residual:      Shard(1)
```
Net effect: activations are kept sequence-sharded everywhere except
inside attention and MLP compute. Memory for activations drops by
tp× at block boundaries, and the extra all-gathers are overlapped
with compute.
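For a rough sense of scale, with hypothetical shapes (seq_len=8192, hidden=8192, bf16, tp=4):

```python
# Illustrative arithmetic only; the shapes are assumptions, not a shipped config.
seq_len, hidden, bytes_per_elem, tp = 8192, 8192, 2, 4  # bf16 activations, tp=4

replicated_mib = seq_len * hidden * bytes_per_elem / 2**20            # no SP: full copy on every rank
seq_sharded_mib = (seq_len // tp) * hidden * bytes_per_elem / 2**20   # SP: Shard(1) holds seq_len/tp positions

print(f"{replicated_mib:.0f} MiB vs {seq_sharded_mib:.0f} MiB per rank")  # 128 MiB vs 32 MiB
```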
When SequenceParallel is disabled¶
Three cases, all handled automatically by apply_tensor_parallel:
| Condition | Why SP is off |
|---|---|
| Pipeline parallelism is enabled | DTensor state at PP stage boundaries causes type-coercion issues |
| Embeddings are tied | ColwiseParallel on the tied output head needs a different sharding than the embedding forward hook can provide |
| MoE blocks | Boolean indexing in MoE dispatch is incompatible with sequence-sharded DTensors |
In these cases the block gets “basic TP” — column/row parallel on the Linears, no sharding of the sequence dimension. Activations are fully replicated across the TP group, which costs more memory but works.
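A sketch of what that fallback looks like under the same assumed module names as above: plain column/row-parallel Linears with PyTorch's default replicated inputs and outputs, and no SequenceParallel on the norms.

```python
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module


def shard_block_basic(block, tp_mesh):
    # "Basic TP": only the Linears are parallelized; activations stay replicated.
    plan = {
        "attention.q_proj": ColwiseParallel(),
        "attention.k_proj": ColwiseParallel(),
        "attention.v_proj": ColwiseParallel(),
        "attention.o_proj": RowwiseParallel(),  # default output_layouts=Replicate()
        "mlp.gate_proj": ColwiseParallel(),
        "mlp.up_proj": ColwiseParallel(),
        "mlp.down_proj": RowwiseParallel(),
    }
    parallelize_module(block, tp_mesh, plan)
```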
Token embedding hook¶
With SequenceParallel on, the final norm expects its input to be
Replicate()-placed (so it can redistribute to Shard(1) internally).
The token embedding output is a plain tensor by default — without a
hook, SequenceParallel would reinterpret it as Shard(1) without
actually scattering, inflating the global sequence dim by tp.
apply_tensor_parallel installs a forward hook on token_embedding:
```python
def _embed_hook(module, input, output):
    # Re-label the plain embedding output as an already-replicated DTensor on the TP mesh.
    return DTensor.from_local(output, mesh, placements=[Replicate()])
```
Effect: the embedding output is labeled as already-replicated, so
SequenceParallel redistributes to Shard(1) correctly.
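Registration uses the standard nn.Module hook API (token_embedding is the attribute name used above):

```python
model.token_embedding.register_forward_hook(_embed_hook)
```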
Post-hooks on attention and MLP¶
Operations inside attention (SDPA, view, contiguous) and MLP
strip DTensor metadata on their outputs — you get a plain Tensor
where the next + residual expects a Shard(1) DTensor.
apply_tensor_parallel wraps the outputs back into DTensors via
forward post-hooks:
```python
def _wrap_shard1(module, input, output):
    # Re-wrap the plain tensor output as a DTensor sequence-sharded on dim 1.
    return DTensor.from_local(output, mesh, placements=[Shard(1)])
```
Without this, you get “mixed Tensor and DTensor” errors at the residual add. With it, DTensor metadata is restored and the rest of the block flows correctly.
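Registration follows the same pattern, assuming blocks live in model.layers with attention and mlp submodules (the names are illustrative):

```python
for block in model.layers:
    block.attention.register_forward_hook(_wrap_shard1)
    block.mlp.register_forward_hook(_wrap_shard1)
```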
Meta-device init¶
When TP is active, the model is built on torch.device("meta")
first — no storage is allocated yet. apply_tensor_parallel wraps
the (still-meta) parameters as DTensors, then apply_fsdp2 further
shards them, and only model.to_empty(device=device) allocates real
storage afterward.
The sequence is in FSDP2 § Application order.
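A condensed sketch of that order; build_model, config, and init_weights are placeholders for whatever the training script actually does:

```python
import torch

with torch.device("meta"):
    model = build_model(config)            # parameters exist but have no storage yet

apply_tensor_parallel(model, device_mesh)  # meta params become DTensors with TP placements
apply_fsdp2(model, device_mesh)            # FSDP2 further shards along its mesh dim
model.to_empty(device="cuda")              # only now is real (uninitialized) memory allocated
model.init_weights()                       # placeholder: re-initialize parameters in place
```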
Head divisibility¶
n_heads % tp == 0 and n_kv_heads % tp == 0 are required —
ColwiseParallel splits heads evenly across TP ranks, and uneven
splits are rejected.
For Llama-3 70B (n_heads=64, n_kv_heads=8), tp must be in {1, 2, 4, 8} — and tp=8 means 1 KV head per TP rank, which is usually
where the GQA sharing stops paying off.
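The same rule as a standalone check (the function name is illustrative):

```python
def check_tp_heads(n_heads: int, n_kv_heads: int, tp: int) -> None:
    # Reject any TP degree that would split heads unevenly across ranks.
    if n_heads % tp or n_kv_heads % tp:
        raise ValueError(
            f"tp={tp} must evenly divide n_heads={n_heads} and n_kv_heads={n_kv_heads}"
        )

check_tp_heads(n_heads=64, n_kv_heads=8, tp=8)  # Llama-3 70B: valid for tp in {1, 2, 4, 8}
```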
TP across vs within nodes¶
Within a node (NVLink): TP’s all-gathers and reduce-scatters are on 900 GB/s links — fast enough to overlap with compute.
Across nodes (InfiniBand): TP over IB is slow; you’ll see it in the MFU numbers as soon as
tp > GPUs_per_node.
Shipped recipes keep tp ≤ 4 (single-node TP on 4×H200 per node) —
see Parallelism recipes.
When to use TP¶
| Situation | Use TP? |
|---|---|
| FSDP alone fits | No — at 7B on 4 GPUs, FSDP alone is already sufficient |
| FSDP alone doesn’t fit | Yes — memory forces it |
| Need to shard attention weights | Yes |
| MoE | Yes, for attention only — TP on experts is replaced by EP |
The rule of thumb and supporting benchmarks are in Parallelism recipes § Choosing a parallelism combination.
See also¶
Device mesh — how the tp sub-mesh is built.
FSDP2 — runs after TP and shards the remaining params.
Expert parallelism — what happens to MoE experts (TP replicates, EP shards).
Parallelism order — the full sequence with reasoning.
Configuration § Validation rules — the head-divisibility check.