Checkpointing is the process of saving all model state during training to durable storage so that training can be stopped and resumed at a later time.

The capacity required for LLM checkpointing scales with the size of the model, not the size of the training cluster.1 Model checkpoint sizes can be approximated by assuming 16 bytes per parameter (see LLM training memory requirements).
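As a quick back-of-the-envelope example (the 70B-parameter model below is purely illustrative):

```python
# Approximate checkpoint size at ~16 bytes of saved state per parameter
# (weights plus optimizer state). The model size is an arbitrary example.
num_params = 70e9          # e.g. a 70B-parameter model
bytes_per_param = 16
size_tb = num_params * bytes_per_param / 1e12
print(f"~{size_tb:.1f} TB per checkpoint")  # ~1.1 TB
```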

Similarly, the performance (write bandwidth) required for checkpointing scales with model size, not cluster size.

The naïve way to checkpoint is to have a single data parallel (DP) replica write all of the model weights and optimizer state to shared storage after weights are synchronized. Nobody does this in practice though, because funneling the entire checkpoint through one replica’s network and storage bandwidth is very slow.
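A minimal sketch of this single-writer pattern, assuming plain data parallelism where rank 0 holds a full copy of the state (function and path names are illustrative):

```python
# Naïve single-writer checkpoint: rank 0 serializes the full state by itself,
# so the whole checkpoint is bottlenecked on one node's bandwidth.
import torch
import torch.distributed as dist

def naive_checkpoint(model, optimizer, path):
    if dist.get_rank() == 0:
        torch.save(
            {"model": model.state_dict(), "optimizer": optimizer.state_dict()},
            path,
        )
    dist.barrier()  # every other rank sits idle while rank 0 writes
```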

Distributed checkpointing

Distributed checkpointing is the scheme where each data parallel replica writes a non-overlapping part of the model weights and optimizer state after the sync. This is analogous to how MPI-IO does collective writes to maximize the aggregate bandwidth of all clients to shared storage.
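A minimal sketch using PyTorch’s torch.distributed.checkpoint (DCP), whose default planner deduplicates replicated tensors and spreads the writes across ranks; paths are illustrative and the exact API surface varies between PyTorch releases:

```python
import os
import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
from torch.nn.parallel import DistributedDataParallel as DDP

def save_checkpoint(model, optimizer, step):
    # Every rank calls save(); DCP splits the work so each rank writes a
    # non-overlapping shard of the checkpoint to shared storage.
    state_dict = {"model": model.state_dict(), "optimizer": optimizer.state_dict()}
    dcp.save(state_dict, checkpoint_id=f"/shared/checkpoints/step_{step}")

if __name__ == "__main__":
    dist.init_process_group("nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    model = DDP(torch.nn.Linear(4096, 4096).cuda())
    optimizer = torch.optim.AdamW(model.parameters())
    # ... training steps ...
    save_checkpoint(model, optimizer, step=1000)
    dist.destroy_process_group()
```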

Multilevel/asynchronous checkpointing

Asynchronous checkpointing is common at scale. In this scheme, the GPU is only blocked when copying checkpoint data from GPU memory to host CPU memory. The GPU then proceeds with computing while the CPU asynchronously flushes the checkpoint data to nonvolatile storage. ByteDance’s MegaScale does this.2
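A hand-rolled sketch of the idea (not MegaScale’s implementation; names and paths are illustrative):

```python
import threading
import torch

def async_checkpoint(state_dict, path):
    # Blocking part: copy tensors from GPU memory to host memory. A production
    # version would stage into pre-allocated pinned buffers with non_blocking=True.
    cpu_snapshot = {
        k: v.detach().cpu() if torch.is_tensor(v) else v
        for k, v in state_dict.items()
    }
    # Non-blocking part: flush the host-side snapshot to storage in the
    # background while the GPU keeps training.
    writer = threading.Thread(target=torch.save, args=(cpu_snapshot, path), daemon=True)
    writer.start()
    return writer  # join() before overwriting this checkpoint or exiting
```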

There may be multiple levels of nonvolatile storage that are used at different intervals to further reduce bandwidth requirements and increase durability of the checkpoints.

Microsoft’s Nebula framework does this by:

  1. Synchronously checkpointing to CPU memory
  2. Asynchronously copying that to an adjacent node’s local SSD to protect against a single-node failure
  3. Asynchronously copying that to object storage to protect against multi-node failures, allow rollback, and store checkpoints long-term
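An illustrative sketch of this kind of tiering (not Nebula’s actual implementation; the paths are hypothetical and the file copy stands in for real replication and object-storage upload mechanisms):

```python
import shutil
import threading
import torch

def tiered_flush(cpu_snapshot, step):
    # Level 1: frequent, fast writes to a local (or neighboring node's) SSD.
    ssd_path = f"/local_ssd/ckpt_step{step}.pt"
    torch.save(cpu_snapshot, ssd_path)

    # Level 2: infrequent, durable copies for rollback and long-term storage;
    # shutil.copy here stands in for an S3/Blob upload.
    def upload():
        shutil.copy(ssd_path, f"/mnt/object_store/ckpt_step{step}.pt")
    threading.Thread(target=upload, daemon=True).start()
```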

NVIDIA’s NeMo Framework also supports asynchronous checkpointing.

PyTorch also supports asynchronous checkpointing. It sounds very similar to what NeMo has implemented.
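In recent PyTorch releases this is exposed as torch.distributed.checkpoint.async_save (the exact signature and stability level vary by version); a hedged sketch:

```python
import torch.distributed.checkpoint as dcp

# async_save returns a Future; training continues while the checkpoint is
# flushed in the background.
future = dcp.async_save(
    {"model": model.state_dict(), "optimizer": optimizer.state_dict()},
    checkpoint_id="/shared/checkpoints/step_2000",
)
# ... more training steps ...
future.result()  # wait for the previous flush before issuing the next async_save
```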

In practice

ByteDance uses HDFS for checkpointing.2

DeepSeek uses asynchronous checkpointing: state is copied from HBM to DRAM, then drained out to 3FS in the background.3

Footnotes

  1. A Checkpoint on Checkpoints in Large Language Models (vastdata.com)

  2. [2402.15627] MegaScale: Scaling Large Language Model Training to More Than 10,000 GPUs (arxiv.org)

  3. See 3FS