prov/rxm: add FI_PEER support to rxm #10510

Open · wants to merge 7 commits into main
6 changes: 4 additions & 2 deletions include/ofi_util.h
@@ -955,13 +955,15 @@ struct rxm_av {
struct fid_peer_av peer_av;
struct fid_av *util_coll_av;
struct fid_av *offload_coll_av;
void (*foreach_ep)(struct util_av *av, struct util_ep *util_ep);
};

int rxm_util_av_open(struct fid_domain *domain_fid, struct fi_av_attr *attr,
struct fid_av **fid_av, void *context, size_t conn_size,
void (*remove_handler)(struct util_ep *util_ep,
struct util_peer_addr *peer));
size_t rxm_av_max_peers(struct rxm_av *av);
struct util_peer_addr *peer),
void (*foreach_ep)(struct util_av *av,
struct util_ep *ep));
size_t rxm_av_max_peers(struct rxm_av *av);
void rxm_ref_peer(struct util_peer_addr *peer);
void *rxm_av_alloc_conn(struct rxm_av *av);
void rxm_av_free_conn(struct rxm_av *av, void *conn_ctx);
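The new foreach_ep parameter changes the rxm_util_av_open() signature: callers now hand the AV a per-endpoint hook it can run when membership changes (for example, to refresh receive-side state after an insert or remove). A minimal caller sketch under that assumption; example_update_ep and example_av_open are illustrative names, not code from this PR.

static void example_update_ep(struct util_av *av, struct util_ep *util_ep)
{
	/* per-endpoint bookkeeping, e.g. refreshing cached peer state */
	(void) av;
	(void) util_ep;
}

static int example_av_open(struct fid_domain *domain, struct fi_av_attr *attr,
			   struct fid_av **av, void *context)
{
	return rxm_util_av_open(domain, attr, av, context,
				sizeof(struct rxm_conn),   /* conn_size */
				rxm_av_remove_handler,     /* existing remove hook */
				example_update_ep);        /* new foreach_ep hook */
}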
208 changes: 46 additions & 162 deletions prov/rxm/src/rxm.h
@@ -183,9 +183,9 @@ do { \
extern struct fi_provider rxm_prov;
extern struct util_prov rxm_util_prov;

extern struct fi_ops_msg rxm_msg_ops;
extern struct fi_ops_msg rxm_msg_ops, rxm_no_recv_msg_ops;
extern struct fi_ops_msg rxm_msg_thru_ops;
extern struct fi_ops_tagged rxm_tagged_ops;
extern struct fi_ops_tagged rxm_tagged_ops, rxm_no_recv_tagged_ops;
extern struct fi_ops_tagged rxm_tagged_thru_ops;
extern struct fi_ops_rma rxm_rma_ops;
extern struct fi_ops_rma rxm_rma_thru_ops;
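The added rxm_no_recv_msg_ops and rxm_no_recv_tagged_ops tables suggest the receive entry points are disabled when the receive path is owned by a peer SRX rather than by rxm itself. A sketch of how such a selection might look; the helper and the external_srx flag are assumptions, not taken from this PR.

static void example_select_rx_ops(struct fid_ep *ep_fid, bool external_srx)
{
	if (external_srx) {
		/* receives are posted through the peer SRX owner, so the
		 * endpoint exposes ops tables without recv entry points */
		ep_fid->msg = &rxm_no_recv_msg_ops;
		ep_fid->tagged = &rxm_no_recv_tagged_ops;
	} else {
		ep_fid->msg = &rxm_msg_ops;
		ep_fid->tagged = &rxm_tagged_ops;
	}
}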
@@ -265,6 +265,8 @@ struct rxm_fabric {
struct rxm_domain {
struct util_domain util_domain;
struct fid_domain *msg_domain;
struct fid_ep rx_ep;
struct fid_peer_srx *srx;
size_t max_atomic_size;
size_t rx_post_size;
uint64_t mr_key;
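With FI_PEER advertised and a fid_peer_srx hanging off the domain, the domain participates in the shared receive handshake described in fi_peer(3): an owner builds a struct fid_peer_srx and passes it through fi_srx_context() via struct fi_peer_srx_context. A rough sketch of that handshake only; whether this PR wires exactly this direction is not visible in the hunk, and passing FI_PEER through the rx attributes is my assumption from the man page.

static int example_import_peer_srx(struct fid_domain *domain,
				   struct fid_ep *ep,
				   struct fid_peer_srx *owner_srx)
{
	struct fi_peer_srx_context peer_ctx = {
		.size = sizeof(peer_ctx),
		.srx = owner_srx,
	};
	/* assumption: FI_PEER is signalled through the rx attributes */
	struct fi_rx_attr rx_attr = { .op_flags = FI_PEER };
	struct fid_ep *srx_ep;
	int ret;

	ret = fi_srx_context(domain, &rx_attr, &srx_ep, &peer_ctx);
	if (ret)
		return ret;
	return fi_ep_bind(ep, &srx_ep->fid, 0);
}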
@@ -417,13 +419,15 @@ struct rxm_pkt {
char data[];
};

enum rxm_sar_seg_type {
RXM_SAR_SEG_FIRST = 1,
RXM_SAR_SEG_MIDDLE = 2,
RXM_SAR_SEG_LAST = 3,
};

union rxm_sar_ctrl_data {
struct {
enum rxm_sar_seg_type {
RXM_SAR_SEG_FIRST = 1,
RXM_SAR_SEG_MIDDLE = 2,
RXM_SAR_SEG_LAST = 3,
} seg_type : 2;
enum rxm_sar_seg_type seg_type;
uint32_t offset;
};
uint64_t align;
@@ -441,24 +445,29 @@ rxm_sar_set_seg_type(struct ofi_ctrl_hdr *ctrl_hdr, enum rxm_sar_seg_type seg_type)
((union rxm_sar_ctrl_data *)&(ctrl_hdr->ctrl_data))->seg_type = seg_type;
}
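Pulling rxm_sar_seg_type out of the union drops the 2-bit bitfield, but the segment type still travels in ctrl_hdr->ctrl_data. For symmetry with the setter above, a read-side sketch (the helper name is illustrative):

static inline enum rxm_sar_seg_type
example_sar_get_seg_type(struct ofi_ctrl_hdr *ctrl_hdr)
{
	return ((union rxm_sar_ctrl_data *) &ctrl_hdr->ctrl_data)->seg_type;
}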

struct rxm_recv_match_attr {
fi_addr_t addr;
uint64_t tag;
uint64_t ignore;
};

struct rxm_unexp_msg {
struct dlist_entry entry;
fi_addr_t addr;
uint64_t tag;
};

struct rxm_iov {
struct iovec iov[RXM_IOV_LIMIT];
void *desc[RXM_IOV_LIMIT];
uint8_t count;
};

struct rxm_proto_info {
/* Used for SAR protocol */
struct {
struct dlist_entry entry;
struct dlist_entry pkt_list;
struct fi_peer_rx_entry *rx_entry;
size_t total_recv_len;
struct rxm_conn *conn;
uint64_t msg_id;
} sar;
/* Used for Rendezvous protocol */
struct {
/* This is used to send RNDV ACK */
struct rxm_tx_buf *tx_buf;
} rndv;
};
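struct rxm_proto_info picks up the SAR and rendezvous state that previously lived in the removed rxm_recv_entry. Presumably it is carved from the new proto_info_pool and linked to both the rx buffer and the matched peer rx_entry; a sketch under that assumption (the helper name is made up):

static struct rxm_proto_info *
example_alloc_proto_info(struct rxm_ep *ep, struct rxm_rx_buf *rx_buf,
			 struct fi_peer_rx_entry *rx_entry)
{
	struct rxm_proto_info *info;

	info = ofi_buf_alloc(ep->proto_info_pool);
	if (!info)
		return NULL;

	dlist_init(&info->sar.pkt_list);
	info->sar.rx_entry = rx_entry;
	info->sar.total_recv_len = 0;
	rx_buf->proto_info = info;
	return info;
}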

struct rxm_buf {
/* Must stay at top */
struct fi_context fi_context;
@@ -476,9 +485,10 @@ struct rxm_rx_buf {
/* MSG EP / shared context to which bufs would be posted to */
struct fid_ep *rx_ep;
struct dlist_entry repost_entry;
struct dlist_entry unexp_entry;
struct rxm_conn *conn; /* msg ep data was received on */
struct rxm_recv_entry *recv_entry;
struct rxm_unexp_msg unexp_msg;
struct fi_peer_rx_entry *peer_entry;
struct rxm_proto_info *proto_info;
uint64_t comp_flags;
struct fi_recv_context recv_context;
bool repost;
@@ -606,49 +616,6 @@ struct rxm_deferred_tx_entry {
};
};

struct rxm_recv_entry {
struct dlist_entry entry;
struct rxm_iov rxm_iov;
fi_addr_t addr;
void *context;
uint64_t flags;
uint64_t tag;
uint64_t ignore;
uint64_t comp_flags;
size_t total_len;
struct rxm_recv_queue *recv_queue;

/* Used for SAR protocol */
struct {
struct dlist_entry entry;
size_t total_recv_len;
struct rxm_conn *conn;
uint64_t msg_id;
} sar;
/* Used for Rendezvous protocol */
struct {
/* This is used to send RNDV ACK */
struct rxm_tx_buf *tx_buf;
} rndv;
};
OFI_DECLARE_FREESTACK(struct rxm_recv_entry, rxm_recv_fs);

enum rxm_recv_queue_type {
RXM_RECV_QUEUE_UNSPEC,
RXM_RECV_QUEUE_MSG,
RXM_RECV_QUEUE_TAGGED,
};

struct rxm_recv_queue {
struct rxm_ep *rxm_ep;
enum rxm_recv_queue_type type;
struct rxm_recv_fs *fs;
struct dlist_entry recv_list;
struct dlist_entry unexp_msg_list;
dlist_func_t *match_recv;
dlist_func_t *match_unexp;
};

struct rxm_eager_ops {
void (*comp_tx)(struct rxm_ep *rxm_ep,
struct rxm_tx_buf *tx_eager_buf);
@@ -688,6 +655,8 @@ struct rxm_ep {
struct fi_ops_transfer_peer *offload_coll_peer_xfer_ops;
uint64_t offload_coll_mask;

struct fid_peer_srx *srx;

struct fid_cq *msg_cq;
uint64_t msg_cq_last_poll;
size_t comp_per_progress;
@@ -701,7 +670,6 @@ struct rxm_ep {
bool do_progress;
bool enable_direct_send;

size_t min_multi_recv_size;
size_t buffered_min;
size_t buffered_limit;
size_t inject_limit;
@@ -713,15 +681,13 @@ struct rxm_ep {
struct ofi_bufpool *rx_pool;
struct ofi_bufpool *tx_pool;
struct ofi_bufpool *coll_pool;
struct ofi_bufpool *proto_info_pool;

struct rxm_pkt *inject_pkt;

struct dlist_entry deferred_queue;
struct dlist_entry rndv_wait_list;

struct rxm_recv_queue recv_queue;
struct rxm_recv_queue trecv_queue;
struct ofi_bufpool *multi_recv_pool;

struct rxm_eager_ops *eager_ops;
struct rxm_rndv_ops *rndv_ops;
};
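The endpoint now carries a proto_info_pool in place of the two receive queues and the multi-recv pool. A sketch of how that pool could be created next to the existing rx/tx pools; the sizing values are placeholders and the ofi_bufpool_create() usage reflects my reading of the util code, so treat it as an assumption.

static int example_create_proto_pool(struct rxm_ep *rxm_ep)
{
	/* sizing and flags are illustrative, not from this PR */
	return ofi_bufpool_create(&rxm_ep->proto_info_pool,
				  sizeof(struct rxm_proto_info),
				  16,    /* alignment */
				  0,     /* no max count */
				  1024,  /* chunk count */
				  0);    /* flags */
}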
@@ -755,11 +721,15 @@ int rxm_cq_open(struct fid_domain *domain, struct fi_cq_attr *attr,
struct fid_cq **cq_fid, void *context);
ssize_t rxm_handle_rx_buf(struct rxm_rx_buf *rx_buf);

int rxm_srx_context(struct fid_domain *domain, struct fi_rx_attr *attr,
struct fid_ep **rx_ep, void *context);

int rxm_endpoint(struct fid_domain *domain, struct fi_info *info,
struct fid_ep **ep, void *context);

void rxm_cq_write_error(struct util_cq *cq, struct util_cntr *cntr,
void *op_context, int err);
void rxm_cq_write_tx_error(struct rxm_ep *rxm_ep, uint8_t op, void *op_context,
int err);
void rxm_cq_write_rx_error(struct rxm_ep *rxm_ep, uint8_t op, void *op_context,
int err);
void rxm_cq_write_error_all(struct rxm_ep *rxm_ep, int err);
void rxm_handle_comp_error(struct rxm_ep *rxm_ep);
ssize_t rxm_handle_comp(struct rxm_ep *rxm_ep, struct fi_cq_data_entry *comp);
@@ -878,50 +848,6 @@ int rxm_msg_mr_reg_internal(struct rxm_domain *rxm_domain, const void *buf,
size_t len, uint64_t acs, uint64_t flags,
struct fid_mr **mr);

static inline void rxm_cntr_incerr(struct util_cntr *cntr)
{
if (cntr)
cntr->cntr_fid.ops->adderr(&cntr->cntr_fid, 1);
}

static inline void
rxm_cq_write(struct util_cq *cq, void *context, uint64_t flags, size_t len,
void *buf, uint64_t data, uint64_t tag)
{
int ret;

FI_DBG(&rxm_prov, FI_LOG_CQ, "Reporting %s completion\n",
fi_tostr((void *) &flags, FI_TYPE_CQ_EVENT_FLAGS));

ret = ofi_cq_write(cq, context, flags, len, buf, data, tag);
if (ret) {
FI_WARN(&rxm_prov, FI_LOG_CQ,
"Unable to report completion\n");
assert(0);
}
if (cq->wait)
cq->wait->signal(cq->wait);
}

static inline void
rxm_cq_write_src(struct util_cq *cq, void *context, uint64_t flags, size_t len,
void *buf, uint64_t data, uint64_t tag, fi_addr_t addr)
{
int ret;

FI_DBG(&rxm_prov, FI_LOG_CQ, "Reporting %s completion\n",
fi_tostr((void *) &flags, FI_TYPE_CQ_EVENT_FLAGS));

ret = ofi_cq_write_src(cq, context, flags, len, buf, data, tag, addr);
if (ret) {
FI_WARN(&rxm_prov, FI_LOG_CQ,
"Unable to report completion\n");
assert(0);
}
if (cq->wait)
cq->wait->signal(cq->wait);
}

ssize_t rxm_get_conn(struct rxm_ep *rxm_ep, fi_addr_t addr,
struct rxm_conn **rxm_conn);

@@ -956,17 +882,10 @@ ssize_t
rxm_inject_send(struct rxm_ep *rxm_ep, struct rxm_conn *rxm_conn,
const void *buf, size_t len);

struct rxm_recv_entry *
rxm_recv_entry_get(struct rxm_ep *rxm_ep, const struct iovec *iov,
void **desc, size_t count, fi_addr_t src_addr,
uint64_t tag, uint64_t ignore, void *context,
uint64_t flags, struct rxm_recv_queue *recv_queue);
struct rxm_rx_buf *
rxm_get_unexp_msg(struct rxm_recv_queue *recv_queue, fi_addr_t addr,
uint64_t tag, uint64_t ignore);
ssize_t rxm_handle_unexp_sar(struct rxm_recv_queue *recv_queue,
struct rxm_recv_entry *recv_entry,
struct rxm_rx_buf *rx_buf);
ssize_t rxm_handle_unexp_sar(struct fi_peer_rx_entry *peer_entry);
int rxm_srx_context(struct fid_domain *domain, struct fi_rx_attr *attr,
struct fid_ep **rx_ep, void *context);

int rxm_post_recv(struct rxm_rx_buf *rx_buf);
void rxm_av_remove_handler(struct util_ep *util_ep,
struct util_peer_addr *peer);
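rxm_srx_context() has the same shape as the fi_srx_context() entry point, so from the application side the shared receive context is used in the usual way: open it on the domain, bind it to each endpoint, then post receives against it. A minimal usage sketch with generic names and no error handling:

static void example_use_srx(struct fid_domain *domain, struct fid_ep *ep,
			    void *buf, size_t len, void *desc)
{
	struct fi_rx_attr rx_attr = { 0 };
	struct fid_ep *srx;

	fi_srx_context(domain, &rx_attr, &srx, NULL);
	fi_ep_bind(ep, &srx->fid, 0);   /* repeat for every ep sharing it */
	fi_enable(ep);
	fi_recv(srx, buf, len, desc, FI_ADDR_UNSPEC, NULL);
}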
@@ -987,41 +906,6 @@ rxm_free_rx_buf(struct rxm_rx_buf *rx_buf)
}
}

static inline void
rxm_recv_entry_release(struct rxm_recv_entry *entry)
{
if (entry->recv_queue)
ofi_freestack_push(entry->recv_queue->fs, entry);
else
ofi_buf_free(entry);
}

static inline void
rxm_cq_write_recv_comp(struct rxm_rx_buf *rx_buf, void *context, uint64_t flags,
size_t len, char *buf)
{
if (rx_buf->ep->util_coll_peer_xfer_ops &&
rx_buf->pkt.hdr.tag & RXM_PEER_XFER_TAG_FLAG) {
struct fi_cq_tagged_entry cqe = {
.tag = rx_buf->pkt.hdr.tag,
.op_context = rx_buf->recv_entry->context,
};
rx_buf->ep->util_coll_peer_xfer_ops->
complete(rx_buf->ep->util_coll_ep, &cqe, 0);
return;
}

if (rx_buf->ep->rxm_info->caps & FI_SOURCE)
rxm_cq_write_src(rx_buf->ep->util_ep.rx_cq, context,
flags, len, buf, rx_buf->pkt.hdr.data,
rx_buf->pkt.hdr.tag,
rx_buf->conn->peer->fi_addr);
else
rxm_cq_write(rx_buf->ep->util_ep.rx_cq, context,
flags, len, buf, rx_buf->pkt.hdr.data,
rx_buf->pkt.hdr.tag);
}

struct rxm_mr *rxm_mr_get_map_entry(struct rxm_domain *domain, uint64_t key);

struct rxm_recv_entry *
2 changes: 1 addition & 1 deletion prov/rxm/src/rxm_attr.c
@@ -40,7 +40,7 @@
OFI_RX_RMA_CAPS | FI_ATOMICS | FI_DIRECTED_RECV | \
FI_MULTI_RECV)

#define RXM_DOMAIN_CAPS (FI_LOCAL_COMM | FI_REMOTE_COMM)
#define RXM_DOMAIN_CAPS (FI_LOCAL_COMM | FI_REMOTE_COMM | FI_PEER)
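Adding FI_PEER to RXM_DOMAIN_CAPS lets callers discover peer support through fi_getinfo(). A hedged example of requesting it, assuming FI_PEER is asked for like the other domain caps:

static int example_probe_peer_caps(void)
{
	struct fi_info *hints, *info;
	int ret;

	hints = fi_allocinfo();
	if (!hints)
		return -FI_ENOMEM;

	hints->caps = FI_MSG | FI_TAGGED;
	hints->domain_attr->caps = FI_LOCAL_COMM | FI_REMOTE_COMM | FI_PEER;

	ret = fi_getinfo(FI_VERSION(1, 21), NULL, NULL, 0, hints, &info);
	if (!ret)
		fi_freeinfo(info);   /* info describes a peer-capable rxm domain */
	fi_freeinfo(hints);
	return ret;
}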


/* Since we are a layering provider, the attributes for which we rely on the
13 changes: 5 additions & 8 deletions prov/rxm/src/rxm_conn.c
@@ -58,7 +58,7 @@ struct rxm_eq_cm_entry {
static void rxm_close_conn(struct rxm_conn *conn)
{
struct rxm_deferred_tx_entry *tx_entry;
struct rxm_recv_entry *rx_entry;
struct fi_peer_rx_entry *rx_entry;
struct rxm_rx_buf *buf;

FI_DBG(&rxm_prov, FI_LOG_EP_CTRL, "closing conn %p\n", conn);
@@ -74,16 +74,13 @@ static void rxm_close_conn(struct rxm_conn *conn)

while (!dlist_empty(&conn->deferred_sar_segments)) {
buf = container_of(conn->deferred_sar_segments.next,
struct rxm_rx_buf, unexp_msg.entry);
dlist_remove(&buf->unexp_msg.entry);
rxm_free_rx_buf(buf);
struct rxm_rx_buf, unexp_entry);
dlist_remove(&buf->unexp_entry);
}

while (!dlist_empty(&conn->deferred_sar_msgs)) {
rx_entry = container_of(conn->deferred_sar_msgs.next,
struct rxm_recv_entry, sar.entry);
dlist_remove(&rx_entry->entry);
rxm_recv_entry_release(rx_entry);
rx_entry = (struct fi_peer_rx_entry*)conn->deferred_sar_msgs.next;
rx_entry->srx->owner_ops->free_entry(rx_entry);
}
fi_close(&conn->msg_ep->fid);
rxm_flush_msg_cq(conn->ep);