Skip to content

Commit

Permalink
ENH: Dask: sort and argsort
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed Jan 23, 2025
1 parent 8a79994 commit c0f8617
Show file tree
Hide file tree
Showing 3 changed files with 187 additions and 25 deletions.
117 changes: 110 additions & 7 deletions array_api_compat/dask/array/_aliases.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

from ...common import _aliases
from typing import Callable

from ...common import _aliases, array_namespace

from ..._internal import get_xp

Expand Down Expand Up @@ -29,24 +31,32 @@
)

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from typing import Optional, Union

from ...common._typing import Device, Dtype, Array, NestedSequence, SupportsBufferProtocol
from ...common._typing import (
Device,
Dtype,
Array,
NestedSequence,
SupportsBufferProtocol,
)

import dask.array as da

isdtype = get_xp(np)(_aliases.isdtype)
unstack = get_xp(da)(_aliases.unstack)


# da.astype doesn't respect copy=True
def astype(
x: Array,
dtype: Dtype,
/,
*,
copy: bool = True,
device: Optional[Device] = None
device: Optional[Device] = None,
) -> Array:
"""
Array API compatibility wrapper for astype().
Expand All @@ -61,8 +71,10 @@ def astype(
x = x.astype(dtype)
return x.copy() if copy else x


# Common aliases


# This arange func is modified from the common one to
# not pass stop/step as keyword arguments, which will cause
# an error with dask
Expand Down Expand Up @@ -189,6 +201,7 @@ def asarray(
concatenate as concat,
)


# dask.array.clip does not work unless all three arguments are provided.
# Furthermore, the masking workaround in common._aliases.clip cannot work with
# dask (meaning uint64 promoting to float64 is going to just be unfixed for
Expand All @@ -205,8 +218,10 @@ def clip(
See the corresponding documentation in the array library and/or the array API
specification for more details.
"""

def _isscalar(a):
return isinstance(a, (int, float, type(None)))

min_shape = () if _isscalar(min) else min.shape
max_shape = () if _isscalar(max) else max.shape

Expand All @@ -228,10 +243,98 @@ def _isscalar(a):

return astype(da.minimum(da.maximum(x, min), max), x.dtype)

# exclude these from all since dask.array has no sorting functions
_da_unsupported = ['sort', 'argsort']

_common_aliases = [alias for alias in _aliases.__all__ if alias not in _da_unsupported]
def _ensure_single_chunk(x: Array, axis: int) -> tuple[Array, Callable[[Array], Array]]:
"""
Make sure that Array is not broken into multiple chunks along axis.
Returns
-------
x : Array
The input Array with a single chunk along axis.
restore : Callable[Array, Array]
function to apply to the output to rechunk it back into reasonable chunks
"""
if axis < 0:
axis += x.ndim
if x.numblocks[axis] < 2:
return x, lambda x: x

# Break chunks on other axes in an attempt to keep chunk size low
x = x.rechunk({i: -1 if i == axis else "auto" for i in range(x.ndim)})

# Rather than reconstructing the original chunks, which can be a
# very expensive affair, just break down oversized chunks without
# incurring in any transfers over the network.
# This has the downside of a risk of overchunking if the array is
# then used in operations against other arrays that match the
# original chunking pattern.
return x, lambda x: x.rechunk()


def sort(
x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True
) -> Array:
"""
Array API compatibility layer around the lack of sort() in Dask.
Warnings
--------
This function temporarily rechunks the array along `axis` to a single chunk.
This can be extremely inefficient and can lead to out-of-memory errors.
See the corresponding documentation in the array library and/or the array API
specification for more details.
"""
x, restore = _ensure_single_chunk(x, axis)

meta_xp = array_namespace(x._meta)
x = da.map_blocks(
meta_xp.sort,
x,
axis=axis,
meta=x._meta,
dtype=x.dtype,
descending=descending,
stable=stable,
)

return restore(x)


def argsort(
x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True
) -> Array:
"""
Array API compatibility layer around the lack of argsort() in Dask.
See the corresponding documentation in the array library and/or the array API
specification for more details.
Warnings
--------
This function temporarily rechunks the array along `axis` into a single chunk.
This can be extremely inefficient and can lead to out-of-memory errors.
"""
x, restore = _ensure_single_chunk(x, axis)

