Checkpointing is the process of saving all model state to durable storage during training so that training can be stopped and resumed at a later time.
The capacity required for LLM checkpointing scales with the size of the model, not the size of the training cluster.1 Model checkpoint sizes can be approximated by assuming 16 bytes per parameter (see LLM training memory requirements).
Similarly, the performance required for checkpointing scales with model size instead of cluster size.
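As a rough illustration of the sizing rule above, the arithmetic can be sketched in a few lines of Python. The 70-billion-parameter model and the 60-second write budget are hypothetical values chosen only for the example; the 16 bytes per parameter is the approximation stated above.

```python
BYTES_PER_PARAM = 16  # rule-of-thumb approximation from the text


def checkpoint_bytes(num_params: int) -> int:
    """Approximate checkpoint size in bytes for a model with num_params parameters."""
    return num_params * BYTES_PER_PARAM


def required_write_bandwidth(num_params: int, write_seconds: float) -> float:
    """Bytes per second needed to write one full checkpoint within write_seconds."""
    return checkpoint_bytes(num_params) / write_seconds


# Hypothetical example: a 70B-parameter model and a 60-second write budget.
params = 70_000_000_000
print(f"checkpoint size: {checkpoint_bytes(params) / 1e12:.2f} TB")           # 1.12 TB
print(f"bandwidth needed: {required_write_bandwidth(params, 60) / 1e9:.1f} GB/s")  # 18.7 GB/s
```

Note that neither function takes the number of GPUs or nodes as an input, which is the point: both the capacity and the bandwidth depend only on the model size and on how long the write is allowed to take.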
The naïve way to checkpoint is to write all model weights and optimizer state from one data parallel (DP) replica to shared storage after weights are synchronized. Nobody does this in practice, though, because it is very slow: only a single replica's worth of clients is writing.
Distributed checkpointing
In distributed checkpointing, each data parallel replica writes a non-overlapping part of the model weights and optimizer state after synchronization. This is analogous to how MPI-IO does collective writes to maximize the aggregate bandwidth of all clients to shared storage.
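The partitioning idea can be sketched as follows. This is a minimal illustration, not how any particular framework does it: the function names (`shard_range`, `preallocate`, `write_shard`) are hypothetical, the model state is stood in for by a plain `bytes` object, and the replicas are simulated by a loop rather than by separate processes.

```python
import os
import tempfile
from typing import Tuple


def shard_range(total_bytes: int, num_writers: int, rank: int) -> Tuple[int, int]:
    """Return the half-open [start, end) byte range owned by writer `rank`."""
    base, extra = divmod(total_bytes, num_writers)
    # The first `extra` writers each take one additional byte.
    start = rank * base + min(rank, extra)
    end = start + base + (1 if rank < extra else 0)
    return start, end


def preallocate(path: str, total_bytes: int) -> None:
    """Create the shared checkpoint file at its final size."""
    with open(path, "wb") as f:
        f.truncate(total_bytes)


def write_shard(path: str, state: bytes, num_writers: int, rank: int) -> None:
    """Write only this writer's slice of the state at its own file offset."""
    start, end = shard_range(len(state), num_writers, rank)
    with open(path, "r+b") as f:
        f.seek(start)
        f.write(state[start:end])


# Simulate 4 data parallel replicas that all hold the same synchronized state.
state = bytes(range(256)) * 4
path = os.path.join(tempfile.mkdtemp(), "checkpoint.bin")
preallocate(path, len(state))
for rank in range(4):
    write_shard(path, state, num_writers=4, rank=rank)
```

Because every replica holds an identical copy of the state after synchronization, each one can compute its own range from its rank alone, with no coordination beyond agreeing on the number of writers.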
Multilevel/asynchronous checkpointing
Asynchronous checkpointing is common at scale. In this scheme, the GPU is only blocked when copying checkpoint data from GPU memory to host CPU memory. The GPU then proceeds with computing while the CPU asynchronously flushes the checkpoint data to nonvolatile storage. ByteDance’s MegaScale does this.2
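A minimal sketch of this pattern, using only the Python standard library, is shown below. It is not MegaScale's implementation: `async_checkpoint` is a hypothetical name, a `deepcopy` of a dictionary stands in for the blocking GPU-to-host copy, and a local file stands in for nonvolatile storage.

```python
import copy
import os
import pickle
import tempfile
import threading


def async_checkpoint(state: dict, path: str) -> threading.Thread:
    """Snapshot `state` synchronously, then flush the snapshot in the background."""
    # Blocking step: stands in for the GPU-to-host-memory copy.
    snapshot = copy.deepcopy(state)

    def flush() -> None:
        tmp = path + ".tmp"
        with open(tmp, "wb") as f:
            pickle.dump(snapshot, f)
            f.flush()
            os.fsync(f.fileno())
        # Rename only after the data is on disk, so a crash mid-write
        # never leaves a partial file under the final name.
        os.replace(tmp, path)

    writer = threading.Thread(target=flush)
    writer.start()
    return writer


# Training continues and mutates the live state while the flush runs.
state = {"step": 1, "weights": [0.0, 1.0]}
path = os.path.join(tempfile.mkdtemp(), "step_1.ckpt")
writer = async_checkpoint(state, path)
state["step"] = 2          # the next training step
state["weights"][0] = 9.9
writer.join()              # wait for the flush before reading it back

with open(path, "rb") as f:
    restored = pickle.load(f)
print(restored["step"])    # 1, the state at the time of the snapshot
```

The snapshot is what makes the overlap safe: once the copy exists in host memory, the live state can change freely without corrupting the checkpoint being written.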
There may be multiple levels of nonvolatile storage that are used at different intervals to further reduce bandwidth requirements and increase durability of the checkpoints.
Microsoft’s Nebula framework does this by:
- Synchronously checkpointing to CPU memory
- Asynchronously copying that to an adjacent node’s local SSD to protect against a single-node failure
- Asynchronously copying that to object storage to protect against multi-node failures, allow rollback, and store checkpoints long-term
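The three levels listed above can be sketched as below. This is only an illustration of the tiering idea and does not use or reflect Nebula's actual API: `multilevel_checkpoint` is a hypothetical name, and two local directories stand in for the adjacent node's SSD and for object storage.

```python
import copy
import os
import pickle
import shutil
import tempfile
import threading
from typing import Tuple


def multilevel_checkpoint(
    state: dict, step: int, neighbor_dir: str, object_dir: str
) -> Tuple[dict, threading.Thread]:
    """Checkpoint to memory synchronously, then drain to two slower tiers."""
    # Level 1 (synchronous): copy the state into host memory.
    snapshot = copy.deepcopy(state)

    def drain() -> None:
        name = f"step_{step}.ckpt"
        # Level 2 (asynchronous): stand-in for an adjacent node's local SSD.
        neighbor_path = os.path.join(neighbor_dir, name)
        with open(neighbor_path, "wb") as f:
            pickle.dump(snapshot, f)
        # Level 3 (asynchronous): stand-in for object storage.
        shutil.copy(neighbor_path, os.path.join(object_dir, name))

    worker = threading.Thread(target=drain)
    worker.start()
    return snapshot, worker


# Hypothetical usage with temporary directories as the two slower tiers.
neighbor_dir = tempfile.mkdtemp(prefix="neighbor_ssd_")
object_dir = tempfile.mkdtemp(prefix="object_store_")
snapshot, worker = multilevel_checkpoint({"step": 100}, 100, neighbor_dir, object_dir)
worker.join()
print(sorted(os.listdir(object_dir)))  # ['step_100.ckpt']
```

A real system would typically drain to the slowest tier less often than to the faster ones; this sketch writes every checkpoint to every tier to keep the example short.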
NVIDIA’s NeMo Framework also supports asynchronous checkpointing.
PyTorch also supports asynchronous checkpointing, and its approach appears very similar to what NeMo has implemented.
In practice
ByteDance uses HDFS for checkpointing.2
DeepSeek checkpoints asynchronously from HBM to DRAM, then drains the checkpoint to 3FS in the background.3