kempnerforge.distributed.pipeline_parallel¶
Pipeline parallelism for KempnerForge.
Splits a Transformer model across pipeline stages, assigning layer ranges to different ranks. Uses torch.distributed.pipelining for schedule execution.
- Stage assignment:
Stage 0: token_embedding + first chunk of transformer layers
Middle stages: transformer layer chunks only
Last stage: last chunk of layers + final norm + output head
- Application order when combining parallelisms (see the sketch below):
Build per-stage model via build_stage_module()
Tensor parallelism (per stage, via apply_tensor_parallel) — must see raw blocks
Activation checkpointing (per stage, via apply_ac)
FSDP2 (per stage, via apply_fsdp2 with reshard_after_forward=False)
- Note on FSDP reshard policy:
When using PP, set reshard_after_forward=False in apply_fsdp2 to avoid per-microbatch all-gathers. PP schedules send multiple microbatches through each stage, so keeping gathered params avoids redundant communication.
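A minimal sketch of that application order, assuming the apply_* helpers take the stage module plus their sub-meshes (the exact signatures are assumptions; only the ordering is prescribed by this module):

```python
# Hedged sketch: the apply_tensor_parallel / apply_ac / apply_fsdp2 signatures
# are assumed; device_mesh, config, tp_mesh, and dp_mesh come from earlier setup.
from kempnerforge.distributed.pipeline_parallel import (
    build_stage_module, get_pp_rank, get_pp_size,
)

pp_rank = get_pp_rank(device_mesh)
pp_size = get_pp_size(device_mesh)

stage_module = build_stage_module(config, pp_rank, pp_size)      # 1. per-stage chunk
apply_tensor_parallel(stage_module, tp_mesh)                     # 2. TP must see raw blocks
apply_ac(stage_module)                                           # 3. activation checkpointing
apply_fsdp2(stage_module, dp_mesh, reshard_after_forward=False)  # 4. FSDP2 last
```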
Functions
build_pipeline_schedule | Create a pipeline execution schedule.
build_pipeline_stage | Wrap a stage module in a PipelineStage for schedule execution.
build_stage_module | Build the model chunk for a specific pipeline stage.
compute_layer_assignment | Compute which layers go to which PP stage.
get_pp_mesh | Extract the PP sub-mesh from a DeviceMesh.
get_pp_rank | Get the pipeline parallel rank for this process.
get_pp_size | Get the pipeline parallel world size.
Classes
PipelineStageModule | A model chunk for a single pipeline stage.
- kempnerforge.distributed.pipeline_parallel.get_pp_mesh(device_mesh)[source]¶
Extract the PP sub-mesh from a DeviceMesh.
Returns None if no ‘pp’ dimension exists.
- Parameters:
device_mesh (DeviceMesh)
- Return type:
DeviceMesh | None
- kempnerforge.distributed.pipeline_parallel.get_pp_rank(device_mesh)[source]¶
Get the pipeline parallel rank for this process.
- Parameters:
device_mesh (torch.distributed.device_mesh.DeviceMesh)
- Return type:
int
- kempnerforge.distributed.pipeline_parallel.get_pp_size(device_mesh)[source]¶
Get the pipeline parallel world size.
- Parameters:
device_mesh (torch.distributed.device_mesh.DeviceMesh)
- Return type:
int
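Together, these three helpers read the PP topology off a DeviceMesh. A hedged usage sketch (the mesh shape and the ‘dp’ dim name are illustrative; only the ‘pp’ dimension is required by this module):

```python
from torch.distributed.device_mesh import init_device_mesh

from kempnerforge.distributed.pipeline_parallel import (
    get_pp_mesh, get_pp_rank, get_pp_size,
)

# Illustrative: 8 GPUs arranged as 2 pipeline stages x 4 data-parallel ranks.
device_mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("pp", "dp"))

pp_mesh = get_pp_mesh(device_mesh)  # 1-D sub-mesh over the 'pp' dimension
pp_rank = get_pp_rank(device_mesh)  # this process's stage index: 0 or 1
pp_size = get_pp_size(device_mesh)  # number of pipeline stages: 2
```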
- kempnerforge.distributed.pipeline_parallel.compute_layer_assignment(n_layers, pp_size)[source]¶
Compute which layers go to which PP stage.
Distributes layers as evenly as possible. Earlier stages get one extra layer when n_layers is not evenly divisible by pp_size.
- Parameters:
n_layers (int) – Number of transformer layers in the full model.
pp_size (int) – Total number of pipeline stages.
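A hedged re-implementation of that rule, for illustration only (the actual return format of compute_layer_assignment may differ):

```python
def split_layers(n_layers: int, pp_size: int) -> list[range]:
    """Illustrative version of the even-split rule described above."""
    base, extra = divmod(n_layers, pp_size)
    ranges, start = [], 0
    for stage in range(pp_size):
        n = base + (1 if stage < extra else 0)  # earlier stages absorb the remainder
        ranges.append(range(start, start + n))
        start += n
    return ranges

# e.g. 10 layers over 4 stages: [range(0, 3), range(3, 6), range(6, 8), range(8, 10)]
```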
- class kempnerforge.distributed.pipeline_parallel.PipelineStageModule[source]¶
Bases: Module
A model chunk for a single pipeline stage.
- Only allocates parameters needed for this stage:
Stage 0: token_embedding + assigned transformer layers
Middle stages: assigned transformer layers only
Last stage: assigned transformer layers + final norm + output head
Layer keys in self.layers match the full Transformer (e.g. “4”, “5”, “6”) to maintain DCP checkpoint compatibility.
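For example, a stage assigned global layers 4–6 can hold them in an nn.ModuleDict keyed by global index. A runnable sketch, with a stand-in block class:

```python
import torch.nn as nn

class TransformerBlock(nn.Module):  # stand-in for the real transformer block
    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(8, 8)

# String keys "4", "5", "6" match the unsplit model's layer indices, so
# state_dict keys line up across pp_size values for DCP checkpointing.
layers = nn.ModuleDict({str(i): TransformerBlock() for i in range(4, 7)})
print(list(layers.state_dict())[:2])  # ['4.proj.weight', '4.proj.bias']
```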
- init_weights_and_freqs()[source]¶
Initialize weights and RoPE frequencies after meta-device materialization.
- Return type:
None
- forward(x)[source]¶
Forward pass for this pipeline stage.
- Parameters:
x (torch.Tensor) – For stage 0: token IDs of shape (batch, seq_len). For other stages: hidden states of shape (batch, seq_len, dim).
- Returns:
For the last stage: logits of shape (batch, seq_len, vocab_size). For other stages: hidden states of shape (batch, seq_len, dim).
- Return type:
torch.Tensor
- kempnerforge.distributed.pipeline_parallel.build_stage_module(config, pp_rank, pp_size)[source]¶
Build the model chunk for a specific pipeline stage.
- Parameters:
config (ModelConfig) – Model configuration.
pp_rank (int) – This process’s pipeline rank (0-indexed).
pp_size (int) – Total number of pipeline stages.
- Returns:
A PipelineStageModule containing only the parameters for this stage.
- Return type:
PipelineStageModule
- kempnerforge.distributed.pipeline_parallel.build_pipeline_stage(stage_module, device_mesh, device, batch_size, seq_len, param_dtype=torch.bfloat16)[source]¶
Wrap a stage module in a PipelineStage for schedule execution.
- Parameters:
stage_module (PipelineStageModule) – The model chunk for this stage.
device_mesh (torch.distributed.device_mesh.DeviceMesh) – Full DeviceMesh with a ‘pp’ dimension.
device (torch.device) – Device for this stage.
batch_size (int) – Micro-batch size (for shape inference).
seq_len (int) – Sequence length (for shape inference).
param_dtype (torch.dtype) – Dtype for intermediate activations (matches mixed precision).
- Returns:
A PipelineStage ready for use with a pipeline schedule.
- Return type:
torch.distributed.pipelining.PipelineStage
- kempnerforge.distributed.pipeline_parallel.build_pipeline_schedule(stage, n_microbatches, loss_fn, schedule='1f1b')[source]¶
Create a pipeline execution schedule.
- Parameters:
stage (torch.distributed.pipelining.PipelineStage) – The PipelineStage for this rank.
n_microbatches (int) – Number of microbatches per training step. Must be >= pp_size for 1F1B to fill the pipeline.
loss_fn (callable) – Loss function (applied on last stage only).
schedule (str) – Schedule type — “1f1b”, “gpipe”, or “interleaved_1f1b”. Note: “interleaved_1f1b” requires multiple stages per rank (virtual pipeline stages). Pass a list of stages instead.
- Returns:
A pipeline schedule object. Call schedule.step() in the training loop:
First stage: schedule.step(input_tensor, target=labels)
Other stages: schedule.step(target=labels) or schedule.step()
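A hedged end-to-end sketch tying build_stage_module, build_pipeline_stage, and build_pipeline_schedule together (device_mesh, config, device, and the batch tensors are assumed to be set up already; the loss function is illustrative):

```python
import torch.nn.functional as F

from kempnerforge.distributed.pipeline_parallel import (
    build_pipeline_schedule, build_pipeline_stage, build_stage_module,
    get_pp_rank, get_pp_size,
)

pp_rank = get_pp_rank(device_mesh)
pp_size = get_pp_size(device_mesh)

stage_module = build_stage_module(config, pp_rank, pp_size)
stage = build_pipeline_stage(
    stage_module, device_mesh, device,
    batch_size=micro_batch_size, seq_len=seq_len,
)

def loss_fn(logits, labels):  # applied on the last stage only
    return F.cross_entropy(logits.flatten(0, 1), labels.flatten())

# n_microbatches must be >= pp_size so 1F1B can fill the pipeline.
schedule = build_pipeline_schedule(stage, n_microbatches=8, loss_fn=loss_fn)

# Per the Returns note above: the first stage feeds inputs, later stages feed targets.
if pp_rank == 0:
    schedule.step(input_ids, target=labels)
else:
    schedule.step(target=labels)
```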