meta_xp = array_namespace(x._meta)
dtype = meta_xp.argsort(x._meta).dtype
meta = meta_xp.astype(x._meta, dtype)
x = da.map_blocks(
meta_xp.argsort,
x,
axis=axis,
meta=meta,
dtype=dtype,
descending=descending,
stable=stable,
)

return restore(x)


_common_aliases = _aliases.__all__

__all__ = _common_aliases + ['__array_namespace_info__', 'asarray', 'acos',
'acosh', 'asin', 'asinh', 'atan', 'atan2',
Expand All @@ -242,4 +345,4 @@ def _isscalar(a):
'complex64', 'complex128', 'iinfo', 'finfo',
'can_cast', 'result_type']

_all_ignore = ["get_xp", "da", "np"]
_all_ignore = ["Callable", "array_namespace", "get_xp", "da", "np"]
22 changes: 5 additions & 17 deletions dask-xfails.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,13 @@ array_api_tests/test_array_object.py::test_setitem_masking
# Various indexing errors
array_api_tests/test_array_object.py::test_getitem_masking

# asarray(copy=False) is not yet implemented
# copied from numpy xfails, TODO: should this pass with dask?
array_api_tests/test_creation_functions.py::test_asarray_arrays

# zero division error, and typeerror: tuple indices must be integers or slices not tuple
array_api_tests/test_creation_functions.py::test_eye

# finfo(float32).eps returns float32 but should return float
array_api_tests/test_data_type_functions.py::test_finfo[float32]

# out[-1]=dask.aray<getitem ...> but should be some floating number
# out[-1]=dask.array<getitem ...> but should be some floating number
# (I think the test is not forcing the op to be computed?)
array_api_tests/test_creation_functions.py::test_linspace

Expand All @@ -48,15 +44,7 @@ array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -0 and x2_i > 0
array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity]
array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity]

# No sorting in dask
array_api_tests/test_has_names.py::test_has_names[sorting-argsort]
array_api_tests/test_has_names.py::test_has_names[sorting-sort]
array_api_tests/test_sorting_functions.py::test_argsort
array_api_tests/test_sorting_functions.py::test_sort
array_api_tests/test_signatures.py::test_func_signature[argsort]
array_api_tests/test_signatures.py::test_func_signature[sort]

# Array methods and attributes not already on np.ndarray cannot be wrapped
# Array methods and attributes not already on da.Array cannot be wrapped
array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__]
array_api_tests/test_has_names.py::test_has_names[array_method-to_device]
array_api_tests/test_has_names.py::test_has_names[array_attribute-device]
Expand All @@ -76,6 +64,7 @@ array_api_tests/test_set_functions.py::test_unique_values
# fails for ndim > 2
array_api_tests/test_linalg.py::test_svdvals
array_api_tests/test_linalg.py::test_cholesky

# dtype mismatch got uint64, but should be uint8, NPY_PROMOTION_STATE=weak doesn't help :(
array_api_tests/test_linalg.py::test_tensordot

Expand Down Expand Up @@ -105,6 +94,8 @@ array_api_tests/test_linalg.py::test_cross
array_api_tests/test_linalg.py::test_det
array_api_tests/test_linalg.py::test_eigh
array_api_tests/test_linalg.py::test_eigvalsh
array_api_tests/test_linalg.py::test_matrix_norm
array_api_tests/test_linalg.py::test_matrix_rank
array_api_tests/test_linalg.py::test_pinv
array_api_tests/test_linalg.py::test_slogdet
array_api_tests/test_has_names.py::test_has_names[linalg-cross]
Expand All @@ -115,9 +106,6 @@ array_api_tests/test_has_names.py::test_has_names[linalg-matrix_power]
array_api_tests/test_has_names.py::test_has_names[linalg-pinv]
array_api_tests/test_has_names.py::test_has_names[linalg-slogdet]

array_api_tests/test_linalg.py::test_matrix_norm
array_api_tests/test_linalg.py::test_matrix_rank

# missing mode kw
# https://github.com/dask/dask/issues/10388
array_api_tests/test_linalg.py::test_qr
Expand Down
73 changes: 72 additions & 1 deletion tests/test_dask.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from contextlib import contextmanager

