Skip to content

Commit

Permalink
Merge pull request #139 from lithomas1/dask-fft
Browse files Browse the repository at this point in the history
Wrap fft for dask
  • Loading branch information
asmeurer authored Oct 16, 2024
2 parents e9da040 + 2182b4f commit 5affae5
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 36 deletions.
1 change: 1 addition & 0 deletions .github/workflows/array-api-tests-dask.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ jobs:
uses: ./.github/workflows/array-api-tests.yml
with:
package-name: dask
package-version: '>= 2024.9.0'
module-name: dask.array
extra-requires: numpy
pytest-extra-args: --disable-deadline --max-examples=5
3 changes: 2 additions & 1 deletion .github/workflows/array-api-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.9', '3.10', '3.11', '3.12']
# min version of dask we needs drops support for python 3.9
python-version: ${{ inputs.package-name == 'dask' && fromJson('[''3.10'', ''3.11'', ''3.12'']') || fromJson('[''3.9'', ''3.10'', ''3.11'', ''3.12'']') }}

steps:
- name: Checkout array-api-compat
Expand Down
1 change: 1 addition & 0 deletions array_api_compat/dask/array/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@
__array_api_version__ = '2022.12'

__import__(__package__ + '.linalg')
__import__(__package__ + '.fft')
33 changes: 13 additions & 20 deletions array_api_compat/dask/array/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,9 @@

import numpy as np
from numpy import (
# Constants
e,
inf,
nan,
pi,
newaxis,
# Dtypes
iinfo,
finfo,
bool_ as bool,
float32,
float64,
Expand All @@ -29,8 +25,6 @@
uint64,
complex64,
complex128,
iinfo,
finfo,
can_cast,
result_type,
)
Expand Down Expand Up @@ -206,19 +200,18 @@ def _isscalar(a):

return astype(xp.minimum(xp.maximum(x, min), max), x.dtype)

# exclude these from all since
# exclude these from all since dask.array has no sorting functions
_da_unsupported = ['sort', 'argsort']

common_aliases = [alias for alias in _aliases.__all__ if alias not in _da_unsupported]
_common_aliases = [alias for alias in _aliases.__all__ if alias not in _da_unsupported]

__all__ = common_aliases + ['__array_namespace_info__', 'asarray', 'bool',
'acos', 'acosh', 'asin', 'asinh', 'atan', 'atan2',
'atanh', 'bitwise_left_shift', 'bitwise_invert',
'bitwise_right_shift', 'concat', 'pow', 'e',
'inf', 'nan', 'pi', 'newaxis', 'float32',
'float64', 'int8', 'int16', 'int32', 'int64',
'uint8', 'uint16', 'uint32', 'uint64',
'complex64', 'complex128', 'iinfo', 'finfo',
'can_cast', 'result_type']
__all__ = _common_aliases + ['__array_namespace_info__', 'asarray', 'acos',
'acosh', 'asin', 'asinh', 'atan', 'atan2',
'atanh', 'bitwise_left_shift', 'bitwise_invert',
'bitwise_right_shift', 'concat', 'pow', 'iinfo', 'finfo', 'can_cast',
'result_type', 'bool', 'float32', 'float64', 'int8', 'int16', 'int32', 'int64',
'uint8', 'uint16', 'uint32', 'uint64',
'complex64', 'complex128', 'iinfo', 'finfo',
'can_cast', 'result_type']

_all_ignore = ['get_xp', 'da', 'partial', 'common_aliases', 'np']
_all_ignore = ["get_xp", "da", "np"]
24 changes: 24 additions & 0 deletions array_api_compat/dask/array/fft.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from dask.array.fft import * # noqa: F403
# dask.array.fft doesn't have __all__. If it is added, replace this with
#
# from dask.array.fft import __all__ as linalg_all
_n = {}
exec('from dask.array.fft import *', _n)
del _n['__builtins__']
fft_all = list(_n)
del _n

from ...common import _fft
from ..._internal import get_xp

import dask.array as da

fftfreq = get_xp(da)(_fft.fftfreq)
rfftfreq = get_xp(da)(_fft.rfftfreq)

__all__ = [elem for elem in fft_all if elem != "annotations"] + ["fftfreq", "rfftfreq"]

del get_xp
del da
del fft_all
del _fft
15 changes: 0 additions & 15 deletions dask-skips.txt
Original file line number Diff line number Diff line change
@@ -1,17 +1,2 @@
# FFT isn't conformant
array_api_tests/test_fft.py
array_api_tests/test_signatures.py::test_extension_func_signature[fft.fft]
array_api_tests/test_signatures.py::test_extension_func_signature[fft.ifft]
array_api_tests/test_signatures.py::test_extension_func_signature[fft.fftn]
array_api_tests/test_signatures.py::test_extension_func_signature[fft.ifftn]
array_api_tests/test_signatures.py::test_extension_func_signature[fft.rfft]
array_api_tests/test_signatures.py::test_extension_func_signature[fft.irfft]
array_api_tests/test_signatures.py::test_extension_func_signature[fft.rfftn]
array_api_tests/test_signatures.py::test_extension_func_signature[fft.irfftn]
array_api_tests/test_signatures.py::test_extension_func_signature[fft.hfft]
array_api_tests/test_signatures.py::test_extension_func_signature[fft.ihfft]
array_api_tests/test_signatures.py::test_extension_func_signature[fft.fftfreq]
array_api_tests/test_signatures.py::test_extension_func_signature[fft.rfftfreq]

# slow and not implemented in dask
array_api_tests/test_linalg.py::test_matrix_power

0 comments on commit 5affae5

Please sign in to comment.