Source code for kempnerforge.distributed.tensor_parallel

"""Tensor parallelism sharding plans for model components.

Applies column-parallel and row-parallel sharding to attention and MLP
projections using PyTorch's DTensor-based tensor parallelism.

Sharding strategy (with SequenceParallel):
  - Token embedding:  Replicated (SequenceParallel norm handles Replicate→Shard(1))
  - Attention Q/K/V:  ColwiseParallel (split heads across TP ranks)
  - Attention O:      RowwiseParallel (gather heads, reduce-scatter to Shard(1))
  - MLP gate/up:      ColwiseParallel (split hidden dim)
  - MLP down:         RowwiseParallel (gather hidden dim, reduce-scatter to Shard(1))
  - Norm layers:      SequenceParallel (operate on sequence-sharded activations)
  - Final norm:       SequenceParallel
  - Output head:      ColwiseParallel (split vocab, gather to Replicate for loss)
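
Per block this gives, roughly, the following activation flow (the collectives
are implied by the layouts above):

    Shard(1) -> norm -> Shard(1) -> colwise proj (all-gather over seq)
    -> rowwise proj (reduce-scatter) -> Shard(1) residual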

SequenceParallel keeps activations sharded along the sequence dimension
between blocks, reducing activation memory at norm layers to 1/tp of the
unsharded size and replacing the all-reduce in RowwiseParallel with a
reduce-scatter.
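For example, with tp=4 and a global activation of shape (batch, seq, hidden),
each rank's norm layers see a local shard of shape (batch, seq/4, hidden).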

The token embedding stays replicated because RowwiseParallel on nn.Embedding
doesn't correctly redistribute output to Shard(1) — it relabels without
scattering, inflating the global sequence dimension. The first block's
SequenceParallel norm naturally handles the Replicate → Shard(1) transition.

SequenceParallel is disabled when tie_embeddings=True (ColwiseParallel on
the output head imposes incompatible sharding on the shared weight).

Pipeline parallel stages use basic TP (no SequenceParallel) to avoid
DTensors at PP stage boundaries.
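
Example (illustrative sketch; the Transformer constructor and the
apply_ac/apply_fsdp2 calls below are assumptions about the surrounding
kempnerforge training setup):

    from torch.distributed.device_mesh import init_device_mesh

    # 2-D mesh: data parallel x tensor parallel (8 ranks total here).
    device_mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))

    model = Transformer(model_config)        # hypothetical constructor
    apply_tensor_parallel(model, device_mesh)
    # ... then apply_ac(...) and apply_fsdp2(...) on the parallelized model.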
"""

from __future__ import annotations

import logging

import torch.nn as nn
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    SequenceParallel,
    parallelize_module,
)

from kempnerforge.model.mlp import SwiGLUMLP
from kempnerforge.model.moe import MoEMLP

logger = logging.getLogger(__name__)


def _build_block_tp_plan(block, *, sequence_parallel: bool = True) -> dict:
    """Build a TP sharding plan for a single TransformerBlock.

    When sequence_parallel=True, norm layers use SequenceParallel and
    projections use Shard(1) input/output layouts so activations stay
    sequence-sharded between blocks. When False, activations remain
    replicated between blocks (basic TP).
    """
    col_kw = {"input_layouts": Shard(1)} if sequence_parallel else {}
    row_kw = {"output_layouts": Shard(1)} if sequence_parallel else {}

    plan = {}

    if sequence_parallel:
        plan["attention_norm"] = SequenceParallel()
        plan["mlp_norm"] = SequenceParallel()

    plan["attention.q_proj"] = ColwiseParallel(**col_kw)
    plan["attention.k_proj"] = ColwiseParallel(**col_kw)
    plan["attention.v_proj"] = ColwiseParallel(**col_kw)
    plan["attention.o_proj"] = RowwiseParallel(**row_kw)

    # MoE blocks: skip all mlp.* entries — experts and router stay replicated.
    # Dense blocks: shard MLP projections as before.
    if isinstance(block.mlp, MoEMLP):
        pass  # No MLP entries — experts replicated, TP on attention only
    elif isinstance(block.mlp, SwiGLUMLP):
        plan["mlp.gate_proj"] = ColwiseParallel(**col_kw)
        plan["mlp.up_proj"] = ColwiseParallel(**col_kw)
        plan["mlp.down_proj"] = RowwiseParallel(**row_kw)
    else:
        plan["mlp.up_proj"] = ColwiseParallel(**col_kw)
        plan["mlp.down_proj"] = RowwiseParallel(**row_kw)

    return plan
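
# For reference, a dense SwiGLU block with sequence_parallel=True produces a
# plan equivalent to the following (illustrative only; built dynamically above):
#
#   {
#       "attention_norm": SequenceParallel(),
#       "mlp_norm": SequenceParallel(),
#       "attention.q_proj": ColwiseParallel(input_layouts=Shard(1)),
#       "attention.k_proj": ColwiseParallel(input_layouts=Shard(1)),
#       "attention.v_proj": ColwiseParallel(input_layouts=Shard(1)),
#       "attention.o_proj": RowwiseParallel(output_layouts=Shard(1)),
#       "mlp.gate_proj": ColwiseParallel(input_layouts=Shard(1)),
#       "mlp.up_proj": ColwiseParallel(input_layouts=Shard(1)),
#       "mlp.down_proj": RowwiseParallel(output_layouts=Shard(1)),
#   }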


def get_tp_mesh(device_mesh: DeviceMesh) -> DeviceMesh | None:
    """Extract the TP sub-mesh from a DeviceMesh.

    Returns None if no 'tp' dimension exists.
    """
    if "tp" not in device_mesh.mesh_dim_names:  # type: ignore[reportOperatorIssue]
        return None
    return device_mesh["tp"]


def apply_tensor_parallel(
    model: nn.Module,
    device_mesh: DeviceMesh,
) -> None:
    """Apply tensor parallelism to a Transformer or PipelineStageModule.

    Parallelizes attention/MLP projections, norm layers (SequenceParallel),
    and output head. Token embedding stays replicated. Should be called
    BEFORE apply_ac and apply_fsdp2.

    PP stages use basic TP without SequenceParallel to avoid DTensors at
    stage boundaries. SequenceParallel is also disabled when weights are tied.

    Args:
        model: Transformer or PipelineStageModule to parallelize.
        device_mesh: Full DeviceMesh with a 'tp' dimension.
    """
    tp_mesh = get_tp_mesh(device_mesh)
    if tp_mesh is None:
        logger.warning("No 'tp' dimension in DeviceMesh — skipping tensor parallelism")
        return

    tie = getattr(getattr(model, "config", None), "tie_embeddings", False)
    is_pp_stage = hasattr(model, "stage_id")
    is_moe = getattr(getattr(model, "config", None), "is_moe", False)

    # SequenceParallel requires TP on embedding/head (otherwise the unsharded
    # output head can't consume Shard(1) input from the final norm).
    # Disabled for MoE: boolean indexing in expert dispatch breaks Shard(1)
    # DTensors, and alternating SP-on/SP-off blocks create DTensor transition
    # errors.
    seq_parallel = not is_pp_stage and not tie and not is_moe

    # Token embedding: wrap output as DTensor Replicate so the first block's
    # SequenceParallel norm properly redistributes to Shard(1). Without this,
    # SequenceParallel receives a plain tensor and labels it Shard(1) without
    # actually scattering, inflating the global sequence dimension.
    if seq_parallel and getattr(model, "token_embedding", None) is not None:
        from torch.distributed.tensor import DTensor

        def _wrap_replicate(module, input, output, mesh=tp_mesh):
            return DTensor.from_local(output, device_mesh=mesh, placements=[Replicate()])

        model.token_embedding.register_forward_hook(_wrap_replicate)  # type: ignore[reportAttributeAccessIssue]

    # Transformer blocks
    for block in model.layers.values():  # type: ignore[reportCallIssue]
        plan = _build_block_tp_plan(block, sequence_parallel=seq_parallel)
        parallelize_module(block, tp_mesh, plan)  # type: ignore[reportArgumentType]

        # Re-wrap attention/MLP outputs as DTensor Shard(1). Operations inside
        # attention (SDPA, view, contiguous) strip DTensor metadata, causing
        # "mixed torch.Tensor and DTensor" errors at the residual connection.
        if seq_parallel:
            from torch.distributed.tensor import DTensor

            def _rewrap_shard1(module, input, output, mesh=tp_mesh):
                if not isinstance(output, DTensor):
                    return DTensor.from_local(output, device_mesh=mesh, placements=[Shard(1)])
                return output

            block.attention.register_forward_hook(_rewrap_shard1)  # type: ignore[reportAttributeAccessIssue]
            block.mlp.register_forward_hook(_rewrap_shard1)  # type: ignore[reportAttributeAccessIssue]

    # Final norm: SequenceParallel (non-PP, non-tied only)
    if seq_parallel and getattr(model, "norm", None) is not None:
        parallelize_module(model, tp_mesh, {"norm": SequenceParallel()})

    # Output head: split vocab dim, gather to Replicate for loss computation.
    # Only when seq_parallel=True — matches the Shard(1) data flow from the final norm.
    if seq_parallel and not tie and getattr(model, "output_head", None) is not None:
        parallelize_module(
            model.output_head,  # type: ignore[reportArgumentType]
            tp_mesh,
            {"proj": ColwiseParallel(input_layouts=Shard(1), output_layouts=Replicate())},
        )

    logger.info(
        f"Applied tensor parallelism: tp_degree={tp_mesh.size()}, "
        f"layers={len(model.layers)}, sequence_parallel={seq_parallel}"  # type: ignore[reportArgumentType]
    )