Skip to content

Commit

Permalink
TL/MLX5: add nonblocking cudaMemcpy support
Browse files Browse the repository at this point in the history
  • Loading branch information
MamziB committed Oct 24, 2024
1 parent 070eb64 commit 363a044
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 20 deletions.
17 changes: 17 additions & 0 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast.h
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,8 @@ typedef struct ucc_tl_mlx5_mcast_coll_req {
int last_pkt_len;
int offset;
ucc_memory_type_t buf_mem_type;
ucc_ee_executor_task_t *exec_task;
ucc_coll_task_t *coll_task;
} ucc_tl_mlx5_mcast_coll_req_t;

typedef struct ucc_tl_mlx5_mcast_oob_p2p_context {
Expand Down Expand Up @@ -427,6 +429,21 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_post_recv_buffers(ucc_tl_mlx5_mcast
return UCC_OK;
}

#define EXEC_TASK_TEST(_errmsg, _etask, _lib) do { \
if (_etask != NULL) { \
status = ucc_ee_executor_task_test(_etask); \
if (status > 0) { \
return status; \
} \
ucc_ee_executor_task_finalize(_etask); \
_etask = NULL; \
if (ucc_unlikely(status < 0)) { \
tl_error(_lib, _errmsg); \
return status; \
} \
} \
} while(0)

ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *tl_context,
ucc_tl_mlx5_mcast_team_t **mcast_team,
ucc_tl_mlx5_mcast_context_t *ctx,
Expand Down
35 changes: 33 additions & 2 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_r_window_recycle(ucc_tl_mlx5_mcast_
return status;
}

if (comm->cuda_mem_enabled) {
while (req->exec_task != NULL) {
EXEC_TASK_TEST("failed to complete the nb memcpy", req->exec_task, comm->lib);
}
}

comm->n_mcast_reliable++;

for (;comm->last_acked < comm->psn; comm->last_acked++) {
Expand Down Expand Up @@ -270,7 +276,8 @@ ucc_status_t ucc_tl_mlx5_mcast_bcast_start(ucc_coll_task_t *coll_task)
return ucc_task_complete(coll_task);
}

coll_task->status = status;
coll_task->status = status;
task->bcast_mcast.req_handle->coll_task = coll_task;

return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(mlx5_team)->pq, &task->super);
}
Expand Down Expand Up @@ -329,10 +336,34 @@ ucc_status_t ucc_tl_mlx5_mcast_check_support(ucc_base_coll_args_t *coll_args,
return UCC_OK;
}

ucc_status_t ucc_tl_mlx5_mcast_bcast_init(ucc_tl_mlx5_task_t *task)
ucc_status_t ucc_tl_mlx5_mcast_bcast_init(ucc_tl_mlx5_task_t *task,
ucc_base_coll_args_t *coll_args)
{
ucc_coll_args_t *args = &coll_args->args;

task->super.post = ucc_tl_mlx5_mcast_bcast_start;
task->super.progress = ucc_tl_mlx5_mcast_collective_progress;
if (args->src.info.mem_type == UCC_MEMORY_TYPE_CUDA) {
task->super.flags = UCC_COLL_TASK_FLAG_EXECUTOR;
}

return UCC_OK;
}

ucc_status_t ucc_tl_mlx5_mcast_schedule_start(ucc_coll_task_t *coll_task)
{
return ucc_schedule_start(coll_task);
}

ucc_status_t ucc_tl_mlx5_mcast_schedule_finalize(ucc_coll_task_t *coll_task)
{
ucc_status_t status;
ucc_tl_mlx5_schedule_t *schedule =
ucc_derived_of(coll_task, ucc_tl_mlx5_schedule_t);

status = ucc_schedule_finalize(coll_task);

ucc_tl_mlx5_put_schedule(schedule);
return status;
}

7 changes: 6 additions & 1 deletion src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,15 @@
#include "tl_mlx5_mcast.h"
#include "tl_mlx5_coll.h"

ucc_status_t ucc_tl_mlx5_mcast_bcast_init(ucc_tl_mlx5_task_t *task);
ucc_status_t ucc_tl_mlx5_mcast_bcast_init(ucc_tl_mlx5_task_t *task,
ucc_base_coll_args_t *coll_args);

ucc_status_t ucc_tl_mlx5_mcast_test(ucc_tl_mlx5_mcast_coll_req_t* _req);

ucc_status_t ucc_tl_mlx5_mcast_check_support(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team);

ucc_status_t ucc_tl_mlx5_mcast_schedule_start(ucc_coll_task_t *coll_task);

