Skip to content

Commit

Permalink
ENH: add dtype kwarg to fft.{fftfreq, rfftfreq}
Browse files Browse the repository at this point in the history
  • Loading branch information
ev-br committed Feb 1, 2025
1 parent 1a288de commit 60d44e3
Showing 1 changed file with 30 additions and 4 deletions.
34 changes: 30 additions & 4 deletions array_api_strict/_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,26 +251,52 @@ def ihfft(
return res

@requires_extension('fft')
def fftfreq(n: int, /, *, d: float = 1.0, device: Optional[Device] = None) -> Array:
def fftfreq(
n: int,
/,
*,
d: float = 1.0,
dtype: Optional[dtype] = None,

Check failure on line 259 in array_api_strict/_fft.py

View workflow job for this annotation

GitHub Actions / check-ruff

Ruff (F821)

array_api_strict/_fft.py:259:21: F821 Undefined name `dtype`
device: Optional[Device] = None
) -> Array:
"""
Array API compatible wrapper for :py:func:`np.fft.fftfreq <numpy.fft.fftfreq>`.
See its docstring for more information.
"""
if device is not None and device not in ALL_DEVICES:
raise ValueError(f"Unsupported device {device!r}")
return Array._new(np.fft.fftfreq(n, d=d), device=device)
if dtype and not dtype in _real_floating_dtypes:
raise ValueError(f"`dtype` must be a real floating-point type. Got {dtype=}.")

np_result = np.fft.fftfreq(n, d=d)
if dtype:
np_result = np_result.astype(dtype._np_dtype)
return Array._new(np_result, device=device)

@requires_extension('fft')
def rfftfreq(n: int, /, *, d: float = 1.0, device: Optional[Device] = None) -> Array:
def rfftfreq(
n: int,
/,
*,
d: float = 1.0,
dtype: Optional[dtype] = None,

Check failure on line 283 in array_api_strict/_fft.py

View workflow job for this annotation

GitHub Actions / check-ruff

Ruff (F821)

array_api_strict/_fft.py:283:21: F821 Undefined name `dtype`
device: Optional[Device] = None
) -> Array:
"""
Array API compatible wrapper for :py:func:`np.fft.rfftfreq <numpy.fft.rfftfreq>`.
See its docstring for more information.
"""
if device is not None and device not in ALL_DEVICES:
raise ValueError(f"Unsupported device {device!r}")
return Array._new(np.fft.rfftfreq(n, d=d), device=device)
if dtype and not dtype in _real_floating_dtypes:
raise ValueError(f"`dtype` must be a real floating-point type. Got {dtype=}.")

np_result = np.fft.rfftfreq(n, d=d)
if dtype:
np_result = np_result.astype(dtype._np_dtype)
return Array._new(np_result, device=device)

@requires_extension('fft')
def fftshift(x: Array, /, *, axes: Union[int, Sequence[int]] = None) -> Array:
Expand Down

0 comments on commit 60d44e3

Please sign in to comment.