diff --git a/.github/fftw-env.yaml b/.github/fftw-env.yaml new file mode 100644 index 0000000..8ab39b4 --- /dev/null +++ b/.github/fftw-env.yaml @@ -0,0 +1,8 @@ +name: withfftw +channels: + - conda-forge + - defaults +dependencies: + - pyfftw + - pip: + - methodtools diff --git a/.github/workflows/test-with-warnings.yaml b/.github/workflows/test-with-warnings.yaml index aa006ae..5f2e2fc 100644 --- a/.github/workflows/test-with-warnings.yaml +++ b/.github/workflows/test-with-warnings.yaml @@ -16,8 +16,7 @@ jobs: with: fetch-depth: 1 - - name: Set up Python - uses: actions/setup-python@v4 + - uses: actions/setup-python@v4 with: python-version: ${{ matrix.python }} diff --git a/.github/workflows/testsuite.yaml b/.github/workflows/testsuite.yaml index 0af3333..97127e4 100644 --- a/.github/workflows/testsuite.yaml +++ b/.github/workflows/testsuite.yaml @@ -17,16 +17,24 @@ jobs: with: fetch-depth: 1 - - name: Set up Python - uses: actions/setup-python@v4 + - uses: conda-incubator/setup-miniconda@v3 with: python-version: ${{ matrix.python }} + mamba-version: "*" + channels: conda-forge,defaults + channel-priority: true + activate-environment: withfftw + environment-file: .github/fftw-env.yaml - name: Install Test Deps + shell: bash -el {0} run: | + which pip + python --version pip install .[tests,fftw] - name: Run Tests + shell: bash -el {0} run: | python -m pytest --cov=powerbox --cov-config=.coveragerc --cov-report xml:./coverage.xml --junitxml=test-reports/xunit.xml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1ab9ce2..b0b2f42 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -36,7 +36,7 @@ repos: - repo: https://github.com/psf/black-pre-commit-mirror - rev: 24.4.0 + rev: 24.4.2 hooks: - id: black diff --git a/pyproject.toml b/pyproject.toml index 700196e..16dc6ec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,6 @@ tests = [ "pytest", "pytest-cov", "scipy", - "pyfftw" ] docs = [ "sphinx", @@ -53,7 +52,7 @@ docs = [ "packaging", # required for camb ] dev = [ - "powerbox[tests,docs]", + "powerbox[tests,docs,fftw]", "pre-commit" ] fftw = [ diff --git a/src/powerbox/dft_backend.py b/src/powerbox/dft_backend.py index 1670c23..c6d042e 100644 --- a/src/powerbox/dft_backend.py +++ b/src/powerbox/dft_backend.py @@ -6,6 +6,7 @@ import warnings from abc import ABC, abstractmethod from functools import cache +from multiprocessing import cpu_count try: import pyfftw @@ -79,17 +80,26 @@ class FFTW(FFTBackend): """FFT backend using pyfftw.""" def __init__(self, nthreads=None): - if nthreads is None: - from multiprocessing import cpu_count - - nthreads = cpu_count() - - self.nthreads = nthreads try: import pyfftw except ImportError: raise ImportError("pyFFTW could not be imported...") + try: + pyfftw.builders._utils._default_threads(4) + except ValueError: + if nthreads and nthreads > 1: + warnings.warn( + "pyFFTW was not installed with multithreading. Using 1 thread.", + stacklevel=2, + ) + nthreads = 1 + + if nthreads is None: + nthreads = cpu_count() + + self.nthreads = nthreads + self._fftshift = pyfftw.interfaces.numpy_fft.fftshift self._ifftshift = pyfftw.interfaces.numpy_fft.ifftshift self._fftfreq = pyfftw.interfaces.numpy_fft.fftfreq @@ -111,7 +121,7 @@ def get_fft_backend(nthreads=None): Will return the Numpy backend if nthreads is None, otherwise the FFTW backend with the given number of threads. """ - if nthreads is None or nthreads > 0: + if nthreads is None or nthreads > 1: try: fftbackend = FFTW(nthreads=nthreads) except ImportError: diff --git a/tests/test_fft.py b/tests/test_fft.py index 1dd0364..a597151 100644 --- a/tests/test_fft.py +++ b/tests/test_fft.py @@ -1,5 +1,6 @@ import pytest +import contextlib import numpy as np from powerbox.dft import fft, fftfreq, fftshift, ifft, ifftshift @@ -14,7 +15,23 @@ (1, 1, 0, 1), ] -BACKENDS = [NumpyFFT(), FFTW(nthreads=1), FFTW(nthreads=2)] +BACKENDS = [ + NumpyFFT(), +] + +HAVE_FFTW = False +HAVE_FFTW_MULTITHREAD = False + +with contextlib.suppress(ValueError, ImportError): + import pyfftw + + BACKENDS.append(FFTW(nthreads=1)) + HAVE_FFTW = True + + pyfftw.builders._utils._default_threads(4) + + BACKENDS.append(FFTW(nthreads=2)) + HAVE_FFTW_MULTITHREAD = True def gauss_ft(k, a, b, n=2): @@ -116,8 +133,14 @@ def test_mixed_2d_fb(g2d, a, b, ainv, binv, backend): assert np.max(np.abs(fx.real - analytic_mix(xgrid, a, b, ainv, binv))) < 1e-10 +NTHREADS_TO_CHECK = (None, 1, False) + +if HAVE_FFTW_MULTITHREAD: + NTHREADS_TO_CHECK += (2,) + + @pytest.mark.parametrize("a,b, ainv, binv", ABCOMBOS) -@pytest.mark.parametrize("nthreads", (None, 1, 2, False)) +@pytest.mark.parametrize("nthreads", NTHREADS_TO_CHECK) def test_mixed_2d_bf(g2d, a, b, ainv, binv, nthreads): Fk, freq = ifft(g2d["fx"], Lk=g2d["L"], a=ainv, b=binv, nthreads=nthreads) L = -2 * np.min(freq) @@ -127,7 +150,7 @@ def test_mixed_2d_bf(g2d, a, b, ainv, binv, nthreads): assert np.max(np.abs(fx.real - analytic_mix(xgrid, a, binv, ainv, b))) < 1e-10 -@pytest.mark.parametrize("nthreads", (None, 1, 2, False)) +@pytest.mark.parametrize("nthreads", NTHREADS_TO_CHECK) def test_fftshift(nthreads): x = np.linspace(0, 1, 11) @@ -135,7 +158,7 @@ def test_fftshift(nthreads): assert np.all(x == y) -@pytest.mark.parametrize("nthreads", (None, 1, 2, False)) +@pytest.mark.parametrize("nthreads", NTHREADS_TO_CHECK) @pytest.mark.parametrize("n", (10, 11)) def test_fftfreq(nthreads, n): freqs = fftfreq(n, nthreads=nthreads)