Skip to content

Commit

Permalink
Merge pull request #59 from steven-murray/pre-commit-ci-update-config
Browse files Browse the repository at this point in the history
[pre-commit.ci] pre-commit autoupdate
  • Loading branch information
steven-murray authored May 3, 2024
2 parents 853272a + 119f0fd commit 3c87176
Show file tree
Hide file tree
Showing 7 changed files with 65 additions and 18 deletions.
8 changes: 8 additions & 0 deletions .github/fftw-env.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
name: withfftw
channels:
- conda-forge
- defaults
dependencies:
- pyfftw
- pip:
- methodtools
3 changes: 1 addition & 2 deletions .github/workflows/test-with-warnings.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@ jobs:
with:
fetch-depth: 1

- name: Set up Python
uses: actions/setup-python@v4
- uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python }}

Expand Down
12 changes: 10 additions & 2 deletions .github/workflows/testsuite.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,24 @@ jobs:
with:
fetch-depth: 1

- name: Set up Python
uses: actions/setup-python@v4
- uses: conda-incubator/setup-miniconda@v3
with:
python-version: ${{ matrix.python }}
mamba-version: "*"
channels: conda-forge,defaults
channel-priority: true
activate-environment: withfftw
environment-file: .github/fftw-env.yaml

- name: Install Test Deps
shell: bash -el {0}
run: |
which pip
python --version
pip install .[tests,fftw]
- name: Run Tests
shell: bash -el {0}
run: |
python -m pytest --cov=powerbox --cov-config=.coveragerc --cov-report xml:./coverage.xml --junitxml=test-reports/xunit.xml
Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ repos:


- repo: https://github.com/psf/black-pre-commit-mirror
rev: 24.4.0
rev: 24.4.2
hooks:
- id: black

Expand Down
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ tests = [
"pytest",
"pytest-cov",
"scipy",
"pyfftw"
]
docs = [
"sphinx",
Expand All @@ -53,7 +52,7 @@ docs = [
"packaging", # required for camb
]
dev = [
"powerbox[tests,docs]",
"powerbox[tests,docs,fftw]",
"pre-commit"
]
fftw = [
Expand Down
24 changes: 17 additions & 7 deletions src/powerbox/dft_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import warnings
from abc import ABC, abstractmethod
from functools import cache
from multiprocessing import cpu_count

try:
import pyfftw
Expand Down Expand Up @@ -79,17 +80,26 @@ class FFTW(FFTBackend):
"""FFT backend using pyfftw."""

def __init__(self, nthreads=None):
if nthreads is None:
from multiprocessing import cpu_count

nthreads = cpu_count()

self.nthreads = nthreads
try:
import pyfftw
except ImportError:
raise ImportError("pyFFTW could not be imported...")

try:
pyfftw.builders._utils._default_threads(4)
except ValueError:
if nthreads and nthreads > 1:
warnings.warn(
"pyFFTW was not installed with multithreading. Using 1 thread.",
stacklevel=2,
)
nthreads = 1

if nthreads is None:
nthreads = cpu_count()

self.nthreads = nthreads

self._fftshift = pyfftw.interfaces.numpy_fft.fftshift
self._ifftshift = pyfftw.interfaces.numpy_fft.ifftshift
self._fftfreq = pyfftw.interfaces.numpy_fft.fftfreq
Expand All @@ -111,7 +121,7 @@ def get_fft_backend(nthreads=None):
Will return the Numpy backend if nthreads is None, otherwise the FFTW backend with
the given number of threads.
"""
if nthreads is None or nthreads > 0:
if nthreads is None or nthreads > 1:
try:
fftbackend = FFTW(nthreads=nthreads)
except ImportError:
Expand Down
31 changes: 27 additions & 4 deletions tests/test_fft.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest

import contextlib
import numpy as np

from powerbox.dft import fft, fftfreq, fftshift, ifft, ifftshift
Expand All @@ -14,7 +15,23 @@
(1, 1, 0, 1),
]

BACKENDS = [NumpyFFT(), FFTW(nthreads=1), FFTW(nthreads=2)]
BACKENDS = [
NumpyFFT(),
]

HAVE_FFTW = False
HAVE_FFTW_MULTITHREAD = False

with contextlib.suppress(ValueError, ImportError):
import pyfftw

BACKENDS.append(FFTW(nthreads=1))
HAVE_FFTW = True

pyfftw.builders._utils._default_threads(4)

BACKENDS.append(FFTW(nthreads=2))
HAVE_FFTW_MULTITHREAD = True


def gauss_ft(k, a, b, n=2):
Expand Down Expand Up @@ -116,8 +133,14 @@ def test_mixed_2d_fb(g2d, a, b, ainv, binv, backend):
assert np.max(np.abs(fx.real - analytic_mix(xgrid, a, b, ainv, binv))) < 1e-10


NTHREADS_TO_CHECK = (None, 1, False)

if HAVE_FFTW_MULTITHREAD:
NTHREADS_TO_CHECK += (2,)


@pytest.mark.parametrize("a,b, ainv, binv", ABCOMBOS)
@pytest.mark.parametrize("nthreads", (None, 1, 2, False))
@pytest.mark.parametrize("nthreads", NTHREADS_TO_CHECK)
def test_mixed_2d_bf(g2d, a, b, ainv, binv, nthreads):
Fk, freq = ifft(g2d["fx"], Lk=g2d["L"], a=ainv, b=binv, nthreads=nthreads)
L = -2 * np.min(freq)
Expand All @@ -127,15 +150,15 @@ def test_mixed_2d_bf(g2d, a, b, ainv, binv, nthreads):
assert np.max(np.abs(fx.real - analytic_mix(xgrid, a, binv, ainv, b))) < 1e-10


@pytest.mark.parametrize("nthreads", (None, 1, 2, False))
@pytest.mark.parametrize("nthreads", NTHREADS_TO_CHECK)
def test_fftshift(nthreads):
x = np.linspace(0, 1, 11)

y = fftshift(ifftshift(x, nthreads=nthreads), nthreads=nthreads)
assert np.all(x == y)


@pytest.mark.parametrize("nthreads", (None, 1, 2, False))
@pytest.mark.parametrize("nthreads", NTHREADS_TO_CHECK)
@pytest.mark.parametrize("n", (10, 11))
def test_fftfreq(nthreads, n):
freqs = fftfreq(n, nthreads=nthreads)
Expand Down

0 comments on commit 3c87176

Please sign in to comment.