From 5fca2e865250737a690eb69ee847aea88b4ada7a Mon Sep 17 00:00:00 2001
From: Mamzi Bayatpour
Date: Mon, 17 Jun 2024 12:57:25 -0700
Subject: [PATCH] TL/MLX5: addressing sam's comments on PR 989

Rename the mcast "device_mem_enabled" setting to "cuda_mem_enabled"
(config knob MCAST_DEVICE_MEM_ENABLE -> MCAST_CUDA_MEM_ENABLE) and stop
calling the CUDA runtime directly from the mcast code:

  * staging buffers (GRH and packet pool) are allocated, zeroed and freed
    through ucc_mc_alloc / ucc_mc_memset / ucc_mc_free, keeping the
    returned ucc_mc_buffer_header_t next to the raw pointer so it can be
    released later;
  * payload copies on the send and receive paths use ucc_mc_memcpy and
    propagate its status instead of ignoring the copy result;
  * team init fails with UCC_ERR_NO_RESOURCE when cuda memory is requested
    but /sys/kernel/mm/memory_peers/nv_mem/version is not accessible;
  * bcast init rejects requests whose source memory type does not match
    the mode the mcast communicator was created with.

---
Review notes (this section is ignored by git-am):

  * The addr field of the ucc_mc_alloc() output header is now read only
    after the returned status has been checked. Previously
    comm->grh_cuda_header->addr and comm->pp_cuda_header->addr were
    dereferenced before the check, i.e. also on allocation failure.
  * ucc_tl_mlx5_check_gpudirect_driver() is declared with (void) so the
    definition is a real prototype.
  * Hunk line counts and the diffstat are unchanged by the two fixes above.
    The "index" blob ids are the ones from the original submission.
  * Line breaks were restored from the hunk headers; indentation and column
    alignment inside lines are reconstructed and may differ from the tree.

 src/components/tl/mlx5/mcast/tl_mlx5_mcast.h  |  8 ++-
 .../tl/mlx5/mcast/tl_mlx5_mcast_coll.c        | 17 ++++++
 .../tl/mlx5/mcast/tl_mlx5_mcast_coll.h        |  2 +
 .../tl/mlx5/mcast/tl_mlx5_mcast_helper.c      | 10 ++--
 .../tl/mlx5/mcast/tl_mlx5_mcast_helper.h      | 10 +++-
 .../tl/mlx5/mcast/tl_mlx5_mcast_progress.c    |  9 ++-
 .../tl/mlx5/mcast/tl_mlx5_mcast_team.c        | 58 +++++++++++++------
 src/components/tl/mlx5/tl_mlx5.c              |  4 +-
 src/components/tl/mlx5/tl_mlx5_coll.c         |  5 ++
 9 files changed, 91 insertions(+), 32 deletions(-)

diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h b/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h
index 264bb74ff9..708ff08f5e 100644
--- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h
+++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h
@@ -17,7 +17,7 @@
 #include "components/tl/ucc_tl_log.h"
 #include "utils/ucc_rcache.h"
 #include "core/ucc_service_coll.h"
-#include "utils/arch/cuda_def.h"
+#include "components/mc/ucc_mc.h"
 
 #define POLL_PACKED 16
 #define REL_DONE ((void*)-1)
@@ -91,7 +91,7 @@ typedef struct mcast_coll_comm_init_spec {
     int scq_moderation;
     int wsize;
     int max_eager;
-    int device_mem_enabled;
+    int cuda_mem_enabled;
     void *oob;
 } ucc_tl_mlx5_mcast_coll_comm_init_spec_t;
 
@@ -196,17 +196,19 @@ typedef struct ucc_tl_mlx5_mcast_coll_comm {
     ucc_rank_t rank;
     ucc_rank_t commsize;
     char *grh_buf;
+    ucc_mc_buffer_header_t *grh_cuda_header;
     struct ibv_mr *grh_mr;
     uint16_t mcast_lid;
     union ibv_gid mgid;
     unsigned max_inline;
     size_t max_eager;
-    int device_mem_enabled;
+    int cuda_mem_enabled;
     int max_per_packet;
     int pending_send;
     int pending_recv;
     struct ibv_mr *pp_mr;
     char *pp_buf;
+    ucc_mc_buffer_header_t *pp_cuda_header;
     struct pp_packet *pp;
     uint32_t psn;
     uint32_t last_psn;
diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c
index 9696ba8c82..a6725698ce 100644
--- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c
+++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c
@@ -283,6 +283,23 @@ void ucc_tl_mlx5_mcast_collective_progress(ucc_coll_task_t *coll_task)
     }
 }
 
