custom reduce scatter
Summary:
X-link: facebookresearch/FBGEMM#763

Piggyback on the two-shot allreduce for the reduce-scatter: it is essentially the first half of the two-shot allreduce.
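
As a rough sketch of that relationship (an illustrative single-process C++ simulation with made-up sizes, not the multi-rank NCCL/CAR path), a reduce-scatter leaves each rank holding only its own reduced chunk, which is exactly where a two-shot allreduce stands after its reduction phase and before its all-gather phase:

#include <cstdio>
#include <vector>

int main() {
  constexpr int kWorldSize = 4;               // illustrative rank count
  constexpr int kN = 16;                      // total elements, divisible by kWorldSize
  constexpr int kNPerRank = kN / kWorldSize;  // each rank ends up owning this many

  // Every rank starts with a full-length input vector.
  std::vector<std::vector<float>> inputs(kWorldSize, std::vector<float>(kN));
  for (int r = 0; r < kWorldSize; ++r) {
    for (int i = 0; i < kN; ++i) {
      inputs[r][i] = static_cast<float>(r + i);
    }
  }

  // Reduce-scatter: rank r ends up with the reduction of its chunk
  // [r * kNPerRank, (r + 1) * kNPerRank). A two-shot allreduce would follow
  // this with an all-gather of the per-rank chunks.
  for (int r = 0; r < kWorldSize; ++r) {
    std::vector<float> out(kNPerRank, 0.0f);
    for (int j = 0; j < kNPerRank; ++j) {
      for (int src = 0; src < kWorldSize; ++src) {
        out[j] += inputs[src][r * kNPerRank + j];
      }
    }
    std::printf("rank %d holds reduced elements [%d, %d)\n",
                r, r * kNPerRank, (r + 1) * kNPerRank);
  }
  return 0;
}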

Differential Revision: D69364062
xw285cornell authored and facebook-github-bot committed Feb 13, 2025
1 parent 1b7789a commit 2a3965c
Showing 3 changed files with 248 additions and 69 deletions.
33 changes: 31 additions & 2 deletions fbgemm_gpu/experimental/gen_ai/src/comm/car.cpp
@@ -177,7 +177,11 @@ void nccl_alltoall(
torch::cuda::nccl::all2all(dsts, srcs, *get_nccl_comm(comm_idx), stream);
}

void nccl_reducescatter(at::Tensor dst, at::Tensor src, int64_t comm_idx) {
void nccl_reducescatter(
at::Tensor dst,
at::Tensor src,
std::optional<at::Tensor> bias,
int64_t comm_idx) {
using namespace c10d;
TORCH_CHECK(src.is_contiguous());
TORCH_CHECK(dst.is_contiguous());
@@ -194,6 +198,10 @@ void nccl_reducescatter(at::Tensor dst, at::Tensor src, int64_t comm_idx) {
*get_nccl_comm(comm_idx),
at::cuda::getCurrentCUDAStream()),
"ncclReduceScatter");

if (bias) {
dst.add_(*bias);
}
}
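
A minimal reference sketch of the semantics above, assuming the bias (when present) matches dst's shape; the names and container types here are illustrative, while the real op simply runs ncclReduceScatter on contiguous tensors and then applies dst.add_(*bias):

#include <optional>
#include <vector>

// per_rank_src holds each rank's full-length input; the function returns the
// chunk that `rank` would receive from the reduce-scatter.
std::vector<float> reduce_scatter_ref(
    const std::vector<std::vector<float>>& per_rank_src,
    int rank,
    int world_size,
    const std::optional<std::vector<float>>& bias) {
  const int n_per_rank = static_cast<int>(per_rank_src[0].size()) / world_size;
  std::vector<float> dst(n_per_rank, 0.0f);
  for (int j = 0; j < n_per_rank; ++j) {
    for (int src = 0; src < world_size; ++src) {
      dst[j] += per_rank_src[src][rank * n_per_rank + j];  // sum across ranks
    }
    if (bias) {
      dst[j] += (*bias)[j];  // bias applied after the reduction, as in the op
    }
  }
  return dst;
}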

void nccl_allreduce(
@@ -259,6 +267,11 @@ void two_shot_car_allreduce(
at::Tensor src,
std::optional<at::Tensor> bias,
int64_t comm_idx);
void car_reduce_scatter(
at::Tensor dst,
at::Tensor src,
std::optional<at::Tensor> bias,
int64_t comm_idx);

at::Tensor car_tensor();

@@ -282,7 +295,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
"nccl_alltoall_single(Tensor(a!) dst, Tensor src, int world_size, int comm_idx=0) -> ()");
m.def("nccl_alltoall(Tensor(a!)[] dst, Tensor[] src, int comm_idx=0) -> ()");

m.def("nccl_reducescatter(Tensor(a!) dst, Tensor src, int comm_idx=0) -> ()");
m.def(
"nccl_reducescatter(Tensor(a!) dst, Tensor src, Tensor? bias=None, int comm_idx=0) -> ()");

m.def(
"nccl_allreduce(Tensor(a!) dst, Tensor src, Tensor? bias=None, int comm_idx=0) -> ()");
@@ -302,6 +316,9 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {

m.def(
"two_shot_car_allreduce(Tensor(a!) dst, Tensor src, Tensor? bias=None, int comm_idx=0) -> ()");

m.def(
"car_reduce_scatter(Tensor(a!) dst, Tensor src, Tensor? bias=None, int comm_idx=0) -> ()");
}
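
A hedged sketch of how the newly registered schema could be invoked from C++ through the dispatcher; it assumes the CAR communicator and exchange buffers have already been initialized elsewhere and that the process runs one rank of the job:

#include <optional>

#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>

// Illustrative call site: dst holds this rank's chunk, so
// world_size * dst.numel() == src.numel() must hold (checked inside the op).
void call_car_reduce_scatter(at::Tensor dst, at::Tensor src) {
  static auto op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("fbgemm::car_reduce_scatter", "")
          .typed<void(at::Tensor, at::Tensor, std::optional<at::Tensor>, int64_t)>();
  op.call(dst, src, /*bias=*/std::nullopt, /*comm_idx=*/0);
}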

TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
@@ -312,6 +329,7 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
m.impl("nccl_reducescatter", nccl_reducescatter);
m.impl("one_shot_car_allreduce", one_shot_car_allreduce);
m.impl("two_shot_car_allreduce", two_shot_car_allreduce);
m.impl("car_reduce_scatter", car_reduce_scatter);
}

// Though it shouldn't be used, it is useful to define these functions for CPU to
@@ -324,6 +342,7 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
m.impl("nccl_reducescatter", nccl_reducescatter);
m.impl("one_shot_car_allreduce", one_shot_car_allreduce);
m.impl("two_shot_car_allreduce", two_shot_car_allreduce);
m.impl("car_reduce_scatter", car_reduce_scatter);
}

// Shape registration functions for car operators.
@@ -360,6 +379,7 @@ void nccl_alltoall_meta(
void nccl_reducescatter_meta(
at::Tensor /* dst */,
at::Tensor /* src */,
std::optional<at::Tensor> bias,
int64_t /* comm_idx */) {
return;
}
@@ -380,6 +400,14 @@ void two_shot_car_allreduce_meta(
return;
}

void car_reduce_scatter_meta(
at::Tensor /* dst */,
at::Tensor /* src */,
std::optional<at::Tensor> /* bias */,
int64_t /* comm_idx */) {
return;
}

TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
m.impl("nccl_allreduce", nccl_allreduce_meta);
m.impl("nccl_allgather", nccl_allgather_meta);
@@ -388,6 +416,7 @@ TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
m.impl("nccl_reducescatter", nccl_reducescatter_meta);
m.impl("one_shot_car_allreduce", one_shot_car_allreduce_meta);
m.impl("two_shot_car_allreduce", two_shot_car_allreduce_meta);
m.impl("car_reduce_scatter", car_reduce_scatter_meta);
}

} // namespace fbgemm_gpu
118 changes: 112 additions & 6 deletions fbgemm_gpu/experimental/gen_ai/src/comm/car.cu
@@ -345,7 +345,7 @@ static DEVICE_INLINE void ld_flag_acquire(int32_t& flag, int32_t* flag_addr) {
#endif
}

template <int32_t kWorldSize, bool has_acc>
template <int32_t kWorldSize, bool has_acc, bool reduce_scatter>
#if defined(USE_ROCM)
__launch_bounds__(512) __global__ void two_shot_all_reduce(
#else
@@ -425,13 +425,18 @@ __launch_bounds__(1024) __global__ void two_shot_all_reduce(
}

// Store to the local buffer.
*reinterpret_cast<uint4*>(&src_d[0][i + N_start]) =
*reinterpret_cast<const uint4*>(&sums);
if constexpr (reduce_scatter) {
*reinterpret_cast<uint4*>(&output[i]) =
*reinterpret_cast<const uint4*>(&sums);
} else {
*reinterpret_cast<uint4*>(&src_d[0][i + N_start]) =
*reinterpret_cast<const uint4*>(&sums);
}
}

__syncthreads();

// barreris among the blocks with the same idx (release-acuqire semantics)
// barriers among the blocks with the same idx (release-acquire semantics)
if (threadIdx.x < kWorldSize) {
// All the blocks notify the other ranks.
int32_t flag_block_offset = kWorldSize + blockIdx.x * kWorldSize;
@@ -445,6 +450,11 @@ __launch_bounds__(1024) __global__ void two_shot_all_reduce(
} while (rank_barrier != flag);
}

if constexpr (reduce_scatter) {
// For reduce-scatter we can stop here and skip the all-gather below.
return;
}

__syncthreads();

// Gather all needed elts from other intra-node ranks
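
A small host-side sketch of the destination index mapping used above (illustrative only; N_start is this rank's starting offset into the full-length shared buffer):

#include <cstdint>

// Destination index for one 8-element (uint4) store, for a thread handling
// offset i within this rank's slice (0 <= i < N_per_rank).
int64_t dest_index(bool reduce_scatter, int64_t N_start, int64_t i) {
  // reduce_scatter == true : write into the rank-local output tensor.
  // reduce_scatter == false: write back into the shared buffer slice so the
  //                          all-gather phase can redistribute it.
  return reduce_scatter ? i : N_start + i;
}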
@@ -628,7 +638,7 @@ void two_shot_car_allreduce(
#define X(kWorldSize) \
if (state->world_size_ == kWorldSize) { \
if (z) { \
two_shot_all_reduce<kWorldSize, true> \
two_shot_all_reduce<kWorldSize, true, false> \
<<<blocks, kThreadsPerBlock, 0, at::cuda::getCurrentCUDAStream()>>>( \
state->rank_, \
state->world_size_, \
@@ -641,7 +651,7 @@ void two_shot_car_allreduce(
C10_CUDA_KERNEL_LAUNCH_CHECK(); \
return; \
} else { \
two_shot_all_reduce<kWorldSize, false> \
two_shot_all_reduce<kWorldSize, false, false> \
<<<blocks, kThreadsPerBlock, 0, at::cuda::getCurrentCUDAStream()>>>( \
state->rank_, \
state->world_size_, \
@@ -667,4 +677,100 @@ void two_shot_car_allreduce(
return;
}

void car_reduce_scatter(
at::Tensor y_reducescatter,
at::Tensor y,
std::optional<at::Tensor> z,
int64_t comm_idx) { // match the API with nccl_allreduce in
// https://fburl.com/code/v538vig9
auto state = get_car_state();
c10::cuda::CUDAGuard gg(y_reducescatter.device());
TORCH_CHECK(y_reducescatter.is_contiguous());
TORCH_CHECK(y.is_contiguous());
TORCH_CHECK((state->world_size_ * y_reducescatter.numel()) == y.numel());
TORCH_CHECK(y.numel() % 8 == 0);
TORCH_CHECK(y.numel() < kMaxCAR);
const auto N = y.numel();
if (z) {
TORCH_CHECK(z->numel() == y.numel());
}
++state->flag_;

std::array<at::BFloat16*, 8> inputs;
for (auto ii = 0; ii < state->world_size_; ++ii) {
inputs[ii] = state->buffers_[ii].data_ptr<at::BFloat16>();
}

std::array<int32_t*, 8> barriers;
for (auto ii = 0; ii < state->world_size_; ++ii) {
barriers[ii] = state->barriers_[ii].data_ptr<int32_t>();
}

AT_CUDA_CHECK(cudaMemcpyAsync(
inputs[state->rank_],
y.data_ptr<at::BFloat16>(),
y.numel() * y.element_size(),
cudaMemcpyDeviceToDevice,
at::cuda::getCurrentCUDAStream()));

constexpr int32_t N_per_thread = 8;
TORCH_CHECK(N % state->world_size_ == 0);
const auto N_per_rank = N / state->world_size_;

TORCH_CHECK(N_per_rank % N_per_thread == 0);
auto threads_per_rank = div_round_up(N_per_rank, N_per_thread);

#if defined(USE_ROCM)
constexpr int32_t kThreadsPerBlock = 512;
#else
constexpr int32_t kThreadsPerBlock = 1024;
#endif

constexpr int32_t kMaxBlocks = 24;

auto blocks = std::min<int32_t>(
cuda_calc_block_count(threads_per_rank, kThreadsPerBlock), kMaxBlocks);

#define X(kWorldSize) \
if (state->world_size_ == kWorldSize) { \
if (z) { \
two_shot_all_reduce<kWorldSize, true, true> \
<<<blocks, kThreadsPerBlock, 0, at::cuda::getCurrentCUDAStream()>>>( \
state->rank_, \
state->world_size_, \
state->flag_ * state->world_size_, \
barriers, \
inputs, \
z->data_ptr<at::BFloat16>(), \
y_reducescatter.data_ptr<at::BFloat16>(), \
N); \
C10_CUDA_KERNEL_LAUNCH_CHECK(); \
return; \
} else { \
two_shot_all_reduce<kWorldSize, false, true> \
<<<blocks, kThreadsPerBlock, 0, at::cuda::getCurrentCUDAStream()>>>( \
state->rank_, \
state->world_size_, \
state->flag_ * state->world_size_, \
barriers, \
inputs, \
nullptr, \
y_reducescatter.data_ptr<at::BFloat16>(), \
N); \
C10_CUDA_KERNEL_LAUNCH_CHECK(); \
return; \
} \
}

TORCH_CHECK(
state->world_size_ == 2 || state->world_size_ == 4 ||
state->world_size_ == 8);
X(2);
X(4);
X(8);

#undef X
return;
}

} // namespace fbgemm_gpu
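
A standalone sketch of the host-side launch arithmetic in car_reduce_scatter, using the constants shown in the diff (8 bf16 elements per thread via one uint4, 1024 threads per block on CUDA or 512 on ROCm, at most 24 blocks); the div_round_up helper here is a local stand-in for the repo's own:

#include <algorithm>
#include <cstdint>
#include <cstdio>

int64_t div_round_up(int64_t a, int64_t b) {
  return (a + b - 1) / b;
}

int main() {
  const int64_t world_size = 8;               // 2, 4, or 8 are supported
  const int64_t N = 8 * 1024 * 1024;          // src numel; must satisfy N % 8 == 0
  const int64_t N_per_rank = N / world_size;  // dst numel on each rank
  const int64_t N_per_thread = 8;             // one uint4 of bf16 per thread
  const int64_t threads_per_rank = div_round_up(N_per_rank, N_per_thread);
  const int64_t kThreadsPerBlock = 1024;      // 512 when built with USE_ROCM
  const int64_t kMaxBlocks = 24;
  const int64_t blocks = std::min<int64_t>(
      div_round_up(threads_per_rank, kThreadsPerBlock), kMaxBlocks);
  std::printf(
      "N_per_rank=%lld threads_per_rank=%lld blocks=%lld\n",
      static_cast<long long>(N_per_rank),
      static_cast<long long>(threads_per_rank),
      static_cast<long long>(blocks));
  return 0;
}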
