Add AllReduce distributed strategy design #373
QiJune wants to merge 5 commits into develop from allreduce_design
# AllReduce

## Introduction

Data parallelism enables distributed training by communicating gradients
before the optimizer step, to make sure that the parameters of all model replicas
are updated using exactly the same set of gradients,
and hence model replicas stay consistent across iterations.

AllReduce is a common strategy for communicating gradients in data parallelism.
The following pseudocode describes the training procedure
under the AllReduce strategy.

```python
broadcast(parameters, rank=0)
while True:
    load_minibatch()
    forward()
    backward()
    allreduce(gradients)
    update()
```

First, we broadcast the model parameters of rank 0 to the other processes.
Each process then loads a minibatch of training data,
does the forward/backward computation, and obtains the gradients.
We launch AllReduce to communicate the gradients among the processes.
Finally, each process updates its parameters individually.

## AllReduce in PyTorch

Before discussing how to support AllReduce in GoTorch,
it is necessary to take a thorough survey of
the current implementation of AllReduce in PyTorch.

PyTorch offers several tools to facilitate distributed training, including
[DataParallel](https://pytorch.org/docs/stable/generated/torch.nn.DataParallel.html#torch.nn.DataParallel)
for single-process multi-thread data parallel training
using multiple GPUs on the same machine, and
[DistributedDataParallel](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel)
for multi-process data parallel training
across GPUs and machines.

Single-process multi-GPU is not the recommended mode,
because of the scatter/gather overhead and the GIL contention in every forward pass.
So, let's focus on DistributedDataParallel.

### Collective Communication Library

PyTorch can use different collective communication libraries as the backend,
including [NCCL](https://developer.nvidia.com/nccl) and [Gloo](https://github.com/facebookincubator/gloo).
NCCL supports GPU only, while Gloo supports both CPU and GPU.
NCCL performs better than Gloo on GPU,
so we use NCCL for GPU training and Gloo for CPU training.

Besides, PyTorch provides a library, [c10d](https://github.com/pytorch/pytorch/tree/master/torch/lib/c10d),
which wraps NCCL/Gloo and lets us manipulate `torch::Tensor` objects directly.
This brings much convenience.

The following is an example:

```cpp
#include <c10d/FileStore.hpp>
#include <c10d/ProcessGroupGloo.hpp>

#include <cstdlib>
#include <memory>
#include <vector>

using namespace ::c10d;

int main(int argc, char** argv) {
  // Rank and world size are passed through environment variables.
  int rank = atoi(getenv("RANK"));
  int size = atoi(getenv("SIZE"));
  auto store = std::make_shared<FileStore>("/tmp/c10d_example", size);
  ProcessGroupGloo pg(store, rank, size);

  // Create some tensors
  const auto ntensors = 10;
  std::vector<at::Tensor> tensors;
  for (auto i = 0; i < ntensors; i++) {
    auto x =
        at::ones({1000, 16 * (i + 1)}, at::TensorOptions(at::CPU(at::kFloat)));
    tensors.push_back(x);
  }

  // Kick off work
  std::vector<std::shared_ptr<ProcessGroup::Work>> pending;
  for (auto i = 0; i < ntensors; i++) {
    std::vector<at::Tensor> tmp = {tensors[i]};
    pending.push_back(pg.allreduce(tmp));
  }

  // Wait for work to complete
  for (auto& work : pending) {
    work->wait();
  }
}
```

### DistributedSampler

The training samples are partitioned statically in PyTorch's distributed training.
The [DistributedSampler](https://pytorch.org/docs/stable/_modules/torch/utils/data/distributed.html#DistributedSampler)
generates a sequence of training-sample indices for each training process.
Each process then loads the subset of samples identified by these indices.

**Note:** The dataset is assumed to be of constant size.

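For GoTorch, the same static partitioning idea could look like the Go sketch below. The `shardIndices` helper and its signature are illustrative assumptions rather than an existing API: every process shuffles the full index list with a shared seed and keeps every `worldSize`-th index starting from its own rank.

```go
package main

import (
	"fmt"
	"math/rand"
)

// shardIndices mimics the DistributedSampler idea: every process shuffles the
// full index list with the same seed (here the epoch number), then keeps the
// indices whose position modulo worldSize equals its rank.
func shardIndices(datasetSize, rank, worldSize int, epoch int64) []int {
	indices := make([]int, datasetSize)
	for i := range indices {
		indices[i] = i
	}
	// The shared seed keeps the shuffled order identical on every process.
	rand.New(rand.NewSource(epoch)).Shuffle(len(indices), func(i, j int) {
		indices[i], indices[j] = indices[j], indices[i]
	})
	shard := make([]int, 0, (datasetSize+worldSize-1)/worldSize)
	for i := rank; i < len(indices); i += worldSize {
		shard = append(shard, indices[i])
	}
	return shard
}

func main() {
	// Rank 1 of 4 processes over a 10-sample dataset, epoch 0.
	fmt.Println(shardIndices(10, 1, 4, 0))
}
```
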
### Launch Utility

The `torch.distributed` package provides a launch utility in
[torch.distributed.launch](https://github.com/pytorch/pytorch/blob/master/torch/distributed/launch.py).
This helper utility can be used to
launch multiple processes per node for distributed training.
If the utility is used for GPU training,
each distributed process operates on a single GPU.

### Optimization

The naive implementation of the training procedure
in the [Introduction](#Introduction) section has two performance concerns:

- Collective communication performs poorly on small tensors,
  which is especially prominent on large models
  with massive numbers of small parameters.
- Separating gradient computation and synchronization forfeits the opportunity
  to overlap computation with communication due to the hard boundary in between.

PyTorch applies two optimizations to address these problems:

- Bucketing gradients to reduce the overhead of AllReduce kernels.
- Registering AllReduce kernels as autograd hooks
  to overlap communication with computation.

For more details, please refer to the [paper](https://arxiv.org/abs/2006.15704).

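To make the bucketing idea concrete, here is a minimal Go sketch under simplifying assumptions: gradients are plain `[]float32` slices, the hypothetical `allReduce` function stands in for a real collective call, and copying the reduced values back into the original gradients is omitted.

```go
package main

import "fmt"

// allReduce is a stand-in for a real collective call (e.g. through a c10d
// binding); here it only reports how many elements are communicated at once.
func allReduce(bucket []float32) {
	fmt.Printf("allreduce on %d elements\n", len(bucket))
}

// bucketAllReduce coalesces many small gradients into buckets of at most
// bucketCap elements, so each collective call works on a larger buffer.
func bucketAllReduce(grads [][]float32, bucketCap int) {
	bucket := make([]float32, 0, bucketCap)
	for _, g := range grads {
		bucket = append(bucket, g...)
		if len(bucket) >= bucketCap {
			allReduce(bucket)
			bucket = bucket[:0]
		}
	}
	if len(bucket) > 0 {
		allReduce(bucket) // flush the last partial bucket
	}
}

func main() {
	grads := [][]float32{make([]float32, 3), make([]float32, 5), make([]float32, 2)}
	bucketAllReduce(grads, 6)
}
```
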
## AllReduce in GoTorch

We plan to implement the functionality of
DistributedDataParallel in GoTorch gradually.
At stage 1, we provide a naive solution;
a distributed MNIST example is the target of this stage.
At stage 2, we will provide an optimized solution;
bucketing gradients and registering hooks will be implemented at this stage.

### RecordIODataLoader

The RecordIO format is a simple format for a sequence of binary records.
It provides a way to seek to the beginning of any record in a file.
We can partition the RecordIO data and assign the partitions to training processes.
At stage 1, we support static sharding only.
The steps of static sharding in distributed training are as follows
(a Go sketch appears after the list):

1. Convert samples into RecordIO format.
1. Partition the records into several tasks. Each task contains
   one or more `{file, start_idx, end_idx}` tuples.
1. Shuffle the tasks and assign a subset of them to each training process.
1. Decode the records in the tasks and feed them to the neural network.

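The Go sketch below illustrates this static sharding scheme. The `Chunk` and `Task` types and the helper functions are assumptions made for illustration, not an existing GoTorch API.

```go
package main

import (
	"fmt"
	"math/rand"
)

// Chunk locates a contiguous run of records inside one RecordIO file.
type Chunk struct {
	File     string
	StartIdx int
	EndIdx   int // exclusive
}

// Task is the unit of work assigned to one training process.
type Task struct {
	Chunks []Chunk
}

// makeTasks splits a file with numRecords records into tasks of taskSize records.
func makeTasks(file string, numRecords, taskSize int) []Task {
	var tasks []Task
	for start := 0; start < numRecords; start += taskSize {
		end := start + taskSize
		if end > numRecords {
			end = numRecords
		}
		tasks = append(tasks, Task{Chunks: []Chunk{{File: file, StartIdx: start, EndIdx: end}}})
	}
	return tasks
}

// assignTasks shuffles tasks with a shared seed and statically assigns
// every worldSize-th task to the process with the given rank.
func assignTasks(tasks []Task, rank, worldSize int, seed int64) []Task {
	rand.New(rand.NewSource(seed)).Shuffle(len(tasks), func(i, j int) {
		tasks[i], tasks[j] = tasks[j], tasks[i]
	})
	var mine []Task
	for i := rank; i < len(tasks); i += worldSize {
		mine = append(mine, tasks[i])
	}
	return mine
}

func main() {
	tasks := makeTasks("train-00000.recordio", 1000, 128)
	fmt.Println(assignTasks(tasks, 0, 4, 42))
}
```
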
### Go Wrapper of c10d Library

[ProcessGroupNCCL](https://github.com/pytorch/pytorch/blob/master/torch/lib/c10d/ProcessGroupNCCL.hpp)
implements the NCCL bindings for the c10d library.
After adding a Go wrapper of this class,
we can call AllReduce on torch tensors from Go.

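From the Go side, the wrapper's surface might look like the sketch below. The `Tensor` type, the `ProcessGroup` interface, and the `noopGroup` stub are purely illustrative assumptions about the intended API shape; a real implementation would forward the calls to `ProcessGroupNCCL` through cgo.

```go
package main

import "fmt"

// Tensor stands in for a GoTorch tensor handle (the real type would wrap an
// at::Tensor owned by the C++ side).
type Tensor struct {
	Data []float32
}

// ProcessGroup is a hypothetical Go-facing interface mirroring c10d's
// ProcessGroup; the NCCL-backed implementation would live behind cgo.
type ProcessGroup interface {
	AllReduce(tensors []Tensor) error
	Rank() int
	Size() int
}

// noopGroup is a stub so the sketch compiles: it pretends to be a
// single-process group and leaves the tensors untouched.
type noopGroup struct{}

func (noopGroup) AllReduce(tensors []Tensor) error { return nil }
func (noopGroup) Rank() int                        { return 0 }
func (noopGroup) Size() int                        { return 1 }

func main() {
	var pg ProcessGroup = noopGroup{}
	grads := []Tensor{{Data: []float32{1, 2, 3}}}
	if err := pg.AllReduce(grads); err != nil {
		fmt.Println("allreduce failed:", err)
	}
	fmt.Println("rank", pg.Rank(), "of", pg.Size(), "finished allreduce")
}
```
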
### Go Launch Utility

Go provides the [os/exec](https://golang.org/pkg/os/exec/) package to spawn processes,
which we can use to build a launcher similar to `torch.distributed.launch`.

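A minimal launcher sketch built on `os/exec` is shown below. The `./trainer` binary path and the `RANK`/`SIZE` environment variables (mirroring the c10d example above) are assumptions for illustration; this is not an existing GoTorch tool. Each worker would read `RANK` and `SIZE` to join the process group.

```go
package main

import (
	"fmt"
	"log"
	"os"
	"os/exec"
)

func main() {
	const worldSize = 4 // number of worker processes per node

	cmds := make([]*exec.Cmd, 0, worldSize)
	for rank := 0; rank < worldSize; rank++ {
		// "./trainer" is a placeholder for the actual training binary.
		cmd := exec.Command("./trainer")
		cmd.Env = append(os.Environ(),
			fmt.Sprintf("RANK=%d", rank),
			fmt.Sprintf("SIZE=%d", worldSize))
		cmd.Stdout = os.Stdout
		cmd.Stderr = os.Stderr
		if err := cmd.Start(); err != nil {
			log.Fatalf("failed to start worker %d: %v", rank, err)
		}
		cmds = append(cmds, cmd)
	}

	// Wait for every worker; fail the launcher if any worker fails.
	for rank, cmd := range cmds {
		if err := cmd.Wait(); err != nil {
			log.Fatalf("worker %d exited with error: %v", rank, err)
		}
	}
}
```
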
### Optimization at Stage 2

TBD

## Reference

- <https://pytorch.org/docs>
- <https://arxiv.org/abs/2006.15704>
Review comment: GoTorch does not have a GIL; does the single-process multi-GPU mode fit GoTorch?

Reply: The answer is no. There are two reasons:

1. The overhead of scatter/gather is also non-negligible. We once used scatter / parallel-do / gather to support multi-GPU AllReduce in Paddle with C++, and from that experience the speedup ratio was not very good.
2. Scatter/gather only works across multiple GPUs on a single node; it cannot scale to multi-node multi-GPU training.