diff --git a/csrc/cpu/comm/shm.cpp b/csrc/cpu/comm/shm.cpp index 859c2fec292d..78170b538543 100644 --- a/csrc/cpu/comm/shm.cpp +++ b/csrc/cpu/comm/shm.cpp @@ -21,9 +21,13 @@ // states for collectives enum coll_state { coll_begin = 0, - coll_allreduce_naive__copy_in_done, // this state is for rank != 0 - coll_allreduce_naive__reduce_done, // this state is for rank == 0 - coll_allreduce_naive__copy_out_done, // this state is for rank != 0 + coll_allreduce_naive__copy_in_done, + coll_allreduce_naive__reduce_done, + // alternative state when allreduce is working on alternative buffer + // of the double buffer. + coll_alt1_allreduce_naive__copy_in_done, + coll_alt2_allreduce_naive__copy_in_done, + coll_alt1_allreduce_naive__reduce_done, }; // SHM building blocks @@ -71,6 +75,8 @@ void shared_close(SharedData* data) } } +static int world_size; + // SHM based allreduce helper functions // buffer that holds shm name #define NAME_BUF_SIZE 1000 @@ -78,64 +84,37 @@ void shared_close(SharedData* data) #define NAIVE_ALLREDUCE_THRESHOLD 1048576 #define SHM_BUFFER_NAME "deepspeed_allreduce_buffer" struct allreduce_workspace { - enum coll_state state; - sem_t mutex; - sem_t turnstile1; - sem_t turnstile2; - int counter; - char buffer[MAX_BUF_SIZE]; + enum coll_state states[2]; // idx=0 -- state for symmetric_naive_all_reduce + // idx=1 -- state for distributed_naive_all_reduce + // double buffer to avoid syncing between rounds + // offset=0 -- 2*NAIVE_ALLREDUCE_THRESHOLD : buffer for symmetric_naive_all_reduce + // after that : buffer for distributed_naive_all_reduce + char buffer[2 * NAIVE_ALLREDUCE_THRESHOLD + 2 * MAX_BUF_SIZE]; }; -struct allreduce_workspace** workspace; -void wait_buffer_state_until(int index, enum coll_state state) -{ - volatile enum coll_state* state_ptr = &(workspace[index]->state); +#define BUFFER0_OFFSET(current_buffer) current_buffer* NAIVE_ALLREDUCE_THRESHOLD +#define BUFFER1_OFFSET(current_buffer) 2 * NAIVE_ALLREDUCE_THRESHOLD + current_buffer* MAX_BUF_SIZE - while (*state_ptr != state) - ; -} +struct allreduce_workspace** workspace; + +// buffer for small messages, double buffer +char** symmetric_buffer[2]; +// buffer for large messages, double buffer +char** distributed_buffer[2]; -void wait_buffer_state_until_range(int index, enum coll_state start, int size) +void wait_buffer_state_until_2(int index, + enum coll_state state0, + enum coll_state state1, + int state_group) { - volatile enum coll_state* state_ptr = &(workspace[index]->state); - enum coll_state end = (enum coll_state)(start + size); + volatile enum coll_state* state_ptr = &(workspace[index]->states[state_group]); while (1) { volatile enum coll_state cur_state = *state_ptr; - if (cur_state >= start and cur_state < end) break; + if (cur_state == state0 || cur_state == state1) break; } } -void wait_buffer_state_until_not(int index, enum coll_state state) -{ - volatile enum coll_state* state_ptr = &(workspace[index]->state); - - while (*state_ptr == state) - ; -} - -void barrier_wait(int root_idx, int num_ranks) -{ - // Phase 1: Wait for all threads to enter the barrier - auto shared = workspace[root_idx]; - sem_wait(&shared->mutex); - shared->counter++; - if (shared->counter == num_ranks) { - for (int i = 0; i < num_ranks; ++i) { sem_post(&shared->turnstile1); } - } - sem_post(&shared->mutex); - sem_wait(&shared->turnstile1); - - // Phase 2: Wait for all threads to exit the barrier - sem_wait(&shared->mutex); - shared->counter--; - if (shared->counter == 0) { - for (int i = 0; i < num_ranks; ++i) { 
sem_post(&shared->turnstile2); } - } - sem_post(&shared->mutex); - sem_wait(&shared->turnstile2); -} - __m512 cvt_bf16_to_fp32(const __m256i src) __attribute__((target("avx512bw"))); inline __m512 cvt_bf16_to_fp32(const __m256i src) { @@ -167,123 +146,53 @@ inline __m256i cvt_fp32_to_bf16(const __m512 src) void reduce_2_bf16_buffers_iio(int num_elements, void* in0, void* in1, void* out) __attribute__((target("avx512bw"))); -void reduce_bf16_buffers(int start_elements, - int num_elements, - int num_buffers, - int to_buffer_idx, - struct allreduce_workspace** workspace) +void reduce_bf16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers) __attribute__((target("avx512bw"))); void reduce_2_fp32_buffers_iio(int num_elements, void* in0, void* in1, void* out) __attribute__((target("avx512bw"))); -void reduce_fp32_buffers(int start_elements, - int num_elements, - int num_buffers, - int to_buffer_idx, - struct allreduce_workspace** workspace) +void reduce_fp32_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers) __attribute__((target("avx512bw"))); -// N_REDUCE_LIMIT is the number of buffers that can be reduced together in one shot. -// Compared with do N-1 2-reduces which needs 2*(N-1) read and N-1 write, -// N-reduce only needs N read and 1 write, this saves 2/3 memory bandwidth. -// When increase N_REDUCE_LIMIT to a bigger number, do the following steps -// 1. Extend REPEAT_ macros list down below -// 2. Extend switch cases which call "REPEAT(X, ...)" down below -#define N_REDUCE_LIMIT 16 - -void reduce_all_buffers(struct allreduce_workspace** workspace, - int start_elements, +void reduce_all_buffers(int start_elements, int num_elements, c10::ScalarType scalar_type, - int num_buffers, - int to_buffer_idx) + int to_buffer_idx, + char* to_buffer, + char** buffers) { switch (scalar_type) { case c10::ScalarType::BFloat16: - if (num_buffers > 2 && num_buffers <= N_REDUCE_LIMIT) { - reduce_bf16_buffers( - start_elements, num_elements, num_buffers, to_buffer_idx, workspace); + if (world_size == 2) { + // add the other buffer to to_buffer + reduce_2_bf16_buffers_iio(num_elements, + buffers[1 - to_buffer_idx] + start_elements * 2, + to_buffer + start_elements * 2, + to_buffer + start_elements * 2); } else { - for (int i = 0; i < num_buffers; i++) { - if (i == to_buffer_idx) continue; - reduce_2_bf16_buffers_iio( - num_elements, - workspace[i]->buffer + start_elements * 2, - workspace[to_buffer_idx]->buffer + start_elements * 2, - workspace[to_buffer_idx]->buffer + start_elements * 2); - } + reduce_bf16_buffers(start_elements, num_elements, to_buffer, buffers); } break; case c10::ScalarType::Float: - if (num_buffers > 2 && num_buffers <= N_REDUCE_LIMIT) { - reduce_fp32_buffers( - start_elements, num_elements, num_buffers, to_buffer_idx, workspace); + if (world_size == 2) { + reduce_2_fp32_buffers_iio(num_elements, + buffers[1 - to_buffer_idx] + start_elements * 4, + to_buffer + start_elements * 4, + to_buffer + start_elements * 4); } else { - for (int i = 0; i < num_buffers; i++) { - if (i == to_buffer_idx) continue; - reduce_2_fp32_buffers_iio( - num_elements, - workspace[i]->buffer + start_elements * 4, - workspace[to_buffer_idx]->buffer + start_elements * 4, - workspace[to_buffer_idx]->buffer + start_elements * 4); - } + assert(world_size > 2); + reduce_fp32_buffers(start_elements, num_elements, to_buffer, buffers); } break; default: assert(!"Should not get here"); } } -#define REPEAT(N, x) REPEAT_##N(x) -#define REPEAT_1(x) x(1) -#define 
REPEAT_2(x) \ - REPEAT_1(x); \ - x(2) -#define REPEAT_3(x) \ - REPEAT_2(x); \ - x(3) -#define REPEAT_4(x) \ - REPEAT_3(x); \ - x(4) -#define REPEAT_5(x) \ - REPEAT_4(x); \ - x(5) -#define REPEAT_6(x) \ - REPEAT_5(x); \ - x(6) -#define REPEAT_7(x) \ - REPEAT_6(x); \ - x(7) -#define REPEAT_8(x) \ - REPEAT_7(x); \ - x(8) -#define REPEAT_9(x) \ - REPEAT_8(x); \ - x(9) -#define REPEAT_10(x) \ - REPEAT_9(x); \ - x(10) -#define REPEAT_11(x) \ - REPEAT_10(x); \ - x(11) -#define REPEAT_12(x) \ - REPEAT_11(x); \ - x(12) -#define REPEAT_13(x) \ - REPEAT_12(x); \ - x(13) -#define REPEAT_14(x) \ - REPEAT_13(x); \ - x(14) -#define REPEAT_15(x) \ - REPEAT_14(x); \ - x(15) - -#define CVT_ADD_BF16(x) \ - do { \ - auto in##x##_val = \ - cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(workspace[x]->buffer + i))); \ - inout_val = _mm512_add_ps(inout_val, in##x##_val); \ +#define CVT_ADD_BF16(x) \ + do { \ + auto in##x##_val = cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[x] + i))); \ + inout_val = _mm512_add_ps(inout_val, in##x##_val); \ } while (0) // Reduce functions down below use vectorized algorithm, the number of bytes processed each @@ -292,11 +201,7 @@ void reduce_all_buffers(struct allreduce_workspace** workspace, // whether this number needs to be changed #define VECTOR_LENGTH_IN_BYTES 32 -void reduce_bf16_buffers(int start_elements, - int num_elements, - int num_buffers, - int to_buffer_idx, - struct allreduce_workspace** workspace) +void reduce_bf16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers) { const int element_size = 2; const int vector_length = VECTOR_LENGTH_IN_BYTES / element_size; @@ -307,34 +212,40 @@ void reduce_bf16_buffers(int start_elements, #pragma omp parallel for for (int i = start_elements * element_size; i < (start_elements + main_elements) * element_size; i += VECTOR_LENGTH_IN_BYTES) { - auto inout_val = cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(workspace[0]->buffer + i))); - switch (num_buffers) { - case 16: REPEAT(15, CVT_ADD_BF16); break; - case 15: REPEAT(14, CVT_ADD_BF16); break; - case 14: REPEAT(13, CVT_ADD_BF16); break; - case 13: REPEAT(12, CVT_ADD_BF16); break; - case 12: REPEAT(11, CVT_ADD_BF16); break; - case 11: REPEAT(10, CVT_ADD_BF16); break; - case 10: REPEAT(9, CVT_ADD_BF16); break; - case 9: REPEAT(8, CVT_ADD_BF16); break; - case 8: REPEAT(7, CVT_ADD_BF16); break; - case 7: REPEAT(6, CVT_ADD_BF16); break; - case 6: REPEAT(5, CVT_ADD_BF16); break; - case 5: REPEAT(4, CVT_ADD_BF16); break; - case 4: REPEAT(3, CVT_ADD_BF16); break; - case 3: REPEAT(2, CVT_ADD_BF16); break; - default: assert(!"Should not get here."); + auto inout_val = cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[0] + i))); + switch (world_size) { + case 16: CVT_ADD_BF16(15); + case 15: CVT_ADD_BF16(14); + case 14: CVT_ADD_BF16(13); + case 13: CVT_ADD_BF16(12); + case 12: CVT_ADD_BF16(11); + case 11: CVT_ADD_BF16(10); + case 10: CVT_ADD_BF16(9); + case 9: CVT_ADD_BF16(8); + case 8: CVT_ADD_BF16(7); + case 7: CVT_ADD_BF16(6); + case 6: CVT_ADD_BF16(5); + case 5: CVT_ADD_BF16(4); + case 4: CVT_ADD_BF16(3); + case 3: + CVT_ADD_BF16(2); + CVT_ADD_BF16(1); + break; + default: + for (int j = 1; j < world_size; j++) { + auto in_val = cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[j] + i))); + inout_val = _mm512_add_ps(inout_val, in_val); + } } - _mm256_storeu_si256((__m256i*)(workspace[to_buffer_idx]->buffer + i), - cvt_fp32_to_bf16(inout_val)); + _mm256_storeu_si256((__m256i*)(to_buffer + i), cvt_fp32_to_bf16(inout_val)); } // process 
remaining part int i = (start_elements + main_elements) * element_size; while (remain_elements > 0) { float val = 0.0f; - for (int j = 0; j < num_buffers; j++) { val += *(at::BFloat16*)(workspace[j]->buffer + i); } - *(at::BFloat16*)(workspace[to_buffer_idx]->buffer + i) = val; + for (int j = 0; j < world_size; j++) { val += *(at::BFloat16*)(buffers[j] + i); } + *(at::BFloat16*)(to_buffer + i) = val; remain_elements--; i += element_size; } @@ -367,17 +278,13 @@ void reduce_2_bf16_buffers_iio(int num_elements, void* in0, void* in1, void* out } } -#define CVT_ADD_F32(x) \ - do { \ - auto in##x##_val = _mm256_loadu_ps((float*)(workspace[x]->buffer + i)); \ - inout_val = _mm256_add_ps(inout_val, in##x##_val); \ +#define CVT_ADD_F32(x) \ + do { \ + auto in##x##_val = _mm256_loadu_ps((float*)(buffers[x] + i)); \ + inout_val = _mm256_add_ps(inout_val, in##x##_val); \ } while (0) -void reduce_fp32_buffers(int start_elements, - int num_elements, - int num_buffers, - int to_buffer_idx, - struct allreduce_workspace** workspace) +void reduce_fp32_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers) { const int element_size = 4; const int vector_length = VECTOR_LENGTH_IN_BYTES / element_size; @@ -388,33 +295,40 @@ void reduce_fp32_buffers(int start_elements, #pragma omp parallel for for (int i = start_elements * element_size; i < (start_elements + main_elements) * element_size; i += VECTOR_LENGTH_IN_BYTES) { - auto inout_val = _mm256_loadu_ps((float*)(workspace[0]->buffer + i)); - switch (num_buffers) { - case 16: REPEAT(15, CVT_ADD_F32); break; - case 15: REPEAT(14, CVT_ADD_F32); break; - case 14: REPEAT(13, CVT_ADD_F32); break; - case 13: REPEAT(12, CVT_ADD_F32); break; - case 12: REPEAT(11, CVT_ADD_F32); break; - case 11: REPEAT(10, CVT_ADD_F32); break; - case 10: REPEAT(9, CVT_ADD_F32); break; - case 9: REPEAT(8, CVT_ADD_F32); break; - case 8: REPEAT(7, CVT_ADD_F32); break; - case 7: REPEAT(6, CVT_ADD_F32); break; - case 6: REPEAT(5, CVT_ADD_F32); break; - case 5: REPEAT(4, CVT_ADD_F32); break; - case 4: REPEAT(3, CVT_ADD_F32); break; - case 3: REPEAT(2, CVT_ADD_F32); break; - default: assert(!"Should not get here."); + auto inout_val = _mm256_loadu_ps((float*)(buffers[0] + i)); + switch (world_size) { + case 16: CVT_ADD_F32(15); + case 15: CVT_ADD_F32(14); + case 14: CVT_ADD_F32(13); + case 13: CVT_ADD_F32(12); + case 12: CVT_ADD_F32(11); + case 11: CVT_ADD_F32(10); + case 10: CVT_ADD_F32(9); + case 9: CVT_ADD_F32(8); + case 8: CVT_ADD_F32(7); + case 7: CVT_ADD_F32(6); + case 6: CVT_ADD_F32(5); + case 5: CVT_ADD_F32(4); + case 4: CVT_ADD_F32(3); + case 3: + CVT_ADD_F32(2); + CVT_ADD_F32(1); + break; + default: + for (int j = 1; j < world_size; j++) { + auto in_val = _mm256_loadu_ps((float*)(buffers[j] + i)); + inout_val = _mm256_add_ps(inout_val, in_val); + } } - _mm256_storeu_ps((float*)(workspace[to_buffer_idx]->buffer + i), inout_val); + _mm256_storeu_ps((float*)(to_buffer + i), inout_val); } // process remaining part int i = (start_elements + main_elements) * element_size; while (remain_elements > 0) { float val = 0.0f; - for (int j = 0; j < num_buffers; j++) { val += *(float*)(workspace[j]->buffer + i); } - *(float*)(workspace[to_buffer_idx]->buffer + i) = val; + for (int j = 0; j < world_size; j++) { val += *(float*)(buffers[j] + i); } + *(float*)(to_buffer + i) = val; remain_elements--; i += element_size; } @@ -448,7 +362,6 @@ void reduce_2_fp32_buffers_iio(int num_elements, void* in0, void* in1, void* out } static bool is_initialized = 0; -static int world_size; 
static int world_rank; void shm_initialize(int size, int rank, char* addr_string, char* port_string) @@ -477,10 +390,15 @@ void shm_initialize(int size, int rank, char* addr_string, char* port_string) snprintf(shm_name, NAME_BUF_SIZE, "%s_%d", shm_name_prefix, rank); shared_create(&allreduce_buffer, shm_name, workspace_buf, sizeof(struct allreduce_workspace)); workspace_buf = (struct allreduce_workspace*)allreduce_buffer.bytes; - workspace_buf->state = coll_begin; + workspace_buf->states[0] = coll_alt2_allreduce_naive__copy_in_done; + workspace_buf->states[1] = coll_begin; // create the workspace pointer list workspace = (struct allreduce_workspace**)malloc(size * sizeof(struct allreduce_workspace*)); + symmetric_buffer[0] = (char**)malloc(size * sizeof(char**)); + symmetric_buffer[1] = (char**)malloc(size * sizeof(char**)); + distributed_buffer[0] = (char**)malloc(size * sizeof(char**)); + distributed_buffer[1] = (char**)malloc(size * sizeof(char**)); // map shm of all ranks for (int i = 0; i < size; i++) { @@ -494,11 +412,11 @@ void shm_initialize(int size, int rank, char* addr_string, char* port_string) workspace[i] = workspace_buf_other; } else { workspace[i] = workspace_buf; - workspace_buf->counter = 0; - sem_init(&workspace_buf->mutex, 1, 1); - sem_init(&workspace_buf->turnstile1, 1, 0); - sem_init(&workspace_buf->turnstile2, 1, 0); } + symmetric_buffer[0][i] = workspace[i]->buffer + BUFFER0_OFFSET(0); + symmetric_buffer[1][i] = workspace[i]->buffer + BUFFER0_OFFSET(1); + distributed_buffer[0][i] = workspace[i]->buffer + BUFFER1_OFFSET(0); + distributed_buffer[1][i] = workspace[i]->buffer + BUFFER1_OFFSET(1); } } @@ -539,46 +457,122 @@ size_t slice_el_start(size_t chunk_el, int slice_idx) return slice_size * slice_idx; } -void naive_all_reduce(char* data_ptr, - c10::ScalarType scalar_type, - size_t chunk_size, - size_t chunk_el) +/* + Symmetric naive all_reduce + step 0: before entering the function for the ith time, the state is copy(i-1) + step 1: each rank copies data from the input (data_ptr) to SHM buffer[i] + step 2: set its own state to copy(i) + step 3: wait until every other rank's state is copy(i) or later + step 4: reduce across SHM buffer[i] directly into the output (data_ptr) +*/ +void symmetric_naive_all_reduce(char* data_ptr, + c10::ScalarType scalar_type, + size_t chunk_size, + size_t chunk_el) { - parallel_memcpy(workspace[world_rank]->buffer, data_ptr, chunk_size); - std::atomic_thread_fence(std::memory_order_release); - workspace[world_rank]->state = coll_allreduce_naive__copy_in_done; +#ifdef DO_PROFILE + static double total_t1_t0 = 0.0; + static double total_t2_t1 = 0.0; + static double total_t3_t2 = 0.0; + static int count = -16; // warmup + auto t0 = std::chrono::system_clock::now(); +#endif - if (world_rank == 0) { - // compute allreduce result on rank 0 - for (int i = 1; i < world_size; i++) { - // wait until the other rank copy the buffer - wait_buffer_state_until(i, coll_allreduce_naive__copy_in_done); - } - reduce_all_buffers(workspace, 0, chunk_el, scalar_type, world_size, 0); - std::atomic_thread_fence(std::memory_order_release); - workspace[world_rank]->state = coll_allreduce_naive__reduce_done; - parallel_memcpy(data_ptr, workspace[0]->buffer, chunk_size); + /* + We can't have an infinite number of buffers and states. Two sets of buffers + and three sets of states are just enough.
Consider the current rank in step 3, + with its own state set to copy(i); any other rank will then be in one of the + following situations: + ------------------------------------------------ + my state | can I proceed? | the other rank state + ================================================ + | N | copy(i-1) + |----------------|--------------------- + copy(i) | Y | copy(i) + |----------------|--------------------- + | Y | copy(i+1) + ------------------------------------------------ + * When my state is copy(i), the other rank cannot be at copy(i-2) or + earlier. If it were, I would still be at copy(i-1) and could not have + proceeded to copy(i). + * The other rank cannot be at copy(i+2) or beyond because my + state is still copy(i); copy(i+1) is as far as the other rank can go. + * From a rank's POV, all the other ranks can be divided into three sets: + - Lagging ranks: ranks that are still working on the previous iteration + - Syncing ranks: ranks that are working on the current iteration + - Leading ranks: ranks that are working on the next iteration + * We use 3 sets of states: one set for syncing ranks, one set for + lagging ranks, and one set for leading ranks. With 3 sets of states, we can + distinguish lagging ranks from leading ranks. + * Note that from any rank's POV, leading ranks and lagging ranks never + appear at the same time. Either all other ranks are syncing or + lagging, or all other ranks are syncing or leading. Otherwise a leading + and a lagging rank would be 2 iterations apart, which cannot happen. + * So we have 2 sets of buffers: one buffer is used by the current iteration; + the other buffer is used by either lagging ranks or leading ranks. + */ + const int state_group = 0; + static int current_buffer = 0; + static int state_idx = 0; + + enum coll_state copy_current, copy_next; + + switch (state_idx) { + case 0: + copy_current = coll_allreduce_naive__copy_in_done; + copy_next = coll_alt1_allreduce_naive__copy_in_done; + break; + case 1: + copy_current = coll_alt1_allreduce_naive__copy_in_done; + copy_next = coll_alt2_allreduce_naive__copy_in_done; + break; + case 2: + copy_current = coll_alt2_allreduce_naive__copy_in_done; + copy_next = coll_allreduce_naive__copy_in_done; + break; + default: assert(!"Should not get here."); } - if (world_rank != 0) { - wait_buffer_state_until(0, coll_allreduce_naive__reduce_done); - parallel_memcpy(data_ptr, workspace[0]->buffer, chunk_size); - std::atomic_thread_fence(std::memory_order_release); - workspace[world_rank]->state = coll_allreduce_naive__copy_out_done; + state_idx = (state_idx + 1) % 3; + + parallel_memcpy(symmetric_buffer[current_buffer][world_rank], data_ptr, chunk_size); + std::atomic_thread_fence(std::memory_order_release); + workspace[world_rank]->states[state_group] = copy_current; + +#ifdef DO_PROFILE + auto t1 = std::chrono::system_clock::now(); +#endif + + for (int i = 0; i < world_size; i++) { + // wait until every other rank has copied its buffer + if (i != world_rank) { wait_buffer_state_until_2(i, copy_current, copy_next, state_group); } } - if (world_rank == 0) { - for (int i = 1; i < world_size; i++) { - wait_buffer_state_until(i, coll_allreduce_naive__copy_out_done); +#ifdef DO_PROFILE + auto t2 = std::chrono::system_clock::now(); +#endif + + // each rank reduces the buffer independently, so there is no need for synchronization afterward + reduce_all_buffers( + 0, chunk_el, scalar_type, world_rank, data_ptr, symmetric_buffer[current_buffer]); + + // switch buffer + current_buffer = 1 - current_buffer; + +#ifdef DO_PROFILE + auto t3 =
std::chrono::system_clock::now(); + + count++; + if (count > 0) { + total_t1_t0 += std::chrono::duration_cast<std::chrono::microseconds>(t1 - t0).count(); + total_t2_t1 += std::chrono::duration_cast<std::chrono::microseconds>(t2 - t1).count(); + total_t3_t2 += std::chrono::duration_cast<std::chrono::microseconds>(t3 - t2).count(); + if (world_rank == 0 && count == 1000) { + printf("symmetric_naive_all_reduce time breakdown:\n"); + printf("\tcopy input buffer: %.2f\n", total_t1_t0 / count); + printf("\twait for copy: %.2f\n", total_t2_t1 / count); + printf("\treduce: %.2f\n", total_t3_t2 / count); } - std::atomic_thread_fence(std::memory_order_release); - workspace[world_rank]->state = coll_begin; - } - if (world_rank != 0) { - // if rank 0 spin too fast it could be in state 1 of next allreduce - // in this case wait_buffer_state_until(0, 0) may cause deadlock - // what we are certain is when rank 0 finishes the state won't be 2 - wait_buffer_state_until_not(0, coll_allreduce_naive__reduce_done); - workspace[world_rank]->state = coll_begin; } +#endif } // naive allreduce distributed, each rank do naive reduce on its slice @@ -597,10 +591,33 @@ void distributed_naive_reduce(char* data_ptr, auto t0 = std::chrono::system_clock::now(); #endif + const int state_group = 1; + static int current_buffer = 0; + static int state_idx = 0; + + enum coll_state copy_current, copy_next, reduce_current; + + // similar to symmetric_naive_all_reduce, but here we only need two sets of + // states, because distributed naive reduce has two barriers in the algorithm + switch (state_idx) { + case 0: + copy_current = coll_allreduce_naive__copy_in_done; + reduce_current = coll_allreduce_naive__reduce_done; + copy_next = coll_alt1_allreduce_naive__copy_in_done; + break; + case 1: + copy_current = coll_alt1_allreduce_naive__copy_in_done; + reduce_current = coll_alt1_allreduce_naive__reduce_done; + copy_next = coll_allreduce_naive__copy_in_done; + break; + default: assert(!"Should not get here."); } + state_idx = (state_idx + 1) % 2; + int data_size = chunk_size / chunk_el; - parallel_memcpy(workspace[world_rank]->buffer, data_ptr, chunk_size); + parallel_memcpy(distributed_buffer[current_buffer][world_rank], data_ptr, chunk_size); std::atomic_thread_fence(std::memory_order_release); - workspace[world_rank]->state = coll_allreduce_naive__copy_in_done; + workspace[world_rank]->states[state_group] = copy_current; #ifdef DO_PROFILE auto t1 = std::chrono::system_clock::now(); @@ -608,7 +625,8 @@ void distributed_naive_reduce(char* data_ptr, for (int i = 0; i < world_size; i++) { // wait until all the other ranks copy the buffer - wait_buffer_state_until_range(i, coll_allreduce_naive__copy_in_done, 2); + if (i != world_rank) + wait_buffer_state_until_2(i, copy_current, reduce_current, state_group); } #ifdef DO_PROFILE @@ -616,40 +634,36 @@ void distributed_naive_reduce(char* data_ptr, #endif // reduce scatter - reduce_all_buffers(workspace, - slice_el_start(chunk_el, world_rank), + reduce_all_buffers(slice_el_start(chunk_el, world_rank), slice_size(chunk_el, world_rank), scalar_type, - world_size, - world_rank); + world_rank, + distributed_buffer[current_buffer][world_rank], + distributed_buffer[current_buffer]); std::atomic_thread_fence(std::memory_order_release); - workspace[world_rank]->state = coll_allreduce_naive__reduce_done; + workspace[world_rank]->states[state_group] = reduce_current; #ifdef DO_PROFILE auto t3 = std::chrono::system_clock::now(); #endif for (int i = 0; i < world_size; i++) { - int rank = (i + world_rank) % world_size; - // wait until the other rank reduce the buffer -
wait_buffer_state_until_range(rank, coll_allreduce_naive__reduce_done, 2); - parallel_memcpy(slice_data(data_ptr, chunk_el, data_size, rank), - slice_data(workspace[rank]->buffer, chunk_el, chunk_size / chunk_el, rank), - slice_size(chunk_el, rank) * data_size); + // wait until all the other ranks reduce the buffer + if (i != world_rank) wait_buffer_state_until_2(i, reduce_current, copy_next, state_group); } - std::atomic_thread_fence(std::memory_order_release); - workspace[world_rank]->state = coll_allreduce_naive__copy_out_done; -#ifdef DO_PROFILE auto t4 = std::chrono::system_clock::now(); -#endif for (int i = 0; i < world_size; i++) { - wait_buffer_state_until_not(i, coll_allreduce_naive__reduce_done); + int rank = (i + world_rank) % world_size; + parallel_memcpy( + slice_data(data_ptr, chunk_el, data_size, rank), + slice_data( + distributed_buffer[current_buffer][rank], chunk_el, chunk_size / chunk_el, rank), + slice_size(chunk_el, rank) * data_size); } - std::atomic_thread_fence(std::memory_order_release); - workspace[world_rank]->state = coll_begin; + current_buffer = 1 - current_buffer; #ifdef DO_PROFILE auto t5 = std::chrono::system_clock::now(); @@ -665,8 +679,8 @@ void distributed_naive_reduce(char* data_ptr, printf("\tcopy input buffer: %.2f\n", total_t1_t0 / count); printf("\twait for copy: %.2f\n", total_t2_t1 / count); printf("\treduce: %.2f\n", total_t3_t2 / count); - printf("\tcopy buffer to output: %.2f\n", total_t4_t3 / count); - printf("\twait finish: %.2f\n", total_t5_t4 / count); + printf("\twait for reduce finish: %.2f\n", total_t4_t3 / count); + printf("\tcopy out: %.2f\n", total_t5_t4 / count); } } #endif @@ -679,7 +693,7 @@ void all_reduce_outer_loop(torch::Tensor& data, size_t numel, int data_size) size_t chunk_size = data_size - offset > MAX_BUF_SIZE ? MAX_BUF_SIZE : data_size - offset; size_t chunk_el = chunk_size / (data_size / numel); if (chunk_size < NAIVE_ALLREDUCE_THRESHOLD) - naive_all_reduce(data_ptr, data.scalar_type(), chunk_size, chunk_el); + symmetric_naive_all_reduce(data_ptr, data.scalar_type(), chunk_size, chunk_el); else distributed_naive_reduce(data_ptr, data.scalar_type(), chunk_size, chunk_el); }
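Note (not part of the patch): the new allreduce_workspace packs four regions into one buffer: two NAIVE_ALLREDUCE_THRESHOLD-sized slots used by symmetric_naive_all_reduce and two MAX_BUF_SIZE-sized slots used by distributed_naive_reduce. The compile-only sketch below mirrors the BUFFER0_OFFSET / BUFFER1_OFFSET arithmetic; kMaxBufSize is an assumed stand-in for MAX_BUF_SIZE, which is defined earlier in shm.cpp and not shown in this diff.

#include <cstddef>

constexpr std::size_t kNaiveThreshold = 1048576;   // NAIVE_ALLREDUCE_THRESHOLD
constexpr std::size_t kMaxBufSize = 32 * 1048576;  // assumed value of MAX_BUF_SIZE

// Mirrors BUFFER0_OFFSET: two small slots for symmetric_naive_all_reduce.
constexpr std::size_t buffer0_offset(int current_buffer)
{
    return current_buffer * kNaiveThreshold;
}

// Mirrors BUFFER1_OFFSET: two large slots for distributed_naive_reduce, placed
// right after the two small slots.
constexpr std::size_t buffer1_offset(int current_buffer)
{
    return 2 * kNaiveThreshold + current_buffer * kMaxBufSize;
}

// The four regions tile the 2*NAIVE_ALLREDUCE_THRESHOLD + 2*MAX_BUF_SIZE buffer exactly.
static_assert(buffer0_offset(1) + kNaiveThreshold == buffer1_offset(0), "small slots tile");
static_assert(buffer1_offset(1) + kMaxBufSize == 2 * kNaiveThreshold + 2 * kMaxBufSize,
              "large slots tile");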
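Note (not part of the patch): the sketch below is a minimal, single-process simulation of the 3-state / double-buffer handshake described in the symmetric_naive_all_reduce comment. Threads stand in for ranks and a plain atomic integer per "rank" stands in for the SHM state word; all names (kRanks, kIters, rank_loop) are illustrative. It only demonstrates why waiting for the "current or next" state is sufficient and why a rank can never run more than one iteration ahead of any peer.

#include <atomic>
#include <cstdio>
#include <thread>
#include <vector>

constexpr int kRanks = 4;
constexpr int kIters = 9;

// One state word per rank; the 3 rotating values play the role of
// coll_allreduce_naive__copy_in_done / coll_alt1_... / coll_alt2_...
std::atomic<int> states[kRanks];

void rank_loop(int rank)
{
    for (int iter = 0; iter < kIters; iter++) {
        int copy_current = iter % 3;     // state published after "copy-in"
        int copy_next = (iter + 1) % 3;  // state a leading rank may already show
        int buffer = iter % 2;           // which half of the double buffer this iteration uses

        // step 1/2: pretend we copied into buffer[buffer], then publish our state
        states[rank].store(copy_current, std::memory_order_release);

        // step 3: wait until every other rank is at this iteration or at most one ahead
        for (int peer = 0; peer < kRanks; peer++) {
            if (peer == rank) continue;
            int s;
            do {
                s = states[peer].load(std::memory_order_acquire);
            } while (s != copy_current && s != copy_next);
        }

        // step 4: every rank now reduces buffer[buffer] independently; no barrier is
        // needed afterwards because the next iteration writes into the other buffer.
        if (rank == 0) printf("iter %d done on buffer %d\n", iter, buffer);
    }
}

int main()
{
    // start in the "previous iteration" state, analogous to the
    // coll_alt2_allreduce_naive__copy_in_done initialization in shm_initialize
    for (int r = 0; r < kRanks; r++) states[r].store(2);
    std::vector<std::thread> threads;
    for (int r = 0; r < kRanks; r++) threads.emplace_back(rank_loop, r);
    for (auto& t : threads) t.join();
    return 0;
}

A peer can never be two iterations ahead: to leave iteration i+1 it would have to see every other rank at state (i+1)%3 or (i+2)%3, while a lagging rank still shows i%3. This is exactly why three rotating state values and two buffers suffice.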
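Note (not part of the patch): reduce_bf16_buffers and reduce_fp32_buffers now rely on deliberate switch fall-through, entering at `case world_size` and falling through so that buffers[world_size-1] down to buffers[1] are all accumulated onto the value loaded from buffers[0]. The scalar sketch below shows the same control flow without the AVX-512 intrinsics; reduce_scalar is a hypothetical helper used only for illustration.

#include <cstdio>

float reduce_scalar(const float* buffers, int world_size)
{
    float acc = buffers[0];
    switch (world_size) {
        case 8: acc += buffers[7];  // deliberate fall-through from here down
        case 7: acc += buffers[6];
        case 6: acc += buffers[5];
        case 5: acc += buffers[4];
        case 4: acc += buffers[3];
        case 3:
            acc += buffers[2];
            acc += buffers[1];
            break;
        default:  // any other world size: plain loop
            for (int j = 1; j < world_size; j++) acc += buffers[j];
    }
    return acc;
}

int main()
{
    float vals[8] = {1, 2, 3, 4, 5, 6, 7, 8};
    printf("%.1f\n", reduce_scalar(vals, 5));  // prints 15.0 (= 1+2+3+4+5)
    return 0;
}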