diff --git a/src/components/ec/base/ucc_ec_base.h b/src/components/ec/base/ucc_ec_base.h index 52a2427318..da76d61140 100644 --- a/src/components/ec/base/ucc_ec_base.h +++ b/src/components/ec/base/ucc_ec_base.h @@ -176,8 +176,15 @@ typedef struct ucc_ee_executor_task { ucc_ee_executor_t *eee; ucc_ee_executor_task_args_t args; ucc_status_t status; + void *completion; } ucc_ee_executor_task_t; +typedef struct node_ucc_ee_executor_task node_ucc_ee_executor_task_t; +typedef struct node_ucc_ee_executor_task { + ucc_ee_executor_task_t *etask; + node_ucc_ee_executor_task_t *next; +} node_ucc_ee_executor_task_t; + typedef struct ucc_ee_executor_ops { ucc_status_t (*init)(const ucc_ee_executor_params_t *params, ucc_ee_executor_t **executor); diff --git a/src/components/tl/ucp/allgather/allgather_knomial.c b/src/components/tl/ucp/allgather/allgather_knomial.c index d5a760a23a..074e4a1cd1 100644 --- a/src/components/tl/ucp/allgather/allgather_knomial.c +++ b/src/components/tl/ucp/allgather/allgather_knomial.c @@ -13,6 +13,7 @@ #include "coll_patterns/sra_knomial.h" #include "utils/ucc_math.h" #include "utils/ucc_coll_utils.h" +#include #define SAVE_STATE(_phase) \ do { \ @@ -50,6 +51,7 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task) args->root : 0; ucc_rank_t rank = VRANK(task->subset.myrank, broot, size); size_t local = GET_LOCAL_COUNT(args, size, rank); + ucp_mem_h *mh_list = task->allgather_kn.mh_list; void *sbuf; ptrdiff_t peer_seg_offset, local_seg_offset; ucc_rank_t peer, peer_dist; @@ -65,32 +67,36 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task) if (KN_NODE_EXTRA == node_type) { peer = ucc_knomial_pattern_get_proxy(p, rank); if (p->type != KN_PATTERN_ALLGATHERX) { - UCPCHECK_GOTO(ucc_tl_ucp_send_nb(task->allgather_kn.sbuf, + UCPCHECK_GOTO(ucc_tl_ucp_send_nb_with_mem(task->allgather_kn.sbuf, local * dt_size, mem_type, ucc_ep_map_eval(task->subset.map, INV_VRANK(peer,broot,size)), - team, task), + team, task, mh_list[task->allgather_kn.count_mh++]), task, out); + ucc_assert(task->allgather_kn.count_mh-1 <= task->allgather_kn.max_mh); + } - UCPCHECK_GOTO(ucc_tl_ucp_recv_nb(rbuf, data_size, mem_type, + UCPCHECK_GOTO(ucc_tl_ucp_send_nb_with_mem(rbuf, data_size, mem_type, ucc_ep_map_eval(task->subset.map, INV_VRANK(peer,broot,size)), - team, task), + team, task, mh_list[task->allgather_kn.count_mh++]), task, out); + ucc_assert(task->allgather_kn.count_mh-1 <= task->allgather_kn.max_mh); } if ((p->type != KN_PATTERN_ALLGATHERX) && (node_type == KN_NODE_PROXY)) { peer = ucc_knomial_pattern_get_extra(p, rank); extra_count = GET_LOCAL_COUNT(args, size, peer); peer = ucc_ep_map_eval(task->subset.map, peer); - UCPCHECK_GOTO(ucc_tl_ucp_recv_nb(PTR_OFFSET(task->allgather_kn.sbuf, + UCPCHECK_GOTO(ucc_tl_ucp_recv_nb_with_mem(PTR_OFFSET(task->allgather_kn.sbuf, local * dt_size), extra_count * dt_size, - mem_type, peer, team, task), + mem_type, peer, team, task, mh_list[task->allgather_kn.count_mh++]), task, out); + ucc_assert(task->allgather_kn.count_mh-1 <= task->allgather_kn.max_mh); } UCC_KN_PHASE_EXTRA: if ((KN_NODE_EXTRA == node_type) || (KN_NODE_PROXY == node_type)) { - if (UCC_INPROGRESS == ucc_tl_ucp_test(task)) { + if (UCC_INPROGRESS == ucc_tl_ucp_test_with_etasks(task)) { SAVE_STATE(UCC_KN_PHASE_EXTRA); return; } @@ -114,12 +120,13 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task) continue; } } - UCPCHECK_GOTO(ucc_tl_ucp_send_nb(sbuf, local_seg_count * dt_size, + UCPCHECK_GOTO(ucc_tl_ucp_send_nb_with_mem(sbuf, local_seg_count 
* dt_size, mem_type, ucc_ep_map_eval(task->subset.map, INV_VRANK(peer, broot, size)), - team, task), + team, task, mh_list[task->allgather_kn.count_mh++]), task, out); + ucc_assert(task->allgather_kn.count_mh-1 <= task->allgather_kn.max_mh); } for (loop_step = 1; loop_step < radix; loop_step++) { @@ -137,15 +144,16 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task) } } UCPCHECK_GOTO( - ucc_tl_ucp_recv_nb(PTR_OFFSET(rbuf, peer_seg_offset * dt_size), + ucc_tl_ucp_recv_nb_with_mem(PTR_OFFSET(rbuf, peer_seg_offset * dt_size), peer_seg_count * dt_size, mem_type, ucc_ep_map_eval(task->subset.map, INV_VRANK(peer, broot, size)), - team, task), + team, task, mh_list[task->allgather_kn.count_mh++]), task, out); + ucc_assert(task->allgather_kn.count_mh-1 <= task->allgather_kn.max_mh); } UCC_KN_PHASE_LOOP: - if (UCC_INPROGRESS == ucc_tl_ucp_test_recv(task)) { + if (UCC_INPROGRESS == ucc_tl_ucp_test_recv_with_etasks(task)) { SAVE_STATE(UCC_KN_PHASE_LOOP); return; } @@ -154,20 +162,22 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task) if (KN_NODE_PROXY == node_type) { peer = ucc_knomial_pattern_get_extra(p, rank); - UCPCHECK_GOTO(ucc_tl_ucp_send_nb(args->dst.info.buffer, data_size, + UCPCHECK_GOTO(ucc_tl_ucp_send_nb_with_mem(args->dst.info.buffer, data_size, mem_type, ucc_ep_map_eval(task->subset.map, INV_VRANK(peer, broot, size)), - team, task), + team, task, mh_list[task->allgather_kn.count_mh++]), task, out); + ucc_assert(task->allgather_kn.count_mh-1 <= task->allgather_kn.max_mh); } UCC_KN_PHASE_PROXY: - if (UCC_INPROGRESS == ucc_tl_ucp_test(task)) { + if (UCC_INPROGRESS == ucc_tl_ucp_test_with_etasks(task)) { SAVE_STATE(UCC_KN_PHASE_PROXY); return; } out: + ucc_assert(task->allgather_kn.count_mh-1 == task->allgather_kn.max_mh); ucc_assert(UCC_TL_UCP_TASK_P2P_COMPLETE(task)); task->super.status = UCC_OK; UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_allgather_kn_done", 0); @@ -234,6 +244,155 @@ ucc_status_t ucc_tl_ucp_allgather_knomial_start(ucc_coll_task_t *coll_task) return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super); } +ucc_status_t register_memory(ucc_coll_task_t *coll_task){ + + ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, + ucc_tl_ucp_task_t); + ucc_coll_args_t *args = &TASK_ARGS(task); + ucc_tl_ucp_team_t *team = TASK_TEAM(task); + ucc_coll_type_t ct = args->coll_type; + ucc_kn_radix_t radix = task->allgather_kn.p.radix; + uint8_t node_type = task->allgather_kn.p.node_type; + ucc_knomial_pattern_t *p = &task->allgather_kn.p; + void *rbuf = args->dst.info.buffer; + ucc_memory_type_t mem_type = args->dst.info.mem_type; + size_t count = args->dst.info.count; + size_t dt_size = ucc_dt_size(args->dst.info.datatype); + size_t data_size = count * dt_size; + ucc_rank_t size = task->subset.map.ep_num; + ucc_rank_t broot = args->coll_type == UCC_COLL_TYPE_BCAST ? 
+ args->root : 0; + ucc_rank_t rank = VRANK(task->subset.myrank, broot, size); + size_t local = GET_LOCAL_COUNT(args, size, rank); + void *sbuf; + ptrdiff_t peer_seg_offset, local_seg_offset; + ucc_rank_t peer, peer_dist; + ucc_kn_radix_t loop_step; + size_t peer_seg_count, local_seg_count; + ucc_status_t status; + size_t extra_count; + + ucc_tl_ucp_context_t *ctx = UCC_TL_UCP_TEAM_CTX(team); + ucp_mem_map_params_t mmap_params; + // ucp_mem_h mh; + int size_of_list = 1; + int count_mh = 0; + ucp_mem_h *mh_list = (ucp_mem_h *)malloc(size_of_list * sizeof(ucp_mem_h)); + + UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_allgather_kn_start", 0); + task->allgather_kn.etask = NULL; + task->allgather_kn.phase = UCC_KN_PHASE_INIT; + if (ct == UCC_COLL_TYPE_ALLGATHER) { + ucc_kn_ag_pattern_init(size, rank, radix, args->dst.info.count, + &task->allgather_kn.p); + } else { + ucc_kn_agx_pattern_init(size, rank, radix, args->dst.info.count, + &task->allgather_kn.p); + } + + mmap_params.field_mask = UCP_MEM_MAP_PARAM_FIELD_ADDRESS | + UCP_MEM_MAP_PARAM_FIELD_LENGTH | + UCP_MEM_MAP_PARAM_FIELD_MEMORY_TYPE; + mmap_params.memory_type = ucc_memtype_to_ucs[mem_type]; + if (KN_NODE_EXTRA == node_type) { + if (p->type != KN_PATTERN_ALLGATHERX) { + mmap_params.address = task->allgather_kn.sbuf; + mmap_params.length = local * dt_size; + MEM_MAP(); + } + + mmap_params.address = rbuf; + mmap_params.length = data_size; + MEM_MAP(); + } + if ((p->type != KN_PATTERN_ALLGATHERX) && (node_type == KN_NODE_PROXY)) { + peer = ucc_knomial_pattern_get_extra(p, rank); + extra_count = GET_LOCAL_COUNT(args, size, peer); + peer = ucc_ep_map_eval(task->subset.map, peer); + mmap_params.address = PTR_OFFSET(task->allgather_kn.sbuf, + local * dt_size); + mmap_params.length = extra_count * dt_size; + MEM_MAP(); + } + + if (KN_NODE_EXTRA == node_type) { + goto out; + } + while (!ucc_knomial_pattern_loop_done(p)) { + ucc_kn_ag_pattern_peer_seg(rank, p, &local_seg_count, + &local_seg_offset); + sbuf = PTR_OFFSET(rbuf, local_seg_offset * dt_size); + for (loop_step = radix - 1; loop_step > 0; loop_step--) { + peer = ucc_knomial_pattern_get_loop_peer(p, rank, loop_step); + if (peer == UCC_KN_PEER_NULL) + continue; + if (coll_task->bargs.args.coll_type == UCC_COLL_TYPE_BCAST) { + peer_dist = ucc_knomial_calc_recv_dist(size - p->n_extra, + ucc_knomial_pattern_loop_rank(p, peer), p->radix, 0); + if (peer_dist < task->allgather_kn.recv_dist) { + continue; + } + } + mmap_params.address = sbuf; + mmap_params.length = local_seg_count * dt_size; + MEM_MAP(); + } + + for (loop_step = 1; loop_step < radix; loop_step++) { + peer = ucc_knomial_pattern_get_loop_peer(p, rank, loop_step); + if (peer == UCC_KN_PEER_NULL) + continue; + ucc_kn_ag_pattern_peer_seg(peer, p, &peer_seg_count, + &peer_seg_offset); + + if (coll_task->bargs.args.coll_type == UCC_COLL_TYPE_BCAST) { + peer_dist = ucc_knomial_calc_recv_dist(size - p->n_extra, + ucc_knomial_pattern_loop_rank(p, peer), p->radix, 0); + if (peer_dist > task->allgather_kn.recv_dist) { + continue; + } + } + mmap_params.address = PTR_OFFSET(rbuf, peer_seg_offset * dt_size); + mmap_params.length = peer_seg_count * dt_size; + MEM_MAP(); + } + ucc_kn_ag_pattern_next_iter(p); + } + + if (KN_NODE_PROXY == node_type) { + mmap_params.address = args->dst.info.buffer; + mmap_params.length = data_size; + MEM_MAP(); + } + +out: + task->allgather_kn.mh_list = mh_list; + task->allgather_kn.max_mh = count_mh-1; + task->allgather_kn.count_mh = 0; + return UCC_OK; +} + +ucc_status_t 
ucc_tl_ucp_allgather_knomial_finalize(ucc_coll_task_t *coll_task){ + ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, + ucc_tl_ucp_task_t); + ucc_status_t status; + ucc_tl_ucp_team_t *team = TASK_TEAM(task); + ucc_tl_ucp_context_t *ctx = UCC_TL_UCP_TEAM_CTX(team); + + ucc_mpool_cleanup(&task->allgather_kn.etask_node_mpool, 1); + for (int i=0; iallgather_kn.max_mh+1; i++){ + ucp_mem_unmap(ctx->worker.ucp_context, task->allgather_kn.mh_list[i]); + } + free(task->allgather_kn.mh_list); + status = ucc_tl_ucp_coll_finalize(&task->super); + if (status < 0){ + tl_error(UCC_TASK_LIB(task), + "failed to initialize ucc_mpool"); + } + + return UCC_OK; +} + ucc_status_t ucc_tl_ucp_allgather_knomial_init_r( ucc_base_coll_args_t *coll_args, ucc_base_team_t *team, ucc_coll_task_t **task_h, ucc_kn_radix_t radix) @@ -241,18 +400,34 @@ ucc_status_t ucc_tl_ucp_allgather_knomial_init_r( ucc_tl_ucp_team_t *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t); ucc_tl_ucp_task_t *task; ucc_sbgp_t *sbgp; + ucc_status_t status; task = ucc_tl_ucp_init_task(coll_args, team); + status = ucc_mpool_init(&task->allgather_kn.etask_node_mpool, 0, sizeof(node_ucc_ee_executor_task_t), + 0, UCC_CACHE_LINE_SIZE, 16, UINT_MAX, NULL, + tl_team->super.super.context->ucc_context->thread_mode, "etasks_linked_list_nodes"); + if (status < 0){ + tl_error(UCC_TASK_LIB(task), + "failed to initialize ucc_mpool"); + } + if (tl_team->cfg.use_reordering && coll_args->args.coll_type == UCC_COLL_TYPE_ALLREDUCE) { sbgp = ucc_topo_get_sbgp(tl_team->topo, UCC_SBGP_FULL_HOST_ORDERED); task->subset.myrank = sbgp->group_rank; task->subset.map = sbgp->map; } + task->allgather_kn.etask_linked_list_head = NULL; task->allgather_kn.p.radix = radix; task->super.flags |= UCC_COLL_TASK_FLAG_EXECUTOR; task->super.post = ucc_tl_ucp_allgather_knomial_start; task->super.progress = ucc_tl_ucp_allgather_knomial_progress; + task->super.finalize = ucc_tl_ucp_allgather_knomial_finalize; + status = register_memory(&task->super); + if (status < 0){ + tl_error(UCC_TASK_LIB(task), + "failed to register memory"); + } *task_h = &task->super; return UCC_OK; } diff --git a/src/components/tl/ucp/tl_ucp_coll.h b/src/components/tl/ucp/tl_ucp_coll.h index 414c6a04eb..4347ab2874 100644 --- a/src/components/tl/ucp/tl_ucp_coll.h +++ b/src/components/tl/ucp/tl_ucp_coll.h @@ -57,6 +57,17 @@ void ucc_tl_ucp_team_default_score_str_free( } \ } while(0) +#define MEM_MAP() do { \ + status = ucs_status_to_ucc_status(ucp_mem_map(ctx->worker.ucp_context, &mmap_params, &mh_list[count_mh++])); \ + if (UCC_OK != status) { \ + return status; \ + } \ + if (count_mh == size_of_list){ \ + size_of_list *= 2; \ + mh_list = (ucp_mem_h *)realloc(mh_list, size_of_list * sizeof(ucp_mem_h)); \ + } \ +} while(0) + #define EXEC_TASK_WAIT(_etask, ...) 
\ do { \ if (_etask != NULL) { \ @@ -183,7 +194,12 @@ typedef struct ucc_tl_ucp_task { ucc_knomial_pattern_t p; void *sbuf; ucc_ee_executor_task_t *etask; + node_ucc_ee_executor_task_t *etask_linked_list_head; ucc_rank_t recv_dist; + ucc_mpool_t etask_node_mpool; + ucp_mem_h *mh_list; + int count_mh; + int max_mh; } allgather_kn; struct { /* @@ -406,6 +422,48 @@ static inline ucc_status_t ucc_tl_ucp_test(ucc_tl_ucp_task_t *task) return UCC_INPROGRESS; } +static inline ucc_status_t ucc_tl_ucp_test_with_etasks(ucc_tl_ucp_task_t *task) +{ + int polls = 0; + ucc_status_t status; + ucc_status_t status_2; + node_ucc_ee_executor_task_t *current_node; + node_ucc_ee_executor_task_t *prev_node; + + if (UCC_TL_UCP_TASK_P2P_COMPLETE(task) && task->allgather_kn.etask_linked_list_head==NULL) { + return UCC_OK; + } + while (polls++ < task->n_polls) { + current_node = task->allgather_kn.etask_linked_list_head; + prev_node = NULL; + while(current_node != NULL) { + status = ucc_ee_executor_task_test(current_node->etask); + if (status > 0) { + ucp_memcpy_device_complete(current_node->etask->completion, ucc_status_to_ucs_status(status)); + status_2 = ucc_ee_executor_task_finalize(current_node->etask); + ucc_mpool_put(current_node); + if (ucc_unlikely(status_2 < 0)){ + tl_error(UCC_TASK_LIB(task), "task finalize didnt work"); + return status_2; + } + if (prev_node != NULL){ + prev_node->next = current_node->next; //to remove from list + } + else{ //i'm on first node + task->allgather_kn.etask_linked_list_head = current_node->next; + } + } + prev_node = current_node; + current_node = current_node->next; //to iterate to next node + } + if (UCC_TL_UCP_TASK_P2P_COMPLETE(task) && task->allgather_kn.etask_linked_list_head == NULL) { + return UCC_OK; + } + ucp_worker_progress(UCC_TL_UCP_TASK_TEAM(task)->worker->ucp_worker); + } + return UCC_INPROGRESS; +} + #define UCC_TL_UCP_TASK_RECV_COMPLETE(_task) \ (((_task)->tagged.recv_posted == (_task)->tagged.recv_completed)) @@ -428,6 +486,47 @@ static inline ucc_status_t ucc_tl_ucp_test_recv(ucc_tl_ucp_task_t *task) return UCC_INPROGRESS; } +static inline ucc_status_t ucc_tl_ucp_test_recv_with_etasks(ucc_tl_ucp_task_t *task) { + int polls = 0; + ucc_status_t status; + ucc_status_t status_2; + node_ucc_ee_executor_task_t *current_node; + node_ucc_ee_executor_task_t *prev_node; + + if (UCC_TL_UCP_TASK_RECV_COMPLETE(task) && task->allgather_kn.etask_linked_list_head==NULL) { + return UCC_OK; + } + while (polls++ < task->n_polls) { + current_node = task->allgather_kn.etask_linked_list_head; + prev_node = NULL; + while(current_node != NULL) { + status = ucc_ee_executor_task_test(current_node->etask); + if (status > 0) { + ucp_memcpy_device_complete(current_node->etask->completion, ucc_status_to_ucs_status(status)); + status_2 = ucc_ee_executor_task_finalize(current_node->etask); + ucc_mpool_put(current_node); + if (ucc_unlikely(status_2 < 0)){ + tl_error(UCC_TASK_LIB(task), "task finalize didnt work"); + return status_2; + } + if (prev_node != NULL){ + prev_node->next = current_node->next; //to remove from list + } + else{ //i'm on first node + task->allgather_kn.etask_linked_list_head = current_node->next; + } + } + prev_node = current_node; + current_node = current_node->next; //to iterate to next node + } + if (UCC_TL_UCP_TASK_RECV_COMPLETE(task) && task->allgather_kn.etask_linked_list_head==NULL) { + return UCC_OK; + } + ucp_worker_progress(UCC_TL_UCP_TASK_TEAM(task)->worker->ucp_worker); + } + return UCC_INPROGRESS; +} + static inline ucc_status_t 
ucc_tl_ucp_test_send(ucc_tl_ucp_task_t *task) { int polls = 0; diff --git a/src/components/tl/ucp/tl_ucp_context.c b/src/components/tl/ucp/tl_ucp_context.c index 1c7c49b53f..889de8a798 100644 --- a/src/components/tl/ucp/tl_ucp_context.c +++ b/src/components/tl/ucp/tl_ucp_context.c @@ -133,6 +133,85 @@ ucc_tl_ucp_context_service_init(const char *prefix, ucp_params_t ucp_params, return ucc_status; } +static int memcpy_device_start(void *dest, void *src, size_t size, + void *completion, void *user_data) { + + ucc_status_t status; + ucc_ee_executor_task_args_t eargs; + ucc_ee_executor_t *exec; + ucc_tl_ucp_task_t *task = (ucc_tl_ucp_task_t *) user_data; + + status = ucc_coll_task_get_executor(&task->super, &exec); + if (ucc_unlikely(status != UCC_OK)) { + return status; + } + + eargs.task_type = UCC_EE_EXECUTOR_TASK_COPY; + eargs.copy.src = src; + eargs.copy.dst = dest; + eargs.copy.len = size; + node_ucc_ee_executor_task_t *new_node; + new_node = ucc_mpool_get(&task->allgather_kn.etask_node_mpool); + if (ucc_unlikely(!new_node)) { + return UCC_ERR_NO_MEMORY; + } + status = ucc_ee_executor_task_post(exec, &eargs, + &new_node->etask); + task->allgather_kn.etask_linked_list_head->etask->completion = completion; + + if (ucc_unlikely(status != UCC_OK)) { + task->super.status = status; + return status; + } + new_node->next = task->allgather_kn.etask_linked_list_head; + task->allgather_kn.etask_linked_list_head = new_node; + + return 1; + + } + +static int memcpy_device(void *dest, void *src, size_t size, void *user_data){ + + ucc_status_t status; + ucc_ee_executor_task_args_t eargs; + ucc_ee_executor_t *exec; + ucc_ee_executor_task_t *etask; + ucc_tl_ucp_task_t *task = (ucc_tl_ucp_task_t *) user_data; + + status = ucc_coll_task_get_executor(&task->super, &exec); + if (ucc_unlikely(status != UCC_OK)) { + return status; + } + + eargs.task_type = UCC_EE_EXECUTOR_TASK_COPY; + eargs.copy.src = src; + eargs.copy.dst = dest; + eargs.copy.len = size; + + status = ucc_ee_executor_task_post(exec, &eargs, &etask); + if (ucc_unlikely(status < 0)) { + return status; + } + status = ucc_ee_executor_task_test(etask); + while (status>0) { + status = ucc_ee_executor_task_test(etask); + if (ucc_unlikely(status < 0)) { + return status; + } + } + status = ucc_ee_executor_task_finalize(etask); + if (ucc_unlikely(status < 0)) { + return status; + } + return 1; +} + +ucp_worker_mem_callbacks_t copy_callback = +{ + .memcpy_device_start = memcpy_device_start, + .memcpy_device = memcpy_device +}; + UCC_CLASS_INIT_FUNC(ucc_tl_ucp_context_t, const ucc_base_context_params_t *params, const ucc_base_config_t *config) @@ -194,7 +273,7 @@ UCC_CLASS_INIT_FUNC(ucc_tl_ucp_context_t, self); self->ucp_memory_types = context_attr.memory_types; - worker_params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE; + worker_params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE | UCP_WORKER_PARAM_FIELD_CALLBACKS; switch (params->thread_mode) { case UCC_THREAD_SINGLE: case UCC_THREAD_FUNNELED: @@ -209,6 +288,8 @@ UCC_CLASS_INIT_FUNC(ucc_tl_ucp_context_t, break; } + worker_params.callbacks = copy_callback; + UCP_CHECK(ucp_worker_create(ucp_context, &worker_params, &ucp_worker), "failed to create ucp worker", err_worker_create, self); diff --git a/src/components/tl/ucp/tl_ucp_sendrecv.h b/src/components/tl/ucp/tl_ucp_sendrecv.h index 4d1bc84440..53bf2cc48c 100644 --- a/src/components/tl/ucp/tl_ucp_sendrecv.h +++ b/src/components/tl/ucp/tl_ucp_sendrecv.h @@ -94,6 +94,36 @@ ucc_tl_ucp_send_common(void *buffer, size_t msglen, ucc_memory_type_t mtype, 
return ucp_tag_send_nbx(ep, buffer, 1, ucp_tag, &req_param); } +static inline ucs_status_ptr_t +ucc_tl_ucp_send_common_with_mem(void *buffer, size_t msglen, ucc_memory_type_t mtype, + ucc_rank_t dest_group_rank, ucc_tl_ucp_team_t *team, + ucc_tl_ucp_task_t *task, ucp_send_nbx_callback_t cb, void *user_data, ucp_mem_h mh) +{ + ucc_coll_args_t *args = &TASK_ARGS(task); + ucp_request_param_t req_param; + ucc_status_t status; + ucp_ep_h ep; + ucp_tag_t ucp_tag; + + status = ucc_tl_ucp_get_ep(team, dest_group_rank, &ep); + if (ucc_unlikely(UCC_OK != status)) { + return UCS_STATUS_PTR(UCS_ERR_NO_MESSAGE); + } + ucp_tag = UCC_TL_UCP_MAKE_SEND_TAG((args->mask & UCC_COLL_ARGS_FIELD_TAG), + task->tagged.tag, UCC_TL_TEAM_RANK(team), team->super.super.params.id, + team->super.super.params.scope_id, team->super.super.params.scope); + req_param.op_attr_mask = + UCP_OP_ATTR_FIELD_CALLBACK | UCP_OP_ATTR_FIELD_DATATYPE | + UCP_OP_ATTR_FIELD_USER_DATA | UCP_OP_ATTR_FIELD_MEMORY_TYPE | UCP_OP_ATTR_FIELD_MEMH; + req_param.datatype = ucp_dt_make_contig(msglen); + req_param.cb.send = cb; + req_param.memory_type = ucc_memtype_to_ucs[mtype]; + req_param.user_data = user_data; + req_param.memh = mh; + task->tagged.send_posted++; + return ucp_tag_send_nbx(ep, buffer, 1, ucp_tag, &req_param); +} + static inline ucc_status_t ucc_tl_ucp_send_nb(void *buffer, size_t msglen, ucc_memory_type_t mtype, ucc_rank_t dest_group_rank, ucc_tl_ucp_team_t *team, @@ -112,6 +142,24 @@ ucc_tl_ucp_send_nb(void *buffer, size_t msglen, ucc_memory_type_t mtype, return UCC_OK; } +static inline ucc_status_t +ucc_tl_ucp_send_nb_with_mem(void *buffer, size_t msglen, ucc_memory_type_t mtype, + ucc_rank_t dest_group_rank, ucc_tl_ucp_team_t *team, + ucc_tl_ucp_task_t *task, ucp_mem_h mh) +{ + ucs_status_ptr_t ucp_status; + + ucp_status = ucc_tl_ucp_send_common_with_mem(buffer, msglen, mtype, dest_group_rank, + team, task, ucc_tl_ucp_send_completion_cb, + (void *)task, mh); + if (UCS_OK != ucp_status) { + UCC_TL_UCP_CHECK_REQ_STATUS(); + } else { + ucc_atomic_add32(&task->tagged.send_completed, 1); + } + return UCC_OK; +} + static inline ucc_status_t ucc_tl_ucp_send_cb(void *buffer, size_t msglen, ucc_memory_type_t mtype, ucc_rank_t dest_group_rank, ucc_tl_ucp_team_t *team, @@ -157,6 +205,35 @@ ucc_tl_ucp_recv_common(void *buffer, size_t msglen, ucc_memory_type_t mtype, ucp_tag_mask, &req_param); } +static inline ucs_status_ptr_t +ucc_tl_ucp_recv_common_with_mem(void *buffer, size_t msglen, ucc_memory_type_t mtype, + ucc_rank_t dest_group_rank, ucc_tl_ucp_team_t *team, + ucc_tl_ucp_task_t *task, ucp_tag_recv_nbx_callback_t cb, void *user_data, ucp_mem_h mh) +{ + ucc_coll_args_t *args = &TASK_ARGS(task); + ucp_request_param_t req_param; + ucp_tag_t ucp_tag, ucp_tag_mask; + + // coverity[result_independent_of_operands:FALSE] + UCC_TL_UCP_MAKE_RECV_TAG(ucp_tag, ucp_tag_mask, + (args->mask & UCC_COLL_ARGS_FIELD_TAG), + task->tagged.tag, dest_group_rank, + team->super.super.params.id, + team->super.super.params.scope_id, + team->super.super.params.scope); + req_param.op_attr_mask = + UCP_OP_ATTR_FIELD_CALLBACK | UCP_OP_ATTR_FIELD_DATATYPE | + UCP_OP_ATTR_FIELD_USER_DATA | UCP_OP_ATTR_FIELD_MEMORY_TYPE | UCP_OP_ATTR_FIELD_MEMH; + req_param.datatype = ucp_dt_make_contig(msglen); + req_param.cb.recv = cb; + req_param.memory_type = ucc_memtype_to_ucs[mtype]; + req_param.user_data = user_data; + req_param.memh = mh; + task->tagged.recv_posted++; + return ucp_tag_recv_nbx(team->worker->ucp_worker, buffer, 1, ucp_tag, + ucp_tag_mask, &req_param); +} + 
static inline ucc_status_t ucc_tl_ucp_recv_nb(void *buffer, size_t msglen, ucc_memory_type_t mtype, ucc_rank_t dest_group_rank, ucc_tl_ucp_team_t *team, @@ -176,6 +253,25 @@ ucc_tl_ucp_recv_nb(void *buffer, size_t msglen, ucc_memory_type_t mtype, } +static inline ucc_status_t +ucc_tl_ucp_recv_nb_with_mem(void *buffer, size_t msglen, ucc_memory_type_t mtype, + ucc_rank_t dest_group_rank, ucc_tl_ucp_team_t *team, + ucc_tl_ucp_task_t *task, ucp_mem_h mh) +{ + ucs_status_ptr_t ucp_status; + + ucp_status = ucc_tl_ucp_recv_common_with_mem(buffer, msglen, mtype, dest_group_rank, + team, task, ucc_tl_ucp_recv_completion_cb, + (void *)task, mh); + if (UCS_OK != ucp_status) { + UCC_TL_UCP_CHECK_REQ_STATUS(); + } else { + ucc_atomic_add32(&task->tagged.recv_completed, 1); + } + return UCC_OK; + +} + static inline ucc_status_t ucc_tl_ucp_recv_cb(void *buffer, size_t msglen, ucc_memory_type_t mtype, ucc_rank_t dest_group_rank, ucc_tl_ucp_team_t *team,
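
Illustrative note (not part of the patch): the executor-task bookkeeping added above boils down to pushing one list node per posted device copy onto the task, and later draining that list before a phase is allowed to complete. The sketch below is a minimal, hedged summary of that flow using the types introduced in this patch; the helper names push_etask() and drain_etasks() are hypothetical, and ucp_memcpy_device_complete() is assumed to come from the UCX-side device-memcpy callback API that this change builds on.

/* Sketch only: track an executor copy on the task's linked list and later
 * complete it back to UCP. Assumes the node_ucc_ee_executor_task_t node type,
 * the etask_node_mpool, and the etask->completion field added in this patch. */
#include "tl_ucp_coll.h"

static ucc_status_t push_etask(ucc_tl_ucp_task_t *task, ucc_ee_executor_t *exec,
                               void *dst, void *src, size_t len, void *completion)
{
    ucc_ee_executor_task_args_t  eargs;
    node_ucc_ee_executor_task_t *node;
    ucc_status_t                 status;

    node = ucc_mpool_get(&task->allgather_kn.etask_node_mpool);
    if (ucc_unlikely(!node)) {
        return UCC_ERR_NO_MEMORY;
    }
    eargs.task_type = UCC_EE_EXECUTOR_TASK_COPY;
    eargs.copy.dst  = dst;
    eargs.copy.src  = src;
    eargs.copy.len  = len;
    status = ucc_ee_executor_task_post(exec, &eargs, &node->etask);
    if (ucc_unlikely(status != UCC_OK)) {
        ucc_mpool_put(node);
        return status;
    }
    node->etask->completion = completion;              /* UCP completion handle to signal later */
    node->next = task->allgather_kn.etask_linked_list_head;
    task->allgather_kn.etask_linked_list_head = node;  /* push on the list head */
    return UCC_OK;
}

/* Poll every outstanding copy; a finished entry is unlinked, its UCP
 * completion is signalled, and the node goes back to the mpool. */
static ucc_status_t drain_etasks(ucc_tl_ucp_task_t *task)
{
    node_ucc_ee_executor_task_t **prev = &task->allgather_kn.etask_linked_list_head;
    node_ucc_ee_executor_task_t  *node = *prev;
    ucc_status_t                  status;

    while (node != NULL) {
        status = ucc_ee_executor_task_test(node->etask);
        if (status == UCC_OK) {
            ucp_memcpy_device_complete(node->etask->completion, UCS_OK);
            status = ucc_ee_executor_task_finalize(node->etask);
            *prev  = node->next;                       /* unlink the finished node */
            ucc_mpool_put(node);
            if (ucc_unlikely(status < 0)) {
                return status;
            }
        } else if (ucc_unlikely(status < 0)) {
            return status;
        } else {
            prev = &node->next;                        /* still running: keep it, move on */
        }
        node = *prev;
    }
    return task->allgather_kn.etask_linked_list_head ? UCC_INPROGRESS : UCC_OK;
}

This is the same pattern that ucc_tl_ucp_test_with_etasks() and ucc_tl_ucp_test_recv_with_etasks() apply, together with the usual ucp_worker_progress() polling, before reporting a phase of the collective as complete.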