diff --git a/src/components/tl/mlx5/Makefile.am b/src/components/tl/mlx5/Makefile.am index a7bc249f87..f1ab2ee78e 100644 --- a/src/components/tl/mlx5/Makefile.am +++ b/src/components/tl/mlx5/Makefile.am @@ -29,6 +29,8 @@ mcast = \ mcast/tl_mlx5_mcast_service_coll.c \ mcast/tl_mlx5_mcast_one_sided_reliability.h \ mcast/tl_mlx5_mcast_one_sided_reliability.c \ + mcast/tl_mlx5_mcast_allgather.h \ + mcast/tl_mlx5_mcast_allgather.c \ mcast/tl_mlx5_mcast_team.c sources = \ diff --git a/src/components/tl/mlx5/alltoall/alltoall.c b/src/components/tl/mlx5/alltoall/alltoall.c index 5afc7c7d30..7b1f2e3ce6 100644 --- a/src/components/tl/mlx5/alltoall/alltoall.c +++ b/src/components/tl/mlx5/alltoall/alltoall.c @@ -1,5 +1,5 @@ /** - * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ diff --git a/src/components/tl/mlx5/alltoall/alltoall.h b/src/components/tl/mlx5/alltoall/alltoall.h index c2bb39b62f..9fd9d787cc 100644 --- a/src/components/tl/mlx5/alltoall/alltoall.h +++ b/src/components/tl/mlx5/alltoall/alltoall.h @@ -1,5 +1,5 @@ /** - * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ diff --git a/src/components/tl/mlx5/alltoall/alltoall_mkeys.c b/src/components/tl/mlx5/alltoall/alltoall_mkeys.c index 0fa197e6c7..e8b3052501 100644 --- a/src/components/tl/mlx5/alltoall/alltoall_mkeys.c +++ b/src/components/tl/mlx5/alltoall/alltoall_mkeys.c @@ -1,5 +1,5 @@ /** - * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ diff --git a/src/components/tl/mlx5/alltoall/alltoall_mkeys.h b/src/components/tl/mlx5/alltoall/alltoall_mkeys.h index 0ea7b38a0c..e2e8432dc4 100644 --- a/src/components/tl/mlx5/alltoall/alltoall_mkeys.h +++ b/src/components/tl/mlx5/alltoall/alltoall_mkeys.h @@ -1,5 +1,5 @@ /** - * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. 
*/ diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h b/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h index 734acf1f30..663ee636ed 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h @@ -33,14 +33,27 @@ #define GRH_LENGTH 40 #define DROP_THRESHOLD 10000 #define MAX_COMM_POW2 32 +#define MAX_GROUP_COUNT 64 /* Allgather RDMA-based reliability designs */ -#define ONE_SIDED_RELIABILITY_MAX_TEAM_SIZE 1024 -#define ONE_SIDED_NO_RELIABILITY 0 -#define ONE_SIDED_SYNCHRONOUS_PROTO 1 -#define ONE_SIDED_ASYNCHRONOUS_PROTO 2 +#define ONE_SIDED_RELIABILITY_MAX_TEAM_SIZE 1024u #define ONE_SIDED_SLOTS_COUNT 2 /* number of memory slots during async design */ #define ONE_SIDED_SLOTS_INFO_SIZE sizeof(uint32_t) /* size of metadata prepended to each slots in bytes */ +#define ONE_SIDED_MAX_ALLGATHER_COUNTER 32u +#define ONE_SIDED_MAX_CONCURRENT_LEVEL 64 + +enum ucc_tl_mlx5_mcast_one_sided_slot_states { + ONE_SIDED_INVALID = -4, + ONE_SIDED_VALID, + ONE_SIDED_PENDING_INFO, + ONE_SIDED_PENDING_DATA, +}; + +enum ucc_tl_mlx5_mcast_one_sided_reliability_scheme { + ONE_SIDED_NO_RELIABILITY = 0, + ONE_SIDED_SYNCHRONOUS_PROTO, + ONE_SIDED_ASYNCHRONOUS_PROTO +}; #define CUDA_MEM_MCAST_BCAST_MAX_MSG 4000 @@ -90,7 +103,7 @@ typedef struct ucc_tl_mlx5_mcast_p2p_interface { ucc_tl_mlx5_mcast_p2p_recv_nb_fn_t recv_nb; } ucc_tl_mlx5_mcast_p2p_interface_t; -typedef struct mcast_coll_comm_init_spec { +typedef struct ucc_tl_mlx5_mcast_coll_comm_init_spec { ucc_tl_mlx5_mcast_p2p_interface_t p2p_iface; int sx_depth; int rx_depth; @@ -100,8 +113,10 @@ typedef struct mcast_coll_comm_init_spec { int post_recv_thresh; int scq_moderation; int wsize; + int max_push_send; int max_eager; int cuda_mem_enabled; + int one_sided_reliability_enable; void *oob; } ucc_tl_mlx5_mcast_coll_comm_init_spec_t; @@ -152,6 +167,7 @@ typedef struct ucc_tl_mlx5_mcast_coll_context { struct rdma_cm_id *id; struct rdma_event_channel *channel; ucc_mpool_t compl_objects_mp; + ucc_mpool_t mcast_req_mp; ucc_list_link_t pending_nacks_list; ucc_rcache_t *rcache; ucc_tl_mlx5_mcast_ctx_params_t params; @@ -168,7 +184,6 @@ typedef struct ucc_tl_mlx5_mcast_context { ucc_thread_mode_t tm; ucc_tl_mlx5_mcast_coll_context_t mcast_context; ucc_tl_mlx5_mcast_context_config_t cfg; - ucc_mpool_t req_mp; int mcast_enabled; int mcast_ctx_ready; ucc_tl_mlx5_mcast_oob_ctx_t oob_ctx; @@ -178,19 +193,28 @@ struct pp_packet { ucc_list_link_t super; uint32_t psn; int length; + int packet_counter; uintptr_t context; - uintptr_t buf; + int qp_id; + uintptr_t buf; // buffer address, initialized once }; struct mcast_ctx { - struct ibv_qp *qp; - struct ibv_ah *ah; - struct ibv_send_wr swr; - struct ibv_sge ssg; + struct ibv_qp *qp; + struct ibv_ah *ah; + struct ibv_send_wr swr; + struct ibv_sge ssg; + // RC connection info for supporing one-sided based relibality struct ibv_qp **rc_qp; uint16_t *rc_lid; union ibv_gid *rc_gid; + + // multiple mcast group + struct ibv_qp **qp_list; + struct ibv_ah **ah_list; + struct ibv_send_wr *swr_list; + struct ibv_sge *ssg_list; }; struct packet { @@ -219,21 +243,25 @@ typedef struct ucc_tl_mlx5_mcast_one_sided_reliability_comm { /* holds all the remote-addr/rkey of sendbuf from processes in the team * used in sync design. 
it needs to be set during each mcast-allgather call * after sendbuf registration */ - ucc_tl_mlx5_mcast_slot_mem_info_t *sendbuf_memkey_list; + ucc_tl_mlx5_mcast_slot_mem_info_t *sendbuf_memkey_list; /* counter for each target recv packet */ - uint32_t *recvd_pkts_tracker; + uint32_t *recvd_pkts_tracker; /* holds the remote targets' collective call counter. it is used to check * if remote temp slot is ready for RDMA READ in async design */ - uint32_t *remote_slot_info; - struct ibv_mr *remote_slot_info_mr; - int reliability_scheme_msg_threshold; + uint32_t *remote_slot_info; + struct ibv_mr *remote_slot_info_mr; + int reliability_scheme_msg_threshold; /* mem address and mem keys of the temp slots in async design */ - char *slots_buffer; - struct ibv_mr *slots_mr; + char *slots_buffer; + struct ibv_mr *slots_mr; /* size of a temp slot in async design */ - int slot_size; + int slot_size; /* coll req that is used during the oob service calls */ - ucc_service_coll_req_t *reliability_req; + ucc_service_coll_req_t *reliability_req; + int reliability_enabled; + int reliability_ready; + int rdma_read_in_progress; + enum ucc_tl_mlx5_mcast_one_sided_slot_states slots_state; } ucc_tl_mlx5_mcast_one_sided_reliability_comm_t; typedef struct ucc_tl_mlx5_mcast_service_coll { @@ -243,6 +271,32 @@ typedef struct ucc_tl_mlx5_mcast_service_coll { ucc_status_t (*coll_test) (ucc_service_coll_req_t*); } ucc_tl_mlx5_mcast_service_coll_t; +typedef struct ucc_tl_mlx5_mcast_allgather_comm { + uint32_t under_progress_counter; + uint32_t coll_counter; + uint32_t max_num_packets; + uint32_t max_push_send; +} ucc_tl_mlx5_mcast_allgather_comm_t; + +typedef struct ucc_tl_mlx5_mcast_bcast_comm { + uint32_t last_psn; + uint32_t racks_n; + uint32_t sacks_n; + uint32_t last_acked; + uint32_t child_n; + uint32_t parent_n; + struct packet p2p_pkt[MAX_COMM_POW2]; + struct packet p2p_spkt[MAX_COMM_POW2]; + int reliable_in_progress; + int recv_drop_packet_in_progress; + ucc_rank_t parents[MAX_COMM_POW2]; + ucc_rank_t children[MAX_COMM_POW2]; + int nack_requests; + int nacks_counter; + int n_mcast_reliable; + int wsize; +} ucc_tl_mlx5_mcast_bcast_comm_t; + typedef struct ucc_tl_mlx5_mcast_coll_comm { struct pp_packet dummy_packet; ucc_tl_mlx5_mcast_coll_context_t *ctx; @@ -268,21 +322,11 @@ typedef struct ucc_tl_mlx5_mcast_coll_comm { ucc_mc_buffer_header_t *pp_buf_header; struct pp_packet *pp; uint32_t psn; - uint32_t last_psn; - uint32_t racks_n; - uint32_t sacks_n; - uint32_t last_acked; - uint32_t naks_n; - uint32_t child_n; - uint32_t parent_n; int buf_n; - struct packet p2p_pkt[MAX_COMM_POW2]; - struct packet p2p_spkt[MAX_COMM_POW2]; ucc_list_link_t bpool; ucc_list_link_t pending_q; + ucc_list_link_t posted_q; struct mcast_ctx mcast; - int reliable_in_progress; - int recv_drop_packet_in_progress; struct ibv_recv_wr *call_rwr; struct ibv_sge *call_rsgs; uint64_t timer; @@ -291,19 +335,16 @@ typedef struct ucc_tl_mlx5_mcast_coll_comm { void *p2p_ctx; ucc_base_lib_t *lib; struct sockaddr_in6 mcast_addr; - ucc_rank_t parents[MAX_COMM_POW2]; - ucc_rank_t children[MAX_COMM_POW2]; - int nack_requests; - int nacks_counter; - int n_prep_reliable; - int n_mcast_reliable; - int wsize; int cuda_mem_enabled; ucc_tl_mlx5_mcast_join_info_t *group_setup_info; ucc_service_coll_req_t *group_setup_info_req; ucc_tl_mlx5_mcast_service_coll_t service_coll; struct rdma_cm_event *event; ucc_tl_mlx5_mcast_one_sided_reliability_comm_t one_sided; + int mcast_group_count; + int pending_recv_per_qp[MAX_GROUP_COUNT]; + ucc_tl_mlx5_mcast_allgather_comm_t 
allgather_comm; + ucc_tl_mlx5_mcast_bcast_comm_t bcast_comm; struct pp_packet *r_window[1]; // note: do not add any new variable after here } ucc_tl_mlx5_mcast_coll_comm_t; @@ -320,54 +361,86 @@ typedef struct ucc_tl_mlx5_mcast_nack_req { ucc_tl_mlx5_mcast_coll_comm_t *comm; } ucc_tl_mlx5_mcast_nack_req_t; -#define PSN_IS_IN_RANGE(_psn, _call, _comm) \ - ( \ - ((_psn >= _call->start_psn) && \ - (_psn < _call->start_psn + _call->num_packets) && \ - (_psn >= _comm->last_acked) && \ - (_psn < _comm->last_acked + _comm->wsize)) \ +#define PSN_IS_IN_RANGE(_psn, _call, _comm) \ + ( \ + ((_psn >= _call->start_psn) && \ + (_psn < _call->start_psn + _call->num_packets) && \ + (_psn >= _comm->bcast_comm.last_acked) && \ + (_psn < _comm->bcast_comm.last_acked + _comm->bcast_comm.wsize)) \ ) -#define PSN_TO_RECV_OFFSET(_psn, _call, _comm) \ - ( \ - ((ptrdiff_t)((_psn - _call->start_psn) \ - * (_comm->max_per_packet))) \ +#define PSN_TO_RECV_OFFSET(_psn, _call, _comm) \ + ( \ + ((ptrdiff_t)((_psn - _call->start_psn) \ + * (_comm->max_per_packet))) \ ) -#define PSN_TO_RECV_LEN(_psn, _call, _comm) \ - ( \ - ((_psn - _call->start_psn + 1) % \ - _call->num_packets == 0 ? _call->last_pkt_len : \ - _comm->max_per_packet) \ +#define PSN_TO_RECV_LEN(_psn, _call, _comm) \ + ( \ + ((_psn - _call->start_psn + 1) % \ + _call->num_packets == 0 ? _call->last_pkt_len : \ + _comm->max_per_packet) \ ) -#define PSN_RECEIVED(_psn, _comm) \ - ( \ - (_comm->r_window[(_psn) % \ - _comm->wsize]->psn == (_psn)) \ +#define PSN_RECEIVED(_psn, _comm) \ + ( \ + (_comm->r_window[(_psn) % \ + _comm->bcast_comm.wsize]->psn == (_psn)) \ ) +typedef struct ucc_tl_mlx5_mcast_tensor { + int group_id; + size_t offset; + size_t offset_left; + int root; + int count; + int to_recv; + int to_send_left; +} ucc_tl_mlx5_mcast_tensor_t; + +typedef struct ucc_tl_mlx5_mcast_pipelined_ag_schedule { + ucc_tl_mlx5_mcast_tensor_t multicast_op[ONE_SIDED_MAX_CONCURRENT_LEVEL]; + ucc_tl_mlx5_mcast_tensor_t prepost_buf_op[ONE_SIDED_MAX_CONCURRENT_LEVEL]; + int prepost_buf_op_done; + int multicast_op_done; + int total_steps; + int num_recvd; + int to_recv; + int to_send; +} ucc_tl_mlx5_mcast_pipelined_ag_schedule_t; + typedef struct ucc_tl_mlx5_mcast_coll_req { - ucc_tl_mlx5_mcast_coll_comm_t *comm; - size_t length; /* bcast buffer size */ - int proto; - struct ibv_mr *mr; - struct ibv_recv_wr *rwr; - struct ibv_sge *rsgs; - void *rreg; - char *ptr; - int am_root; - ucc_rank_t root; - void **rbufs; - int first_send_psn; - int to_send; - int to_recv; - ucc_rank_t parent; - uint32_t start_psn; - int num_packets; - int last_pkt_len; - int offset; - ucc_memory_type_t buf_mem_type; + ucc_tl_mlx5_mcast_coll_comm_t *comm; + size_t length; + int proto; + struct ibv_mr *mr; + struct ibv_mr *recv_mr; + struct ibv_recv_wr *rwr; + struct ibv_sge *rsgs; + void *rreg; + char *ptr; + char *rptr; + int am_root; + ucc_rank_t root; + void **rbufs; + int first_send_psn; + int to_send; + int to_recv; + ucc_rank_t parent; + uint32_t start_psn; + int num_packets; + int last_pkt_len; + int offset; + ucc_memory_type_t buf_mem_type; + enum ucc_tl_mlx5_mcast_one_sided_reliability_scheme one_sided_reliability_scheme; + uint32_t ag_counter; + int state; + ucc_tl_mlx5_mcast_pipelined_ag_schedule_t *ag_schedule; + int total_steps; + int step; + ucc_service_coll_req_t *allgather_rkeys_req; + ucc_service_coll_req_t *barrier_req; + void *recv_rreg; } ucc_tl_mlx5_mcast_coll_req_t; typedef struct ucc_tl_mlx5_mcast_oob_p2p_context { @@ -427,6 +500,61 @@ static inline ucc_status_t 
ucc_tl_mlx5_mcast_post_recv_buffers(ucc_tl_mlx5_mcast return UCC_OK; } +static inline uint64_t ucc_tl_mlx5_mcast_get_timer(void) +{ + double t_second = ucc_get_time(); + return (uint64_t) (t_second * 1000000); +} + +static inline ucc_status_t ucc_tl_mlx5_mcast_post_user_recv_buffers(ucc_tl_mlx5_mcast_coll_comm_t *comm, + ucc_tl_mlx5_mcast_coll_req_t *req, + int group_id, ucc_rank_t root, + int coll_type, + int count, + size_t offset) +{ + struct ibv_recv_wr *bad_wr = NULL; + struct ibv_recv_wr *rwr = comm->call_rwr; + struct ibv_sge *sge = comm->call_rsgs; + struct pp_packet *pp = NULL; + uint32_t i; + + for (i = 0; i < count; i++) { + if (NULL == (pp = ucc_tl_mlx5_mcast_buf_get_free(comm))) { + tl_error(comm->lib, "not enought free pp packets to cover the entire message"); + return UCC_ERR_NO_RESOURCE; + } + + assert(offset % comm->max_per_packet == 0); + pp->packet_counter = offset / comm->max_per_packet; + pp->qp_id = group_id; + rwr[i].wr_id = ((uint64_t) pp); + sge[2*i + 1].addr = (uint64_t)req->rptr + root * req->length + offset; + sge[2*i + 1].lkey = req->recv_mr->lkey; + offset += comm->max_per_packet; + + if (i == count - 1) { + sge[2*i + 1].length = req->last_pkt_len; + } else { + sge[2*i + 1].length = comm->max_per_packet; + rwr[i].next = &rwr[i+1]; + } + } + + if (i > 0) { + rwr[i-1].next = NULL; + if (ibv_post_recv(comm->mcast.qp_list[group_id], &rwr[0], &bad_wr)) { + tl_error(comm->lib, "Failed to prepost recvs: errno %d buffer count %d", + errno, i); + return UCC_ERR_NO_RESOURCE; + } + comm->pending_recv += i; + comm->pending_recv_per_qp[group_id] += i; + } + + return UCC_OK; +} + ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *tl_context, ucc_tl_mlx5_mcast_team_t **mcast_team, ucc_tl_mlx5_mcast_context_t *ctx, diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_allgather.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_allgather.c new file mode 100644 index 0000000000..82592238d4 --- /dev/null +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_allgather.c @@ -0,0 +1,375 @@ +/** + * Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See file LICENSE for terms. 
+ */ + +#include "tl_mlx5_mcast_helper.h" +#include "tl_mlx5_coll.h" +#include "tl_mlx5_mcast_rcache.h" +#include "tl_mlx5_mcast_progress.h" +#include "tl_mlx5_mcast_allgather.h" +#include + +/* 32 here is the bit count of ib mcast packet's immediate data */ +#define TL_MLX5_MCAST_IB_IMMEDIATE_PACKET_BIT_COUNT 32u + +#define MCAST_GET_MAX_ALLGATHER_PACKET_COUNT(_max_count, _max_team, _max_counter) \ +do { \ + _max_count = 2 << (TL_MLX5_MCAST_IB_IMMEDIATE_PACKET_BIT_COUNT - \ + ucc_ilog2(_max_team) - \ + ucc_ilog2(_max_counter)); \ +} while (0); + +#define MCAST_ALLGATHER_IN_PROGRESS(_req, _comm) \ + (_req->to_send || _req->to_recv || _comm->pending_send || \ + _comm->one_sided.rdma_read_in_progress || (NULL != _req->allgather_rkeys_req)) \ + +static inline ucc_status_t ucc_tl_mlx5_mcast_check_collective(ucc_tl_mlx5_mcast_coll_comm_t *comm, + ucc_tl_mlx5_mcast_coll_req_t *req) +{ + ucc_status_t status; + + ucc_assert(comm->one_sided.reliability_ready); + + if (comm->one_sided.rdma_read_in_progress) { + return ucc_tl_mlx5_mcast_progress_one_sided_communication(comm, req); + } + + /* check if remote rkey/address have arrived - applicable for sync design */ + if (req->allgather_rkeys_req != NULL) { + status = comm->service_coll.coll_test(req->allgather_rkeys_req); + if (status == UCC_OK) { + ucc_assert(ONE_SIDED_SYNCHRONOUS_PROTO == req->one_sided_reliability_scheme); + req->allgather_rkeys_req = NULL; + tl_trace(comm->lib, "Allgather for remote_addr/rkey is completed"); + } else if (status < 0) { + tl_error(comm->lib, "Allgather for remote_addr/rkey failed"); + return status; + } + } + + if (!req->to_send && !req->to_recv) { + // all have been received, nothing to do + return UCC_OK; + + } else if (req->to_send) { + // it is not yet the time to start the reliability protocol + return UCC_INPROGRESS; + } + + if (!comm->timer) { + if (comm->stalled < DROP_THRESHOLD) { + comm->stalled++; + } else { + // kick the timer + comm->timer = ucc_tl_mlx5_mcast_get_timer(); + comm->stalled = 0; + } + } else { + if (comm->stalled < DROP_THRESHOLD || (NULL != req->allgather_rkeys_req)) { + comm->stalled++; + } else { + // calcuate the current time and check if it's time to do RDMA READ + if (ucc_tl_mlx5_mcast_get_timer() - comm->timer >= + comm->ctx->params.timeout) { + tl_debug(comm->lib, "[REL] time out req->to_recv %d left out of total of %d packets", + req->to_recv, req->num_packets * comm->commsize); + status = ucc_tl_mlx5_mcast_reliable_one_sided_get(comm, req, NULL); + if (UCC_OK != status) { + return status; + } + } else { + comm->stalled = 0; + } + } + } + + return UCC_INPROGRESS; +} + +static inline ucc_status_t ucc_tl_mlx5_mcast_reset_reliablity(ucc_tl_mlx5_mcast_coll_req_t *req) +{ + ucc_tl_mlx5_mcast_coll_comm_t *comm = req->comm; + ucc_tl_mlx5_mcast_reg_t *reg = NULL; + ucc_status_t status; + + ucc_assert(req->ag_counter == comm->allgather_comm.under_progress_counter); + + if (comm->one_sided.reliability_enabled && !comm->one_sided.reliability_ready) { + /* initialize the structures needed by reliablity protocol */ + memset(comm->one_sided.recvd_pkts_tracker, 0, comm->commsize * sizeof(uint32_t)); + memset(comm->one_sided.remote_slot_info, ONE_SIDED_INVALID, comm->commsize * sizeof(uint32_t)); + /* local slots state */ + comm->one_sided.slots_state = ONE_SIDED_INVALID; + + if (ONE_SIDED_SYNCHRONOUS_PROTO == req->one_sided_reliability_scheme) { + /* do nonblocking allgather over remote addresses/keys */ + if (!req->rreg) { + /* register sbuf if it is not registered before */ + status = 
ucc_tl_mlx5_mcast_mem_register(comm->ctx, req->ptr, req->length, ®); + if (UCC_OK != status) { + return status; + } + req->rreg = reg; + req->mr = reg->mr; + } + + comm->one_sided.sendbuf_memkey_list[comm->rank].rkey = req->mr->rkey; + comm->one_sided.sendbuf_memkey_list[comm->rank].remote_addr = (uint64_t)req->ptr; + + tl_trace(comm->lib, "Allgather over sendbuf addresses/rkey: address %p rkey %d", + req->ptr, req->mr->rkey); + + status = comm->service_coll.allgather_post(comm->p2p_ctx, NULL /* in-place */, + comm->one_sided.sendbuf_memkey_list, + sizeof(ucc_tl_mlx5_mcast_slot_mem_info_t), + &req->allgather_rkeys_req); + if (UCC_OK != status) { + tl_error(comm->lib, "oob allgather failed during one-sided reliability reset of a collective call"); + return status; + } + } + + memset(comm->pending_recv_per_qp, 0, sizeof(int) * MAX_GROUP_COUNT); + comm->one_sided.reliability_ready = 1; + } + + return UCC_OK; +} + +static inline void ucc_tl_mlx5_mcast_init_async_reliability_slots(ucc_tl_mlx5_mcast_coll_req_t *req) +{ + ucc_tl_mlx5_mcast_coll_comm_t *comm = req->comm; + char *dest; + + ucc_assert(req->ag_counter == comm->allgather_comm.under_progress_counter); + + if (ONE_SIDED_ASYNCHRONOUS_PROTO == req->one_sided_reliability_scheme && + ONE_SIDED_INVALID == comm->one_sided.slots_state) { + /* copy the sendbuf and seqnum to the internal temp buf in case other processes need + * to read from it */ + ucc_assert(req->length <= comm->one_sided.reliability_scheme_msg_threshold); + dest = PTR_OFFSET(comm->one_sided.slots_buffer, + (req->ag_counter % ONE_SIDED_SLOTS_COUNT) + * comm->one_sided.slot_size); + + /* both user buffer and reliablity slots are on host */ + memcpy(PTR_OFFSET(dest, ONE_SIDED_SLOTS_INFO_SIZE), req->ptr, req->length); + memcpy(dest, &req->ag_counter, ONE_SIDED_SLOTS_INFO_SIZE); + + comm->one_sided.slots_state = ONE_SIDED_VALID; + } +} + +static inline ucc_status_t ucc_tl_mlx5_mcast_do_staging_based_allgather(ucc_tl_mlx5_mcast_coll_req_t *req) +{ + ucc_status_t status = UCC_OK; + ucc_tl_mlx5_mcast_coll_comm_t *comm = req->comm; + const int zcopy = req->proto != MCAST_PROTO_EAGER; + int num_recvd; + + ucc_assert(req->to_recv >= 0 && req->to_send >= 0); + + status = ucc_tl_mlx5_mcast_reset_reliablity(req); + if (UCC_OK != status) { + return status; + } + + if (req->to_send || req->to_recv) { + ucc_assert(comm->allgather_comm.max_push_send >= comm->pending_send); + if (req->to_send && + (comm->allgather_comm.max_push_send - comm->pending_send) > 0) { + status = ucc_tl_mlx5_mcast_send_collective(comm, req, ucc_min(comm->allgather_comm.max_push_send - + comm->pending_send, req->to_send), + zcopy, UCC_COLL_TYPE_ALLGATHER, -1, SIZE_MAX); + if (status < 0) { + tl_error(comm->lib, "a failure happend during send packets"); + return status; + } + } + + ucc_tl_mlx5_mcast_init_async_reliability_slots(req); + + if (req->to_recv) { + num_recvd = ucc_tl_mlx5_mcast_recv_collective(comm, req, req->to_recv, UCC_COLL_TYPE_ALLGATHER); + if (num_recvd < 0) { + tl_error(comm->lib, "a failure happend during cq polling"); + status = UCC_ERR_NO_MESSAGE; + return status; + } + } + } + + if (comm->pending_send) { + if (ucc_tl_mlx5_mcast_poll_send(comm) < 0) { + return UCC_ERR_NO_MESSAGE; + } + } + + if (comm->one_sided.reliability_enabled) { + status = ucc_tl_mlx5_mcast_check_collective(comm, req); + if (UCC_INPROGRESS != status && UCC_OK != status) { + return status; + } + } + + if (MCAST_ALLGATHER_IN_PROGRESS(req, comm)) { + return UCC_INPROGRESS; + } + + if (ONE_SIDED_SYNCHRONOUS_PROTO == 
req->one_sided_reliability_scheme) { + /* mcast operations are all done, now wait until all the processes + * are done with their mcast operations */ + if (!req->barrier_req) { + // mcast operations are done and now go to barrier + status = comm->service_coll.barrier_post(comm->p2p_ctx, &req->barrier_req); + if (status != UCC_OK) { + return status; + } + tl_trace(comm->lib, "mcast operations are done and now go to barrier"); + } + + status = comm->service_coll.coll_test(req->barrier_req); + if (status == UCC_OK) { + req->barrier_req = NULL; + tl_trace(comm->lib, "barrier at the end of mcast allgather is completed"); + } else { + return status; + } + } + + /* this task is completed */ + return UCC_OK; +} + +ucc_status_t ucc_tl_mlx5_mcast_allgather_start(ucc_coll_task_t *coll_task) +{ + ucc_tl_mlx5_task_t *task = ucc_derived_of(coll_task, ucc_tl_mlx5_task_t); + ucc_tl_mlx5_team_t *mlx5_team = TASK_TEAM(task); + + return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(mlx5_team)->pq, &task->super); +} + +void ucc_tl_mlx5_mcast_allgather_progress(ucc_coll_task_t *coll_task) +{ + ucc_tl_mlx5_task_t *task = ucc_derived_of(coll_task, ucc_tl_mlx5_task_t); + ucc_tl_mlx5_mcast_coll_req_t *req = task->coll_mcast.req_handle; + + ucc_assert(req != NULL); + + if (req->ag_counter != req->comm->allgather_comm.under_progress_counter) { + /* it is not this task's turn for progress */ + ucc_assert(req->comm->allgather_comm.under_progress_counter < req->ag_counter); + return; + } + + coll_task->status = ucc_tl_mlx5_mcast_do_staging_based_allgather(req); + if (UCC_OK != coll_task->status) { + tl_error(UCC_TASK_LIB(task), "progress mcast allgather failed:%d", coll_task->status); + } +} + +ucc_status_t ucc_tl_mlx5_mcast_allgather_init(ucc_tl_mlx5_task_t *task) +{ + ucc_coll_task_t *coll_task = &(task->super); + ucc_tl_mlx5_team_t *mlx5_team = TASK_TEAM(task); + ucc_tl_mlx5_mcast_team_t *team = mlx5_team->mcast; + ucc_coll_args_t *args = &TASK_ARGS(task); + ucc_datatype_t dt = args->src.info.datatype; + size_t count = args->src.info.count; + ucc_status_t status = UCC_OK; + size_t data_size = ucc_dt_size(dt) * count; + void *sbuf = args->src.info.buffer; + void *rbuf = args->dst.info.buffer; + ucc_tl_mlx5_mcast_coll_comm_t *comm = team->mcast_comm; + ucc_tl_mlx5_mcast_reg_t *reg = NULL; + ucc_rank_t max_team = ONE_SIDED_RELIABILITY_MAX_TEAM_SIZE; + int max_ctr = ONE_SIDED_MAX_ALLGATHER_COUNTER; + ucc_tl_mlx5_mcast_coll_req_t *req; + + + task->coll_mcast.req_handle = NULL; + + tl_trace(comm->lib, "MCAST allgather init, sbuf %p, rbuf %p, size %ld, comm %d, " + "comm_size %d, counter %d", + sbuf, rbuf, data_size, comm->comm_id, comm->commsize, comm->allgather_comm.coll_counter); + + req = ucc_mpool_get(&comm->ctx->mcast_req_mp); + if (!req) { + tl_error(comm->lib, "failed to get a mcast req"); + status = UCC_ERR_NO_MEMORY; + goto failed; + } + memset(req, 0, sizeof(ucc_tl_mlx5_mcast_coll_req_t)); + + req->comm = comm; + req->ptr = sbuf; + req->rptr = rbuf; + req->length = data_size; + req->mr = comm->pp_mr; + req->rreg = NULL; + /* - zero copy protocol only provides zero copy design at sender side + * - truly zero copy protocol provides zero copy design at receiver side as well + * here we select the sender side protocol */ + req->proto = (req->length < comm->max_eager) ? 
MCAST_PROTO_EAGER : + MCAST_PROTO_ZCOPY; + + assert(comm->commsize <= ONE_SIDED_RELIABILITY_MAX_TEAM_SIZE); + + req->offset = 0; + req->num_packets = ucc_max(1, ucc_div_round_up(req->length, comm->max_per_packet)); + + MCAST_GET_MAX_ALLGATHER_PACKET_COUNT(comm->allgather_comm.max_num_packets, max_team, max_ctr); + + if (comm->allgather_comm.max_num_packets < req->num_packets) { + tl_warn(comm->lib, + "msg size is %ld but max supported msg size of mcast allgather is %d", + req->length, comm->allgather_comm.max_num_packets * comm->max_per_packet); + status = UCC_ERR_NOT_SUPPORTED; + goto failed; + } + + req->last_pkt_len = req->length % comm->max_per_packet; + + ucc_assert(req->last_pkt_len > 0 && req->last_pkt_len <= comm->max_per_packet); + + if (req->proto == MCAST_PROTO_ZCOPY) { + /* register the send buffer */ + status = ucc_tl_mlx5_mcast_mem_register(comm->ctx, req->ptr, req->length, ®); + if (UCC_OK != status) { + tl_error(comm->lib, "sendbuf registration failed"); + goto failed; + } + req->rreg = reg; + req->mr = reg->mr; + } + + if (comm->one_sided.reliability_enabled) { + req->one_sided_reliability_scheme = (req->length < + comm->one_sided.reliability_scheme_msg_threshold) ? + ONE_SIDED_ASYNCHRONOUS_PROTO : ONE_SIDED_SYNCHRONOUS_PROTO; + } else { + req->one_sided_reliability_scheme = ONE_SIDED_NO_RELIABILITY; + } + + req->ag_counter = comm->allgather_comm.coll_counter; + req->to_send = req->num_packets; + req->to_recv = comm->commsize * req->num_packets; + + comm->allgather_comm.coll_counter++; + + task->coll_mcast.req_handle = req; + coll_task->status = UCC_OPERATION_INITIALIZED; + task->super.post = ucc_tl_mlx5_mcast_allgather_start; + task->super.progress = ucc_tl_mlx5_mcast_allgather_progress; + return UCC_OK; + +failed: + tl_warn(UCC_TASK_LIB(task), "mcast init allgather failed:%d", status); + if (req) { + ucc_mpool_put(req); + } + return status; +} + diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_allgather.h b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_allgather.h new file mode 100644 index 0000000000..a51aea451f --- /dev/null +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_allgather.h @@ -0,0 +1,15 @@ +/** + * Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See file LICENSE for terms. 
+ */ + +#ifndef UCC_TL_MLX5_MCAST_ALLGATHER_H_ +#define UCC_TL_MLX5_MCAST_ALLGATHER_H_ + +#include "tl_mlx5_mcast.h" +#include "tl_mlx5_coll.h" + +ucc_status_t ucc_tl_mlx5_mcast_allgather_init(ucc_tl_mlx5_task_t *task); + +#endif diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c index 28c4bbce61..b6fbe84e3d 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c @@ -12,12 +12,12 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_r_window_recycle(ucc_tl_mlx5_mcast_ ucc_tl_mlx5_mcast_coll_req_t *req) { ucc_status_t status = UCC_OK; - int wsize = comm->wsize; - int num_free_win = wsize - (comm->psn - comm->last_acked); + int wsize = comm->bcast_comm.wsize; + int num_free_win = wsize - (comm->psn - comm->bcast_comm.last_acked); int req_completed = (req->to_send == 0 && req->to_recv == 0); struct pp_packet *pp = NULL; - ucc_assert(comm->recv_drop_packet_in_progress == false); + ucc_assert(comm->bcast_comm.recv_drop_packet_in_progress == false); ucc_assert(req->to_send >= 0); /* When do we need to perform reliability protocol: @@ -33,12 +33,12 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_r_window_recycle(ucc_tl_mlx5_mcast_ return status; } - comm->n_mcast_reliable++; + comm->bcast_comm.n_mcast_reliable++; - for (;comm->last_acked < comm->psn; comm->last_acked++) { - pp = comm->r_window[comm->last_acked & (wsize-1)]; + for (; comm->bcast_comm.last_acked < comm->psn; comm->bcast_comm.last_acked++) { + pp = comm->r_window[comm->bcast_comm.last_acked & (wsize-1)]; ucc_assert(pp != &comm->dummy_packet); - comm->r_window[comm->last_acked & (wsize-1)] = &comm->dummy_packet; + comm->r_window[comm->bcast_comm.last_acked & (wsize-1)] = &comm->dummy_packet; pp->context = 0; ucc_list_add_tail(&comm->bpool, &pp->super); @@ -60,7 +60,7 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_do_bcast(ucc_tl_mlx5_mcast_coll_req ucc_status_t status = UCC_OK; ucc_tl_mlx5_mcast_coll_comm_t *comm = req->comm; int zcopy = req->proto != MCAST_PROTO_EAGER; - int wsize = comm->wsize; + int wsize = comm->bcast_comm.wsize; int num_free_win; int num_sent; int to_send; @@ -74,29 +74,29 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_do_bcast(ucc_tl_mlx5_mcast_coll_req return status; } - if (ucc_unlikely(comm->recv_drop_packet_in_progress)) { + if (ucc_unlikely(comm->bcast_comm.recv_drop_packet_in_progress)) { /* wait till parent resend the dropped packet */ return UCC_INPROGRESS; } if (req->to_send || req->to_recv) { - num_free_win = wsize - (comm->psn - comm->last_acked); + num_free_win = wsize - (comm->psn - comm->bcast_comm.last_acked); /* Send data if i'm root and there is a space in the window */ if (num_free_win && req->am_root) { num_sent = req->num_packets - req->to_send; ucc_assert(req->to_send > 0); - ucc_assert(req->first_send_psn + num_sent < comm->last_acked + wsize); - if (req->first_send_psn + num_sent < comm->last_acked + wsize && + ucc_assert(req->first_send_psn + num_sent < comm->bcast_comm.last_acked + wsize); + if (req->first_send_psn + num_sent < comm->bcast_comm.last_acked + wsize && req->to_send) { /* How many to send: either all that are left (if they fit into window) or up to the window limit */ to_send = ucc_min(req->to_send, - comm->last_acked + wsize - (req->first_send_psn + num_sent)); + comm->bcast_comm.last_acked + wsize - (req->first_send_psn + num_sent)); ucc_tl_mlx5_mcast_send(comm, req, to_send, zcopy); - num_free_win = wsize - (comm->psn - comm->last_acked); + 
num_free_win = wsize - (comm->psn - comm->bcast_comm.last_acked); } } @@ -119,8 +119,8 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_do_bcast(ucc_tl_mlx5_mcast_coll_req tl_trace(comm->lib, "Did not receive the packet with psn in" " current window range, so get ready for drop" " event. pending_q_size %d current comm psn %d" - " last_acked psn %d stall threshold %d ", - pending_q_size, comm->psn, comm->last_acked, + " bcast_comm.last_acked psn %d stall threshold %d ", + pending_q_size, comm->psn, comm->bcast_comm.last_acked, DROP_THRESHOLD); status = ucc_tl_mlx5_mcast_bcast_check_drop(comm, req); @@ -144,7 +144,7 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_do_bcast(ucc_tl_mlx5_mcast_coll_req return status; } - if (req->to_send || req->to_recv || (zcopy && comm->psn != comm->last_acked)) { + if (req->to_send || req->to_recv || (zcopy && comm->psn != comm->bcast_comm.last_acked)) { return UCC_INPROGRESS; } else { return status; @@ -159,11 +159,6 @@ ucc_status_t ucc_tl_mlx5_mcast_test(ucc_tl_mlx5_mcast_coll_req_t* req) ucc_assert(req->comm->psn >= req->start_psn); status = ucc_tl_mlx5_mcast_do_bcast(req); - if (UCC_INPROGRESS != status) { - ucc_assert(req->comm->ctx != NULL); - ucc_tl_mlx5_mcast_mem_deregister(req->comm->ctx, req->rreg); - req->rreg = NULL; - } return status; } @@ -203,16 +198,16 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_prepare_bcast(void* buf, size_t siz } req->offset = 0; - req->start_psn = comm->last_psn; + req->start_psn = comm->bcast_comm.last_psn; req->num_packets = ucc_max(ucc_div_round_up(req->length, comm->max_per_packet), 1); req->last_pkt_len = req->length - (req->num_packets - 1)*comm->max_per_packet; ucc_assert(req->last_pkt_len > 0 && req->last_pkt_len <= comm->max_per_packet); - comm->last_psn += req->num_packets; - req->first_send_psn = req->start_psn; - req->to_send = req->am_root ? req->num_packets : 0; - req->to_recv = req->am_root ? 0 : req->num_packets; + comm->bcast_comm.last_psn += req->num_packets; + req->first_send_psn = req->start_psn; + req->to_send = req->am_root ? req->num_packets : 0; + req->to_recv = req->am_root ? 
0 : req->num_packets; return UCC_OK; } @@ -229,14 +224,16 @@ ucc_status_t ucc_tl_mlx5_mcast_coll_do_bcast(void* buf, size_t size, ucc_rank_t buf, size, root, comm->comm_id, comm->commsize, comm->rank == root, comm->psn ); - req = ucc_calloc(1, sizeof(ucc_tl_mlx5_mcast_coll_req_t), "mcast_req"); + req = ucc_mpool_get(&comm->ctx->mcast_req_mp); if (!req) { + tl_error(comm->lib, "failed to get mcast req"); return UCC_ERR_NO_MEMORY; } + memset(req, 0, sizeof(ucc_tl_mlx5_mcast_coll_req_t)); status = ucc_tl_mlx5_mcast_prepare_bcast(buf, size, root, comm, req); if (UCC_OK != status) { - ucc_free(req); + ucc_mpool_put(req); return status; } @@ -260,10 +257,10 @@ ucc_status_t ucc_tl_mlx5_mcast_bcast_start(ucc_coll_task_t *coll_task) void *buf = args->src.info.buffer; ucc_tl_mlx5_mcast_coll_comm_t *comm = team->mcast_comm; - task->bcast_mcast.req_handle = NULL; + task->coll_mcast.req_handle = NULL; status = ucc_tl_mlx5_mcast_coll_do_bcast(buf, data_size, root, comm, - &task->bcast_mcast.req_handle); + &task->coll_mcast.req_handle); if (status < 0) { tl_error(UCC_TASK_LIB(task), "mcast_coll_do_bcast failed:%d", status); coll_task->status = status; @@ -278,7 +275,7 @@ ucc_status_t ucc_tl_mlx5_mcast_bcast_start(ucc_coll_task_t *coll_task) void ucc_tl_mlx5_mcast_collective_progress(ucc_coll_task_t *coll_task) { ucc_tl_mlx5_task_t *task = ucc_derived_of(coll_task, ucc_tl_mlx5_task_t); - ucc_tl_mlx5_mcast_coll_req_t *req = task->bcast_mcast.req_handle; + ucc_tl_mlx5_mcast_coll_req_t *req = task->coll_mcast.req_handle; if (req != NULL) { coll_task->status = ucc_tl_mlx5_mcast_test(req); @@ -308,17 +305,20 @@ ucc_status_t ucc_tl_mlx5_mcast_check_support(ucc_base_coll_args_t *coll_args, ucc_coll_args_t *args = &coll_args->args; int buf_size = ucc_dt_size(args->src.info.datatype) * args->src.info.count; - if (UCC_COLL_ARGS_ACTIVE_SET(args)) { - tl_trace(team->context->lib, "mcast bcast not supported for active sets"); + if (UCC_COLL_ARGS_ACTIVE_SET(&coll_args->args) || + ((coll_args->args.coll_type == UCC_COLL_TYPE_ALLGATHER) && + (UCC_IS_INPLACE(coll_args->args) || UCC_IS_PERSISTENT(coll_args->args)))) { + tl_trace(team->context->lib, "mcast collective not supported"); return UCC_ERR_NOT_SUPPORTED; } if (UCC_OK != ucc_tl_mlx5_mcast_check_memory_type_cap(coll_args, team)) { - tl_trace(team->context->lib, "mcast bcast not compatible with this memory type"); + tl_trace(team->context->lib, "mcast collective not compatible with this memory type"); return UCC_ERR_NOT_SUPPORTED; } if (args->src.info.mem_type == UCC_MEMORY_TYPE_CUDA && + coll_args->args.coll_type == UCC_COLL_TYPE_BCAST && buf_size > CUDA_MEM_MCAST_BCAST_MAX_MSG) { /* for large messages (more than one mtu) we need zero-copy design which * is not implemented yet */ diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_context.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_context.c index 0756ac142d..6c9ac92e4c 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_context.c +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_context.c @@ -238,6 +238,7 @@ ucc_status_t ucc_tl_mlx5_mcast_context_init(ucc_tl_mlx5_mcast_context_t *cont device_attr.max_cq, device_attr.max_cqe); ctx->max_qp_wr = device_attr.max_qp_wr; + status = ucc_mpool_init(&ctx->compl_objects_mp, 0, sizeof(ucc_tl_mlx5_mcast_p2p_completion_obj_t), 0, UCC_CACHE_LINE_SIZE, 8, UINT_MAX, &ucc_coll_task_mpool_ops, @@ -249,6 +250,17 @@ ucc_status_t ucc_tl_mlx5_mcast_context_init(ucc_tl_mlx5_mcast_context_t *cont goto error; } + status = ucc_mpool_init(&ctx->mcast_req_mp, 0, 
sizeof(ucc_tl_mlx5_mcast_coll_req_t), 0, + UCC_CACHE_LINE_SIZE, 8, UINT_MAX, + &ucc_coll_task_mpool_ops, + UCC_THREAD_SINGLE, + "ucc_tl_mlx5_mcast_coll_req_t"); + if (ucc_unlikely(UCC_OK != status)) { + tl_warn(lib, "failed to initialize mcast_req_mp mpool"); + status = UCC_ERR_NO_MEMORY; + goto error; + } + ctx->rcache = NULL; status = ucc_tl_mlx5_mcast_setup_rcache(ctx); if (UCC_OK != status) { diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.c index 132cd8df78..a116c08cf8 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.c +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.c @@ -663,7 +663,7 @@ ucc_status_t ucc_tl_mlx5_clean_mcast_comm(ucc_tl_mlx5_mcast_coll_comm_t *comm) tl_debug(comm->lib, "comm_id %d, comm_size %d, comm->psn %d, rank %d, " "nacks counter %d, n_mcast_rel %d", comm->comm_id, comm->commsize, comm->psn, comm->rank, - comm->nacks_counter, comm->n_mcast_reliable); + comm->bcast_comm.nacks_counter, comm->bcast_comm.n_mcast_reliable); } if (comm->p2p_ctx != NULL) { diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.h b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.h index 19dfd88097..d0b1a1ddd3 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.h +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.h @@ -7,6 +7,7 @@ #ifndef TL_MLX5_MCAST_HELPER_H_ #define TL_MLX5_MCAST_HELPER_H_ #include "tl_mlx5_mcast_progress.h" +#include "tl_mlx5_mcast_one_sided_progress.h" #include "utils/ucc_math.h" #include "tl_mlx5.h" @@ -92,7 +93,7 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_send(ucc_tl_mlx5_mcast_coll_comm_t swr[0].imm_data = htonl(pp->psn); swr[0].send_flags = (length <= comm->max_inline) ? IBV_SEND_INLINE : 0; - comm->r_window[pp->psn & (comm->wsize-1)] = pp; + comm->r_window[pp->psn & (comm->bcast_comm.wsize-1)] = pp; comm->psn++; req->to_send--; offset += length; @@ -127,13 +128,13 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_send(ucc_tl_mlx5_mcast_coll_comm_t } static inline ucc_status_t ucc_tl_mlx5_mcast_process_pp(ucc_tl_mlx5_mcast_coll_comm_t *comm, - ucc_tl_mlx5_mcast_coll_req_t *req, - struct pp_packet *pp, - int *num_left, int in_pending_queue) + ucc_tl_mlx5_mcast_coll_req_t *req, + struct pp_packet *pp, + int *num_left, int in_pending_queue) { ucc_status_t status = UCC_OK; - if (PSN_RECEIVED(pp->psn, comm) || pp->psn < comm->last_acked) { + if (PSN_RECEIVED(pp->psn, comm) || pp->psn < comm->bcast_comm.last_acked) { /* This psn was already received */ ucc_assert(pp->context == 0); if (in_pending_queue) { @@ -253,6 +254,201 @@ static inline int ucc_tl_mlx5_mcast_recv(ucc_tl_mlx5_mcast_coll_comm_t *comm, return num_left; } +static inline ucc_status_t ucc_tl_mlx5_mcast_send_collective(ucc_tl_mlx5_mcast_coll_comm_t* + comm, ucc_tl_mlx5_mcast_coll_req_t *req, + int num_packets, const int zcopy, + int coll_type, int group_id, size_t send_offset) +{ + struct ibv_send_wr *swr = &comm->mcast.swr; + struct ibv_sge *ssg = &comm->mcast.ssg; + int max_per_packet = comm->max_per_packet; + size_t offset = (send_offset == SIZE_MAX) ? 
req->offset : send_offset; + int max_commsize = ONE_SIDED_RELIABILITY_MAX_TEAM_SIZE; + int max_ag_counter = ONE_SIDED_MAX_ALLGATHER_COUNTER; + int i; + struct ibv_send_wr *bad_wr; + struct pp_packet *pp; + int rc; + int length; + int mcast_group_index; + + ucc_assert(group_id <= comm->mcast_group_count && UCC_COLL_TYPE_ALLGATHER == coll_type); + + swr->num_sge = 1; + swr->sg_list = & comm->mcast.ssg; + swr->opcode = IBV_WR_SEND_WITH_IMM; + swr->wr.ud.remote_qpn = MULTICAST_QPN; + swr->wr.ud.remote_qkey = DEF_QKEY; + swr->next = NULL; + + for (i = 0; i < num_packets; i++) { + if (NULL == (pp = ucc_tl_mlx5_mcast_buf_get_free(comm))) { + break; + } + ucc_assert(pp->context == 0); + + __builtin_prefetch((void*) pp->buf); + __builtin_prefetch(req->ptr + offset); + + length = req->to_send == 1 ? req->last_pkt_len : max_per_packet; + pp->length = length; + + // generate psn to be used as immediate data + /* example: encapsulate packet counter (top 16 bits), collective counter (middle 8 bits), + * and source rank (low 8 bits) - assuming max_commsize and + * max_ag_counter are 256 */ + pp->psn = (req->num_packets - req->to_send)*max_commsize*max_ag_counter + + (req->ag_counter % max_ag_counter)*max_commsize + comm->rank; + + ssg[0].addr = (uintptr_t)req->ptr + offset; + + if (!zcopy) { + memcpy((void*) pp->buf, req->ptr + offset, length); + ssg[0].addr = (uint64_t) pp->buf; + ssg[0].lkey = comm->pp_mr->lkey; + } else { + pp->context = (uintptr_t)req->ptr + offset; + ssg[0].lkey = req->mr->lkey; + } + + ssg[0].length = length; + swr[0].wr_id = (uint64_t) pp; + swr[0].imm_data = htonl(pp->psn); + swr[0].send_flags = (length <= comm->max_inline) ? IBV_SEND_INLINE : 0; + + comm->psn ++; + req->to_send --; + offset += length; + + swr[0].send_flags |= IBV_SEND_SIGNALED; + comm->pending_send++; + + if (group_id < 0) { + mcast_group_index = i % comm->mcast_group_count; // schedule sends with round-robin + } else { + mcast_group_index = group_id; + } + + swr[0].wr.ud.ah = comm->mcast.ah_list[mcast_group_index]; + + tl_trace(comm->lib, "mcast allgather post_send, psn %d, length %d, " + "zcopy %d, signaled %d qp->state %d qp->qp_num %d qp->pd %p " + "coll_type %d mcast_group_index %d", + pp->psn, pp->length, zcopy, swr[0].send_flags & + IBV_SEND_SIGNALED, + comm->mcast.qp_list[mcast_group_index]->state, + comm->mcast.qp_list[mcast_group_index]->qp_num, + comm->mcast.qp_list[mcast_group_index]->pd, coll_type, + mcast_group_index); + + if (0 != (rc = ibv_post_send(comm->mcast.qp_list[mcast_group_index], &swr[0], &bad_wr))) { + tl_error(comm->lib, "post send failed: ret %d, start_psn %d, to_send %d, " + "to_recv %d, length %d, psn %d, inline %d", + rc, req->start_psn, req->to_send, req->to_recv, + length, pp->psn, length <= comm->max_inline); + return UCC_ERR_NO_MESSAGE; + } + } + + if (send_offset == SIZE_MAX) { + req->offset = offset; + } + + return UCC_OK; +} + +static inline int ucc_tl_mlx5_mcast_recv_collective(ucc_tl_mlx5_mcast_coll_comm_t *comm, + ucc_tl_mlx5_mcast_coll_req_t *req, int + num_left, int coll_type) +{ + int max_commsize = ONE_SIDED_RELIABILITY_MAX_TEAM_SIZE; + int max_ag_counter = ONE_SIDED_MAX_ALLGATHER_COUNTER; + struct pp_packet *pp; + struct pp_packet *next; + uint64_t id; + struct ibv_wc *wc; + int num_comp; + int i; + int real_num_comp; + int recv_progressed = 0; + int ag_counter; + ucc_status_t status; + + /* check if we have already received something */ + ucc_list_for_each_safe(pp, next, &comm->pending_q, super) { + ag_counter = (pp->psn / max_commsize) % + max_ag_counter; + if 
(ag_counter == (req->ag_counter % max_ag_counter)) { + ucc_list_del(&pp->super); + status = ucc_tl_mlx5_mcast_process_packet_collective(comm, req, pp, coll_type); + if (UCC_OK != status) { + tl_error(comm->lib, "process allgather packet failed, status %d", + status); + return -1; + } + recv_progressed++; + } + }; + + wc = ucc_malloc(sizeof(struct ibv_wc) * POLL_PACKED, "WC"); + if (!wc) { + return -1; + } + + while (num_left > recv_progressed) + { + memset(wc, 0, sizeof(sizeof(struct ibv_wc) * POLL_PACKED)); + num_comp = ibv_poll_cq(comm->rcq, POLL_PACKED, &wc[0]); + + if (num_comp < 0) { + tl_error(comm->lib, "recv queue poll completion failed %d", num_comp); + ucc_free(wc); + return -1; + } else if (num_comp == 0) { + break; + } + + if (IBV_WC_SUCCESS != wc[0].status) { + fprintf(stderr, "mcast_recv: %s err pending_recv %d wr_id %ld num_comp %d byte_len %d\n", + ibv_wc_status_str(wc[0].status), comm->pending_recv, wc[0].wr_id, num_comp, wc[0].byte_len); + return -1; + } + + real_num_comp = num_comp; + + for (i = 0; i < real_num_comp; i++) { + ucc_assert(wc[i].status == IBV_WC_SUCCESS); + id = wc[i].wr_id; + pp = (struct pp_packet*) (id); + pp->length = wc[i].byte_len - GRH_LENGTH; + pp->psn = ntohl(wc[i].imm_data); + + tl_trace(comm->lib, "%d collective pkt completion: psn %d, length %d, " + "req_num packets %d, to_send %d, to_recv %d, num_left %d \n", + coll_type, pp->psn, pp->length, req->num_packets, + req->to_send, req->to_recv, num_left); + + status = ucc_tl_mlx5_mcast_process_packet_collective(comm, req, pp, coll_type); + if (UCC_OK != status) { + tl_error(comm->lib, "process allgather packet failed, status %d", + status); + ucc_free(wc); + return -1; + } + + recv_progressed++; + ucc_assert(pp->qp_id < MAX_GROUP_COUNT); + } + + comm->pending_recv -= num_comp; + ucc_tl_mlx5_mcast_post_recv_buffers(comm); + } + + ucc_free(wc); + return recv_progressed; +} + + static inline ucc_status_t ucc_tl_mlx5_mcast_poll_recv(ucc_tl_mlx5_mcast_coll_comm_t *comm) { ucc_status_t status = UCC_OK; @@ -310,8 +506,9 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_reliable(ucc_tl_mlx5_mcast_coll_com { ucc_status_t status = UCC_OK; - if (comm->racks_n != comm->child_n || comm->sacks_n != comm->parent_n || - comm->nack_requests) { + if (comm->bcast_comm.racks_n != comm->bcast_comm.child_n || + comm->bcast_comm.sacks_n != comm->bcast_comm.parent_n || + comm->bcast_comm.nack_requests) { if (comm->pending_send) { status = ucc_tl_mlx5_mcast_poll_send(comm); if (UCC_OK != status) { @@ -319,7 +516,7 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_reliable(ucc_tl_mlx5_mcast_coll_com } } - if (comm->parent_n) { + if (comm->bcast_comm.parent_n) { status = ucc_tl_mlx5_mcast_poll_recv(comm); if (UCC_OK != status) { return status; @@ -332,26 +529,27 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_reliable(ucc_tl_mlx5_mcast_coll_com } } - if (comm->parent_n && !comm->reliable_in_progress) { + if (comm->bcast_comm.parent_n && !comm->bcast_comm.reliable_in_progress) { status = ucc_tl_mlx5_mcast_reliable_send(comm); if (UCC_OK != status) { return status; } } - if (!comm->reliable_in_progress) { - comm->reliable_in_progress = 1; + if (!comm->bcast_comm.reliable_in_progress) { + comm->bcast_comm.reliable_in_progress = 1; } - if (comm->racks_n == comm->child_n && comm->sacks_n == comm->parent_n && - 0 == comm->nack_requests) { + if (comm->bcast_comm.racks_n == comm->bcast_comm.child_n && + comm->bcast_comm.sacks_n == comm->bcast_comm.parent_n && 0 == + comm->bcast_comm.nack_requests) { // Reset for next round. 
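// all expected ACKs from children have arrived, all parents have been ACKed back, and no NACK resends are outstanding, so clear the per-round bcast reliability state below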
- memset(comm->parents, 0, sizeof(comm->parents)); - memset(comm->children, 0, sizeof(comm->children)); + memset(comm->bcast_comm.parents, 0, sizeof(comm->bcast_comm.parents)); + memset(comm->bcast_comm.children, 0, sizeof(comm->bcast_comm.children)); - comm->racks_n = comm->child_n = 0; - comm->sacks_n = comm->parent_n = 0; - comm->reliable_in_progress = 0; + comm->bcast_comm.racks_n = comm->bcast_comm.child_n = 0; + comm->bcast_comm.sacks_n = comm->bcast_comm.parent_n = 0; + comm->bcast_comm.reliable_in_progress = 0; return UCC_OK; } diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_one_sided_progress.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_one_sided_progress.c new file mode 100644 index 0000000000..208e9991c0 --- /dev/null +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_one_sided_progress.c @@ -0,0 +1,273 @@ +/** + * Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See file LICENSE for terms. + */ + +#include "tl_mlx5_mcast_one_sided_progress.h" +#include +#include "tl_mlx5_mcast_rcache.h" + +ucc_status_t ucc_tl_mlx5_mcast_reliable_one_sided_get(ucc_tl_mlx5_mcast_coll_comm_t *comm, + ucc_tl_mlx5_mcast_coll_req_t *req, + int *completed) +{ + int target_completed = 0; + int issued = 0; + ucc_status_t status; + ucc_rank_t target; + ucc_tl_mlx5_mcast_reg_t *reg; + void *src_addr; + void *remote_addr; + uint32_t rkey; + uint32_t lkey; + size_t size; + uint64_t wr; + + /* in sync design this function is only called once */ + ucc_assert(!(ONE_SIDED_SYNCHRONOUS_PROTO == req->one_sided_reliability_scheme && + comm->one_sided.rdma_read_in_progress)); + + for (target = 0; target < comm->commsize; target++) { + if (comm->one_sided.recvd_pkts_tracker[target] != req->num_packets) { + tl_trace(comm->lib, "%d of the packets from source target %d are dropped", + req->num_packets - comm->one_sided.recvd_pkts_tracker[target], target); + + // register the recv buf if it is not already registered + + if (NULL == req->recv_rreg) { + tl_debug(comm->lib, "registering recv buf of size %d", comm->commsize * req->length); + + status = ucc_tl_mlx5_mcast_mem_register(comm->ctx, req->rptr, comm->commsize * req->length, ®); + if (UCC_OK != status) { + return status; + } + + req->recv_rreg = reg; + req->recv_mr = reg->mr; + } + + switch(req->one_sided_reliability_scheme) { + case ONE_SIDED_ASYNCHRONOUS_PROTO: + /* first check if the remote slot is valid */ + /* in this design, if reliability protocol is kicked, allgather is + * completed once all the values in one_sided.recvd_pkts_tracker[] is set to req->num_packets + * and comm->pending_reads is set to 0 */ + if (req->ag_counter == comm->one_sided.remote_slot_info[target]) { + /* read remote data from remote slot + * the content of this data is copied from send buffer by remote + * process */ + src_addr = PTR_OFFSET(req->rptr, (req->length * target)); + remote_addr = PTR_OFFSET(comm->one_sided_async_slots_info_list[target].remote_addr, + ((req->ag_counter % ONE_SIDED_SLOTS_COUNT) * comm->one_sided.slot_size + + ONE_SIDED_SLOTS_INFO_SIZE)); + lkey = req->recv_mr->lkey; + rkey = comm->one_sided_async_slots_info_list[target].rkey; + size = req->length; + wr = MCAST_AG_RDMA_READ_WR; + + comm->pending_reads++; + target_completed++; + comm->one_sided.remote_slot_info[target] = ONE_SIDED_PENDING_DATA; + comm->one_sided.recvd_pkts_tracker[target] = req->num_packets; + + } else if (ONE_SIDED_PENDING_INFO != comm->one_sided.remote_slot_info[target] && + ONE_SIDED_PENDING_DATA != comm->one_sided.remote_slot_info[target]) { + 
/* remote slot is not valid yet. Need to do an rdma + * read to check the latest state */ + src_addr = &comm->one_sided.remote_slot_info[target]; + remote_addr = PTR_OFFSET(comm->one_sided_async_slots_info_list[target].remote_addr, + ((req->ag_counter % ONE_SIDED_SLOTS_COUNT) * comm->one_sided.slot_size)); + lkey = comm->one_sided.remote_slot_info_mr->lkey; + rkey = comm->one_sided_async_slots_info_list[target].rkey; + size = ONE_SIDED_SLOTS_INFO_SIZE; + wr = MCAST_AG_RDMA_READ_INFO_WR; + + comm->one_sided.remote_slot_info[target] = ONE_SIDED_PENDING_INFO; + + } else { + /* rdma read to remote info or data has already been issue but it + * has not been completed */ + continue; + } + break; + + case ONE_SIDED_SYNCHRONOUS_PROTO: + /* read the whole remote send buffer */ + src_addr = PTR_OFFSET(req->rptr, (req->length * target)); + remote_addr = comm->ag_info_list[target].remote_addr; + rkey = comm->ag_info_list[target].rkey; + lkey = req->recv_mr->lkey; + size = req->length; + wr = MCAST_AG_RDMA_READ_WR; + + comm->pending_reads++; + target_completed++; + break; + + default: + return UCC_ERR_NOT_IMPLEMENTED; + } + + issued++; + status = ucc_tl_one_sided_p2p_get(src_addr, remote_addr, size, lkey, rkey, target, wr, comm); + if (UCC_OK != status) { + return status; + } + } else { + /* all the expected packets from this target have arrived */ + target_completed++; + } + } + + comm->one_sided.rdma_read_in_progress = 1; + + if (completed) { + *completed = target_completed; + } + + if (issued) { + tl_debug(comm->lib, "issued %d RDMA READ to remote INFO/DATA. Number of target ranks completed: %d", + issued, target_completed); + } + + return UCC_OK; +} + +ucc_status_t ucc_tl_mlx5_mcast_progress_one_sided_communication(ucc_tl_mlx5_mcast_coll_comm_t *comm, + ucc_tl_mlx5_mcast_coll_req_t *req) +{ + int completed = 0; + ucc_status_t status; + + ucc_assert(comm->one_sided.rdma_read_in_progress); + + if (!req->to_send && !req->to_recv) { + // need to wait until all the rdma reads are done to avoid data invalidation + tl_trace(comm->lib, + "All the mcast packets arrived during the reliablity protocol. 
Current timeout is %d usec", + comm->ctx->params.timeout); + } + + if (ucc_tl_mlx5_mcast_poll_send(comm) < 0) { + return UCC_ERR_NO_MESSAGE; + } + + // check if all the rdma read have been completed and return UCC_OK if so + switch(req->one_sided_reliability_scheme) { + case ONE_SIDED_ASYNCHRONOUS_PROTO: + status = ucc_tl_mlx5_mcast_reliable_one_sided_get(comm, req, &completed); + if (UCC_OK != status) { + return status; + } + + if (!comm->pending_reads && (completed == comm->commsize)) { + tl_debug(comm->lib, "All the pending RDMA READ are comepleted in async reliablity protocol"); + comm->one_sided.rdma_read_in_progress = 0; + req->to_recv = 0; + return UCC_OK; + } + break; + + case ONE_SIDED_SYNCHRONOUS_PROTO: + if (!comm->pending_reads) { + tl_debug(comm->lib, "All the pending RDMA READ are comepleted in sync reliablity protocol"); + comm->one_sided.rdma_read_in_progress = 0; + req->to_recv = 0; + return UCC_OK; + } + break; + + default: + return UCC_ERR_NOT_IMPLEMENTED; + } + + return UCC_INPROGRESS; +} + +ucc_status_t ucc_tl_mlx5_mcast_process_packet_collective(ucc_tl_mlx5_mcast_coll_comm_t *comm, + ucc_tl_mlx5_mcast_coll_req_t *req, + struct pp_packet *pp, + int coll_type) +{ + int out_of_order_recvd = 0; + void *dest; + int offset; + int source_rank; + uint32_t ag_counter; + ucc_status_t status; + int count; + ucc_rank_t target; + ucc_tl_mlx5_mcast_reg_t *reg; + void *src_addr; + void *remote_addr; + uint32_t rkey; + uint32_t lkey; + size_t size; + uint64_t wr; + + ucc_assert(pp->context == 0); // making sure it's a recv packet not send + ucc_assert(UCC_COLL_TYPE_ALLGATHER == coll_type); + + // process the immediate value saved in pp->psn + source_rank = pp->psn % ONE_SIDED_RELIABILITY_MAX_TEAM_SIZE; + ag_counter = (pp->psn / ONE_SIDED_RELIABILITY_MAX_TEAM_SIZE) % ONE_SIDED_MAX_ALLGATHER_COUNTER; + offset = (pp->psn / (ONE_SIDED_MAX_ALLGATHER_COUNTER * ONE_SIDED_RELIABILITY_MAX_TEAM_SIZE)); + + tl_trace(comm->lib, "processing a recvd packet with length %d source_rank" + " %d ag_counter %d offset %d", pp->length, source_rank, + ag_counter, offset); + + ucc_assert(offset < req->num_packets); + // there are scenarios where we receive a packet with same offset/rank more than one time + // this means that a packet which was considered dropped in previous run has not just arrived + // need to check the allgather call counter and ignore this packet if it does not match + + if (ag_counter == (req->ag_counter % ONE_SIDED_MAX_ALLGATHER_COUNTER)) { + if (pp->length) { + if (pp->length == comm->max_per_packet) { + dest = PTR_OFFSET(req->rptr, (offset * pp->length + source_rank * req->length)); + } else { + dest = PTR_OFFSET(req->rptr, ((req->length - pp->length) + source_rank * req->length); + } + memcpy(dest, (void*) pp->buf, pp->length); + } + + if (comm->one_sided.reliability_enabled) { + /* out of order recv'd packet that happen that is fatal in zero-copy + * design is considered just like dropped packet */ + if (out_of_order_recvd == 0) { + comm->one_sided.recvd_pkts_tracker[source_rank]++; + } + + if (comm->one_sided.recvd_pkts_tracker[source_rank] > req->num_packets) { + tl_error(comm->lib, "reliablity failed: comm->one_sided.recvd_pkts_tracker[%d] %d" + " req->num_packets %d offset %d" + " comm->allgather_comm.under_progress_counter %d req->ag_counter" + " %d \n", source_rank, comm->one_sided.recvd_pkts_tracker[source_rank], + req->num_packets, offset, + comm->allgather_comm.under_progress_counter, req->ag_counter); + return UCC_ERR_NO_MESSAGE; + } + } + req->to_recv--; + comm->psn++; 
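+ /* packet belongs to the allgather currently in progress: recycle its pp buffer back to the free pool and update the per-multicast-group pending-recv count */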
+ pp->context = 0; + ucc_list_add_tail(&comm->bpool, &pp->super); + comm->pending_recv_per_qp[pp->qp_id]--; + } else if (ag_counter > (req->ag_counter % ONE_SIDED_MAX_ALLGATHER_COUNTER)) { + /* received out of order allgather packet - add it to queue for future + * processing */ + ucc_list_add_tail(&comm->pending_q, &pp->super); + } else { + /* received a packet which was left from previous iterations + * it is due to the fact that reliablity protocol was initiated. + * return the posted receive buffer back to the pool */ + ucc_assert(comm->one_sided.reliability_enabled); + pp->context = 0; + ucc_list_add_tail(&comm->bpool, &pp->super); + } + + return UCC_OK; +} + diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_one_sided_progress.h b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_one_sided_progress.h new file mode 100644 index 0000000000..e8a5e84e35 --- /dev/null +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_one_sided_progress.h @@ -0,0 +1,29 @@ +/** + * Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See file LICENSE for terms. + */ + +#include +#include +#include +#include "tl_mlx5_mcast.h" +#include "tl_mlx5_mcast_helper.h" +#include "p2p/ucc_tl_mlx5_mcast_p2p.h" + +#ifndef TL_MLX5_MCAST_ONE_SIDED_PROGRESS_H_ +#define TL_MLX5_MCAST_ONE_SIDED_PROGRESS_H_ + +ucc_status_t ucc_tl_mlx5_mcast_progress_one_sided_communication(ucc_tl_mlx5_mcast_coll_comm_t *comm, + ucc_tl_mlx5_mcast_coll_req_t *req); + +ucc_status_t ucc_tl_mlx5_mcast_reliable_one_sided_get(ucc_tl_mlx5_mcast_coll_comm_t *comm, + ucc_tl_mlx5_mcast_coll_req_t *req, + int *completed); + +ucc_status_t ucc_tl_mlx5_mcast_process_packet_collective(ucc_tl_mlx5_mcast_coll_comm_t *comm, + ucc_tl_mlx5_mcast_coll_req_t *req, + struct pp_packet* pp, int coll_type); + +#endif + diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_one_sided_reliability.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_one_sided_reliability.c index 85d63a82d0..a3d16fd6d8 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_one_sided_reliability.c +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_one_sided_reliability.c @@ -169,6 +169,13 @@ ucc_status_t ucc_tl_mlx5_mcast_one_sided_reliability_init(ucc_base_team_t *team) ucc_tl_mlx5_mcast_coll_comm_t *comm = tl_team->mcast->mcast_comm; ucc_status_t status = UCC_OK; + if (comm->commsize > ONE_SIDED_RELIABILITY_MAX_TEAM_SIZE) { + tl_warn(comm->lib, + "team size is %d but max supported team size of mcast one-sided reliability is %d", + comm->commsize, ONE_SIDED_RELIABILITY_MAX_TEAM_SIZE); + return UCC_ERR_NOT_SUPPORTED; + } + status = ucc_tl_mlx5_mcast_one_sided_setup_reliability_buffers(team); if (status != UCC_OK) { tl_error(comm->lib, "setup reliablity resources failed"); diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.c index 41b6ca14f9..3620cf629f 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.c +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.c @@ -24,13 +24,12 @@ static ucc_status_t ucc_tl_mlx5_mcast_reliability_send_completion(ucc_tl_mlx5_mc if (pkt_id != UINT_MAX) { /* we sent the real data to our child so reduce the nack reqs */ - ucc_assert(comm->nack_requests > 0); - ucc_assert(comm->p2p_pkt[pkt_id].type == MCAST_P2P_NACK_SEND_PENDING); - comm->p2p_pkt[pkt_id].type = MCAST_P2P_ACK; - comm->nack_requests--; - status = comm->params.p2p_iface.recv_nb(&comm->p2p_pkt[pkt_id], - sizeof(struct packet), comm->p2p_pkt[pkt_id].from, - UCC_MEMORY_TYPE_HOST, + 
ucc_assert(comm->bcast_comm.nack_requests > 0); + ucc_assert(comm->bcast_comm.p2p_pkt[pkt_id].type == MCAST_P2P_NACK_SEND_PENDING); + comm->bcast_comm.p2p_pkt[pkt_id].type = MCAST_P2P_ACK; + comm->bcast_comm.nack_requests--; + status = comm->params.p2p_iface.recv_nb(&comm->bcast_comm.p2p_pkt[pkt_id], + sizeof(struct packet), comm->bcast_comm.p2p_pkt[pkt_id].from, UCC_MEMORY_TYPE_HOST, comm->p2p_ctx, GET_COMPL_OBJ(comm, ucc_tl_mlx5_mcast_recv_completion, pkt_id, NULL)); if (status < 0) { @@ -46,19 +45,19 @@ static ucc_status_t ucc_tl_mlx5_mcast_reliability_send_completion(ucc_tl_mlx5_mc static inline ucc_status_t ucc_tl_mlx5_mcast_resend_packet_reliable(ucc_tl_mlx5_mcast_coll_comm_t *comm, int p2p_pkt_id) { - uint32_t psn = comm->p2p_pkt[p2p_pkt_id].psn; - struct pp_packet *pp = comm->r_window[psn % comm->wsize]; + uint32_t psn = comm->bcast_comm.p2p_pkt[p2p_pkt_id].psn; + struct pp_packet *pp = comm->r_window[psn % comm->bcast_comm.wsize]; ucc_status_t status; ucc_memory_type_t mem_type; ucc_assert(pp->psn == psn); - ucc_assert(comm->p2p_pkt[p2p_pkt_id].type == MCAST_P2P_NEED_NACK_SEND); + ucc_assert(comm->bcast_comm.p2p_pkt[p2p_pkt_id].type == MCAST_P2P_NEED_NACK_SEND); - comm->p2p_pkt[p2p_pkt_id].type = MCAST_P2P_NACK_SEND_PENDING; + comm->bcast_comm.p2p_pkt[p2p_pkt_id].type = MCAST_P2P_NACK_SEND_PENDING; tl_trace(comm->lib, "[comm %d, rank %d] Send data NACK: to %d, psn %d, context %ld nack_requests %d \n", comm->comm_id, comm->rank, - comm->p2p_pkt[p2p_pkt_id].from, psn, pp->context, comm->nack_requests); + comm->bcast_comm.p2p_pkt[p2p_pkt_id].from, psn, pp->context, comm->bcast_comm.nack_requests); if (comm->cuda_mem_enabled) { mem_type = UCC_MEMORY_TYPE_CUDA; @@ -67,7 +66,7 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_resend_packet_reliable(ucc_tl_mlx5_ } status = comm->params.p2p_iface.send_nb((void*) (pp->context ? 
pp->context : pp->buf), - pp->length, comm->p2p_pkt[p2p_pkt_id].from, mem_type, + pp->length, comm->bcast_comm.p2p_pkt[p2p_pkt_id].from, mem_type, comm->p2p_ctx, GET_COMPL_OBJ(comm, ucc_tl_mlx5_mcast_reliability_send_completion, NULL, p2p_pkt_id)); if (status < 0) { @@ -83,14 +82,14 @@ ucc_status_t ucc_tl_mlx5_mcast_check_nack_requests(ucc_tl_mlx5_mcast_coll_comm_t int i; struct pp_packet *pp; - if (!comm->nack_requests) { + if (!comm->bcast_comm.nack_requests) { return UCC_OK; } if (psn != UINT32_MAX) { - for (i=0; i<comm->child_n; i++) { - if (psn == comm->p2p_pkt[i].psn && - comm->p2p_pkt[i].type == MCAST_P2P_NEED_NACK_SEND) { + for (i=0; i<comm->bcast_comm.child_n; i++) { + if (psn == comm->bcast_comm.p2p_pkt[i].psn && + comm->bcast_comm.p2p_pkt[i].type == MCAST_P2P_NEED_NACK_SEND) { status = ucc_tl_mlx5_mcast_resend_packet_reliable(comm, i); if (status != UCC_OK) { break; @@ -98,10 +97,10 @@ ucc_status_t ucc_tl_mlx5_mcast_check_nack_requests(ucc_tl_mlx5_mcast_coll_comm_t } } } else { - for (i=0; i<comm->child_n; i++){ - if (comm->p2p_pkt[i].type == MCAST_P2P_NEED_NACK_SEND) { - psn = comm->p2p_pkt[i].psn; - pp = comm->r_window[psn % comm->wsize]; + for (i=0; i<comm->bcast_comm.child_n; i++){ + if (comm->bcast_comm.p2p_pkt[i].type == MCAST_P2P_NEED_NACK_SEND) { + psn = comm->bcast_comm.p2p_pkt[i].psn; + pp = comm->r_window[psn % comm->bcast_comm.wsize]; if (psn == pp->psn) { status = ucc_tl_mlx5_mcast_resend_packet_reliable(comm, i); if (status < 0) { @@ -118,9 +117,9 @@ ucc_status_t ucc_tl_mlx5_mcast_check_nack_requests(ucc_tl_mlx5_mcast_coll_comm_t static inline int ucc_tl_mlx5_mcast_find_nack_psn(ucc_tl_mlx5_mcast_coll_comm_t* comm, ucc_tl_mlx5_mcast_coll_req_t *req) { - int psn = ucc_max(comm->last_acked, req->start_psn); + int psn = ucc_max(comm->bcast_comm.last_acked, req->start_psn); int max_search_psn = ucc_min(req->start_psn + req->num_packets, - comm->last_acked + comm->wsize + 1); + comm->bcast_comm.last_acked + comm->bcast_comm.wsize + 1); for (; psn < max_search_psn; psn++) { if (!PSN_RECEIVED(psn, comm)) { @@ -166,7 +165,7 @@ static ucc_status_t ucc_tl_mlx5_mcast_recv_data_completion(ucc_tl_mlx5_mcast_p2p } req->to_recv--; - comm->r_window[pp->psn % comm->wsize] = pp; + comm->r_window[pp->psn % comm->bcast_comm.wsize] = pp; status = ucc_tl_mlx5_mcast_check_nack_requests(comm, pp->psn); if (status < 0) { @@ -174,7 +173,7 @@ static ucc_status_t ucc_tl_mlx5_mcast_recv_data_completion(ucc_tl_mlx5_mcast_p2p } comm->psn++; - comm->recv_drop_packet_in_progress = false; + comm->bcast_comm.recv_drop_packet_in_progress = false; return status; } @@ -197,7 +196,7 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_reliable_send_NACK(ucc_tl_mlx5_mcas parent = ucc_tl_mlx5_mcast_get_nack_parent(req); - comm->nacks_counter++; + comm->bcast_comm.nacks_counter++; status = comm->params.p2p_iface.send_nb(p, sizeof(struct packet), parent, UCC_MEMORY_TYPE_HOST, comm->p2p_ctx, GET_COMPL_OBJ(comm, @@ -214,7 +213,7 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_reliable_send_NACK(ucc_tl_mlx5_mcas pp->psn = psn; pp->length = PSN_TO_RECV_LEN(pp->psn, req, comm); - comm->recv_drop_packet_in_progress = true; + comm->bcast_comm.recv_drop_packet_in_progress = true; if (comm->cuda_mem_enabled) { mem_type = UCC_MEMORY_TYPE_CUDA; @@ -240,20 +239,20 @@ ucc_status_t ucc_tl_mlx5_mcast_reliable_send(ucc_tl_mlx5_mcast_coll_comm_t *comm ucc_status_t status; tl_trace(comm->lib, "comm %p, psn %d, last_acked %d, n_parent %d", - comm, comm->psn, comm->last_acked, comm->parent_n); + comm, comm->psn, comm->bcast_comm.last_acked,
comm->bcast_comm.parent_n); - ucc_assert(!comm->reliable_in_progress); + ucc_assert(!comm->bcast_comm.reliable_in_progress); - for (i=0; i<comm->parent_n; i++) { - parent = comm->parents[i]; - comm->p2p_spkt[i].type = MCAST_P2P_ACK; - comm->p2p_spkt[i].psn = comm->last_acked + comm->wsize; - comm->p2p_spkt[i].comm_id = comm->comm_id; + for (i=0; i<comm->bcast_comm.parent_n; i++) { + parent = comm->bcast_comm.parents[i]; + comm->bcast_comm.p2p_spkt[i].type = MCAST_P2P_ACK; + comm->bcast_comm.p2p_spkt[i].psn = comm->bcast_comm.last_acked + comm->bcast_comm.wsize; + comm->bcast_comm.p2p_spkt[i].comm_id = comm->comm_id; tl_trace(comm->lib, "rank %d, Posting SEND to parent %d, n_parent %d, psn %d", - comm->rank, parent, comm->parent_n, comm->psn); + comm->rank, parent, comm->bcast_comm.parent_n, comm->psn); - status = comm->params.p2p_iface.send_nb(&comm->p2p_spkt[i], + status = comm->params.p2p_iface.send_nb(&comm->bcast_comm.p2p_spkt[i], sizeof(struct packet), parent, UCC_MEMORY_TYPE_HOST, comm->p2p_ctx, GET_COMPL_OBJ(comm, ucc_tl_mlx5_mcast_send_completion, i, NULL)); @@ -273,19 +272,19 @@ static ucc_status_t ucc_tl_mlx5_mcast_recv_completion(ucc_tl_mlx5_mcast_p2p_comp struct pp_packet *pp; ucc_status_t status; - ucc_assert(comm->comm_id == comm->p2p_pkt[pkt_id].comm_id); + ucc_assert(comm->comm_id == comm->bcast_comm.p2p_pkt[pkt_id].comm_id); - if (comm->p2p_pkt[pkt_id].type != MCAST_P2P_ACK) { - ucc_assert(comm->p2p_pkt[pkt_id].type == MCAST_P2P_NACK); - psn = comm->p2p_pkt[pkt_id].psn; - pp = comm->r_window[psn % comm->wsize]; + if (comm->bcast_comm.p2p_pkt[pkt_id].type != MCAST_P2P_ACK) { + ucc_assert(comm->bcast_comm.p2p_pkt[pkt_id].type == MCAST_P2P_NACK); + psn = comm->bcast_comm.p2p_pkt[pkt_id].psn; + pp = comm->r_window[psn % comm->bcast_comm.wsize]; tl_trace(comm->lib, "[comm %d, rank %d] Got NACK: from %d, psn %d, avail %d pkt_id %d", comm->comm_id, comm->rank, - comm->p2p_pkt[pkt_id].from, psn, pp->psn == psn, pkt_id); + comm->bcast_comm.p2p_pkt[pkt_id].from, psn, pp->psn == psn, pkt_id); - comm->p2p_pkt[pkt_id].type = MCAST_P2P_NEED_NACK_SEND; - comm->nack_requests++; + comm->bcast_comm.p2p_pkt[pkt_id].type = MCAST_P2P_NEED_NACK_SEND; + comm->bcast_comm.nack_requests++; if (pp->psn == psn) { /* parent already has this packet so it is ready to forward it to its child */ @@ -296,8 +295,8 @@ static ucc_status_t ucc_tl_mlx5_mcast_recv_completion(ucc_tl_mlx5_mcast_p2p_comp } } else { - ucc_assert(comm->p2p_pkt[pkt_id].type == MCAST_P2P_ACK); - comm->racks_n++; + ucc_assert(comm->bcast_comm.p2p_pkt[pkt_id].type == MCAST_P2P_ACK); + comm->bcast_comm.racks_n++; } ucc_mpool_put(obj); /* return the completion object back to the mem pool compl_objects_mp */ @@ -309,7 +308,7 @@ static ucc_status_t ucc_tl_mlx5_mcast_send_completion(ucc_tl_mlx5_mcast_p2p_comp { ucc_tl_mlx5_mcast_coll_comm_t *comm = (ucc_tl_mlx5_mcast_coll_comm_t*)obj->data[0]; - comm->sacks_n++; + comm->bcast_comm.sacks_n++; ucc_mpool_put(obj); return UCC_OK; } @@ -343,20 +342,20 @@ ucc_status_t ucc_tl_mlx5_mcast_prepare_reliable(ucc_tl_mlx5_mcast_coll_comm_t *c while (mask < comm->commsize) { if (vrank & mask) { req->parent = TO_ORIGINAL((vrank ^ mask), comm->commsize, root); - add_uniq(comm->parents, &comm->parent_n, req->parent); + add_uniq(comm->bcast_comm.parents, &comm->bcast_comm.parent_n, req->parent); break; } else { child = vrank ^ mask; if (child < comm->commsize) { child = TO_ORIGINAL(child, comm->commsize, root); - if (add_uniq(comm->children, &comm->child_n, child)) { + if (add_uniq(comm->bcast_comm.children,
&comm->bcast_comm.child_n, child)) { tl_trace(comm->lib, "rank %d, Posting RECV from child %d, n_child %d, psn %d", - comm->rank, child, comm->child_n, comm->psn); + comm->rank, child, comm->bcast_comm.child_n, comm->psn); - status = comm->params.p2p_iface.recv_nb(&comm->p2p_pkt[comm->child_n - 1], + status = comm->params.p2p_iface.recv_nb(&comm->bcast_comm.p2p_pkt[comm->bcast_comm.child_n - 1], sizeof(struct packet), child, UCC_MEMORY_TYPE_HOST, comm->p2p_ctx, GET_COMPL_OBJ(comm, - ucc_tl_mlx5_mcast_recv_completion, comm->child_n - 1, req)); + ucc_tl_mlx5_mcast_recv_completion, comm->bcast_comm.child_n - 1, req)); if (status < 0) { return status; } @@ -370,12 +369,6 @@ ucc_status_t ucc_tl_mlx5_mcast_prepare_reliable(ucc_tl_mlx5_mcast_coll_comm_t *c return UCC_OK; } -static inline uint64_t ucc_tl_mlx5_mcast_get_timer(void) -{ - double t_second = ucc_get_time(); - return (uint64_t) (t_second * 1000000); -} - ucc_status_t ucc_tl_mlx5_mcast_bcast_check_drop(ucc_tl_mlx5_mcast_coll_comm_t *comm, ucc_tl_mlx5_mcast_coll_req_t *req) { @@ -424,7 +417,7 @@ ucc_status_t ucc_tl_mlx5_mcast_process_packet(ucc_tl_mlx5_mcast_coll_comm_t *com } } - comm->r_window[pp->psn & (comm->wsize-1)] = pp; + comm->r_window[pp->psn & (comm->bcast_comm.wsize-1)] = pp; status = ucc_tl_mlx5_mcast_check_nack_requests(comm, pp->psn); if (status < 0) { return status; @@ -432,7 +425,7 @@ ucc_status_t ucc_tl_mlx5_mcast_process_packet(ucc_tl_mlx5_mcast_coll_comm_t *com req->to_recv--; comm->psn++; - ucc_assert(comm->recv_drop_packet_in_progress == false); + ucc_assert(comm->bcast_comm.recv_drop_packet_in_progress == false); return status; } diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c index 52bc242ae5..b70ca6e2f6 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c @@ -99,11 +99,14 @@ ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *base_context, memcpy(&comm->params, conf_params, sizeof(*conf_params)); - comm->wsize = conf_params->wsize; - comm->max_eager = conf_params->max_eager; - comm->cuda_mem_enabled = conf_params->cuda_mem_enabled; - comm->comm_id = team_params->id; - comm->ctx = mcast_context; + comm->one_sided.reliability_enabled = conf_params->one_sided_reliability_enable; + comm->bcast_comm.wsize = conf_params->wsize; + comm->allgather_comm.max_push_send = conf_params->max_push_send; + comm->max_eager = conf_params->max_eager; + comm->cuda_mem_enabled = conf_params->cuda_mem_enabled; + comm->comm_id = team_params->id; + comm->ctx = mcast_context; + comm->mcast_group_count = 1; /* TODO: add support for more mcast groups */ if (comm->cuda_mem_enabled && (UCC_OK != ucc_tl_mlx5_check_gpudirect_driver())) { tl_warn(mcast_context->lib, "cuda-aware mcast not available as gpu direct is not ready"); @@ -130,20 +133,20 @@ ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *base_context, goto cleanup; } - comm->rank = team_params->rank; - comm->commsize = team_params->size; - comm->max_per_packet = mcast_context->mtu - GRH_LENGTH; - comm->last_acked = comm->last_psn = 0; - comm->racks_n = comm->sacks_n = 0; - comm->child_n = comm->parent_n = 0; - comm->p2p_ctx = conf_params->oob; + comm->rank = team_params->rank; + comm->commsize = team_params->size; + comm->max_per_packet = mcast_context->mtu - GRH_LENGTH; + comm->bcast_comm.last_acked = comm->bcast_comm.last_psn = 0; + comm->bcast_comm.racks_n = comm->bcast_comm.sacks_n = 0; + comm->bcast_comm.child_n =
comm->bcast_comm.parent_n = 0; + comm->p2p_ctx = conf_params->oob; memcpy(&comm->p2p, &conf_params->p2p_iface, sizeof(ucc_tl_mlx5_mcast_p2p_interface_t)); comm->dummy_packet.psn = UINT32_MAX; - for (i=0; i< comm->wsize; i++) { + for (i=0; i< comm->bcast_comm.wsize; i++) { comm->r_window[i] = &comm->dummy_packet; } @@ -284,15 +287,14 @@ ucc_status_t ucc_tl_mlx5_mcast_coll_setup_comm_resources(ucc_tl_mlx5_mcast_coll_ goto error; } - memset(comm->parents, 0, sizeof(comm->parents)); - memset(comm->children, 0, sizeof(comm->children)); + memset(comm->bcast_comm.parents, 0, sizeof(comm->bcast_comm.parents)); + memset(comm->bcast_comm.children, 0, sizeof(comm->bcast_comm.children)); - comm->nacks_counter = 0; - comm->tx = 0; - comm->n_prep_reliable = 0; - comm->n_mcast_reliable = 0; - comm->reliable_in_progress = 0; - comm->recv_drop_packet_in_progress = 0; + comm->bcast_comm.nacks_counter = 0; + comm->bcast_comm.n_mcast_reliable = 0; + comm->bcast_comm.reliable_in_progress = 0; + comm->bcast_comm.recv_drop_packet_in_progress = 0; + comm->tx = 0; return status; diff --git a/src/components/tl/mlx5/tl_mlx5.c b/src/components/tl/mlx5/tl_mlx5.c index 7bf8572aab..5cdd6c51a1 100644 --- a/src/components/tl/mlx5/tl_mlx5.c +++ b/src/components/tl/mlx5/tl_mlx5.c @@ -1,5 +1,5 @@ /** - * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ @@ -88,6 +88,10 @@ static ucc_config_field_t ucc_tl_mlx5_lib_config_table[] = { ucc_offsetof(ucc_tl_mlx5_lib_config_t, mcast_conf.wsize), UCC_CONFIG_TYPE_INT}, + {"MCAST_MAX_PUSH_SEND", "16", "Max number of concurrent send wq for mcast based allgather", + ucc_offsetof(ucc_tl_mlx5_lib_config_t, mcast_conf.max_push_send), + UCC_CONFIG_TYPE_INT}, + {"MCAST_MAX_EAGER", "65536", "Max msg size to be used for Mcast with the eager protocol", ucc_offsetof(ucc_tl_mlx5_lib_config_t, mcast_conf.max_eager), UCC_CONFIG_TYPE_MEMUNITS}, @@ -96,6 +100,10 @@ static ucc_config_field_t ucc_tl_mlx5_lib_config_table[] = { ucc_offsetof(ucc_tl_mlx5_lib_config_t, mcast_conf.cuda_mem_enabled), UCC_CONFIG_TYPE_BOOL}, + {"MCAST_ONE_SIDED_RELIABILITY_ENABLE", "1", "Enable one sided reliability for mcast", + ucc_offsetof(ucc_tl_mlx5_lib_config_t, mcast_conf.one_sided_reliability_enable), + UCC_CONFIG_TYPE_BOOL}, + {NULL}}; static ucc_config_field_t ucc_tl_mlx5_context_config_table[] = { diff --git a/src/components/tl/mlx5/tl_mlx5_coll.c b/src/components/tl/mlx5/tl_mlx5_coll.c index 909b457325..94d336ba6e 100644 --- a/src/components/tl/mlx5/tl_mlx5_coll.c +++ b/src/components/tl/mlx5/tl_mlx5_coll.c @@ -6,11 +6,13 @@ #include "tl_mlx5_coll.h" #include "mcast/tl_mlx5_mcast_coll.h" +#include "mcast/tl_mlx5_mcast_allgather.h" +#include "mcast/tl_mlx5_mcast_rcache.h" #include "alltoall/alltoall.h" -ucc_status_t ucc_tl_mlx5_bcast_mcast_init(ucc_base_coll_args_t *coll_args, - ucc_base_team_t *team, - ucc_coll_task_t **task_h) +ucc_status_t ucc_tl_mlx5_coll_mcast_init(ucc_base_coll_args_t *coll_args, + ucc_base_team_t *team, + ucc_coll_task_t **task_h) { ucc_status_t status = UCC_OK; ucc_tl_mlx5_task_t *task = NULL; @@ -27,14 +29,28 @@ ucc_status_t ucc_tl_mlx5_bcast_mcast_init(ucc_base_coll_args_t *coll_args, task->super.finalize = ucc_tl_mlx5_task_finalize; - status = ucc_tl_mlx5_mcast_bcast_init(task); - if (ucc_unlikely(UCC_OK != status)) { + switch (coll_args->args.coll_type) { + case UCC_COLL_TYPE_BCAST: + status = ucc_tl_mlx5_mcast_bcast_init(task); + if 
(ucc_unlikely(UCC_OK != status)) { + goto free_task; + } + break; + case UCC_COLL_TYPE_ALLGATHER: + status = ucc_tl_mlx5_mcast_allgather_init(task); + if (ucc_unlikely(UCC_OK != status)) { + goto free_task; + } + break; + default: + status = UCC_ERR_NOT_SUPPORTED; + tl_trace(team->context->lib, "mcast not supported for this collective type"); goto free_task; } *task_h = &(task->super); - tl_debug(UCC_TASK_LIB(task), "init coll task %p", task); + tl_debug(UCC_TASK_LIB(task), "initialized mcast collective task %p", task); return UCC_OK; @@ -46,13 +62,35 @@ ucc_status_t ucc_tl_mlx5_task_finalize(ucc_coll_task_t *coll_task) { ucc_tl_mlx5_task_t *task = ucc_derived_of(coll_task, ucc_tl_mlx5_task_t); - ucc_tl_mlx5_mcast_coll_req_t *req = task->bcast_mcast.req_handle; + ucc_tl_mlx5_mcast_coll_req_t *req = task->coll_mcast.req_handle; if (req != NULL) { ucc_assert(coll_task->status != UCC_INPROGRESS); - ucc_free(req); + ucc_assert(req->comm->ctx != NULL); + if (coll_task->bargs.args.coll_type == UCC_COLL_TYPE_ALLGATHER && + coll_task->status == UCC_OK) { + req->comm->allgather_comm.under_progress_counter++; + /* reset the reliability structures so that they get initialized again for the next + * allgather */ + req->comm->one_sided.reliability_ready = 0; + req->comm->stalled = 0; + req->comm->timer = 0; + } + if (req->rreg != NULL) { + ucc_tl_mlx5_mcast_mem_deregister(req->comm->ctx, req->rreg); + req->rreg = NULL; + } + if (req->recv_rreg != NULL) { + ucc_tl_mlx5_mcast_mem_deregister(req->comm->ctx, req->recv_rreg); + req->recv_rreg = NULL; + } + if (req->ag_schedule) { + ucc_free(req->ag_schedule); + req->ag_schedule = NULL; + } + ucc_mpool_put(req); tl_trace(UCC_TASK_LIB(task), "finalizing an mcast task %p", task); - task->bcast_mcast.req_handle = NULL; + task->coll_mcast.req_handle = NULL; } tl_trace(UCC_TASK_LIB(task), "finalizing task %p", task); @@ -82,7 +120,8 @@ ucc_status_t ucc_tl_mlx5_coll_init(ucc_base_coll_args_t *coll_args, status = ucc_tl_mlx5_alltoall_init(coll_args, team, task_h); break; case UCC_COLL_TYPE_BCAST: - status = ucc_tl_mlx5_bcast_mcast_init(coll_args, team, task_h); + case UCC_COLL_TYPE_ALLGATHER: + status = ucc_tl_mlx5_coll_mcast_init(coll_args, team, task_h); break; default: status = UCC_ERR_NOT_SUPPORTED; diff --git a/src/components/tl/mlx5/tl_mlx5_coll.h b/src/components/tl/mlx5/tl_mlx5_coll.h index eb441bdcdf..8ffe3eaf64 100644 --- a/src/components/tl/mlx5/tl_mlx5_coll.h +++ b/src/components/tl/mlx5/tl_mlx5_coll.h @@ -1,5 +1,5 @@ /** - * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms.
*/ @@ -16,7 +16,7 @@ typedef struct ucc_tl_mlx5_task { union { struct { ucc_tl_mlx5_mcast_coll_req_t *req_handle; - } bcast_mcast; + } coll_mcast; }; } ucc_tl_mlx5_task_t; @@ -79,7 +79,7 @@ ucc_tl_mlx5_get_task(ucc_base_coll_args_t *coll_args, ucc_base_team_t *team) UCC_TL_MLX5_PROFILE_REQUEST_NEW(task, "tl_mlx5_task", 0); ucc_coll_task_init(&task->super, coll_args, team); - task->bcast_mcast.req_handle = NULL; + task->coll_mcast.req_handle = NULL; return task; } @@ -113,9 +113,9 @@ static inline void ucc_tl_mlx5_put_schedule(ucc_tl_mlx5_schedule_t *schedule) ucc_mpool_put(schedule); } -ucc_status_t ucc_tl_mlx5_bcast_mcast_init(ucc_base_coll_args_t *coll_args, - ucc_base_team_t *team, - ucc_coll_task_t **task_h); +ucc_status_t ucc_tl_mlx5_coll_mcast_init(ucc_base_coll_args_t *coll_args, + ucc_base_team_t *team, + ucc_coll_task_t **task_h); ucc_status_t ucc_tl_mlx5_task_finalize(ucc_coll_task_t *coll_task); diff --git a/src/components/tl/mlx5/tl_mlx5_context.c b/src/components/tl/mlx5/tl_mlx5_context.c index 7631278ad3..d0539e83c9 100644 --- a/src/components/tl/mlx5/tl_mlx5_context.c +++ b/src/components/tl/mlx5/tl_mlx5_context.c @@ -1,5 +1,5 @@ /** - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ diff --git a/src/components/tl/mlx5/tl_mlx5_dm.c b/src/components/tl/mlx5/tl_mlx5_dm.c index 541273ff7a..2a0c474a39 100644 --- a/src/components/tl/mlx5/tl_mlx5_dm.c +++ b/src/components/tl/mlx5/tl_mlx5_dm.c @@ -1,5 +1,5 @@ /** - * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ diff --git a/src/components/tl/mlx5/tl_mlx5_dm.h b/src/components/tl/mlx5/tl_mlx5_dm.h index 3b611e44b3..05738bf539 100644 --- a/src/components/tl/mlx5/tl_mlx5_dm.h +++ b/src/components/tl/mlx5/tl_mlx5_dm.h @@ -1,5 +1,5 @@ /** - * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ diff --git a/src/components/tl/mlx5/tl_mlx5_lib.c b/src/components/tl/mlx5/tl_mlx5_lib.c index 11829f066f..509af869a4 100644 --- a/src/components/tl/mlx5/tl_mlx5_lib.c +++ b/src/components/tl/mlx5/tl_mlx5_lib.c @@ -1,5 +1,5 @@ /** - * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ diff --git a/src/components/tl/mlx5/tl_mlx5_pd.c b/src/components/tl/mlx5/tl_mlx5_pd.c index bf98352883..551a945169 100644 --- a/src/components/tl/mlx5/tl_mlx5_pd.c +++ b/src/components/tl/mlx5/tl_mlx5_pd.c @@ -1,5 +1,5 @@ /** - * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ diff --git a/src/components/tl/mlx5/tl_mlx5_pd.h b/src/components/tl/mlx5/tl_mlx5_pd.h index 2462cf1383..9ea925781a 100644 --- a/src/components/tl/mlx5/tl_mlx5_pd.h +++ b/src/components/tl/mlx5/tl_mlx5_pd.h @@ -1,5 +1,5 @@ /** - * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. 
*/ diff --git a/src/components/tl/mlx5/tl_mlx5_team.c b/src/components/tl/mlx5/tl_mlx5_team.c index 614ead348b..1e5f6ddf56 100644 --- a/src/components/tl/mlx5/tl_mlx5_team.c +++ b/src/components/tl/mlx5/tl_mlx5_team.c @@ -1,5 +1,5 @@ /** - * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ diff --git a/src/components/tl/mlx5/tl_mlx5_wqe.c b/src/components/tl/mlx5/tl_mlx5_wqe.c index e8e36a27f6..cf4d590658 100644 --- a/src/components/tl/mlx5/tl_mlx5_wqe.c +++ b/src/components/tl/mlx5/tl_mlx5_wqe.c @@ -1,5 +1,5 @@ /** - * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ diff --git a/src/components/tl/mlx5/tl_mlx5_wqe.h b/src/components/tl/mlx5/tl_mlx5_wqe.h index a0a015c310..012dc74fc4 100644 --- a/src/components/tl/mlx5/tl_mlx5_wqe.h +++ b/src/components/tl/mlx5/tl_mlx5_wqe.h @@ -1,5 +1,5 @@ /** - * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */
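For reference, the allgather receive path in this patch decodes a single 32-bit immediate value (carried in pp->psn) into the sender rank, the allgather call counter, and the packet offset, using ONE_SIDED_RELIABILITY_MAX_TEAM_SIZE and ONE_SIDED_MAX_ALLGATHER_COUNTER as mixed radices. The standalone sketch below (not part of the patch) illustrates that layout; only the decode arithmetic mirrors ucc_tl_mlx5_mcast_process_packet_collective(), while ag_imm_encode() and main() are hypothetical helpers added just to show that the mapping round-trips.

#include <assert.h>
#include <stdint.h>
#include <stdio.h>

/* Constants mirror the values used by the patch. */
#define ONE_SIDED_RELIABILITY_MAX_TEAM_SIZE 1024u
#define ONE_SIDED_MAX_ALLGATHER_COUNTER     32u

/* Hypothetical encoder: pack (source_rank, ag_counter, offset) into one
 * 32-bit immediate. offset must stay below UINT32_MAX / (32 * 1024). */
static uint32_t ag_imm_encode(uint32_t source_rank, uint32_t ag_counter, uint32_t offset)
{
    assert(source_rank < ONE_SIDED_RELIABILITY_MAX_TEAM_SIZE);
    return source_rank
           + (ag_counter % ONE_SIDED_MAX_ALLGATHER_COUNTER) * ONE_SIDED_RELIABILITY_MAX_TEAM_SIZE
           + offset * ONE_SIDED_MAX_ALLGATHER_COUNTER * ONE_SIDED_RELIABILITY_MAX_TEAM_SIZE;
}

/* Same arithmetic as the receive path in the patch. */
static void ag_imm_decode(uint32_t imm, uint32_t *source_rank, uint32_t *ag_counter,
                          uint32_t *offset)
{
    *source_rank = imm % ONE_SIDED_RELIABILITY_MAX_TEAM_SIZE;
    *ag_counter  = (imm / ONE_SIDED_RELIABILITY_MAX_TEAM_SIZE) % ONE_SIDED_MAX_ALLGATHER_COUNTER;
    *offset      = imm / (ONE_SIDED_MAX_ALLGATHER_COUNTER * ONE_SIDED_RELIABILITY_MAX_TEAM_SIZE);
}

int main(void)
{
    uint32_t rank, counter, offset;
    uint32_t imm = ag_imm_encode(37, 5, 11);

    ag_imm_decode(imm, &rank, &counter, &offset);
    /* prints: rank 37 ag_counter 5 offset 11 */
    printf("rank %u ag_counter %u offset %u\n",
           (unsigned)rank, (unsigned)counter, (unsigned)offset);
    return 0;
}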