Skip to content

Commit

Permalink
Add tests for GPU resource definition
Browse files Browse the repository at this point in the history
  • Loading branch information
pentschev committed Oct 23, 2024
1 parent 12389a5 commit 1d6d5e8
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 0 deletions.
42 changes: 42 additions & 0 deletions dask_cuda/tests/test_dask_cuda_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,3 +594,45 @@ def test_worker_cudf_spill_warning(enable_cudf_spill_warning): # noqa: F811
assert b"UserWarning: cuDF spilling is enabled" in ret.stderr
else:
assert b"UserWarning: cuDF spilling is enabled" not in ret.stderr


def test_worker_gpu_resource(loop): # noqa: F811
with popen(["dask", "scheduler", "--port", "9369", "--no-dashboard"]):
with popen(
[
"dask",
"cuda",
"worker",
"127.0.0.1:9369",
"--no-dashboard",
]
):
with Client("127.0.0.1:9369", loop=loop) as client:
assert wait_workers(client, n_gpus=get_n_gpus())

workers = client.scheduler_info()["workers"]
for v in workers.values():
assert "GPU" in v["resources"]
assert v["resources"]["GPU"] == 1


def test_worker_gpu_resource_user_defined(loop): # noqa: F811
with popen(["dask", "scheduler", "--port", "9369", "--no-dashboard"]):
with popen(
[
"dask",
"cuda",
"worker",
"127.0.0.1:9369",
"--resources",
"'GPU=55'",
"--no-dashboard",
]
):
with Client("127.0.0.1:9369", loop=loop) as client:
assert wait_workers(client, n_gpus=get_n_gpus())

workers = client.scheduler_info()["workers"]
for v in workers.values():
assert "GPU" in v["resources"]
assert v["resources"]["GPU"] == 55
20 changes: 20 additions & 0 deletions dask_cuda/tests/test_local_cuda_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,26 @@ async def test_all_to_all():
assert all(all_data.count(i) == n_workers for i in all_data)


@gen_test(timeout=20)
async def test_worker_gpu_resource():
async with LocalCUDACluster(asynchronous=True) as cluster:
async with Client(cluster, asynchronous=True) as client:
workers = client.scheduler_info()["workers"]
for v in workers.values():
assert "GPU" in v["resources"]
assert v["resources"]["GPU"] == 1


@gen_test(timeout=20)
async def test_worker_gpu_resource_user_defined():
async with LocalCUDACluster(asynchronous=True, resources={"GPU": 55}) as cluster:
async with Client(cluster, asynchronous=True) as client:
workers = client.scheduler_info()["workers"]
for v in workers.values():
assert "GPU" in v["resources"]
assert v["resources"]["GPU"] == 55


@gen_test(timeout=20)
async def test_rmm_pool():
rmm = pytest.importorskip("rmm")
Expand Down

0 comments on commit 1d6d5e8

Please sign in to comment.