Skip to content

Commit

Permalink
Merge pull request #816 from rapidsai/branch-0.23
Browse files Browse the repository at this point in the history
[RELEASE] ucx-py v0.23
  • Loading branch information
raydouglass authored Dec 8, 2021
2 parents 5577fb1 + fbd6071 commit 9971bd1
Show file tree
Hide file tree
Showing 37 changed files with 1,092 additions and 220 deletions.
182 changes: 137 additions & 45 deletions benchmarks/send-recv-core.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,14 @@
blocking_am_send,
blocking_recv,
blocking_send,
non_blocking_recv,
non_blocking_send,
)

mp = mp.get_context("spawn")

WireupMessage = bytearray(b"wireup")


def register_am_allocators(args, worker):
if not args.enable_am:
Expand Down Expand Up @@ -120,53 +124,69 @@ def server(queue, args):
# out of scope too early.
ep = None

finished_lock = Lock()
op_lock = Lock()
finished = [0]
outstanding = [0]

def op_started():
with op_lock:
outstanding[0] += 1

def op_completed():
with op_lock:
outstanding[0] -= 1
finished[0] += 1

def _send_handle(request, exception, msg):
# Notice, we pass `msg` to the handler in order to make sure
# it doesn't go out of scope prematurely.
assert exception is None

with finished_lock:
finished[0] += 1
op_completed()

def _tag_recv_handle(request, exception, ep, msg):
assert exception is None
ucx_api.tag_send_nb(
req = ucx_api.tag_send_nb(
ep, msg, msg.nbytes, tag=0, cb_func=_send_handle, cb_args=(msg,)
)
if req is None:
op_completed()

def _am_recv_handle(recv_obj, exception, ep):
assert exception is None
msg = Array(recv_obj)
ucx_api.am_send_nbx(ep, msg, msg.nbytes, cb_func=_send_handle, cb_args=(msg,))

def _listener_handler(conn_request):
def _listener_handler(conn_request, msg):
global ep
ep = ucx_api.UCXEndpoint.create_from_conn_request(
worker,
conn_request,
endpoint_error_handling=ucx_api.get_ucx_version() >= (1, 10, 0),
)

if not args.enable_am:
msg_recv_list = []
if not args.reuse_alloc:
for _ in range(args.n_iter):
msg_recv_list.append(xp.zeros(args.n_bytes, dtype="u1"))
else:
t = xp.zeros(args.n_bytes, dtype="u1")
for _ in range(args.n_iter):
msg_recv_list.append(t)

assert msg_recv_list[0].nbytes == args.n_bytes
# Wireup before starting to transfer data
if args.enable_am is True:
ucx_api.am_recv_nb(ep, cb_func=_am_recv_handle, cb_args=(ep,))
else:
wireup = Array(bytearray(len(WireupMessage)))
op_started()
ucx_api.tag_recv_nb(
worker,
wireup,
wireup.nbytes,
tag=0,
cb_func=_tag_recv_handle,
cb_args=(ep, wireup),
)

for i in range(args.n_iter):
if args.enable_am is True:
ucx_api.am_recv_nb(ep, cb_func=_am_recv_handle, cb_args=(ep,))
else:
msg = Array(msg_recv_list[i])
if not args.reuse_alloc:
msg = Array(xp.zeros(args.n_bytes, dtype="u1"))

op_started()
ucx_api.tag_recv_nb(
worker,
msg,
Expand All @@ -176,14 +196,30 @@ def _listener_handler(conn_request):
cb_args=(ep, msg),
)

if not args.enable_am and args.reuse_alloc:
msg = Array(xp.zeros(args.n_bytes, dtype="u1"))
else:
msg = None

listener = ucx_api.UCXListener(
worker=worker, port=args.port or 0, cb_func=_listener_handler
worker=worker, port=args.port or 0, cb_func=_listener_handler, cb_args=(msg,)
)
queue.put(listener.port)

while finished[0] != args.n_iter:
while outstanding[0] == 0:
worker.progress()

# +1 to account for wireup message
if args.delay_progress:
while finished[0] < args.n_iter + 1 and (
outstanding[0] >= args.max_outstanding
or finished[0] + args.max_outstanding >= args.n_iter + 1
):
worker.progress()
else:
while finished[0] != args.n_iter + 1:
worker.progress()


def client(queue, port, server_address, args):
if args.client_cpu_affinity >= 0:
Expand Down Expand Up @@ -225,48 +261,80 @@ def client(queue, port, server_address, args):
endpoint_error_handling=ucx_api.get_ucx_version() >= (1, 10, 0),
)

send_msg = xp.arange(args.n_bytes, dtype="u1")
if args.reuse_alloc:
recv_msg = xp.zeros(args.n_bytes, dtype="u1")

