From 162def8f18b07f9c73cea6f88e2dcb78c86af8e3 Mon Sep 17 00:00:00 2001
From: Mamzi Bayatpour
Date: Thu, 13 Jun 2024 12:14:07 -0700
Subject: [PATCH] TL/MLX5: add device mem mcast bcast

---
 src/components/mc/ucc_mc.c                   |  1 +
 src/components/tl/mlx5/mcast/tl_mlx5_mcast.h |  4 +
 .../tl/mlx5/mcast/tl_mlx5_mcast_coll.c       | 17 ++++
 .../tl/mlx5/mcast/tl_mlx5_mcast_coll.h       |  2 +
 .../tl/mlx5/mcast/tl_mlx5_mcast_helper.c     | 10 ++-
 .../tl/mlx5/mcast/tl_mlx5_mcast_helper.h     |  9 +-
 .../tl/mlx5/mcast/tl_mlx5_mcast_progress.c   | 15 +++-
 .../tl/mlx5/mcast/tl_mlx5_mcast_team.c       | 85 +++++++++++++------
 src/components/tl/mlx5/tl_mlx5.c             |  4 +
 src/components/tl/mlx5/tl_mlx5_coll.c        |  5 ++
 10 files changed, 121 insertions(+), 31 deletions(-)

diff --git a/src/components/mc/ucc_mc.c b/src/components/mc/ucc_mc.c
index 997355443e..4b17d4a4d4 100644
--- a/src/components/mc/ucc_mc.c
+++ b/src/components/mc/ucc_mc.c
@@ -132,6 +132,7 @@ ucc_status_t ucc_mc_get_attr(ucc_mc_attr_t *attr, ucc_memory_type_t mem_type)
     return mc->get_attr(attr);
 }
 
+/* TODO: add the flexibility to bypass the mpool if the user asks for it */
 UCC_MC_PROFILE_FUNC(ucc_status_t, ucc_mc_alloc, (h_ptr, size, mem_type),
                     ucc_mc_buffer_header_t **h_ptr, size_t size,
                     ucc_memory_type_t mem_type)
diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h b/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h
index 1208226bda..4e1caebf2a 100644
--- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h
+++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h
@@ -17,6 +17,7 @@
 #include "components/tl/ucc_tl_log.h"
 #include "utils/ucc_rcache.h"
 #include "core/ucc_service_coll.h"
+#include "components/mc/ucc_mc.h"
 
 #define POLL_PACKED 16
 #define REL_DONE    ((void*)-1)
@@ -98,6 +99,7 @@ typedef struct mcast_coll_comm_init_spec {
     int   scq_moderation;
     int   wsize;
     int   max_eager;
+    int   cuda_mem_enabled;
     void *oob;
 } ucc_tl_mlx5_mcast_coll_comm_init_spec_t;
 
@@ -261,6 +263,7 @@ typedef struct ucc_tl_mlx5_mcast_coll_comm {
     int                     pending_recv;
     struct ibv_mr          *pp_mr;
     char                   *pp_buf;
+    ucc_mc_buffer_header_t *pp_buf_header;
     struct pp_packet       *pp;
     uint32_t                psn;
     uint32_t                last_psn;
@@ -293,6 +296,7 @@ typedef struct ucc_tl_mlx5_mcast_coll_comm {
     int                               n_prep_reliable;
     int                               n_mcast_reliable;
     int                               wsize;
+    int                               cuda_mem_enabled;
     ucc_tl_mlx5_mcast_join_info_t    *group_setup_info;
     ucc_service_coll_req_t           *group_setup_info_req;
     ucc_tl_mlx5_mcast_service_coll_t  service_coll;
diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c
index 9696ba8c82..a6725698ce 100644
--- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c
+++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c
@@ -283,6 +283,23 @@ void ucc_tl_mlx5_mcast_collective_progress(ucc_coll_task_t *coll_task)
     }
 }
 
+ucc_status_t ucc_tl_mlx5_mcast_check_memory_type_cap(ucc_base_coll_args_t *coll_args,
+                                                     ucc_base_team_t *team)
+{
+    ucc_tl_mlx5_team_t            *mlx5_team = ucc_derived_of(team, ucc_tl_mlx5_team_t);
+    ucc_tl_mlx5_mcast_coll_comm_t *comm      = mlx5_team->mcast->mcast_comm;
+    ucc_coll_args_t               *args      = &coll_args->args;
+
+    if ((comm->cuda_mem_enabled &&
+         args->src.info.mem_type == UCC_MEMORY_TYPE_CUDA) ||
+        (!comm->cuda_mem_enabled &&
+         args->src.info.mem_type == UCC_MEMORY_TYPE_HOST)) {
+        return UCC_OK;
+    }
+
+    return UCC_ERR_NO_RESOURCE;
+}
+
 ucc_status_t ucc_tl_mlx5_mcast_bcast_init(ucc_tl_mlx5_task_t *task)
 {
     task->super.post = ucc_tl_mlx5_mcast_bcast_start;
diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.h b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.h
index 74385b1573..a5725915f7 100644
--- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.h
+++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.h
@@ -14,4 +14,6 @@ ucc_status_t ucc_tl_mlx5_mcast_bcast_init(ucc_tl_mlx5_task_t *task);
 
 ucc_status_t ucc_tl_mlx5_mcast_test(ucc_tl_mlx5_mcast_coll_req_t* _req);
 
+ucc_status_t ucc_tl_mlx5_mcast_check_memory_type_cap(ucc_base_coll_args_t *coll_args,
+                                                     ucc_base_team_t *team);
 #endif
diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.c
index f57daeab5e..132cd8df78 100644
--- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.c
+++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.c
@@ -300,7 +300,12 @@ ucc_status_t ucc_tl_mlx5_mcast_init_qps(ucc_tl_mlx5_mcast_coll_context_t *ctx,
         return UCC_ERR_NO_RESOURCE;
     }
 
-    comm->max_inline = qp_init_attr.cap.max_inline_data;
+    if (comm->cuda_mem_enabled) {
+        /* disable inline sends: CUDA device buffers cannot be inlined and ibv send would segfault */
+        comm->max_inline = 0;
+    } else {
+        comm->max_inline = qp_init_attr.cap.max_inline_data;
+    }
 
     return UCC_OK;
 }
@@ -609,6 +614,7 @@ ucc_status_t ucc_tl_mlx5_clean_mcast_comm(ucc_tl_mlx5_mcast_coll_comm_t *comm)
             return UCC_ERR_NO_RESOURCE;
         }
     }
+
     if (comm->grh_buf) {
         ucc_free(comm->grh_buf);
     }
@@ -626,7 +632,7 @@ ucc_status_t ucc_tl_mlx5_clean_mcast_comm(ucc_tl_mlx5_mcast_coll_comm_t *comm)
     }
 
     if (comm->pp_buf) {
-        ucc_free(comm->pp_buf);
+        ucc_mc_free(comm->pp_buf_header);
     }
 
     if (comm->call_rwr) {
diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.h b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.h
index 9d66f3453e..19dfd88097 100644
--- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.h
+++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.h
@@ -47,6 +47,8 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_send(ucc_tl_mlx5_mcast_coll_comm_t
     int               rc;
     int               length;
     ucc_status_t      status;
+    ucc_memory_type_t mem_type = comm->cuda_mem_enabled ? UCC_MEMORY_TYPE_CUDA
+                                                         : UCC_MEMORY_TYPE_HOST;
 
     for (i = 0; i < num_packets; i++) {
         if (comm->params.sx_depth <=
@@ -75,7 +77,12 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_send(ucc_tl_mlx5_mcast_coll_comm_t
         if (zcopy) {
             pp->context = (uintptr_t) PTR_OFFSET(req->ptr, offset);
         } else {
-            memcpy((void*) pp->buf, PTR_OFFSET(req->ptr, offset), length);
+            status = ucc_mc_memcpy((void*) pp->buf, PTR_OFFSET(req->ptr, offset), length,
+                                   mem_type, mem_type);
+            if (ucc_unlikely(status != UCC_OK)) {
+                tl_error(comm->lib, "failed to copy buffer");
+                return status;
+            }
             ssg[0].addr = (uint64_t) pp->buf;
         }
 
diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.c
index 4522097973..f506137f3d 100644
--- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.c
+++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.c
@@ -371,6 +371,7 @@ ucc_status_t ucc_tl_mlx5_mcast_process_packet(ucc_tl_mlx5_mcast_coll_comm_t *com
 {
     ucc_status_t      status = UCC_OK;
     void             *dest;
+    ucc_memory_type_t mem_type;
 
     ucc_assert(pp->psn >= req->start_psn &&
                pp->psn < req->start_psn + req->num_packets);
@@ -379,7 +380,19 @@ ucc_status_t ucc_tl_mlx5_mcast_process_packet(ucc_tl_mlx5_mcast_coll_comm_t *com
 
     if (pp->length > 0 ) {
         dest = req->ptr + PSN_TO_RECV_OFFSET(pp->psn, req, comm);
-        memcpy(dest, (void*) pp->buf, pp->length);
+
+        if (comm->cuda_mem_enabled) {
+            mem_type = UCC_MEMORY_TYPE_CUDA;
+        } else {
+            mem_type = UCC_MEMORY_TYPE_HOST;
+        }
+
+        status = ucc_mc_memcpy(dest, (void*) pp->buf, pp->length,
+                               mem_type, mem_type);
+        if (ucc_unlikely(status != UCC_OK)) {
+            tl_error(comm->lib, "failed to copy buffer");
+            return status;
+        }
     }
 
     comm->r_window[pp->psn & (comm->wsize-1)] = pp;
diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c
index 402ff84472..52bc242ae5 100644
--- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c
+++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c
@@ -13,6 +13,17 @@
 #include "mcast/tl_mlx5_mcast_helper.h"
 #include "mcast/tl_mlx5_mcast_service_coll.h"
 
+static ucc_status_t ucc_tl_mlx5_check_gpudirect_driver()
+{
+    const char *file = "/sys/kernel/mm/memory_peers/nv_mem/version";
+
+    if (!access(file, F_OK)) {
+        return UCC_OK;
+    }
+
+    return UCC_ERR_NO_RESOURCE;
+}
+
 ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *base_context,
                                          ucc_tl_mlx5_mcast_team_t **mcast_team,
                                          ucc_tl_mlx5_mcast_context_t *ctx,
@@ -88,23 +99,14 @@ ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *base_context,
 
     memcpy(&comm->params, conf_params, sizeof(*conf_params));
 
-    comm->wsize     = conf_params->wsize;
-    comm->max_eager = conf_params->max_eager;
-    comm->comm_id   = team_params->id;
-    comm->ctx       = mcast_context;
-    comm->grh_buf   = (char *)ucc_malloc(GRH_LENGTH * sizeof(char), "grh_buf");
-    if (!comm->grh_buf) {
-        status = UCC_ERR_NO_MEMORY;
-        goto cleanup;
-    }
+    comm->wsize            = conf_params->wsize;
+    comm->max_eager        = conf_params->max_eager;
+    comm->cuda_mem_enabled = conf_params->cuda_mem_enabled;
+    comm->comm_id          = team_params->id;
+    comm->ctx              = mcast_context;
 
-    memset(comm->grh_buf, 0, GRH_LENGTH);
-
-    comm->grh_mr = ibv_reg_mr(mcast_context->pd, comm->grh_buf, GRH_LENGTH,
-                              IBV_ACCESS_REMOTE_WRITE |
-                              IBV_ACCESS_LOCAL_WRITE);
-    if (!comm->grh_mr) {
-        tl_error(mcast_context->lib, "could not register memory for GRH, errno %d", errno);
+    if (comm->cuda_mem_enabled && (UCC_OK != ucc_tl_mlx5_check_gpudirect_driver())) {
+        tl_warn(mcast_context->lib, "cuda-aware mcast is not available because gpu direct is not ready");
         status = UCC_ERR_NO_RESOURCE;
         goto cleanup;
     }
@@ -162,9 +164,10 @@ ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *base_context,
 
 ucc_status_t ucc_tl_mlx5_mcast_coll_setup_comm_resources(ucc_tl_mlx5_mcast_coll_comm_t *comm)
 {
-    ucc_status_t status;
-    size_t       page_size;
-    int          buf_size, i, ret;
+    ucc_status_t      status;
+    size_t            page_size;
+    int               buf_size, i, ret;
+    ucc_memory_type_t supported_mem_type;
 
     status = ucc_tl_mlx5_mcast_init_qps(comm->ctx, comm);
     if (UCC_OK != status) {
@@ -197,19 +200,47 @@ ucc_status_t ucc_tl_mlx5_mcast_coll_setup_comm_resources(ucc_tl_mlx5_mcast_coll_
 
     comm->pending_recv = 0;
     comm->buf_n        = comm->params.rx_depth * 2;
 
-    ret = posix_memalign((void**) &comm->pp_buf, page_size, buf_size * comm->buf_n);
-    if (ret) {
-        tl_error(comm->ctx->lib, "posix_memalign failed");
-        return UCC_ERR_NO_MEMORY;
+    supported_mem_type = comm->cuda_mem_enabled ? UCC_MEMORY_TYPE_CUDA
+                                                : UCC_MEMORY_TYPE_HOST;
+
+    comm->grh_buf = ucc_malloc(GRH_LENGTH * sizeof(char), "grh");
+    if (ucc_unlikely(!comm->grh_buf)) {
+        tl_error(comm->ctx->lib, "failed to allocate grh memory");
+        return UCC_ERR_NO_MEMORY;
+    }
+
+    status = ucc_mc_memset(comm->grh_buf, 0, GRH_LENGTH, UCC_MEMORY_TYPE_HOST);
+    if (status != UCC_OK) {
+        tl_error(comm->ctx->lib, "could not memset grh buffer");
+        goto error;
+    }
+
+    comm->grh_mr = ibv_reg_mr(comm->ctx->pd, comm->grh_buf, GRH_LENGTH,
+                              IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_LOCAL_WRITE);
+    if (!comm->grh_mr) {
+        tl_error(comm->ctx->lib, "could not register memory for GRH, errno %d", errno);
+        status = UCC_ERR_NO_RESOURCE;
+        goto error;
+    }
+
+    status = ucc_mc_alloc(&comm->pp_buf_header, buf_size * comm->buf_n, supported_mem_type);
+    if (ucc_unlikely(status != UCC_OK)) {
+        tl_error(comm->ctx->lib, "failed to allocate pp_buf memory");
+        goto error;
+    }
+    comm->pp_buf = comm->pp_buf_header->addr;
+
+    status = ucc_mc_memset(comm->pp_buf, 0, buf_size * comm->buf_n, supported_mem_type);
+    if (status != UCC_OK) {
+        tl_error(comm->ctx->lib, "could not memset pp_buf");
+        goto error;
     }
-    memset(comm->pp_buf, 0, buf_size * comm->buf_n);
-
     comm->pp_mr = ibv_reg_mr(comm->ctx->pd, comm->pp_buf, buf_size * comm->buf_n,
                              IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_LOCAL_WRITE);
     if (!comm->pp_mr) {
-        tl_error(comm->ctx->lib, "could not register pp_buf mr, errno %d", errno);
-        status = UCC_ERR_NO_MEMORY;
+        tl_error(comm->ctx->lib, "could not register pp_buf device mr, errno %d", errno);
+        status = UCC_ERR_NO_RESOURCE;
         goto error;
     }
 
diff --git a/src/components/tl/mlx5/tl_mlx5.c b/src/components/tl/mlx5/tl_mlx5.c
index 75e6f517cc..7bf8572aab 100644
--- a/src/components/tl/mlx5/tl_mlx5.c
+++ b/src/components/tl/mlx5/tl_mlx5.c
@@ -92,6 +92,10 @@ static ucc_config_field_t ucc_tl_mlx5_lib_config_table[] = {
      ucc_offsetof(ucc_tl_mlx5_lib_config_t, mcast_conf.max_eager),
      UCC_CONFIG_TYPE_MEMUNITS},
 
+    {"MCAST_CUDA_MEM_ENABLE", "0", "Enable GPU CUDA memory support for Mcast. GPUDirect RDMA must be enabled",
+     ucc_offsetof(ucc_tl_mlx5_lib_config_t, mcast_conf.cuda_mem_enabled),
+     UCC_CONFIG_TYPE_BOOL},
+
     {NULL}};
 
 static ucc_config_field_t ucc_tl_mlx5_context_config_table[] = {
diff --git a/src/components/tl/mlx5/tl_mlx5_coll.c b/src/components/tl/mlx5/tl_mlx5_coll.c
index e918be166e..a8add9715e 100644
--- a/src/components/tl/mlx5/tl_mlx5_coll.c
+++ b/src/components/tl/mlx5/tl_mlx5_coll.c
@@ -19,6 +19,11 @@ ucc_status_t ucc_tl_mlx5_bcast_mcast_init(ucc_base_coll_args_t *coll_args,
         tl_trace(team->context->lib, "mcast bcast not supported for active sets");
         return UCC_ERR_NOT_SUPPORTED;
     }
+
+    if (UCC_OK != ucc_tl_mlx5_mcast_check_memory_type_cap(coll_args, team)) {
+        tl_trace(team->context->lib, "mcast bcast not compatible with this memory type");
+        return UCC_ERR_NOT_SUPPORTED;
+    }
 
     task = ucc_tl_mlx5_get_task(coll_args, team);
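
A minimal standalone sketch (not part of the patch) of the GPUDirect prerequisite that the new ucc_tl_mlx5_check_gpudirect_driver() helper probes before CUDA staging buffers are allowed: it checks the same /sys/kernel/mm/memory_peers/nv_mem/version sysfs node. The UCC_TL_MLX5_MCAST_CUDA_MEM_ENABLE environment variable named below is an assumption, derived from UCC's usual prefixing of the new MCAST_CUDA_MEM_ENABLE lib config entry.

/*
 * check_gdr.c - probe the GPUDirect peer-memory sysfs node used by this patch.
 * Build: cc -o check_gdr check_gdr.c
 */
#include <stdio.h>
#include <unistd.h>

int main(void)
{
    /* same path ucc_tl_mlx5_check_gpudirect_driver() tests with access() */
    const char *file = "/sys/kernel/mm/memory_peers/nv_mem/version";

    if (access(file, F_OK) == 0) {
        printf("GPUDirect peer-memory driver detected: %s\n", file);
        printf("CUDA staging can be requested via the new MCAST_CUDA_MEM_ENABLE option\n"
               "(assumed env name: UCC_TL_MLX5_MCAST_CUDA_MEM_ENABLE=1).\n");
        return 0;
    }

    /* with the option enabled but no driver, mcast team init warns and returns UCC_ERR_NO_RESOURCE */
    printf("GPUDirect peer-memory driver not found: %s\n", file);
    return 1;
}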