Update explicit-comms for dask-expr support #1323

Merged · 23 commits · Apr 3, 2024
Changes from 5 commits
4 changes: 2 additions & 2 deletions dask_cuda/benchmarks/local_cudf_merge.py
@@ -8,7 +8,6 @@

 import dask
 from dask.base import tokenize
-from dask.dataframe.core import new_dd_object
 from dask.distributed import performance_report, wait
 from dask.utils import format_bytes, parse_bytes

@@ -20,6 +19,7 @@
     print_separator,
     print_throughput_bandwidth,
 )
+from dask_cuda.utils import _make_collection

 # Benchmarking cuDF merge operation based on
 # <https://gist.github.com/rjzamora/0ffc35c19b5180ab04bbf7c793c45955>
@@ -123,7 +123,7 @@ def get_random_ddf(chunk_size, num_chunks, frac_match, chunk_type, args):
         for i, part in enumerate(parts)
     }

-    ddf = new_dd_object(graph, name, meta, divisions)
+    ddf = _make_collection(graph, name, meta, divisions)

     if chunk_type == "build":
         if not args.no_shuffle:
7 changes: 3 additions & 4 deletions dask_cuda/benchmarks/local_cudf_shuffle.py
@@ -8,8 +8,6 @@

 import dask
 import dask.dataframe
-from dask.dataframe.core import new_dd_object
-from dask.dataframe.shuffle import shuffle
 from dask.distributed import Client, performance_report, wait
 from dask.utils import format_bytes, parse_bytes

@@ -22,6 +20,7 @@
     print_separator,
     print_throughput_bandwidth,
 )
+from dask_cuda.utils import _make_collection

 try:
     import cupy
@@ -33,7 +32,7 @@


 def shuffle_dask(df, args):
-    result = shuffle(df, index="data", shuffle="tasks", ignore_index=args.ignore_index)
+    result = df.shuffle("data", shuffle_method="tasks", ignore_index=args.ignore_index)
     if args.backend == "dask-noop":
         result = as_noop(result)
     t1 = perf_counter()
@@ -105,7 +104,7 @@ def create_data(

     df_meta = create_df(0, args.type)
     divs = [None] * (len(dsk) + 1)
-    ret = new_dd_object(dsk, name, df_meta, divs).persist()
+    ret = _make_collection(dsk, name, df_meta, divs).persist()
     wait(ret)

     data_processed = args.in_parts * args.partition_size
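
Note: the change above replaces the legacy dask.dataframe.shuffle.shuffle free function with the collection method, which works under both the legacy and the query-planning (dask-expr) backends. A minimal sketch of the new call on a pandas-backed frame (the frame here is illustrative, not from the benchmark):

import pandas as pd
import dask.dataframe as dd

pdf = pd.DataFrame({"data": [2, 0, 1, 3], "payload": ["a", "b", "c", "d"]})
df = dd.from_pandas(pdf, npartitions=2)

# Same call shape as in shuffle_dask above: shuffle on the "data" column
# using the task-based shuffle method.
result = df.shuffle("data", shuffle_method="tasks", ignore_index=False)
print(result.compute())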
12 changes: 8 additions & 4 deletions dask_cuda/explicit_comms/dataframe/shuffle.py
@@ -14,12 +14,15 @@

 import dask.utils
 import distributed.worker
 from dask.base import tokenize
-from dask.dataframe.core import DataFrame, Series, _concat as dd_concat, new_dd_object
+from dask.dataframe import DataFrame, Series
+from dask.dataframe.core import _concat as dd_concat
 from dask.dataframe.shuffle import group_split_dispatch, hash_object_dispatch
 from distributed import wait
 from distributed.protocol import nested_deserialize, to_serialize
 from distributed.worker import Worker

+from dask_cuda.utils import _make_collection
+
 from .. import comms

 T = TypeVar("T")
@@ -468,8 +471,9 @@ def shuffle(
     npartitions = df.npartitions

     # Step (a):
-    df = df.persist()  # Make sure optimizations are apply on the existing graph
+    df = df.persist()  # Make sure optimizations are applied on the existing graph
     wait([df])  # Make sure all keys have been materialized on workers
+    persisted_keys = [f.key for f in c.client.futures_of(df)]
     name = (
         "explicit-comms-shuffle-"
         f"{tokenize(df, column_names, npartitions, ignore_index)}"
@@ -479,7 +483,7 @@
     # Stage all keys of `df` on the workers and cancel them, which makes it possible
     # for the shuffle to free memory as the partitions of `df` are consumed.
     # See CommsContext.stage_keys() for a description of staging.
-    rank_to_inkeys = c.stage_keys(name=name, keys=df.__dask_keys__())
+    rank_to_inkeys = c.stage_keys(name=name, keys=persisted_keys)
     c.client.cancel(df)

     # Get batchsize
@@ -538,7 +542,7 @@

     # Create a distributed Dataframe from all the pieces
     divs = [None] * (len(dsk) + 1)
-    ret = new_dd_object(dsk, name, df_meta, divs).persist()
+    ret = _make_collection(dsk, name, df_meta, divs).persist()
     wait([ret])

     # Release all temporary dataframes
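
Note: the persisted_keys addition guards against a subtlety of query planning: after optimization, the keys actually persisted on the cluster may not match what df.__dask_keys__() reports, so the shuffle now collects the concrete keys from the futures backing the persisted collection before staging them. A minimal standalone sketch of that pattern, using a throwaway local cluster purely for illustration:

import pandas as pd
import dask.dataframe as dd
from distributed import Client, futures_of, wait

client = Client(processes=False)  # illustrative local cluster, not from the PR
df = dd.from_pandas(pd.DataFrame({"a": range(8)}), npartitions=4).persist()
wait(df)
# Keys that were actually materialized on the workers; under query planning
# these can differ from the keys the collection object reports.
persisted_keys = [f.key for f in futures_of(df)]
print(persisted_keys)
client.close()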
6 changes: 0 additions & 6 deletions dask_cuda/tests/test_explicit_comms.py
@@ -22,12 +22,6 @@

 from dask_cuda.explicit_comms.dataframe.shuffle import shuffle as explicit_comms_shuffle
 from dask_cuda.utils_test import IncreasedCloseTimeoutNanny

-# Skip these tests when dask-expr is active (for now)
-pytestmark = pytest.mark.skipif(
-    dask.config.get("dataframe.query-planning", None) is not False,
-    reason="https://github.com/rapidsai/dask-cuda/issues/1311",
-)
-
 mp = mp.get_context("spawn")  # type: ignore
 ucp = pytest.importorskip("ucp")
15 changes: 15 additions & 0 deletions dask_cuda/utils.py
@@ -764,3 +764,18 @@ def get_rmm_memory_resource_stack(mr) -> list:
     if isinstance(mr, rmm.mr.StatisticsResourceAdaptor):
         return mr.allocation_counts["current_bytes"]
     return None
+
+
+def _make_collection(graph, name, meta, divisions):
+    # Create a DataFrame collection from a task graph.
+    # Accounts for the legacy vs. dask-expr (query-planning) API.
+    try:
+        # New expression-based API
+        from dask.dataframe import from_graph
+
+        keys = [(name, i) for i in range(len(divisions) - 1)]  # one key per partition
+        return from_graph(graph, meta, divisions, keys, name)
+    except ImportError:
+        # Legacy API
+        from dask.dataframe.core import new_dd_object
+
+        return new_dd_object(graph, name, meta, divisions)
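
Note: a minimal usage sketch of the helper above, building a one-task-per-partition graph with pandas (the names and data are illustrative, not from the PR). Since divisions has npartitions + 1 entries, the graph here has len(divisions) - 1 tasks:

import pandas as pd

from dask_cuda.utils import _make_collection

def make_part(i):
    return pd.DataFrame({"x": list(range(i * 3, (i + 1) * 3))})

name = "example-frame"
graph = {(name, i): (make_part, i) for i in range(4)}
meta = pd.DataFrame({"x": pd.Series(dtype="int64")})
divisions = [None] * 5  # npartitions + 1 entries; None means unknown

ddf = _make_collection(graph, name, meta, divisions)
print(ddf.compute())  # concatenates the four 3-row partitions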