Context parallelism partitions the input of an LLM query along the sequence dimension and distributes the pieces over multiple model replicas running on multiple GPUs. For example, suppose you ask an LLM this:

What is the meaning of life?

This may be chopped up such that…

  • "What" goes to model replica 0 on GPU0
  • "is the" goes to model replica 1 on GPU1
  • "meaning of" goes to model replica 2 on GPU2
  • "life?" goes to model replica 3 on GPU3
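The split above can be sketched in a few lines of Python. This is a minimal illustration, not how any particular framework shards its input: the helper `shard_sequence` is a hypothetical name, whitespace-separated words stand in for real tokens, and the shards are simply made as even as possible, so the boundaries differ slightly from the illustrative split in the list.

```python
def shard_sequence(tokens, cp_size):
    """Split tokens into cp_size contiguous shards, as evenly as possible."""
    base, extra = divmod(len(tokens), cp_size)
    shards, start = [], 0
    for rank in range(cp_size):
        # The first `extra` ranks take one additional token each.
        size = base + (1 if rank < extra else 0)
        shards.append(tokens[start:start + size])
        start += size
    return shards


# Assumption: words stand in for tokens; a real tokenizer would differ.
tokens = "What is the meaning of life ?".split()
shards = shard_sequence(tokens, cp_size=4)
for rank, shard in enumerate(shards):
    print(f"model replica {rank} on GPU{rank}: {shard}")
```

Each replica keeps a contiguous slice of the sequence, and concatenating the shards in rank order recovers the original prompt.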

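When attention is computed (during either training or inference), each replica keeps the queries for its own shard, while the key and value shards are passed from replica to replica in a ring until every query shard has seen every key and value shard. This scheme is known as ring attention.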
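To make the ring idea concrete, here is a single-process NumPy simulation. It is a sketch under simplifying assumptions: attention is non-causal and single-head, the "ring" is a Python list rotation rather than real GPU-to-GPU communication, and the function names are hypothetical. Each rank accumulates its output with a running (online) softmax, so it never needs all keys and values at once.

```python
import numpy as np


def full_attention(q, k, v):
    """Reference: ordinary softmax attention over the whole sequence."""
    s = q @ k.T / np.sqrt(q.shape[-1])
    w = np.exp(s - s.max(axis=-1, keepdims=True))
    return (w / w.sum(axis=-1, keepdims=True)) @ v


def ring_attention(q_shards, k_shards, v_shards):
    """Simulate ring attention: KV shards rotate, query shards stay put."""
    n = len(q_shards)
    d = q_shards[0].shape[-1]
    # Per-rank running state: row max, softmax denominator, weighted sum.
    row_max = [np.full((q.shape[0], 1), -np.inf) for q in q_shards]
    denom = [np.zeros((q.shape[0], 1)) for q in q_shards]
    acc = [np.zeros((q.shape[0], v_shards[0].shape[-1])) for q in q_shards]
    kv = list(zip(k_shards, v_shards))
    for _ in range(n):
        for r in range(n):
            k, v = kv[r]
            s = q_shards[r] @ k.T / np.sqrt(d)
            new_max = np.maximum(row_max[r], s.max(axis=-1, keepdims=True))
            rescale = np.exp(row_max[r] - new_max)  # correct earlier partial sums
            p = np.exp(s - new_max)
            denom[r] = denom[r] * rescale + p.sum(axis=-1, keepdims=True)
            acc[r] = acc[r] * rescale + p @ v
            row_max[r] = new_max
        # One ring step: rank r receives the KV shard held by rank r - 1.
        kv = [kv[(r - 1) % n] for r in range(n)]
    return [a / z for a, z in zip(acc, denom)]


rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((8, 4)) for _ in range(3))
ring_out = np.concatenate(
    ring_attention(np.split(q, 4), np.split(k, 4), np.split(v, 4))
)
print(np.allclose(ring_out, full_attention(q, k, v)))
```

After as many ring steps as there are ranks, every rank has processed every KV shard, and the concatenated outputs match attention computed over the unsharded sequence.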

The benefit is twofold: a very large context (and its keys and values) can be spread across the HBM of many GPUs, and prefill time decreases, since every additional GPU contributes FLOPS as well as HBM capacity.

The downside is that a single forward pass involves much more communication. GPUs are also very expensive, so context parallelism is a costly way to address two problems with long prompts:

Inference phase | Performance driver | Reasoning
--- | --- | ---
prefill | Time to first token (TTFT) | Prefill is accelerated by computing across multiple model replicas concurrently.
decode | KV cache capacity | KV cache capacity is distributed over more GPUs’ HBM.
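A rough calculation shows how the KV cache per context-parallel rank shrinks as the sequence is split. The model dimensions below (120 layers, 8 KV heads, head dimension 128, 2 bytes per element, a one-million-token prompt) are hypothetical values chosen for illustration, and the helper name is an assumption. The formula counts one key and one value vector per token, per layer, per KV head, and ignores any further sharding from tensor or pipeline parallelism.

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len, bytes_per_elem=2):
    """KV cache size for one sequence: keys plus values for every layer."""
    # The leading 2 accounts for storing both a key and a value tensor.
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * bytes_per_elem


# Hypothetical model and prompt sizes, for illustration only.
seq_len = 1_000_000
total = kv_cache_bytes(num_layers=120, num_kv_heads=8, head_dim=128, seq_len=seq_len)

for cp in (1, 2, 4):
    # Each CP rank stores keys and values only for its slice of the sequence.
    per_rank = kv_cache_bytes(120, 8, 128, seq_len // cp)
    print(f"CP={cp}: {per_rank / 1e9:.1f} GB of KV cache per rank")
```

Doubling the context-parallel degree halves the KV cache each rank must hold, which is the capacity effect described in the decode row.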

In practice, context parallelism is combined with other forms of parallelism. For example, consider a transformer with 120 layers running on 8-way HGX nodes:

  • 40 layers fit on a single 8-way node, so a single model replica requires 3 nodes (24 GPUs). Pipeline parallelism (PP=3) shards the 120 layers into 3 stages of 40 layers each.
  • Within each 8-way node, tensor parallelism is used (TP=8).

With tensor parallelism inside each node (TP=8) and pipeline parallelism across the three nodes (PP=3), a single model replica occupies 24 GPUs. For long input prompts, you might apply context parallelism to distribute the sequence across two model replicas (CP=2), which then requires 48 GPUs for a single inference request.
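The GPU count in this example is just the product of the three parallelism degrees. The short sketch below reproduces the arithmetic; the function and variable names are hypothetical.

```python
import math


def gpus_per_request(tp, pp, cp):
    """GPUs serving one request: TP x PP per replica, times CP replicas."""
    return tp * pp * cp


num_layers = 120
layers_per_node = 40   # layers that fit on one 8-way node
tp = 8                 # tensor parallelism within a node
pp = math.ceil(num_layers / layers_per_node)  # pipeline stages, one per node

print("pipeline stages:", pp)
print("GPUs per replica:", gpus_per_request(tp, pp, cp=1))
print("GPUs per request with CP=2:", gpus_per_request(tp, pp, cp=2))
```

Every additional context-parallel replica adds another full TP x PP group of GPUs, which is why the cost grows quickly with the CP degree.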