group launch allreduce + allgather #1566

Open
alpha-baby opened this issue Jan 7, 2025 · 1 comment
@alpha-baby

CUDA version (nvidia-smi output):

Tue Jan  7 14:20:53 2025       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.183.06             Driver Version: 535.183.06   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA H20                     On  | 00000000:08:00.0 Off |                    0 |
| N/A   33C    P0              71W / 500W |      0MiB / 97871MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA H20                     On  | 00000000:7E:00.0 Off |                    0 |
| N/A   32C    P0              72W / 500W |      0MiB / 97871MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   2  NVIDIA H20                     On  | 00000000:A2:00.0 Off |                    0 |
| N/A   35C    P0              73W / 500W |      0MiB / 97871MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   3  NVIDIA H20                     On  | 00000000:C6:00.0 Off |                    0 |
| N/A   33C    P0              72W / 500W |      0MiB / 97871MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   4  NVIDIA H20                     On  | 00000001:09:00.0 Off |                    0 |
| N/A   31C    P0              74W / 500W |      0MiB / 97871MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   5  NVIDIA H20                     On  | 00000001:7F:00.0 Off |                    0 |
| N/A   32C    P0              73W / 500W |      0MiB / 97871MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   6  NVIDIA H20                     On  | 00000001:A3:00.0 Off |                    0 |
| N/A   33C    P0              74W / 500W |      0MiB / 97871MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   7  NVIDIA H20                     On  | 00000001:C7:00.0 Off |                    0 |
| N/A   35C    P0              72W / 500W |      0MiB / 97871MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|  No running processes found                                                           |
+---------------------------------------------------------------------------------------+

NCCL version: 2.21.5

Demo code:

#include "cuda_runtime.h"
#include <cstdarg>
#include "mpi.h"
#include "nccl.h"
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <sys/time.h>
#include <time.h>
#include <ctime>

int glocal_rank=-1;

#define MPICHECK(cmd)                                                          \
  do {                                                                         \
    int e = cmd;                                                               \
    if (e != MPI_SUCCESS) {                                                    \
      printf("Failed: MPI error %s:%d '%d'\n", __FILE__, __LINE__, e);         \
      exit(EXIT_FAILURE);                                                      \
    }                                                                          \
  } while (0)

#define CUDACHECK(cmd)                                                         \
  do {                                                                         \
    cudaError_t e = cmd;                                                       \
    if (e != cudaSuccess) {                                                    \
      printf("rank %d Failed: Cuda error %s:%d '%s'\n", glocal_rank, __FILE__, __LINE__,            \
             cudaGetErrorString(e));                                           \
      exit(EXIT_FAILURE);                                                      \
    }                                                                          \
  } while (0)

#define NCCLCHECK(cmd)                                                         \
  do {                                                                         \
    ncclResult_t r = cmd;                                                      \
    if (r != ncclSuccess) {                                                    \
      printf("Failed, NCCL error %s:%d '%s'\n", __FILE__, __LINE__,            \
             ncclGetErrorString(r));                                           \
      exit(EXIT_FAILURE);                                                      \
    }                                                                          \
  } while (0)

static uint64_t getHostHash(const char *string) {
  // Based on DJB2a, result = result * 33 ^ char
  uint64_t result = 5381;
  for (int c = 0; string[c] != '\0'; c++) {
    result = ((result << 5) + result) ^ string[c];
  }
  return result;
}

static void getHostName(char *hostname, int maxlen) {
  gethostname(hostname, maxlen);
  for (int i = 0; i < maxlen; i++) {
    if (hostname[i] == '.') {
      hostname[i] = '\0';
      return;
    }
  }
}

void log_print(const char *format, ...) {
  va_list args;
  va_start(args, format);
  char timestampBuffer[27] = "";
  struct timeval tv;
  gettimeofday(&tv, NULL);
  std::tm timeinfo;
  localtime_r(&tv.tv_sec, &timeinfo);
  snprintf(timestampBuffer, sizeof(timestampBuffer),
           "[%04d-%02d-%02dT%02d:%02d:%02d.%03ld] ", timeinfo.tm_year + 1900,
           timeinfo.tm_mon + 1, timeinfo.tm_mday, timeinfo.tm_hour,
           timeinfo.tm_min, timeinfo.tm_sec, tv.tv_usec / 1000);
  printf("%s", timestampBuffer);
  vprintf(format, args);
  va_end(args);
}

// Compute the difference between two timeval structs, in microseconds
long get_time_diff_microseconds(struct timeval start, struct timeval end) {
    long seconds = end.tv_sec - start.tv_sec;
    long microseconds = end.tv_usec - start.tv_usec;
    return seconds * 1000000 + microseconds;
}

int main(int argc, char *argv[]) {
  int size = 60 * 1024 * 1024;
  int runCount = 1;
  struct timeval start, end;
  int myRank, nRanks, localRank = 0;

  // initializing MPI
  MPICHECK(MPI_Init(&argc, &argv));
  MPICHECK(MPI_Comm_rank(MPI_COMM_WORLD, &myRank));
  MPICHECK(MPI_Comm_size(MPI_COMM_WORLD, &nRanks));

  // calculating localRank based on hostname which is used in selecting a GPU
  uint64_t hostHashs[nRanks];
  char hostname[1024];
  getHostName(hostname, 1024);
  hostHashs[myRank] = getHostHash(hostname);
  MPICHECK(MPI_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, hostHashs,
                         sizeof(uint64_t), MPI_BYTE, MPI_COMM_WORLD));
  for (int p = 0; p < nRanks; p++) {
    if (p == myRank)
      break;
    if (hostHashs[p] == hostHashs[myRank])
      localRank++;
  }

  cudaStream_t s;
  ncclUniqueId id;
  ncclComm_t comm;
  float *sendbuff, *recvbuff;
  float *sendbuff2, *recvbuff2;
  glocal_rank = myRank;
  // get NCCL unique ID at rank 0 and broadcast it to all others
  if (myRank == 0)
    ncclGetUniqueId(&id);
  MPICHECK(MPI_Bcast((void *)&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD));

  printf("set cuda device: %d\n",localRank);
  // picking a GPU based on localRank, allocate device buffers
  CUDACHECK(cudaSetDevice(localRank));
  CUDACHECK(cudaMalloc(&sendbuff, size * sizeof(float)));
  CUDACHECK(cudaMalloc(&recvbuff, size * sizeof(float)));
  CUDACHECK(cudaMalloc(&recvbuff2, size * sizeof(float)));
  CUDACHECK(cudaMalloc(&sendbuff2, size * sizeof(float)));
  CUDACHECK(cudaStreamCreate(&s));

  // initializing NCCL
  NCCLCHECK(ncclCommInitRank(&comm, nRanks, id, myRank));
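  // rankOffset below is computed but never used. Each rank contributes
  // size/nRanks elements to the AllGather, so recvbuff2 receives up to
  // size elements in total across all ranks.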
  size_t rankOffset = size * 4;
  int allgatherSendCount = size/nRanks;
  auto call_func = [&](int count) {
    for (int i = 0; i < count; i++) {
      // Group the two collectives so NCCL can launch them together.
      NCCLCHECK(ncclGroupStart());
      NCCLCHECK(ncclAllReduce((const void *)sendbuff, (void *)recvbuff, size,
                              ncclFloat32, ncclSum, comm, s));
      if (allgatherSendCount > 0) {
        // Each rank sends its own chunk. The original offset
        // (size * allgatherSendCount) points far outside the allocation and is
        // assumed to be a typo for myRank * allgatherSendCount.
        NCCLCHECK(ncclAllGather((const void *)(sendbuff2 + myRank * allgatherSendCount),
                                (void *)recvbuff2, allgatherSendCount,
                                ncclFloat32, comm, s));
      }
      NCCLCHECK(ncclGroupEnd());
    }
    // completing the NCCL operations by synchronizing on the CUDA stream
    CUDACHECK(cudaStreamSynchronize(s));
  };
  // warm-up iterations
  call_func(5);

  // get begin time
  if (gettimeofday(&start, NULL) != 0) {
      perror("gettimeofday");
      return 1;
  }
  runCount = 10000;
  call_func(runCount);
  // get end time
  if (gettimeofday(&end, NULL) != 0) {
      perror("gettimeofday");
      return 1;
  }
  // compute the elapsed time
  long elapsed_microseconds = get_time_diff_microseconds(start, end);

  printf("rank: %d run duration: %lf 秒\n", myRank, (double)elapsed_microseconds/(double)1000000);
  // free device buffers and destroy the stream
  CUDACHECK(cudaFree(sendbuff));
  CUDACHECK(cudaFree(recvbuff));
  CUDACHECK(cudaFree(sendbuff2));
  CUDACHECK(cudaFree(recvbuff2));
  CUDACHECK(cudaStreamDestroy(s));

  // finalizing NCCL
  ncclCommDestroy(comm);

  // finalizing MPI
  MPICHECK(MPI_Finalize());

  printf("[MPI Rank %d] runCount: %d Success \n",  myRank, runCount);
  return 0;
}

Output durations when run with NCCL 2.21.5:

set cuda device: 0
set cuda device: 1
set cuda device: 2
set cuda device: 3
set cuda device: 4
set cuda device: 5
set cuda device: 6
set cuda device: 7
rank: 0 run duration: 40.983264 seconds
rank: 1 run duration: 40.983236 seconds
rank: 2 run duration: 40.983254 seconds
rank: 3 run duration: 40.983243 seconds
rank: 4 run duration: 40.983244 seconds
rank: 5 run duration: 40.983269 seconds
rank: 7 run duration: 40.983285 seconds
rank: 6 run duration: 40.983284 seconds
[MPI Rank 5] runCount: 10000 Success 
[MPI Rank 1] runCount: 10000 Success 
[MPI Rank 3] runCount: 10000 Success 
[MPI Rank 7] runCount: 10000 Success 
[MPI Rank 6] runCount: 10000 Success 
[MPI Rank 0] runCount: 10000 Success 
[MPI Rank 2] runCount: 10000 Success 
[MPI Rank 4] runCount: 10000 Success 

Output durations when run with NCCL 2.18.3:

set cuda device: 0
set cuda device: 1
set cuda device: 2
set cuda device: 4
set cuda device: 5
set cuda device: 6
set cuda device: 7
set cuda device: 3
rank: 0 run duration: 17.108217 seconds
rank: 1 run duration: 17.108238 seconds
rank: 7 run duration: 17.108258 seconds
rank: 2 run duration: 17.108236 seconds
rank: 6 run duration: 17.108263 seconds
rank: 3 run duration: 17.108257 seconds
rank: 4 run duration: 17.108248 seconds
rank: 5 run duration: 17.108247 seconds
[MPI Rank 1] runCount: 10000 Success 
[MPI Rank 7] runCount: 10000 Success 
[MPI Rank 3] runCount: 10000 Success 
[MPI Rank 5] runCount: 10000 Success 
[MPI Rank 6] runCount: 10000 Success 
[MPI Rank 4] runCount: 10000 Success 
[MPI Rank 0] runCount: 10000 Success 
[MPI Rank 2] runCount: 10000 Success

Why does NCCL 2.18.3 yield so much better performance here? I also tested 2.21.5 with each collective run separately: allreduce alone takes about 10 seconds, and allgather alone takes about 6 seconds.
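
For reference, the separate-collective timings mentioned above were presumably produced by something along these lines (a minimal sketch, not taken from the original program): issue each collective in its own loop, without the ncclGroupStart/ncclGroupEnd pair, and time the loops individually. It reuses the buffers, communicator, and stream from the demo.

  // Sketch (assumed): time AllReduce and AllGather separately instead of grouping them.
  auto time_allreduce = [&](int count) {
    for (int i = 0; i < count; i++) {
      NCCLCHECK(ncclAllReduce((const void *)sendbuff, (void *)recvbuff, size,
                              ncclFloat32, ncclSum, comm, s));
    }
    CUDACHECK(cudaStreamSynchronize(s));
  };
  auto time_allgather = [&](int count) {
    for (int i = 0; i < count; i++) {
      NCCLCHECK(ncclAllGather((const void *)sendbuff2, (void *)recvbuff2,
                              allgatherSendCount, ncclFloat32, comm, s));
    }
    CUDACHECK(cudaStreamSynchronize(s));
  };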

@kiskra-nvidia
Member

Just to make sure I understood: you are saying that grouping allreduce and allgather results in a considerable slowdown with NCCL 2.21.5 compared to NCCL 2.18.3, correct? And that if you run each collective separately, the performance is as expected (10+6=16, more or less in line with the 17 s seen with 2.18.3)?

I don't have a ready answer for you, but I would recommend that you try the most recent NCCL release (2.24.3 was released today!). If you continue seeing this issue with the current release, we'll be happy to investigate. For that, we would probably want to start by comparing the logs from the different versions, obtained with NCCL_DEBUG=INFO NCCL_DEBUG_SUBSYS=INIT,ENV,GRAPH,TUNING (no need to run 10000 iterations in that case -- you could actually reduce runCount to 0, as the warm-up iterations should tell us all we need to know).
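
If it helps, here is one way the suggested variables could be set from inside the demo itself rather than exported in the shell or on the mpirun command line (a sketch assuming a POSIX environment; placing the calls before ncclCommInitRank is enough, since NCCL reads its environment during initialization):

  // Sketch (assumed): enable the suggested NCCL logging from within the program.
  // setenv() is declared in <stdlib.h>, which the demo already includes.
  setenv("NCCL_DEBUG", "INFO", 1);
  setenv("NCCL_DEBUG_SUBSYS", "INIT,ENV,GRAPH,TUNING", 1);
  // ... then ncclCommInitRank(&comm, nRanks, id, myRank);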
