diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h b/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h index 1208226bda..7843db514f 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h @@ -152,6 +152,7 @@ typedef struct ucc_tl_mlx5_mcast_coll_context { ucc_rcache_t *rcache; ucc_tl_mlx5_mcast_ctx_params_t params; ucc_base_lib_t *lib; + enum ucc_tl_capabilities tl_caps; } ucc_tl_mlx5_mcast_coll_context_t; typedef struct ucc_tl_mlx5_mcast_join_info_t { diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c index 402ff84472..7d1db0f64b 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c @@ -68,6 +68,8 @@ ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *base_context, conf_params->rx_sge = 2; conf_params->scq_moderation = 64; + mcast_context->tl_caps = base_context->ucc_context->tl_caps; + comm = (ucc_tl_mlx5_mcast_coll_comm_t*) ucc_calloc(1, sizeof(ucc_tl_mlx5_mcast_coll_comm_t) + sizeof(struct pp_packet*)*(conf_params->wsize-1), diff --git a/src/components/tl/ucp/tl_ucp_context.c b/src/components/tl/ucp/tl_ucp_context.c index 1c7c49b53f..96673596f5 100644 --- a/src/components/tl/ucp/tl_ucp_context.c +++ b/src/components/tl/ucp/tl_ucp_context.c @@ -194,6 +194,11 @@ UCC_CLASS_INIT_FUNC(ucc_tl_ucp_context_t, self); self->ucp_memory_types = context_attr.memory_types; + if (self->ucp_memory_types & UCC_BIT(ucc_memtype_to_ucs[UCC_MEMORY_TYPE_CUDA])) { + /* TL MLX5 needs this information */ + self->super.super.ucc_context->tl_caps |= UCC_TL_UCP_CUDA_ENABLED; + } + worker_params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE; switch (params->thread_mode) { case UCC_THREAD_SINGLE: diff --git a/src/core/ucc_context.h b/src/core/ucc_context.h index 3944d5675b..43c55093f7 100644 --- a/src/core/ucc_context.h +++ b/src/core/ucc_context.h @@ -35,6 +35,12 @@ typedef struct ucc_context_id { #define UCC_CTX_ID_EQUAL(_id1, _id2) (UCC_PROC_INFO_EQUAL((_id1).pi, (_id2).pi) \ && (_id1).seq_num == (_id2).seq_num) +enum ucc_tl_capabilities { + /* capabalities that every TL needs to be aware of + * about other TLs */ + UCC_TL_UCP_CUDA_ENABLED = UCC_BIT(0) +}; + enum { /* all ranks have identical set of TLs*/ UCC_ADDR_STORAGE_FLAG_TLS_SYMMETRIC = UCC_BIT(0), @@ -78,6 +84,7 @@ typedef struct ucc_context { uint64_t cl_flags; ucc_tl_team_t *service_team; int32_t throttle_progress; + enum ucc_tl_capabilities tl_caps; } ucc_context_t; typedef struct ucc_context_config {