
[CPU] SHM based allreduce improvement for small message size #5571

Merged: 37 commits merged into deepspeedai:master on Jun 12, 2024

Conversation

delock (Collaborator) commented May 27, 2024

On CPU servers, when running SHM-based allreduce for small messages, the performance is largely dominated by synchronization latency. This latency comes from two sources:

  1. Waiting for status changes from other ranks.
  2. Using `#pragma omp parallel for` to accelerate memory-bandwidth-bound operations such as parallel_memcpy or reduce (see the sketch after this list).
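
For context, here is a minimal sketch of an OpenMP-parallelized copy of the kind described in item 2. It is not the actual shm.cpp parallel_memcpy; the chunk size and function name are placeholders. The implicit barrier at the end of the parallel region is the kind of fixed synchronization cost that dominates for tiny messages.

```cpp
// Minimal sketch, not the actual DeepSpeed parallel_memcpy: split a copy into
// fixed-size chunks and let OpenMP threads handle them (compile with -fopenmp).
#include <cstdint>
#include <cstring>

static void parallel_memcpy_sketch(void* dst, const void* src, std::size_t nbytes)
{
    const std::size_t chunk = 4096;  // hypothetical per-thread granularity
    const std::int64_t nchunks = (std::int64_t)((nbytes + chunk - 1) / chunk);
#pragma omp parallel for
    for (std::int64_t i = 0; i < nchunks; ++i) {
        const std::size_t off = (std::size_t)i * chunk;
        const std::size_t len = (off + chunk <= nbytes) ? chunk : nbytes - off;
        // Each thread copies its own chunk; the implicit barrier at the end of
        // the parallel region is the synchronization cost mentioned above.
        std::memcpy((char*)dst + off, (const char*)src + off, len);
    }
}
```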

Each synchronization adds a little time to the allreduce latency. In the current implementation, for small messages, five syncs are needed on rank 0: 1) copy-in; 2) wait for other ranks to finish copying; 3) reduce; 4) copy-out; 5) wait for other ranks to finish copy-out.

We redesigned the algorithm for small-message allreduce (called symmetric_naive_allreduce) to use only three syncs, with each rank doing exactly the same steps: 1) copy-in; 2) wait for other ranks to finish copying; 3) reduce directly into the output buffer. We use double buffering so we can skip the last wait and go directly to the next call using the other buffer. A carefully designed state check avoids the need for a global barrier among ranks.
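
To make the three-step flow concrete, here is a minimal single-process sketch of the idea, with threads standing in for ranks. The names, the sequence-number state check, and the fp32 sum are illustrative placeholders, not the actual shm.cpp implementation.

```cpp
// Illustrative sketch only (assumptions: 4 "ranks" emulated by threads in one
// process, fp32 sum; the real code lives in POSIX shared memory across
// processes and uses different names and state encoding).
#include <atomic>
#include <cstddef>
#include <vector>

constexpr int kRanks = 4;
constexpr int kBufs  = 2;  // double buffering: even calls use buffer 0, odd calls buffer 1

struct Slot {
    std::vector<float> data;   // staging copy of one rank's input
    std::atomic<long> seq{0};  // sequence number published after copy-in
};
static Slot g_slots[kBufs][kRanks];

// Every rank runs the same three steps. The reduction order is fixed
// (rank 0..N-1), so all ranks compute identical results and nobody has to wait
// for the others to finish step 3: the next call simply uses the other buffer.
void symmetric_naive_allreduce_sketch(int rank, long call_id, float* buf, std::size_t numel)
{
    const int b = (int)(call_id % kBufs);
    Slot& mine = g_slots[b][rank];
    mine.data.assign(buf, buf + numel);                          // 1) copy-in
    mine.seq.store(call_id + 1, std::memory_order_release);
    for (int r = 0; r < kRanks; ++r)                             // 2) wait for peers' copy-in
        while (g_slots[b][r].seq.load(std::memory_order_acquire) < call_id + 1) {}
    for (std::size_t i = 0; i < numel; ++i) {                    // 3) reduce into output
        float acc = 0.f;
        for (int r = 0; r < kRanks; ++r) acc += g_slots[b][r].data[i];
        buf[i] = acc;
    }
    // No copy-out and no final wait: double buffering plus the per-buffer
    // sequence check keeps a fast rank from overwriting a buffer that a slower
    // rank is still reading.
}
```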

Tests show that for message sizes < 1MB, allreduce latency is reduced by 30% to 50%. This is especially helpful for tensor-parallel decoding with small batch sizes, where the tensor size is usually a few tens of kilobytes.

| message size (bytes) | new method latency (us) | old method latency (us) |
| ---: | ---: | ---: |
| 2 | 13.34 | 20.39 |
| 4 | 13.44 | 19.57 |
| 8 | 13.70 | 19.76 |
| 16 | 13.27 | 20.43 |
| 32 | 13.42 | 19.75 |
| 64 | 13.38 | 19.80 |
| 128 | 13.70 | 19.44 |
| 256 | 13.99 | 20.33 |
| 512 | 13.91 | 20.28 |
| 1024 | 15.00 | 22.86 |
| 2048 | 15.82 | 20.93 |
| 4096 | 16.00 | 21.08 |
| 8192 | 16.31 | 21.50 |
| 16384 | 16.27 | 22.95 |
| 32768 | 16.13 | 25.17 |
| 65536 | 18.92 | 25.90 |
| 131072 | 21.12 | 27.42 |
| 262144 | 23.09 | 32.36 |
| 524288 | 32.78 | 42.80 |

Because the new method computes the same reduced value independently on all ranks, caution is needed to ensure the result is identical across ranks. We use the test at https://github.com/delock/ds_allreduce_bench/blob/main/ds_comm_bench.py#L70 to ensure the implementation is correct; https://github.com/delock/ds_allreduce_bench/blob/main/validate.sh is a test script that gives better coverage.

delock (Collaborator, Author) commented Jun 6, 2024

Hi @awan-10, this PR is ready for review, can this PR be reviewed? Thanks!

@tjruwase tjruwase requested review from adk9 and tjruwase and removed request for arashb, awan-10 and mrwyattii June 9, 2024 22:52
tjruwase (Contributor) commented Jun 9, 2024

> Hi @awan-10, this PR is ready for review, can this PR be reviewed? Thanks!

@delock, we are reviewing now. Thanks for the PR!

Several review threads on csrc/cpu/comm/shm.cpp (now outdated and resolved). One of them, from a reviewer (Contributor), commented on this code:

    parallel_memcpy(slice_data(data_ptr, chunk_el, data_size, rank),
                    slice_data(workspace[rank]->buffer, chunk_el, chunk_size / chunk_el, rank),
                    slice_size(chunk_el, rank) * data_size);
    wait_buffer_state_until_2(i, reduce_current, copy_next, state_group);

Suggested change:

    - wait_buffer_state_until_2(i, reduce_current, copy_next, state_group);
    + if (i != world_rank) { wait_buffer_state_until_2(i, reduce_current, copy_next, state_group); }
delock (Collaborator, Author) commented Jun 12, 2024

Hi @adk9, code updated according to comments, thanks!

adk9 (Contributor) commented Jun 12, 2024

> Hi @adk9, code updated according to comments, thanks!

Hi @delock, thanks for your changes! The formatting check seems to fail for this PR. Could you run pre-commit or clang-format on your changes?

delock (Collaborator, Author) commented Jun 12, 2024

> Hi @delock, thanks for your changes! The formatting check seems to fail for this PR. Could you run pre-commit or clang-format on your changes?

Hi @adk9, the formatting has been fixed in the latest CI run. Thanks!

@adk9 adk9 added this pull request to the merge queue Jun 12, 2024
Merged via the queue into deepspeedai:master with commit eda5075 Jun 12, 2024
12 checks passed
github-merge-queue bot pushed a commit that referenced this pull request on Jul 16, 2024 (#5604):
This PR allows `deepspeed.comm.inference_all_reduce()` to enter the torch.compile graph even though it is implemented as a C++ kernel in DeepSpeed.

The previous implementation registered the `inference_all_reduce()` C++ kernel as a pybind function so it could be called from Python code. However, a pybind function cannot be recognized by PyTorch, so the graph breaks when `inference_all_reduce` is called.

We address this issue by registering `inference_all_reduce` as a PyTorch custom op, `torch.ops.deepspeed.inference_all_reduce`, so it can be built into the PyTorch graph.
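
For illustration, here is a minimal sketch of the custom-op registration pattern that makes a C++ kernel visible to torch.compile. The namespace and the trivial kernel body are hypothetical placeholders, not the actual DeepSpeed binding code.

```cpp
// Minimal sketch of exposing a C++ kernel as a PyTorch custom op so that
// torch.compile can capture it instead of graph-breaking on a plain pybind
// call. The namespace "myns" and the placeholder kernel body are illustrative,
// not the actual DeepSpeed implementation.
#include <torch/library.h>
#include <ATen/ATen.h>

static at::Tensor my_inference_all_reduce(const at::Tensor& input)
{
    // Placeholder: a real kernel would perform the SHM-based allreduce here.
    return input.clone();
}

// Declare the op schema and register the CPU implementation. After this, the
// op is callable from Python as torch.ops.myns.inference_all_reduce(tensor)
// and appears to torch.compile / TorchInductor as a first-class graph node.
TORCH_LIBRARY(myns, m) {
    m.def("inference_all_reduce(Tensor input) -> Tensor");
}

TORCH_LIBRARY_IMPL(myns, CPU, m) {
    m.impl("inference_all_reduce", &my_inference_all_reduce);
}
```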

The output trace code from TorchInductor:
```
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[5, 4]", primals_2: "f32[5]", primals_3: "f32[4, 4]"):
        # File: /home/gma/DeepSpeed/deepspeed/comm/torch.py:161 in inference_all_reduce, code: return torch.ops.deepspeed.inference_all_reduce_(tensor)
        inference_all_reduce: "f32[4, 4]" = torch.ops.deepspeed.inference_all_reduce.default(primals_3)

        # File: /home/gma/allreduce_graph/test_allreduce.py:33 in forward, code: return self.linear(input)
        permute: "f32[4, 5]" = torch.ops.aten.permute.default(primals_1, [1, 0]);  primals_1 = None
        addmm: "f32[4, 5]" = torch.ops.aten.addmm.default(primals_2, inference_all_reduce, permute);  primals_2 = permute = None

        # No stacktrace found for following nodes
        copy_: "f32[4, 4]" = torch.ops.aten.copy_.default(primals_3, inference_all_reduce);  primals_3 = None
        return [addmm, inference_all_reduce]
```

Note that in this PR the inference_all_reduce op for CPU does not handle multi-node setups or the FP16 data type. For FP16 support, we will align with the PyTorch CPU FP16 plan. For multi-node, we are still looking at the possibility of upstreaming the oneCCL integration into PyTorch, so that we can make use of oneCCL for multi-node tensor-parallel inference with PyTorch.

This PR is independent of #5571. They can work separately or together without issue.

---------

Co-authored-by: Olatunji Ruwase <[email protected]>
Co-authored-by: Masahiro Tanaka <[email protected]>