ucc_status_t ucc_tl_mlx5_mcast_schedule_finalize(ucc_coll_task_t *coll_task);
#endif
38 changes: 25 additions & 13 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.c
Original file line number Diff line number Diff line change
Expand Up @@ -398,9 +398,10 @@ ucc_status_t ucc_tl_mlx5_mcast_process_packet(ucc_tl_mlx5_mcast_coll_comm_t *com
ucc_tl_mlx5_mcast_coll_req_t *req,
struct pp_packet* pp)
{
ucc_status_t status = UCC_OK;
void *dest;
ucc_memory_type_t mem_type;
ucc_status_t status = UCC_OK;
void *dest;
ucc_ee_executor_task_args_t eargs;
ucc_ee_executor_t *exec;
ucc_assert(pp->psn >= req->start_psn &&
pp->psn < req->start_psn + req->num_packets);

Expand All @@ -409,18 +410,29 @@ ucc_status_t ucc_tl_mlx5_mcast_process_packet(ucc_tl_mlx5_mcast_coll_comm_t *com

if (pp->length > 0 ) {
dest = req->ptr + PSN_TO_RECV_OFFSET(pp->psn, req, comm);

if (comm->cuda_mem_enabled) {
mem_type = UCC_MEMORY_TYPE_CUDA;
} else {
mem_type = UCC_MEMORY_TYPE_HOST;
}
while (req->exec_task != NULL) {
EXEC_TASK_TEST("failed to complete the nb memcpy", req->exec_task, comm->lib);
}

status = ucc_mc_memcpy(dest, (void*) pp->buf, pp->length,
mem_type, mem_type);
if (ucc_unlikely(status != UCC_OK)) {
tl_error(comm->lib, "failed to copy buffer");
return status;
/* for cuda memcpy use nonblocking copy */
status = ucc_coll_task_get_executor(req->coll_task, &exec);
if (ucc_unlikely(status != UCC_OK)) {
return status;
}

eargs.task_type = UCC_EE_EXECUTOR_TASK_COPY;
eargs.copy.src = (void*) pp->buf;
eargs.copy.dst = dest;
eargs.copy.len = pp->length;

assert(req->exec_task == NULL);
status = ucc_ee_executor_task_post(exec, &eargs, &req->exec_task);
if (ucc_unlikely(status != UCC_OK)) {
return status;
}
} else {
memcpy(dest, (void*) pp->buf, pp->length);
}
}

Expand Down
36 changes: 32 additions & 4 deletions src/components/tl/mlx5/tl_mlx5_coll.c
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,12 @@ ucc_status_t ucc_tl_mlx5_bcast_mcast_init(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team,
ucc_coll_task_t **task_h)
{
ucc_status_t status = UCC_OK;
ucc_tl_mlx5_task_t *task = NULL;
ucc_tl_mlx5_team_t *tl_team = ucc_derived_of(team, ucc_tl_mlx5_team_t);
ucc_status_t status = UCC_OK;
ucc_tl_mlx5_task_t *task = NULL;
ucc_coll_task_t *bcast_task;
ucc_schedule_t *schedule;


status = ucc_tl_mlx5_mcast_check_support(coll_args, team);
if (UCC_OK != status) {
Expand All @@ -27,12 +31,36 @@ ucc_status_t ucc_tl_mlx5_bcast_mcast_init(ucc_base_coll_args_t *coll_args,

task->super.finalize = ucc_tl_mlx5_task_finalize;

status = ucc_tl_mlx5_mcast_bcast_init(task);
status = ucc_tl_mlx5_get_schedule(tl_team, coll_args,
(ucc_tl_mlx5_schedule_t **)&schedule);
if (ucc_unlikely(UCC_OK != status)) {
return status;
}

status = ucc_tl_mlx5_mcast_bcast_init(task, coll_args);
if (ucc_unlikely(UCC_OK != status)) {
goto free_task;
}

bcast_task = &(task->super);

status = ucc_schedule_add_task(schedule, bcast_task);
if (ucc_unlikely(UCC_OK != status)) {
goto free_task;
}

status = ucc_event_manager_subscribe(&schedule->super,
UCC_EVENT_SCHEDULE_STARTED,
bcast_task,
ucc_task_start_handler);
if (ucc_unlikely(UCC_OK != status)) {
goto free_task;
}

*task_h = &(task->super);
schedule->super.post = ucc_tl_mlx5_mcast_schedule_start;
schedule->super.progress = NULL;
schedule->super.finalize = ucc_tl_mlx5_mcast_schedule_finalize;
*task_h = &schedule->super;

tl_debug(UCC_TASK_LIB(task), "init coll task %p", task);

Expand Down

0 comments on commit 363a044

Please sign in to comment.