if args.enable_am:
msg = xp.arange(args.n_bytes, dtype="u1")
blocking_am_send(worker, ep, send_msg)
blocking_am_recv(worker, ep)
else:
msg_send_list = []
msg_recv_list = []
if not args.reuse_alloc:
for i in range(args.n_iter):
msg_send_list.append(xp.arange(args.n_bytes, dtype="u1"))
msg_recv_list.append(xp.zeros(args.n_bytes, dtype="u1"))
else:
t1 = xp.arange(args.n_bytes, dtype="u1")
t2 = xp.zeros(args.n_bytes, dtype="u1")
for i in range(args.n_iter):
msg_send_list.append(t1)
msg_recv_list.append(t2)
assert msg_send_list[0].nbytes == args.n_bytes
assert msg_recv_list[0].nbytes == args.n_bytes
wireup_recv = bytearray(len(WireupMessage))
blocking_send(worker, ep, WireupMessage)
blocking_recv(worker, ep, wireup_recv)

op_lock = Lock()
finished = [0]
outstanding = [0]

def maybe_progress():
while outstanding[0] >= args.max_outstanding:
worker.progress()

def op_started():
with op_lock:
outstanding[0] += 1

def op_completed():
with op_lock:
outstanding[0] -= 1
finished[0] += 1

if args.cuda_profile:
xp.cuda.profiler.start()

times = []
for i in range(args.n_iter):
start = clock()

if args.enable_am:
blocking_am_send(worker, ep, msg)
blocking_am_send(worker, ep, send_msg)
blocking_am_recv(worker, ep)
else:
blocking_send(worker, ep, msg_send_list[i])
blocking_recv(worker, ep, msg_recv_list[i])
if not args.reuse_alloc:
recv_msg = xp.zeros(args.n_bytes, dtype="u1")

if args.delay_progress:
maybe_progress()
non_blocking_send(worker, ep, send_msg, op_started, op_completed)
maybe_progress()
non_blocking_recv(worker, ep, recv_msg, op_started, op_completed)
else:
blocking_send(worker, ep, send_msg)
blocking_recv(worker, ep, recv_msg)

stop = clock()
times.append(stop - start)

if args.delay_progress:
while finished[0] != 2 * args.n_iter:
worker.progress()

if args.cuda_profile:
xp.cuda.profiler.stop()

assert len(times) == args.n_iter
delay_progress_str = (
f"True ({args.max_outstanding})" if args.delay_progress is True else "False"
)
print("Roundtrip benchmark")
print("--------------------------")
print(f"n_iter | {args.n_iter}")
print(f"n_bytes | {format_bytes(args.n_bytes)}")
print(f"object | {args.object_type}")
print(f"reuse alloc | {args.reuse_alloc}")
print(f"transfer API | {'AM' if args.enable_am else 'TAG'}")
print(f"delay progress | {delay_progress_str}")
print(f"UCX_TLS | {ucp.get_config()['TLS']}")
print(f"UCX_NET_DEVICES | {ucp.get_config()['NET_DEVICES']}")
print("==========================")
Expand All @@ -290,13 +358,14 @@ def client(queue, port, server_address, args):
med = format_bytes(2 * args.n_bytes / np.median(times))
print(f"Average | {avg}/s")
print(f"Median | {med}/s")
print("--------------------------")
print("Iterations")
print("--------------------------")
for i, t in enumerate(times):
ts = format_bytes(2 * args.n_bytes / t)
ts = (" " * (9 - len(ts))) + ts
print("%03d |%s/s" % (i, ts))
if not args.no_detailed_report:
print("--------------------------")
print("Iterations")
print("--------------------------")
for i, t in enumerate(times):
ts = format_bytes(2 * args.n_bytes / t)
ts = (" " * (9 - len(ts))) + ts
print("%03d |%s/s" % (i, ts))


def parse_args():
Expand Down Expand Up @@ -417,6 +486,29 @@ def parse_args():
action="store_true",
help="Use Active Message API instead of TAG for transfers",
)
parser.add_argument(
"--delay-progress",
default=False,
action="store_true",
help="Delay ucp_worker_progress calls until a minimum number of "
"outstanding operations is reached, implies non-blocking send/recv. "
"The --max-outstanding argument may be used to control number of "
"maximum outstanding operations. (Default: disabled)",
)
parser.add_argument(
"--max-outstanding",
metavar="N",
default=32,
type=int,
help="Number of maximum outstanding operations, see --delay-progress. "
"(Default: 32)",
)
parser.add_argument(
"--no-detailed-report",
default=False,
action="store_true",
help="Disable detailed report per iteration.",
)

args = parser.parse_args()
if args.cuda_profile and args.object_type == "numpy":
Expand Down
3 changes: 1 addition & 2 deletions benchmarks/send-recv.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,6 @@ async def server_handler(ep):

loop = asyncio.get_event_loop()
loop.run_until_complete(run())
loop.close()


def client(queue, port, server_address, args):
Expand Down Expand Up @@ -203,7 +202,7 @@ async def run():

loop = asyncio.get_event_loop()
loop.run_until_complete(run())
loop.close()

times = queue.get()
assert len(times) == args.n_iter
print("Roundtrip benchmark")
Expand Down
Loading

0 comments on commit 9971bd1

Please sign in to comment.