
Pipeline Parallelism

distributed training via model partitioning

At some point, a model stops fitting in a single GPU's memory and you have to split it across multiple GPUs. BLOOM, Hugging Face's 176B-parameter Transformer, requires ~350GB just to store its weights in bfloat16. An A100 80GB GPU holds roughly a quarter of that. Training on 384 GPUs required spreading different layers across machines — a technique called model partitioning or pipeline parallelism.

The naive version of this is embarrassingly bad at GPU utilization. If GPU2 can't start until GPU1 finishes, you've essentially serialized your compute across machines that are supposed to run in parallel. This post walks through three approaches — naive, GPipe, and PipeDream — and explains the concrete tradeoffs between GPU utilization, memory consumption, and mathematical equivalence to single-GPU training. Along the way I'll highlight the parts that didn't fully click for me until I traced through the actual scheduling logic.

Naive Model Parallelism

The most straightforward implementation: split the model layers into contiguous groups, assign each group to a GPU, and run training one minibatch at a time. For a 4-layer model split across 2 GPUs: GPU1 runs layers L1–L2, produces the intermediate activation tensor, and MPI-sends it to GPU2. GPU2 runs L3–L4, computes the loss, and begins the backward pass. At the L2→L3 boundary in the backward direction, GPU2 sends its input gradients back to GPU1, which finishes backprop. Gradient updates happen locally on each GPU.

This approach is bit-exact with single-GPU training — same math, same numerics. The communication is point-to-point (MPI.Send/MPI.Recv), not collective, so there's no need for AllReduce or broadcast primitives. Simple and correct. But watch the pebble graph below and you'll immediately see the problem:

Pebble graph illustrating naive model parallelism with GPU1 forward caching and MPI communication
Naive model parallelism: GPU1 runs forward and waits while GPU2 runs backward. One GPU is always idle.

Three problems are immediately visible:

  • GPU utilization is 1/n: at any given moment, exactly one GPU is computing and the rest are idle. With 8 pipeline stages, each GPU is productive only 12.5% of the time.
  • Communication and computation don't overlap: while the activation tensor is in flight over the network, no GPU does useful work. The interconnect stall is dead time.
  • Memory blowup on early stages: GPU1 must cache all forward-pass activations for the entire minibatch until the backward pass reaches it — potentially gigabytes of intermediate tensors kept alive for the full forward+backward duration.

GPipe: Microbatches and Gradient Accumulation

GPipe's core idea: split each minibatch into m equal-sized microbatches, process them sequentially through the pipeline, and accumulate gradients before applying the optimizer step. The math is exact: the gradient of a sum is the sum of the gradients, so summing microbatch gradients gives you the same gradient estimate as processing the full minibatch at once. This is called gradient accumulation, and it is mathematically equivalent to single-GPU training (exactly so on paper; in floating point, only the summation order differs).
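To see the equivalence concretely, here is a minimal numeric sketch (plain Python, illustrative names) showing that summing per-microbatch gradients reproduces the full-minibatch gradient for a scalar least-squares model:

```python
# Minimal numeric check that gradient accumulation reproduces the
# full-minibatch gradient, for a scalar model y_hat = w * x with
# mean-squared-error loss. Plain Python, illustrative names only.

def per_sample_grad(w, x, y):
    # d/dw (w*x - y)^2 = 2 * (w*x - y) * x
    return 2.0 * (w * x - y) * x

def full_batch_grad(w, xs, ys):
    # gradient of the mean loss over the whole minibatch
    return sum(per_sample_grad(w, x, y) for x, y in zip(xs, ys)) / len(xs)

def accumulated_grad(w, xs, ys, num_microbatches):
    # sum per-microbatch gradient sums; divide by the total count once
    mb = len(xs) // num_microbatches
    acc = 0.0
    for m in range(num_microbatches):
        acc += sum(per_sample_grad(w, x, y)
                   for x, y in zip(xs[m*mb:(m+1)*mb], ys[m*mb:(m+1)*mb]))
    return acc / len(xs)

w = 0.5
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0, 4.1, 5.9, 8.2, 9.8, 12.1, 14.0, 16.3]
assert abs(full_batch_grad(w, xs, ys)
           - accumulated_grad(w, xs, ys, num_microbatches=4)) < 1e-9
```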

The key win: while GPU2 is running the forward pass for microbatch 2, GPU1 can already start the forward pass for microbatch 3. Multiple microbatches are in flight simultaneously, keeping more GPUs busy at once.

GPipe: Interleaving and Its Limits

Sketch of interleaved GPipe showing dependency arrows
GPipe with interleaving: dependency arrows show which microbatch results each GPU is waiting on.

In practice, the interleaving of communication and computation is limited. A GPU can't start processing microbatch i until the previous stage has finished and transmitted its output. If all stages take the same time, you get a clean pipeline — but you still get startup and teardown overhead where some GPUs are idle. These idle slots are called pipeline bubbles.

GPipe: Pipeline Bubbles

A bubble is idle time in the pipeline caused by a data dependency that hasn't resolved yet. GPU4 can't run the forward pass for microbatch 1 until GPU3 has finished it and sent the activation along. Similarly, in the backward pass each stage has to wait for the next stage to send back its input gradient. The fraction of time wasted in bubbles is a function of pipeline depth n (number of GPU stages) and number of microbatches m:

\text{bubble fraction} = 1 - \frac{m}{m + n - 1}

As m → ∞, the bubble fraction → 0. In practice, m ≈ 4n is a common target (bubble fraction ≈ 20%). The tradeoff: larger m means larger total batch size, which requires learning rate scaling (linear scaling rule) and increases the amount of activation memory you're caching. There's no free lunch — you're trading memory pressure for utilization.
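Plugging numbers into the formula makes the tradeoff tangible — a quick sketch:

```python
# Plugging numbers into the bubble-fraction formula above.

def bubble_fraction(n_stages, m_microbatches):
    return 1.0 - m_microbatches / (m_microbatches + n_stages - 1)

print(bubble_fraction(4, 1))    # 0.75 -> a single microbatch wastes 75%
print(bubble_fraction(4, 16))   # ~0.158 -> m = 4n lands near the ~20% target
```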

Demonstration of pipeline bubble inefficiencies caused by data dependencies
Pipeline bubbles: GPUs sit idle waiting for the previous stage to finish.
Example calculations comparing single vs 4-microbatch bubble fractions
Bubble fraction comparison: 4 microbatches cuts wasted time dramatically vs. 1 microbatch.

GPipe: Memory and Gradient Checkpointing

The memory problem in GPipe is stark: all m microbatch activations are in flight simultaneously during the all-forward phase. Each GPU must cache activations from the time a microbatch was forwarded until the corresponding backward reaches it. For m=8 microbatches on a 4-GPU pipeline, GPU1 holds 8 microbatches' worth of activations simultaneously — 8× the activation memory of a single-GPU run.

GPipe's solution: gradient checkpointing (also called activation recomputation). Instead of caching all intermediate activations, cache only the inputs at pipeline-stage boundaries and recompute activations on the fly during the backward pass. This trades compute for memory. Without gradient checkpointing, peak memory per GPU is O(batchsize × layers_per_gpu). With it:

O\left(\text{batchsize} + \frac{\#\text{total layers}}{\#\text{GPUs}} \cdot \frac{\text{batchsize}}{\#\text{microbatches}}\right)
Memory state during backward pass with gradient checkpointing
Gradient checkpointing: only boundary inputs are cached. Activations are recomputed during backward, adding ~33% compute overhead.

The 33% compute overhead from recomputation is usually worth it for large models — it's often the only way to fit a model in GPU memory at all. PyTorch's `torch.utils.checkpoint.checkpoint()` and Megatron-LM's activation recomputation are both GPipe-style gradient checkpointing. If you've used either, you've used this idea.
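Here is a toy sketch of the recomputation idea, assuming a chain of scalar layers f_i(x) = a_i·x² (chosen so the backward pass genuinely needs the recomputed activations). The function names are illustrative, not a real checkpointing API:

```python
# Toy sketch of activation recomputation: a chain of scalar layers
# f_i(x) = a_i * x^2, so the backward pass genuinely needs each
# layer's input activation (d/dx of a*x^2 is 2*a*x). We cache only
# the input at each segment boundary and recompute the rest on the
# fly. All names are illustrative, not a real checkpointing API.

def layer(a, x):
    return a * x * x

def forward_checkpointed(coeffs, x, seg):
    boundary_inputs = []
    for i, a in enumerate(coeffs):
        if i % seg == 0:
            boundary_inputs.append(x)  # cache only segment inputs
        x = layer(a, x)
    return x, boundary_inputs

def backward_checkpointed(coeffs, boundary_inputs, seg, grad_out):
    grad = grad_out
    for seg_start in range(len(coeffs) - seg, -1, -seg):
        # recompute this segment's activations from its cached input
        acts = [boundary_inputs[seg_start // seg]]
        for a in coeffs[seg_start:seg_start + seg]:
            acts.append(layer(a, acts[-1]))
        # backprop through the segment using the recomputed activations
        for j in reversed(range(seg)):
            grad = grad * 2.0 * coeffs[seg_start + j] * acts[j]
    return grad

coeffs = [2.0, 3.0, 0.5, 4.0]      # four "layers", two segments
y, cached = forward_checkpointed(coeffs, 1.0, seg=2)
assert len(cached) == 2            # 2 boundary inputs cached, not 4 activations
g = backward_checkpointed(coeffs, cached, 2, grad_out=1.0)
assert abs(g - 331776.0) < 1e-6    # matches d/dx of 20736*x^16 at x=1
```

The point of the sketch: the backward pass never reads a full activation cache, only the two boundary inputs, at the cost of rerunning each segment's forward once.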

PipeDream: 1F1B and Earlier Backward Passes

PipeDream's key insight: you don't have to wait until all microbatches have been forwarded before starting any backward passes. As soon as the last stage completes the forward pass for microbatch 1, it can immediately start the backward pass for microbatch 1 — even while earlier stages are still processing microbatches 2, 3, 4... This is the 1F1B (one forward, one backward) pattern: in steady state, each GPU alternates between a forward pass for a new microbatch and a backward pass for an older one.
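A sketch of what a 1F1B schedule looks like per stage, assuming 0-indexed stages and the usual warmup / steady-state / cooldown structure. This illustrates the pattern; it is not DeepSpeed's or Megatron's actual scheduler:

```python
# Sketch of a per-stage 1F1B schedule (0-indexed stages), following
# the warmup / steady-state / cooldown structure described above.
# Illustrative only, not a real framework's scheduler.

def one_f_one_b(stage, n_stages, n_micro):
    """Return the ('F'/'B', microbatch_id) step sequence for one stage."""
    steps = []
    next_f = next_b = 0
    # warmup: earlier stages run extra forwards before their first backward
    for _ in range(min(n_stages - 1 - stage, n_micro)):
        steps.append(('F', next_f)); next_f += 1
    # steady state: alternate 1F1B; cooldown: drain the remaining backwards
    while next_b < n_micro:
        if next_f < n_micro:
            steps.append(('F', next_f)); next_f += 1
        steps.append(('B', next_b)); next_b += 1
    return steps

# last stage starts its first backward right after its first forward
print(one_f_one_b(stage=3, n_stages=4, n_micro=8)[:4])
# first stage warms up with 3 forwards before alternating
print(one_f_one_b(stage=0, n_stages=4, n_micro=8)[:5])
```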

PipeDream schedule with 4 GPUs and 8 microbatches showing 1F1B pattern
PipeDream 1F1B schedule: blue = forward, green = backward. Numbered by microbatch ID.

The memory benefit is substantial. In GPipe, all m microbatches are in flight during the all-forward phase, so you need activation memory proportional to m. In PipeDream's steady state, GPU1 starts a backward as soon as it finishes a forward — so at most n microbatches are in flight simultaneously (where n is pipeline depth). For both algorithms, activation memory without gradient checkpointing is:

O\left(\#\text{max microbatches in flight} \cdot \text{microbatch-size} \cdot \frac{\#\text{total layers}}{\#\text{GPUs}}\right)

The max-microbatches-in-flight term is where GPipe and PipeDream differ: GPipe's all-forward phase puts all m microbatches in flight; PipeDream's 1F1B schedule keeps at most n in flight (the pipeline depth). Look at GPU1 in the diagram during steady state — it alternates F and B, never starting a new forward without completing a backward first.

PipeDream steady state showing GPU1 alternating forward and backward passes
PipeDream steady state: GPU1 alternates 1F1B after the warmup phase, keeping at most n microbatches in flight.

In the example above (4 GPUs, 8 microbatches): PipeDream has at most 4 microbatches in flight, GPipe has 8. PipeDream halves the activation memory overhead. The bubble fraction is identical between the two algorithms — that's determined by the pipeline structure (n stages, m microbatches), not by when backwards start. You can verify this visually: take the PipeDream schedule and slide all the backward passes to the right (consolidating them after all forwards complete) and you recover GPipe. Same total time, different activation memory profile.
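A minimal way to state the difference, assuming GPipe completes all m forwards before any backward while 1F1B caps in-flight work at the pipeline depth:

```python
# Minimal statement of the memory difference, assuming GPipe finishes
# all m forwards before any backward while 1F1B caps in-flight
# microbatches at the pipeline depth n.

def peak_in_flight_gpipe(n_stages, n_micro):
    return n_micro                    # all-forward phase holds everything

def peak_in_flight_pipedream(n_stages, n_micro):
    return min(n_stages, n_micro)     # 1F1B steady state on the first stage

# the 4-GPU / 8-microbatch example from the text
assert peak_in_flight_gpipe(4, 8) == 8
assert peak_in_flight_pipedream(4, 8) == 4   # half the activation memory
```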

Communication Volume: Pipeline vs Data Parallelism

For a model with dense layers of hidden dimension N, each pipeline stage boundary sends activations of size (microbatch_size × N) forward and gradients of the same size backward. Total pipeline communication per minibatch: (n-1) × 2 × batchsize × N floats — it scales with pipeline depth and activation size, not parameter count. Data parallelism (Ring AllReduce) transfers roughly 2 × total_params floats per GPU per step — scaling with model size, not activations. For very large models with small activations, pipeline parallelism can be cheaper to communicate. For models with large activations (vision, long-context language models), it can be more expensive. The other critical difference: data-parallel AllReduce overlaps with the backward pass naturally; pipeline-parallel point-to-point transfers are on the critical path and harder to hide.
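A back-of-envelope sketch, assuming a ring AllReduce that moves roughly twice the gradient size per GPU; all sizes below are illustrative:

```python
# Back-of-envelope per-step communication volumes, assuming a ring
# AllReduce that moves roughly 2x the gradient size per GPU. All
# sizes below are illustrative assumptions.

def pipeline_comm_floats(n_stages, batch_size, hidden_dim):
    # (n-1) stage boundaries, activations forward + gradients backward
    return (n_stages - 1) * 2 * batch_size * hidden_dim

def data_parallel_comm_floats(total_params):
    # ring AllReduce: each GPU sends ~2*(n-1)/n ~= 2x its gradient size
    return 2 * total_params

pp = pipeline_comm_floats(n_stages=8, batch_size=512, hidden_dim=12288)
dp = data_parallel_comm_floats(total_params=10_000_000_000)
print(pp, dp)  # ~88M floats for the pipeline vs. ~20B for AllReduce
```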

Combining Pipeline and Data Parallelism

Pipeline and data parallelism are orthogonal — you can use both simultaneously. In a combined setup, you run multiple pipeline *replicas* (data parallelism across replicas) where each replica is itself a pipeline (pipeline parallelism across stages). The constraint: your effective batch size is (microbatch_size × n_microbatches × n_data_parallel_replicas), so you need a large enough batch to keep both dimensions busy without gradient noise from tiny microbatches.

Illustration of orthogonal communication partners in combined data and pipeline parallelism
Combined DP + PP: each GPU participates in two communicators — one for pipeline neighbors, one for data-parallel peers.

The implementation uses MPI Communicators — subgroups of GPUs that only communicate within the group. Each GPU belongs to two: one for its pipeline stage peers (all GPUs with the same layer slice, for AllReduce), and one for its pipeline neighbors (the stages before and after, for point-to-point activations and gradients). These communicators partition the GPU cluster into a 2D grid: pipeline depth × data-parallel width. DeepSpeed, Megatron-LM, and FairScale all implement this pattern. In practice, large training runs often use a 3rd dimension — tensor parallelism within each layer — giving a 3D parallelism grid: pipeline × data × tensor.
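The rank bookkeeping for the 2D grid can be sketched in a few lines. The stage-major layout here is an assumption; real frameworks pick their own rank ordering:

```python
# Sketch of the 2D process-grid bookkeeping: map a global rank to its
# pipeline stage and data-parallel replica, and enumerate the two
# groups it communicates with. Stage-major layout is an assumption.

def grid_position(rank, pipeline_depth):
    stage = rank % pipeline_depth      # which layer slice this GPU holds
    replica = rank // pipeline_depth   # which data-parallel replica it's in
    return stage, replica

def pipeline_group(replica, pipeline_depth):
    # ranks forming one pipeline: point-to-point activation/grad transfers
    return [replica * pipeline_depth + s for s in range(pipeline_depth)]

def data_parallel_group(stage, pipeline_depth, dp_width):
    # ranks holding the same layer slice: gradient AllReduce peers
    return [r * pipeline_depth + stage for r in range(dp_width)]

# 8 GPUs as a 4 (pipeline) x 2 (data-parallel) grid
assert grid_position(5, 4) == (1, 1)
assert pipeline_group(1, 4) == [4, 5, 6, 7]
assert data_parallel_group(1, 4, 2) == [1, 5]
```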

Implementation: GPipe in Python

Unlike data parallelism (which requires AllReduce — a collective operation requiring coordination among all workers), pipeline parallelism uses only point-to-point sends and receives between adjacent stages. This means each GPU can follow a simple, static schedule without global synchronization. DeepSpeed's pipeline engine uses exactly this design: one worker per GPU, executing a sequence of commands determined before the minibatch starts.

def minibatch_steps(self):
    yield [ZeroGrad()]

    # STAGE 1: First, we FWD all microbatches
    for microbatch_id in range(self.num_micro_batches):
        yield self.steps_FWD_microbatch(microbatch_id)

    # at this position, all microbatches are in flight and
    # memory demand is highest

    # STAGE 2: Then, we BWD all microbatches
    for microbatch_id in reversed(range(self.num_micro_batches)):
        yield from self.steps_BWD_microbatch(microbatch_id)

    # updating the weights is the last step of processing any batch
    yield [OptimizerStep()]

This is the GPipe schedule: all forwards first, then all backwards in reverse order. The comment at peak memory is key — between the last FWD and the first BWD, every microbatch's activations are live simultaneously. For the forward pass of each microbatch:

def steps_FWD_microbatch(self, microbatch_id):
    cmds = []
    if self.is_first_stage:
        # first pipeline stage loads data from disk
        cmds.append(LoadMicroBatchInput(microbatch_id=microbatch_id))
    else:
        # all other stages receive activations from prev pipeline stage
        cmds.append(RecvActivations())

    cmds.append(Forward(microbatch_id=microbatch_id))

    if not self.is_last_stage:
        # all but the last pipeline stage send their output to next stage
        cmds.append(SendActivations())
    return cmds

Load input (or receive activations), run the forward pass, send activations to the next stage. Clean and self-contained. The backward pass is symmetric but runs in reverse — the last stage has the loss, so it loads targets instead of receiving gradients. A notable detail: the `BackwardGradAllReduce` on the first microbatch (processed last in backward order) overlaps the gradient AllReduce with the actual backward computation, hiding some of the data-parallel communication cost:

def steps_BWD_microbatch(self, microbatch_id):
    cmds = []
    if self.is_last_stage:
        # last pipeline stage loads targets from disk
        cmds.append(LoadMicroBatchTarget(microbatch_id=microbatch_id))
    else:
        # all other stages wait to receive grad from next stage
        cmds.append(RecvOutputGrad())

    if self.is_first_microbatch(microbatch_id):
        # microbatch 0 is processed last in backward order: overlap
        # its backprop with the data-parallel gradient AllReduce
        cmds.append(BackwardGradAllReduce(microbatch_id=microbatch_id))
    else:
        cmds.append(BackwardGradAcc(microbatch_id=microbatch_id))

    if not self.is_first_stage:
        # all but the first pipeline stage send their input grad to prev stage
        cmds.append(SendInputGrad())
    yield cmds

The `BackwardGradAcc` vs `BackwardGradAllReduce` distinction is subtle but important. For all but the last microbatch (in backward order), we accumulate gradients locally without synchronizing with other data-parallel replicas. Only on the last backward do we AllReduce — and by launching it as a background NCCL operation overlapping with the final backprop, we hide part of the network round-trip latency.

Hardware Context: Interconnects and Scaling

Hardware hierarchy showing multi-node GPU clusters with PCIe, NVLink, InfiniBand
Distributed training hardware: NVLink at ~900GB/s within a node; InfiniBand HDR at ~200Gbps between nodes.

The bandwidth numbers matter a lot for pipeline parallelism design. NVLink bandwidth within a node is ~900GB/s bidirectional — fast enough that intra-node pipeline stages are almost never bandwidth-limited. Cross-node InfiniBand runs at 25–200Gbps depending on generation — easily an order of magnitude or more slower. A good rule of thumb: keep bandwidth-hungry collectives like AllReduce within a node, and let pipeline parallelism's small point-to-point transfers carry the inter-node communication you can't avoid.
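A rough transfer-time estimate for a single stage-boundary tensor at these bandwidths (sizes are illustrative; real links add latency and protocol overhead on top):

```python
# Rough transfer time for one stage-boundary activation tensor at the
# bandwidths quoted above. Sizes are illustrative assumptions; real
# links add latency and protocol overhead on top.

def transfer_ms(num_values, bandwidth_bytes_per_s, bytes_per_value=2):
    # bytes_per_value=2 assumes bfloat16 activations
    return num_values * bytes_per_value / bandwidth_bytes_per_s * 1e3

acts = 512 * 12288              # e.g. microbatch 512, hidden dim 12288
nvlink = 900e9                  # ~900 GB/s intra-node
infiniband_200g = 200e9 / 8     # 200 Gbps -> 25 GB/s inter-node

print(transfer_ms(acts, nvlink))           # ~0.014 ms
print(transfer_ms(acts, infiniband_200g))  # ~0.5 ms, 36x slower
```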

Visual comparison of strong vs weak scaling strategies
Strong vs weak scaling: strong = fixed problem size across more GPUs; weak = fixed per-GPU workload.

Pipeline parallelism is a form of weak scaling for model size: each GPU holds a fixed number of layers, and you scale up the total parameter count by adding more GPUs to the pipeline. Bubble overhead is independent of model size (it's determined by pipeline depth and microbatch count). This is why pipeline parallelism is a first-class citizen in the infrastructure of LLM training — adding more layers doesn't increase utilization loss.

Conclusion

Pipeline parallelism is ultimately about scheduling. Given that a model must be split across GPUs, the question is: in what order do you run forward and backward passes across microbatches to maximize utilization and minimize memory? Naive MP gives you correctness but terrible utilization. GPipe restores utilization with microbatches but blows up activation memory. PipeDream halves the activation memory with 1F1B while maintaining the same bubble fraction. None of these choices are free — every improvement comes with a cost somewhere else.

What I find most interesting about this space is the interaction between pipeline depth and batch size. Deep pipelines (more stages) have higher bubble overhead unless you increase microbatch count, which increases batch size, which requires learning rate scaling. At some point you're constrained by convergence — very large batches don't generalize as well without careful warmup and decay schedules. The scheduling algorithm and the optimization algorithm aren't actually independent. Modern large-scale training infrastructure (Megatron-LM, DeepSpeed, FairScale) has to co-design both.

If you're building or debugging a multi-GPU training setup, the most common failure mode I've seen is incorrect gradient accumulation — treating microbatch gradients as independent updates instead of accumulating them before the optimizer step. Always validate against single-GPU training numerics before debugging performance. Correctness first, then throughput.

Original article by Simon Boehm

Reproduced with permission from the author.