
[checkpoint] feat: open source fast checkpoint system #38

Merged 1 commit into main from opensource_053024 on May 31, 2024

Conversation

@MingjiHan99 (Collaborator) commented May 31, 2024

Summary

We improved vescale.checkpoint with the following new features for fast checkpointing (the first three are built-in techniques that require no manual activation):

  • Saving Plan Caching: During training, the program may save model and optimizer checkpoints every n steps. Once a saving plan is created, it remains valid as long as the model is unchanged. We implemented plan caching to avoid regenerating the plan when checkpointing a model or optimizer multiple times, reducing unnecessary compute and communication costs. As of 05/30/2024, PyTorch DCP does not support plan caching. (A minimal sketch of the caching idea appears after this list.)

  • Saving Plan Load-Balancing: In data parallel training, models are replicated across GPUs with different data parallel ranks but the same pipeline and tensor parallel ranks. Existing PyTorch DCP (as of 05/30/2024) deduplicates replicated tensors with a simple algorithm that makes GPUs with data parallel rank 0 save the entire model, creating a load imbalance. We implemented a load-balancing algorithm that spreads the saving work across data parallel ranks when deduplicating model tensors. (A sketch of one such balancing scheme appears after this list.)

  • D2H Tensor Copying via Pinned Memory: When copying tensors from GPU to host memory, vescale.checkpoint uses pinned host memory, avoiding the cost of allocating host memory each time a checkpoint is saved. As of 05/30/2024, PyTorch DCP does not support pinned memory. (A sketch of a reusable pinned buffer appears after this list.)

  • Checkpoint Broadcasting: In data parallel training, models are replicated across GPUs with different data parallel ranks but the same pipeline and tensor parallel ranks. If broadcast_checkpoint is enabled, vescale.checkpoint.load lets GPUs with data parallel rank 0 load the model and broadcast it to other GPUs with higher data parallel ranks. When GPUs are connected with NCCL and I/O bandwidth is fully utilized, broadcasting model tensors speeds up checkpoint loading compared to having every GPU load the model from persistent storage. E.g.:

    # prepare checkpoint state for the model and optimizer
    checkpoint_state = { "model": distributed_model, "optimizer": distributed_optimizer }
    # load the checkpoint
    vescale.checkpoint.load("/user/vescale/gpt/", checkpoint_state, broadcast_checkpoint=True)
  • Asynchronous Checkpointing: When vescale.checkpoint.save is called, it first generates a saving plan and then synchronously copies tensors from GPU to host memory. If async_checkpoint is enabled, the training program can continue after the D2H copying, while vescale.checkpoint.save continues to serialize tensors and dump the checkpoint to persistent storage asynchronously without blocking training. As of 05/30/2024, PyTorch DCP does not support asynchronous checkpointing. E.g.:

    # prepare checkpoint state for the model and optimizer
    checkpoint_state = { "model": distributed_model, "optimizer": distributed_optimizer }
    # save the checkpoint asynchronously
    vescale.checkpoint.save("/user/vescale/gpt/", checkpoint_state, async_checkpoint=True)
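
A minimal sketch of the plan-caching idea, with hypothetical names rather than the actual vescale.checkpoint internals: the expensive planning step runs once, and later saves reuse the cached plan as long as the model structure is unchanged.

    # Sketch of saving-plan caching; `make_plan` stands in for the
    # expensive planning step (compute + collective communication).
    _plan_cache = {}

    def plan_or_reuse(key, state_dict, make_plan):
        # Reuse the cached plan for this key; it stays valid as long as
        # the model (and hence the plan) is unchanged between saves.
        if key not in _plan_cache:
            _plan_cache[key] = make_plan(state_dict)
        return _plan_cache[key]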
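The load-balanced deduplication can be sketched as a greedy assignment: largest replicated tensors first, each going to the currently least-loaded data parallel rank. This is one illustrative scheme, not necessarily the exact algorithm used by vescale.checkpoint.

    import heapq

    def balance_dedup(tensor_sizes, dp_world_size):
        """Assign each replicated tensor to exactly one data parallel rank,
        greedily evening out the bytes each rank writes (largest first)."""
        heap = [(0, rank) for rank in range(dp_world_size)]  # (bytes, rank)
        heapq.heapify(heap)
        assignment = {}
        for name, size in sorted(tensor_sizes.items(), key=lambda kv: -kv[1]):
            load, rank = heapq.heappop(heap)
            assignment[name] = rank
            heapq.heappush(heap, (load + size, rank))
        return assignment

    # e.g. four replicated tensors spread over 2 data parallel ranks
    print(balance_dedup({"embed.w": 400, "fc1.w": 300, "fc1.b": 3, "fc2.w": 300}, 2))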
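The pinned-memory D2H copy can be sketched as follows (illustrative only; the buffer management in vescale.checkpoint may differ): page-locked host buffers are allocated once and reused across saves, and the copy itself can run asynchronously.

    import torch

    _pinned_buffers = {}  # (shape, dtype) -> reusable page-locked host buffer

    def d2h_copy(gpu_tensor):
        key = (tuple(gpu_tensor.shape), gpu_tensor.dtype)
        if key not in _pinned_buffers:
            # Allocate pinned host memory once; later checkpoints reuse it
            # instead of paying the allocation cost on every save.
            _pinned_buffers[key] = torch.empty(
                gpu_tensor.shape, dtype=gpu_tensor.dtype, pin_memory=True
            )
        buf = _pinned_buffers[key]
        # Pinned memory enables an asynchronous DMA transfer; synchronize
        # (e.g. torch.cuda.synchronize()) before serializing the buffer.
        buf.copy_(gpu_tensor, non_blocking=True)
        return buf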

Acknowledgement

We sincerely appreciate all contributors, including but not limited to @shanesyy-1992 @raywan-110 @lazychao @AHEADer @MingjiHan99

@shanesyy-1992

From my understanding, Checkpoint Broadcasting might be beneficial only when storage throughput is limited. Maybe it's better to add some guidance on when to use this feature.

@MingjiHan99 MingjiHan99 merged commit c4afc72 into main May 31, 2024
1 check passed
@MingjiHan99 MingjiHan99 deleted the opensource_053024 branch May 31, 2024 07:12
@raywan-110

Let's keep pushing forward 💪!
