kempnerforge.distributed.pipeline_parallel

Pipeline parallelism for KempnerForge.

Splits a Transformer model across pipeline stages, assigning layer ranges to different ranks. Uses torch.distributed.pipelining for schedule execution.

Stage assignment:
  • Stage 0: token_embedding + first chunk of transformer layers

  • Middle stages: transformer layer chunks only

  • Last stage: last chunk of layers + final norm + output head

Application order when combining parallelisms:
  1. Build per-stage model via build_stage_module()

  2. Tensor parallelism (per stage, via apply_tensor_parallel); applied before AC and FSDP2 so it sees the raw, unwrapped transformer blocks

  3. Activation checkpointing (per stage, via apply_ac)

  4. FSDP2 (per stage, via apply_fsdp2 with reshard_after_forward=False)

Note on FSDP reshard policy:

When using PP, set reshard_after_forward=False in apply_fsdp2 to avoid per-microbatch all-gathers. PP schedules send multiple microbatches through each stage, so keeping gathered params avoids redundant communication.
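A minimal sketch of this order on one PP rank. The argument lists for apply_tensor_parallel, apply_ac, and apply_fsdp2 below (tp_mesh, dp_mesh) are assumptions for illustration, not the exact signatures:

    pp_rank, pp_size = get_pp_rank(device_mesh), get_pp_size(device_mesh)

    stage_module = build_stage_module(config, pp_rank, pp_size)   # 1. per-stage model chunk
    apply_tensor_parallel(stage_module, tp_mesh)                  # 2. TP first, on the raw blocks
    apply_ac(stage_module)                                        # 3. activation checkpointing
    apply_fsdp2(stage_module, dp_mesh,                            # 4. FSDP2 last; keep params
                reshard_after_forward=False)                      #    gathered across microbatches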

Functions

build_pipeline_schedule(stage, ...[, schedule])

Create a pipeline execution schedule.

build_pipeline_stage(stage_module, ...[, ...])

Wrap a stage module in a PipelineStage for schedule execution.

build_stage_module(config, pp_rank, pp_size)

Build the model chunk for a specific pipeline stage.

compute_layer_assignment(n_layers, pp_size)

Compute which layers go to which PP stage.

get_pp_mesh(device_mesh)

Extract the PP sub-mesh from a DeviceMesh.

get_pp_rank(device_mesh)

Get the pipeline parallel rank for this process.

get_pp_size(device_mesh)

Get the pipeline parallel world size.

Classes

PipelineStageModule

A model chunk for a single pipeline stage.

kempnerforge.distributed.pipeline_parallel.get_pp_mesh(device_mesh)[source]

Extract the PP sub-mesh from a DeviceMesh.

Returns None if no ‘pp’ dimension exists.

Parameters:

device_mesh (DeviceMesh)

Return type:

DeviceMesh | None

kempnerforge.distributed.pipeline_parallel.get_pp_rank(device_mesh)[source]

Get the pipeline parallel rank for this process.

Parameters:

device_mesh (torch.distributed.device_mesh.DeviceMesh)

Return type:

int

kempnerforge.distributed.pipeline_parallel.get_pp_size(device_mesh)[source]

Get the pipeline parallel world size.

Parameters:

device_mesh (torch.distributed.device_mesh.DeviceMesh)

Return type:

int
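A usage sketch for the three mesh helpers, assuming 8 ranks launched under torchrun and laid out as an illustrative 2 × 4 mesh:

    from torch.distributed.device_mesh import init_device_mesh

    # Illustrative layout: 2 pipeline stages x 4 data-parallel ranks.
    device_mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("pp", "dp"))

    pp_mesh = get_pp_mesh(device_mesh)   # the 'pp' sub-mesh, or None if no 'pp' dim exists
    pp_rank = get_pp_rank(device_mesh)   # this process's stage index: 0 or 1 here
    pp_size = get_pp_size(device_mesh)   # 2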

kempnerforge.distributed.pipeline_parallel.compute_layer_assignment(n_layers, pp_size)[source]

Compute which layers go to which PP stage.

Distributes layers as evenly as possible. Earlier stages get one extra layer when n_layers is not evenly divisible by pp_size.

Parameters:
  • n_layers (int) – Total number of transformer layers.

  • pp_size (int) – Number of pipeline stages.

Returns:

List of (start_layer, end_layer) tuples, one per stage. end_layer is exclusive.

Raises:

ValueError – If pp_size > n_layers.

Return type:

list[tuple[int, int]]
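The documented split can be reproduced with divmod. A minimal sketch of the behavior (not the library source):

    def layer_assignment_sketch(n_layers, pp_size):
        # Mirrors the documented behavior: even split, remainder to earlier stages.
        if pp_size > n_layers:
            raise ValueError("pp_size cannot exceed n_layers")
        base, extra = divmod(n_layers, pp_size)
        out, start = [], 0
        for stage in range(pp_size):
            end = start + base + (1 if stage < extra else 0)  # earlier stages take the remainder
            out.append((start, end))
            start = end
        return out

    # 10 layers over 4 stages: the first two stages get one extra layer.
    assert layer_assignment_sketch(10, 4) == [(0, 3), (3, 6), (6, 8), (8, 10)]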

class kempnerforge.distributed.pipeline_parallel.PipelineStageModule[source]

Bases: Module

A model chunk for a single pipeline stage.

Only allocates parameters needed for this stage:
  • Stage 0: token_embedding + assigned transformer layers

  • Middle stages: assigned transformer layers only

  • Last stage: assigned transformer layers + final norm + output head

Layer keys in self.layers match the full Transformer (e.g. “4”, “5”, “6”) to maintain DCP checkpoint compatibility.
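A sketch of the layer-key convention, using a stand-in block class for illustration:

    import torch.nn as nn

    block = lambda: nn.Linear(8, 8)   # stand-in for the real transformer block class
    start, end = 4, 7                 # this stage's layer_range
    layers = nn.ModuleDict({str(i): block() for i in range(start, end)})
    print(list(layers.state_dict()))  # ['4.weight', '4.bias', '5.weight', ...]

Because the keys are global layer indices, state_dict entries line up with the full, non-pipelined Transformer, so DCP checkpoints remain loadable across different PP splits.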

__init__(config, stage_id, num_stages, layer_range)[source]

Parameters:
  • config (ModelConfig) – Model configuration.

  • stage_id (int) – This stage’s index (0-indexed).

  • num_stages (int) – Total number of pipeline stages.

  • layer_range (tuple[int, int]) – (start_layer, end_layer) for this stage; end_layer is exclusive.

Return type:

None

init_weights_and_freqs()[source]

Initialize weights and RoPE frequencies after meta-device materialization.

Return type:

None

forward(x)[source]

Forward pass for this pipeline stage.

Parameters:

x (torch.Tensor) – For stage 0: token IDs of shape (batch, seq_len). For other stages: hidden states of shape (batch, seq_len, dim).

Returns:

For the last stage: logits of shape (batch, seq_len, vocab_size). For other stages: hidden states of shape (batch, seq_len, dim).

Return type:

torch.Tensor

kempnerforge.distributed.pipeline_parallel.build_stage_module(config, pp_rank, pp_size)[source]

Build the model chunk for a specific pipeline stage.

Parameters:
  • config (ModelConfig) – Model configuration.

  • pp_rank (int) – This process’s pipeline rank (0-indexed).

  • pp_size (int) – Total number of pipeline stages.

Returns:

A PipelineStageModule containing only the parameters for this stage.

Return type:

PipelineStageModule

kempnerforge.distributed.pipeline_parallel.build_pipeline_stage(stage_module, device_mesh, device, batch_size, seq_len, param_dtype=torch.bfloat16)[source]

Wrap a stage module in a PipelineStage for schedule execution.

Parameters:
  • stage_module (PipelineStageModule) – The prepared model chunk for this stage.

  • device_mesh (DeviceMesh) – Device mesh containing the ‘pp’ dimension.

  • device (torch.device) – Device this stage runs on.

  • batch_size (int) – Training batch size.

  • seq_len (int) – Sequence length.

  • param_dtype (torch.dtype) – Parameter dtype. Defaults to torch.bfloat16.

Returns:

A PipelineStage ready for use with a pipeline schedule.

Return type:

torch.distributed.pipelining.PipelineStage
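A call sketch using the signature above (batch size, sequence length, and device are illustrative):

    import torch

    stage = build_pipeline_stage(
        stage_module,                  # prepared per the application order above
        device_mesh,
        device=torch.device("cuda"),
        batch_size=32,
        seq_len=2048,
        param_dtype=torch.bfloat16,    # the default
    )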

kempnerforge.distributed.pipeline_parallel.build_pipeline_schedule(stage, n_microbatches, loss_fn, schedule='1f1b')[source]

Create a pipeline execution schedule.

Parameters:
  • stage (torch.distributed.pipelining.PipelineStage) – The PipelineStage for this rank.

  • n_microbatches (int) – Number of microbatches per training step. Must be >= pp_size for 1F1B to fill the pipeline.

  • loss_fn (callable) – Loss function (applied on last stage only).

  • schedule (str) – Schedule type: “1f1b”, “gpipe”, or “interleaved_1f1b”. Note that “interleaved_1f1b” requires multiple stages per rank (virtual pipeline stages), so pass a list of stages rather than a single stage.

Returns:

A pipeline schedule object. Call schedule.step() in the training loop:
  • First stage: schedule.step(input_tensor)

  • Last stage: schedule.step(target=labels)

  • Middle stages: schedule.step()
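A training-loop sketch tying the pieces together (dataloader and tensor names are illustrative):

    schedule = build_pipeline_schedule(stage, n_microbatches=8, loss_fn=loss_fn, schedule="1f1b")

    for input_ids, labels in dataloader:
        if pp_rank == 0:
            schedule.step(input_ids)        # first stage feeds token IDs in
        elif pp_rank == pp_size - 1:
            schedule.step(target=labels)    # last stage supplies targets to loss_fn
        else:
            schedule.step()                 # middle stages only relay activations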