Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add temporary dask-cudf workaround for categorical sorting #15801

Merged
merged 2 commits into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 0 additions & 19 deletions python/dask_cudf/dask_cudf/expr/_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

from dask import config
from dask.dataframe.core import is_dataframe_like
from dask.dataframe.dispatch import is_categorical_dtype

import cudf

Expand Down Expand Up @@ -82,24 +81,6 @@ def from_dict(cls, *args, **kwargs):
with config.set({"dataframe.backend": "cudf"}):
return DXDataFrame.from_dict(*args, **kwargs)

def sort_values(
self,
by,
**kwargs,
):
# Raise if the first column is categorical, otherwise the
# upstream divisions logic may produce errors
# (See: https://github.com/rapidsai/cudf/issues/11795)
check_by = by[0] if isinstance(by, list) else by
if is_categorical_dtype(self.dtypes.get(check_by, None)):
raise NotImplementedError(
"Dask-cudf does not support sorting on categorical "
"columns when query-planning is enabled. Please use "
"the legacy API for now."
f"\n{_LEGACY_WORKAROUND}",
)
return super().sort_values(by, **kwargs)

def groupby(
self,
by,
Expand Down
25 changes: 25 additions & 0 deletions python/dask_cudf/dask_cudf/expr/_expr.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
# Copyright (c) 2024, NVIDIA CORPORATION.
import functools

import dask_expr._shuffle as _shuffle_module
from dask_expr import new_collection
from dask_expr._cumulative import CumulativeBlockwise
from dask_expr._expr import Expr, VarColumns
from dask_expr._reductions import Reduction, Var

from dask.dataframe.core import is_dataframe_like, make_meta, meta_nonempty
from dask.dataframe.dispatch import is_categorical_dtype

##
## Custom expression patching
Expand Down Expand Up @@ -121,3 +124,25 @@ def _patched_var(


Expr.var = _patched_var


# Temporary work-around for missing cudf + categorical support
# See: https://github.com/rapidsai/cudf/issues/11795
# TODO: Fix RepartitionQuantiles and remove this in cudf>24.06

_original_get_divisions = _shuffle_module._get_divisions


def _patched_get_divisions(frame, other, *args, **kwargs):
# NOTE: The following two lines contains the "patch"
# (we simply convert the partitioning column to pandas)
if is_categorical_dtype(other._meta.dtype) and hasattr(
other.frame._meta, "to_pandas"
):
other = new_collection(other).to_backend("pandas")._expr

# Call "original" function
return _original_get_divisions(frame, other, *args, **kwargs)


_shuffle_module._get_divisions = _patched_get_divisions
23 changes: 2 additions & 21 deletions python/dask_cudf/dask_cudf/tests/test_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import cudf

import dask_cudf
from dask_cudf.tests.utils import QUERY_PLANNING_ON, xfail_dask_expr
from dask_cudf.tests.utils import xfail_dask_expr


@pytest.mark.parametrize("ascending", [True, False])
Expand All @@ -20,12 +20,7 @@
"a",
"b",
"c",
pytest.param(
"d",
marks=xfail_dask_expr(
"Possible segfault when sorting by categorical column.",
),
),
"d",
["a", "b"],
["c", "d"],
],
Expand All @@ -47,20 +42,6 @@ def test_sort_values(nelem, nparts, by, ascending):
dd.assert_eq(got, expect, check_index=False)


@pytest.mark.parametrize("by", ["b", ["b", "a"]])
def test_sort_values_categorical_raises(by):
df = cudf.DataFrame()
df["a"] = np.ascontiguousarray(np.arange(10)[::-1])
df["b"] = df["a"].astype("category")
ddf = dd.from_pandas(df, npartitions=10)

if QUERY_PLANNING_ON:
with pytest.raises(
NotImplementedError, match="sorting on categorical"
):
ddf.sort_values(by=by)


@pytest.mark.parametrize("ascending", [True, False])
@pytest.mark.parametrize("by", ["a", "b", ["a", "b"]])
def test_sort_values_single_partition(by, ascending):
Expand Down
Loading