+ucc_status_t ucc_tl_mlx5_mcast_check_memory_type_cap(ucc_base_coll_args_t *coll_args,
+                                                     ucc_base_team_t *team)
+{
+    ucc_tl_mlx5_team_t *mlx5_team = ucc_derived_of(team, ucc_tl_mlx5_team_t);
+    ucc_tl_mlx5_mcast_coll_comm_t *comm = mlx5_team->mcast->mcast_comm;
+    ucc_coll_args_t *args = &coll_args->args;
+
+    if ((comm->cuda_mem_enabled &&
+        args->src.info.mem_type == UCC_MEMORY_TYPE_CUDA) ||
+        (!comm->cuda_mem_enabled &&
+        args->src.info.mem_type == UCC_MEMORY_TYPE_HOST)) {
+        return UCC_OK;
+    }
+
+    return UCC_ERR_NO_RESOURCE;
+}
+
 ucc_status_t ucc_tl_mlx5_mcast_bcast_init(ucc_tl_mlx5_task_t *task)
 {
     task->super.post = ucc_tl_mlx5_mcast_bcast_start;
diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.h b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.h
index 74385b1573..a5725915f7 100644
--- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.h
+++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.h
@@ -14,4 +14,6 @@ ucc_status_t ucc_tl_mlx5_mcast_bcast_init(ucc_tl_mlx5_task_t *task);
 
 ucc_status_t ucc_tl_mlx5_mcast_test(ucc_tl_mlx5_mcast_coll_req_t* _req);
 
+ucc_status_t ucc_tl_mlx5_mcast_check_memory_type_cap(ucc_base_coll_args_t *coll_args,
+                                                     ucc_base_team_t *team);
 #endif
diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.c
index 608b59d7cd..0fe8590dfb 100644
--- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.c
+++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.c
@@ -300,7 +300,7 @@ ucc_status_t ucc_tl_mlx5_mcast_init_qps(ucc_tl_mlx5_mcast_coll_context_t *ctx,
         return UCC_ERR_NO_RESOURCE;
     }
 
-    if (comm->device_mem_enabled) {
+    if (comm->cuda_mem_enabled) {
         /* max inline send otherwise it segfault during ibv send */
         comm->max_inline = 0;
     } else {
@@ -482,8 +482,8 @@ ucc_status_t ucc_tl_mlx5_clean_mcast_comm(ucc_tl_mlx5_mcast_coll_comm_t *comm)
     }
 
     if (comm->grh_buf) {
-        if (comm->device_mem_enabled) {
-            cudaFree(comm->grh_buf);
+        if (comm->cuda_mem_enabled) {
+            ucc_mc_free(comm->grh_cuda_header);
         } else {
             ucc_free(comm->grh_buf);
         }
@@ -502,8 +502,8 @@ ucc_status_t ucc_tl_mlx5_clean_mcast_comm(ucc_tl_mlx5_mcast_coll_comm_t *comm)
     }
 
     if (comm->pp_buf) {
-        if (comm->device_mem_enabled) {
-            cudaFree(comm->pp_buf);
+        if (comm->cuda_mem_enabled) {
+            ucc_mc_free(comm->pp_cuda_header);
         } else {
             ucc_free(comm->pp_buf);
         }
diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.h b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.h
index f45107ad56..32fa4959d5 100644
--- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.h
+++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.h
@@ -75,9 +75,13 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_send(ucc_tl_mlx5_mcast_coll_comm_t
         if (zcopy) {
             pp->context = (uintptr_t) PTR_OFFSET(req->ptr, offset);
         } else {
-            if (comm->device_mem_enabled) {
-                CUDA_FUNC(cudaMemcpy((void*) pp->buf, PTR_OFFSET(req->ptr, offset),
-                                     length, cudaMemcpyDeviceToDevice));
+            if (comm->cuda_mem_enabled) {
+                status = ucc_mc_memcpy((void*) pp->buf, PTR_OFFSET(req->ptr, offset), length,
+                                       UCC_MEMORY_TYPE_CUDA, UCC_MEMORY_TYPE_CUDA);
+                if (ucc_unlikely(status != UCC_OK)) {
+                    tl_error(comm->lib, "failed to copy cuda buffer");
+                    return status;
+                }
             } else {
                 memcpy((void*) pp->buf, PTR_OFFSET(req->ptr, offset), length);
             }
diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.c
index 40b754ecc3..341433b41e 100644
--- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.c
+++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.c
@@ -379,8 +379,13 @@ ucc_status_t ucc_tl_mlx5_mcast_process_packet(ucc_tl_mlx5_mcast_coll_comm_t *com
         if (pp->length > 0 ) {
             dest = req->ptr + PSN_TO_RECV_OFFSET(pp->psn, req, comm);
 
-            if (comm->device_mem_enabled) {
-                cudaMemcpy(dest, (void*) pp->buf, pp->length, cudaMemcpyDeviceToDevice);
+            if (comm->cuda_mem_enabled) {
+                status = ucc_mc_memcpy(dest, (void*) pp->buf, pp->length,
+                                       UCC_MEMORY_TYPE_CUDA, UCC_MEMORY_TYPE_CUDA);
+                if (ucc_unlikely(status != UCC_OK)) {
+                    tl_error(comm->lib, "failed to copy cuda buffer");
+                    return status;
+                }
             } else {
                 memcpy(dest, (void*) pp->buf, pp->length);
             }
diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c
index 4f53ba50d5..e42d980e77 100644
--- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c
+++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c
@@ -12,6 +12,18 @@
 #include "p2p/ucc_tl_mlx5_mcast_p2p.h"
 #include "mcast/tl_mlx5_mcast_helper.h"
 
+static ucc_status_t ucc_tl_mlx5_check_gpudirect_driver(void)
+{
+    ucc_status_t status = UCC_ERR_NO_RESOURCE;
+    const char *file = "/sys/kernel/mm/memory_peers/nv_mem/version";
+
+    if (!access(file, F_OK)) {
+        status = UCC_OK;
+    }
+
+    return status;
+}
+
 static ucc_status_t ucc_tl_mlx5_mcast_service_bcast_post(void *arg, void *buf, size_t size, ucc_rank_t root,
                                                          ucc_service_coll_req_t **bcast_req)
 {
@@ -120,10 +132,16 @@ ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *base_context,
 
     comm->wsize = conf_params->wsize;
     comm->max_eager = conf_params->max_eager;
-    comm->device_mem_enabled = conf_params->device_mem_enabled;
+    comm->cuda_mem_enabled = conf_params->cuda_mem_enabled;
     comm->comm_id = team_params->id;
     comm->ctx = mcast_context;
 
+    if (comm->cuda_mem_enabled && (UCC_OK != ucc_tl_mlx5_check_gpudirect_driver())) {
+        tl_warn(mcast_context->lib, "cuda-aware mcast not available as gpu direct is not ready");
+        status = UCC_ERR_NO_RESOURCE;
+        goto cleanup;
+    }
+
     comm->rcq = ibv_create_cq(mcast_context->ctx, comm->params.rx_depth, NULL, NULL, 0);
     if (!comm->rcq) {
         ibv_dereg_mr(comm->grh_mr);
@@ -212,35 +230,41 @@ ucc_status_t ucc_tl_mlx5_mcast_coll_setup_comm_resources(ucc_tl_mlx5_mcast_coll_
     comm->pending_recv = 0;
     comm->buf_n = comm->params.rx_depth * 2;
 
-    if (comm->device_mem_enabled) {
-        /* TODO add check to make sure GPUDirect is enabled
-         * lsmod | grep nv_peer */
-        CUDA_FUNC(cudaMalloc((void **)&comm->grh_buf, GRH_LENGTH * sizeof(char)));
-        if (!comm->grh_buf) {
-            tl_error(comm->ctx->lib, "cuda memcpy failed");
-            status = UCC_ERR_NO_MEMORY;
-            goto error;
+    if (comm->cuda_mem_enabled) {
+        status = ucc_mc_alloc(&comm->grh_cuda_header, GRH_LENGTH *
+                              sizeof(char), UCC_MEMORY_TYPE_CUDA);
+        if (ucc_unlikely(status != UCC_OK)) {
+            tl_error(comm->ctx->lib, "failed to allocate cuda memory");
+            return status;
         }
+        comm->grh_buf = comm->grh_cuda_header->addr;
 
-        CUDA_FUNC(cudaMemset(comm->grh_buf, 0, GRH_LENGTH));
+        status = ucc_mc_memset(comm->grh_buf, 0, GRH_LENGTH, UCC_MEMORY_TYPE_CUDA);
+        if (status != UCC_OK) {
+            tl_error(comm->ctx->lib, "could not cuda memset");
+            goto error;
+        }
 
         comm->grh_mr = ibv_reg_mr(comm->ctx->pd, comm->grh_buf, GRH_LENGTH,
                                   IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_LOCAL_WRITE);
         if (!comm->grh_mr) {
-            tl_error(comm->ctx->lib, "Could not register device memory for GRH, errno %d", errno);
+            tl_error(comm->ctx->lib, "could not register device memory for GRH, errno %d", errno);
             status = UCC_ERR_NO_RESOURCE;
             goto error;
         }
 
-        // assuming the device page size is same as host page size
-        CUDA_FUNC(cudaMalloc((void**) &comm->pp_buf, buf_size * comm->buf_n));
-        if (!comm->pp_buf) {
-            tl_error(comm->ctx->lib, "cuda memcpy failed");
-            status = UCC_ERR_NO_MEMORY;
+        status = ucc_mc_alloc(&comm->pp_cuda_header, buf_size * comm->buf_n, UCC_MEMORY_TYPE_CUDA);
+        if (ucc_unlikely(status != UCC_OK)) {
+            tl_error(comm->ctx->lib, "failed to allocate cuda memory");
             goto error;
         }
+        comm->pp_buf = comm->pp_cuda_header->addr;
 
-        CUDA_FUNC(cudaMemset(comm->pp_buf, 0, buf_size * comm->buf_n));
+        status = ucc_mc_memset(comm->pp_buf, 0, buf_size * comm->buf_n, UCC_MEMORY_TYPE_CUDA);
+        if (status != UCC_OK) {
+            tl_error(comm->ctx->lib, "could not cuda memset");
+            goto error;
+        }
 
         comm->pp_mr = ibv_reg_mr(comm->ctx->pd, comm->pp_buf, buf_size * comm->buf_n,
                                  IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_LOCAL_WRITE |
diff --git a/src/components/tl/mlx5/tl_mlx5.c b/src/components/tl/mlx5/tl_mlx5.c
index 864865dd41..38ac2e4ca0 100644
--- a/src/components/tl/mlx5/tl_mlx5.c
+++ b/src/components/tl/mlx5/tl_mlx5.c
@@ -92,8 +92,8 @@ static ucc_config_field_t ucc_tl_mlx5_lib_config_table[] = {
      ucc_offsetof(ucc_tl_mlx5_lib_config_t, mcast_conf.max_eager),
      UCC_CONFIG_TYPE_MEMUNITS},
 
-    {"MCAST_DEVICE_MEM_ENABLE", "0", "Enable GPU memory support for Mcast",
-     ucc_offsetof(ucc_tl_mlx5_lib_config_t, mcast_conf.device_mem_enabled),
+    {"MCAST_CUDA_MEM_ENABLE", "0", "Enable GPU CUDA memory support for Mcast. GPUDirect RDMA must be enabled",
+     ucc_offsetof(ucc_tl_mlx5_lib_config_t, mcast_conf.cuda_mem_enabled),
      UCC_CONFIG_TYPE_INT},
 
     {NULL}};
diff --git a/src/components/tl/mlx5/tl_mlx5_coll.c b/src/components/tl/mlx5/tl_mlx5_coll.c
index e918be166e..a8add9715e 100644
--- a/src/components/tl/mlx5/tl_mlx5_coll.c
+++ b/src/components/tl/mlx5/tl_mlx5_coll.c
@@ -19,6 +19,11 @@ ucc_status_t ucc_tl_mlx5_bcast_mcast_init(ucc_base_coll_args_t *coll_args,
         tl_trace(team->context->lib, "mcast bcast not supported for active sets");
         return UCC_ERR_NOT_SUPPORTED;
     }
+
+    if (UCC_OK != ucc_tl_mlx5_mcast_check_memory_type_cap(coll_args, team)) {
+        tl_trace(team->context->lib, "mcast bcast not compatible with this memory type");
+        return UCC_ERR_NOT_SUPPORTED;
+    }
 
     task = ucc_tl_mlx5_get_task(coll_args, team);
 