import array_api_strict
import dask
import numpy as np
import pytest
Expand All @@ -20,9 +21,10 @@ def assert_no_compute():
Context manager that raises if at any point inside it anything calls compute()
or persist(), e.g. as it can be triggered implicitly by __bool__, __array__, etc.
"""

def get(dsk, *args, **kwargs):
raise AssertionError("Called compute() or persist()")

with dask.config.set(scheduler=get):
yield

Expand All @@ -40,6 +42,7 @@ def test_assert_no_compute():

# Test no_compute for functions that use generic _aliases with xp=np


def test_unary_ops_no_compute(xp):
with assert_no_compute():
a = xp.asarray([1.5, -1.5])
Expand All @@ -59,6 +62,7 @@ def test_matmul_tensordot_no_compute(xp):

# Test no_compute for functions that are fully bespoke for dask


def test_asarray_no_compute(xp):
with assert_no_compute():
a = xp.arange(10)
Expand Down Expand Up @@ -88,6 +92,14 @@ def test_clip_no_compute(xp):
xp.clip(a, 1, 8)


@pytest.mark.parametrize("chunks", (5, 10))
def test_sort_argsort_nocompute(xp, chunks):
with assert_no_compute():
a = xp.arange(10, chunks=chunks)
xp.sort(a)
xp.argsort(a)


def test_generators_are_lazy(xp):
"""
Test that generator functions are fully lazy, e.g. that
Expand All @@ -106,3 +118,62 @@ def test_generators_are_lazy(xp):
xp.ones_like(a)
xp.empty_like(a)
xp.full_like(a, fill_value=123)


@pytest.mark.parametrize("axis", [0, 1])
@pytest.mark.parametrize("func", ["sort", "argsort"])
def test_sort_argsort_chunks(xp, func, axis):
"""Test that sort and argsort are functionally correct when
the array is chunked along the sort axis, e.g. the sort is
not just local to each chunk.
"""
a = da.random.random((10, 10), chunks=(5, 5))
actual = getattr(xp, func)(a, axis=axis)
expect = getattr(np, func)(a.compute(), axis=axis)
np.testing.assert_array_equal(actual, expect)


@pytest.mark.parametrize(
"shape,chunks",
[
# 3 GiB; 128 MiB per chunk; must rechunk before sorting.
# Sort chunks can be 128 MiB each; no need for final rechunk.
((20_000, 20_000), "auto"),
# 3 GiB; 128 MiB per chunk; must rechunk before sorting.
# Must sort on two 1.5 GiB chunks; benefits from final rechunk.
((2, 2**30 * 3 // 16), "auto"),
# 3 GiB; 1.5 GiB per chunk; no need to rechunk before sorting.
# Surely the user must know what they're doing, so don't
# perform the final rechunk.
((2, 2**30 * 3 // 16), (1, -1)),
],
)
@pytest.mark.parametrize("func", ["sort", "argsort"])
def test_sort_argsort_chunk_size(xp, func, shape, chunks):
"""
Test that sort and argsort produce reasonably-sized chunks
in the output array, even if they had to go through a singular
huge one to perform the operation.
"""
a = da.random.random(shape, chunks=chunks)
b = getattr(xp, func)(a)
max_chunk_size = max(b.chunks[0]) * max(b.chunks[1]) * b.dtype.itemsize
assert (
max_chunk_size <= 128 * 1024 * 1024 # 128 MiB
or b.chunks == a.chunks
)


@pytest.mark.parametrize("func", ["sort", "argsort"])
def test_sort_argsort_meta(xp, func):
"""Test meta-namespace other than numpy"""
typ = type(array_api_strict.asarray(0))
a = da.random.random(10)
b = a.map_blocks(array_api_strict.asarray)
assert isinstance(b._meta, typ)
c = getattr(xp, func)(b)
assert isinstance(c._meta, typ)
d = c.compute()
# Note: np.sort(array_api_strict.asarray(0)) would return a numpy array
assert isinstance(d, typ)
np.testing.assert_array_equal(d, getattr(np, func)(a.compute()))

0 comments on commit c0f8617

Please sign in to comment.