Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Mar 25, 2024
1 parent 95ffea7 commit b67fd5c
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 20 deletions.
2 changes: 1 addition & 1 deletion src/powerbox/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,6 @@
# package is not installed
__version__ = "unknown"

from .import_fft import config
from .powerbox import LogNormalPowerBox, PowerBox
from .tools import angular_average, angular_average_nd, get_power
from .import_fft import config
23 changes: 20 additions & 3 deletions src/powerbox/dft.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,22 @@

__all__ = ["fft", "ifft", "fftfreq", "fftshift", "ifftshift"]

from .import_fft import config
# To avoid MKL-related bugs, numpy needs to be imported after pyfftw: see https://github.com/pyFFTW/pyFFTW/issues/40
import numpy as np

from .import_fft import config


def fft(
X, L=None, Lk=None, a=0, b=2 * np.pi, left_edge=None, axes=None, ret_cubegrid=False, threads=None
X,
L=None,
Lk=None,
a=0,
b=2 * np.pi,
left_edge=None,
axes=None,
ret_cubegrid=False,
threads=None,
):
r"""
Arbitrary-dimension nice Fourier Transform.
Expand Down Expand Up @@ -132,7 +141,15 @@ def fft(


def ifft(
X, Lk=None, L=None, a=0, b=2 * np.pi, axes=None, left_edge=None, ret_cubegrid=False, threads=None,
X,
Lk=None,
L=None,
a=0,
b=2 * np.pi,
axes=None,
left_edge=None,
ret_cubegrid=False,
threads=None,
):
r"""
Arbitrary-dimension nice inverse Fourier Transform.
Expand Down
17 changes: 9 additions & 8 deletions src/powerbox/import_fft.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import warnings
import numpy as np
import warnings


def config(THREADS=None):
Expand All @@ -8,17 +8,18 @@ def config(THREADS=None):
if THREADS is None:
if THREADS is None:
from multiprocessing import cpu_count

THREADS = cpu_count()
if THREADS > 0:
try:
#warnings.warn("Using pyFFTW with " + str(THREADS) + " threads...")
# warnings.warn("Using pyFFTW with " + str(THREADS) + " threads...")
from pyfftw import empty_aligned as empty
from pyfftw.interfaces.cache import enable, set_keepalive_time
from pyfftw.interfaces.numpy_fft import fftfreq as _fftfreq
from pyfftw.interfaces.numpy_fft import fftn as _fftn
from pyfftw.interfaces.numpy_fft import fftshift as _fftshift
from pyfftw.interfaces.numpy_fft import ifftn as _ifftn
from pyfftw.interfaces.numpy_fft import ifftshift as _ifftshift
from pyfftw import empty_aligned as empty

def fftn(*args, **kwargs):
return _fftn(*args, threads=THREADS, **kwargs)
Expand All @@ -30,20 +31,21 @@ def ifftn(*args, **kwargs):

except ImportError:
HAVE_FFTW = False
#warnings.warn("USE_FFTW set to True but pyFFTW could not be loaded. Make sure pyFFTW is installed properly. Proceeding with numpy...", UserWarning)
# warnings.warn("USE_FFTW set to True but pyFFTW could not be loaded. Make sure pyFFTW is installed properly. Proceeding with numpy...", UserWarning)
from numpy.fft import fftfreq as _fftfreq
from numpy.fft import fftn
from numpy.fft import fftshift as _fftshift
from numpy.fft import ifftn
from numpy.fft import ifftshift as _ifftshift
else:
HAVE_FFTW = False
#warnings.warn("Using numpy FFT...")
# warnings.warn("Using numpy FFT...")
from numpy.fft import fftfreq as _fftfreq
from numpy.fft import fftn
from numpy.fft import fftshift as _fftshift
from numpy.fft import ifftn
from numpy.fft import ifftshift as _ifftshift

empty = np.empty

def fftshift(x, *args, **kwargs):
Expand All @@ -56,7 +58,6 @@ def fftshift(x, *args, **kwargs):

return out * x.unit if hasattr(x, "unit") else out


