Context parallelism is a technique whereby the input to an LLM is partitioned along the sequence dimension and distributed over multiple GPUs. So, for example, if you ask an LLM this:

What is the meaning of life?

This may be chopped up, as sketched in code after the list, such that…

  • What goes to GPU0
  • is the goes to GPU1
  • meaning of goes to GPU2
  • life goes to GPU3
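
Here's a toy sketch of that split in plain Python. The shard boundaries are just near-even contiguous chunks, so they won't match the hand-drawn grouping above exactly, and real systems shard token ids or hidden states rather than words:

```python
# Toy 4-way split of the example prompt along the sequence dimension.
tokens = ["What", "is", "the", "meaning", "of", "life?"]
n_gpus = 4

def shard(seq, n):
    """Split seq into n contiguous, near-equal chunks (shard i -> GPU i)."""
    base, extra = divmod(len(seq), n)
    chunks, start = [], 0
    for i in range(n):
        size = base + (1 if i < extra else 0)
        chunks.append(seq[start:start + size])
        start += size
    return chunks

for rank, chunk in enumerate(shard(tokens, n_gpus)):
    print(f"GPU{rank}: {chunk}")
```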

When attention is computed (during either training or inference), every query shard eventually needs to see every key and value, so the key and value shards are passed around the GPUs in a ring, one hop per step, until each shard has visited every rank (this is Ring Attention).
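
Below is a minimal single-process sketch of that ring pass, simulating each "GPU" as a list entry; the function name and shard layout are made up for illustration, and causal masking is omitted for brevity. The one real subtlety is the online softmax, which lets the partial attention against each arriving KV shard be merged exactly:

```python
import numpy as np

def ring_attention(q_shards, k_shards, v_shards):
    """Rank r keeps its query shard resident; on step s it 'receives'
    the K/V shard that started on rank (r + s) % world, as if shards
    were rotated one hop around the ring each step."""
    world = len(q_shards)
    d = q_shards[0].shape[-1]
    outputs = []
    for rank in range(world):
        q = q_shards[rank]
        m = np.full(q.shape[0], -np.inf)  # running row-wise max of scores
        l = np.zeros(q.shape[0])          # running softmax denominator
        acc = np.zeros_like(q)            # unnormalized weighted sum of V
        for step in range(world):
            k = k_shards[(rank + step) % world]
            v = v_shards[(rank + step) % world]
            scores = q @ k.T / np.sqrt(d)
            m_new = np.maximum(m, scores.max(axis=-1))
            scale = np.exp(m - m_new)        # rescale the stale statistics
            p = np.exp(scores - m_new[:, None])
            l = l * scale + p.sum(axis=-1)
            acc = acc * scale[:, None] + p @ v
            m = m_new
        outputs.append(acc / l[:, None])
    return np.concatenate(outputs)

# Sanity check: the ring result matches ordinary full attention.
rng = np.random.default_rng(0)
q, k, v = (rng.normal(size=(8, 16)) for _ in range(3))
s = q @ k.T / np.sqrt(16)
w = np.exp(s - s.max(-1, keepdims=True))
ref = (w / w.sum(-1, keepdims=True)) @ v
out = ring_attention(np.split(q, 4), np.split(k, 4), np.split(v, 4))
assert np.allclose(out, ref)
```

The running max/denominator trick is the same blockwise online softmax FlashAttention uses; without it you couldn't normalize until every shard had arrived.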

The benefit here is that you can spread very large contexts (and their keys and values) across the HBM of many GPUs, and prefill time decreases, since every additional GPU brings FLOPS as well as HBM capacity.
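
To make the memory side concrete, here's a back-of-the-envelope sketch. The dimensions are assumptions (roughly Llama-3-70B-shaped: 80 layers, 8 grouped KV heads of dimension 128, bf16), not a spec:

```python
# Back-of-the-envelope KV-cache size for a 1M-token context.
layers, kv_heads, head_dim = 80, 8, 128
seq_len, bytes_per_elem = 1_000_000, 2       # bf16
# 2 = one K plus one V tensor per layer
kv_bytes = 2 * layers * kv_heads * head_dim * seq_len * bytes_per_elem
print(f"KV cache: {kv_bytes / 2**30:.0f} GiB")   # ~305 GiB
```

~305 GiB of KV cache alone is roughly four 80 GB GPUs' worth of HBM before any weights are loaded, so splitting the cache N ways is often the only way such a context fits at all.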

The downside is that a single forward pass now requires much more communication. And GPUs are very expensive.
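
Sticking with the same assumed dimensions, a rough estimate of that extra traffic (each of the N-1 ring steps moves one KV shard per GPU per layer):

```python
# Rough ring traffic per GPU, same assumed dims, context split 8 ways.
layers, kv_heads, head_dim = 80, 8, 128
seq_len, bytes_per_elem, n_gpus = 1_000_000, 2, 8
shard_bytes = 2 * kv_heads * head_dim * (seq_len // n_gpus) * bytes_per_elem
per_layer = (n_gpus - 1) * shard_bytes       # each KV shard makes n-1 hops
print(f"{per_layer / 2**30:.1f} GiB sent per GPU per layer; "
      f"{layers * per_layer / 2**30:.0f} GiB per forward pass")
```

In practice, implementations try to overlap each shard's transfer with the attention compute on the previous shard, but that bandwidth is a cost a single-GPU forward pass never pays.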