Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

UCC/CTX: passing cuda-check from tl ucp to mlx5 #1013

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast.h
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ typedef struct ucc_tl_mlx5_mcast_coll_context {
ucc_rcache_t *rcache;
ucc_tl_mlx5_mcast_ctx_params_t params;
ucc_base_lib_t *lib;
enum ucc_tl_capabilities tl_caps;
MamziB marked this conversation as resolved.
Show resolved Hide resolved
} ucc_tl_mlx5_mcast_coll_context_t;

typedef struct ucc_tl_mlx5_mcast_join_info_t {
Expand Down
2 changes: 2 additions & 0 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *base_context,
conf_params->rx_sge = 2;
conf_params->scq_moderation = 64;

mcast_context->tl_caps = base_context->ucc_context->tl_caps;

comm = (ucc_tl_mlx5_mcast_coll_comm_t*)
ucc_calloc(1, sizeof(ucc_tl_mlx5_mcast_coll_comm_t) +
sizeof(struct pp_packet*)*(conf_params->wsize-1),
Expand Down
5 changes: 5 additions & 0 deletions src/components/tl/ucp/tl_ucp_context.c
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,11 @@ UCC_CLASS_INIT_FUNC(ucc_tl_ucp_context_t,
self);

self->ucp_memory_types = context_attr.memory_types;
if (self->ucp_memory_types & UCC_BIT(ucc_memtype_to_ucs[UCC_MEMORY_TYPE_CUDA])) {
/* TL MLX5 needs this information */
self->super.super.ucc_context->tl_caps |= UCC_TL_UCP_CUDA_ENABLED;
}

worker_params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE;
switch (params->thread_mode) {
case UCC_THREAD_SINGLE:
Expand Down
7 changes: 7 additions & 0 deletions src/core/ucc_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@ typedef struct ucc_context_id {
#define UCC_CTX_ID_EQUAL(_id1, _id2) (UCC_PROC_INFO_EQUAL((_id1).pi, (_id2).pi) \
&& (_id1).seq_num == (_id2).seq_num)

enum ucc_tl_capabilities {
/* capabalities that every TL needs to be aware of
* about other TLs */
UCC_TL_UCP_CUDA_ENABLED = UCC_BIT(0)
};

enum {
/* all ranks have identical set of TLs*/
UCC_ADDR_STORAGE_FLAG_TLS_SYMMETRIC = UCC_BIT(0),
Expand Down Expand Up @@ -78,6 +84,7 @@ typedef struct ucc_context {
uint64_t cl_flags;
ucc_tl_team_t *service_team;
int32_t throttle_progress;
enum ucc_tl_capabilities tl_caps;
} ucc_context_t;

typedef struct ucc_context_config {
Expand Down
Loading