def ifftshift(x, *args, **kwargs):
"""
The same as numpy except it preserves units (if Astropy quantities are used).
Expand All @@ -67,7 +68,6 @@ def ifftshift(x, *args, **kwargs):

return out * x.unit if hasattr(x, "unit") else out


def fftfreq(N, d=1.0, b=2 * np.pi):
"""
Return fourier frequencies for a box with N cells, using general Fourier convention.
Expand All @@ -88,4 +88,5 @@ def fftfreq(N, d=1.0, b=2 * np.pi):
The N symmetric frequency components of the Fourier transform. Always centred at 0.
"""
return fftshift(_fftfreq(N, d=d)) * (2 * np.pi / b)
return fftn, ifftn, fftfreq, fftshift, ifftshift, empty, HAVE_FFTW

return fftn, ifftn, fftfreq, fftshift, ifftshift, empty, HAVE_FFTW
44 changes: 38 additions & 6 deletions src/powerbox/powerbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
import warnings

from . import dft
from .tools import _magnitude_grid
from .import_fft import config
from .tools import _magnitude_grid

# TODO: add hankel-transform version of LogNormal

Expand Down Expand Up @@ -141,7 +141,15 @@ def __init__(
self.fourier_b = b
self.vol_normalised_power = vol_normalised_power
self.V = self.boxlength**self.dim
_,_, self.fftfreq, self.fftshift, self.ifftshift, self.empty, self.HAVE_FFTW = config(threads)
(
_,
_,
self.fftfreq,
self.fftshift,
self.ifftshift,
self.empty,
self.HAVE_FFTW,
) = config(threads)
self.threads = threads
if self.vol_normalised_power:
self.pk = lambda k: pk(k) / self.V
Expand Down Expand Up @@ -232,7 +240,13 @@ def delta_x(self):
dk[...] = self.delta_k()
dk[...] = (
self.V
* dft.ifft(dk, L=self.boxlength, a=self.fourier_a, b=self.fourier_b, threads=self.threads)[0]
* dft.ifft(
dk,
L=self.boxlength,
a=self.fourier_a,
b=self.fourier_b,
threads=self.threads,
)[0]
)
dk = np.real(dk)

Expand Down Expand Up @@ -369,7 +383,13 @@ def correlation_array(self):
pa = self.empty((self.N,) * self.dim)
pa[...] = self.power_array()
return self.V * np.real(
dft.ifft(pa, L=self.boxlength, a=self.fourier_a, b=self.fourier_b, threads=self.threads)[0]
dft.ifft(
pa,
L=self.boxlength,
a=self.fourier_a,
b=self.fourier_b,
threads=self.threads,
)[0]
)

def gaussian_correlation_array(self):
Expand All @@ -381,7 +401,13 @@ def gaussian_power_array(self):
gca = self.empty((self.N,) * self.dim)
gca[...] = self.gaussian_correlation_array()
gpa = np.abs(
dft.fft(gca, L=self.boxlength, a=self.fourier_a, b=self.fourier_b, threads=self.threads)[0]
dft.fft(
gca,
L=self.boxlength,
a=self.fourier_a,
b=self.fourier_b,
threads=self.threads,
)[0]
)
gpa[self.k() == 0] = 0
return gpa
Expand All @@ -404,7 +430,13 @@ def delta_x(self):
dk[...] = self.delta_k()
dk[...] = (
np.sqrt(self.V)
* dft.ifft(dk, L=self.boxlength, a=self.fourier_a, b=self.fourier_b, threads=self.threads)[0]
* dft.ifft(
dk,
L=self.boxlength,
a=self.fourier_a,
b=self.fourier_b,
threads=self.threads,
)[0]
)
dk = np.real(dk)

Expand Down
11 changes: 9 additions & 2 deletions src/powerbox/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from . import dft


def _getbins(bins, coords, log):
mx = coords.max()
if not np.iterable(bins):
Expand Down Expand Up @@ -538,9 +539,15 @@ def get_power(
V = np.prod(boxlength)

# Calculate the n-D power spectrum and align it with the k from powerbox.
FT, freq, k = dft.fft(deltax, L=boxlength, a=a, b=b, ret_cubegrid=True, threads=threads)
FT, freq, k = dft.fft(
deltax, L=boxlength, a=a, b=b, ret_cubegrid=True, threads=threads
)

FT2 = dft.fft(deltax2, L=boxlength, a=a, b=b,threads=threads)[0] if deltax2 is not None else FT
FT2 = (
dft.fft(deltax2, L=boxlength, a=a, b=b, threads=threads)[0]
if deltax2 is not None
else FT
)
P = np.real(FT * np.conj(FT2) / V**2)

if vol_normalised_power:
Expand Down

0 comments on commit b67fd5c

Please sign in to comment.