Skip to content

Commit

Permalink
Fixed issue #48
Browse files Browse the repository at this point in the history
Fixed typo in _numpy_fft.py
  • Loading branch information
oleksandr-pavlyk committed Feb 24, 2021
1 parent cc946a1 commit 6fd2760
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 5 deletions.
2 changes: 1 addition & 1 deletion mkl_fft/_numpy_fft.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
g#!/usr/bin/env python
#!/usr/bin/env python
# Copyright (c) 2017-2019, Intel Corporation
#
# Redistribution and use in source and binary forms, with or without
Expand Down
69 changes: 65 additions & 4 deletions mkl_fft/_pydfti.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -912,6 +912,48 @@ def _iter_fftnd(a, s=None, axes=None, function=fft, overwrite_arg=False, scale_f
return a


def flat_to_multi(ind, shape):
nd = len(shape)
m_ind = [-1] * nd
j = ind
for i in range(nd):
si = shape[nd-1-i]
q = j // si
r = j - si * q
m_ind[nd-1-i] = r
j = q
return m_ind


def iter_complementary(x, axes, func, kwargs, result):
if axes is None:
return func(x, **kwargs)
x_shape = x.shape
nd = x.ndim
r = list(range(nd))
sl = [slice(None, None, None)] * nd
if not isinstance(axes, tuple):
axes = (axes,)
for ai in axes:
r[ai] = None
size = 1
sub_shape = []
dual_ind = []
for ri in r:
if ri is not None:
size *= x_shape[ri]
sub_shape.append(x_shape[ri])
dual_ind.append(ri)

for ind in range(size):
m_ind = flat_to_multi(ind, sub_shape)
for k1, k2 in zip(dual_ind, m_ind):
sl[k1] = k2
np.copyto(result[tuple(sl)], func(x[tuple(sl)], **kwargs))

return result


def _direct_fftnd(x, overwrite_arg=False, direction=+1, double fsc=1.0):
"""Perform n-dimensional FFT over all axes"""
cdef int err
Expand Down Expand Up @@ -988,6 +1030,7 @@ def _direct_fftnd(x, overwrite_arg=False, direction=+1, double fsc=1.0):

return f_arr


def _check_shapes_for_direct(xs, shape, axes):
if len(axes) > 7: # Intel MKL supports up to 7D
return False
Expand All @@ -1006,6 +1049,14 @@ def _check_shapes_for_direct(xs, shape, axes):
return True


def _output_dtype(dt):
if dt == np.double:
return np.cdouble
if dt == np.single:
return np.csingle
return dt


def _fftnd_impl(x, shape=None, axes=None, overwrite_x=False, direction=+1, double fsc=1.0):
if direction not in [-1, +1]:
raise ValueError("Direction of FFT should +1 or -1")
Expand All @@ -1026,10 +1077,20 @@ def _fftnd_impl(x, shape=None, axes=None, overwrite_x=False, direction=+1, doubl
if _direct:
return _direct_fftnd(x, overwrite_arg=overwrite_x, direction=direction, fsc=fsc)
else:
sc = (<object> fsc)**(1/x.ndim)
return _iter_fftnd(x, s=shape, axes=axes,
overwrite_arg=overwrite_x, scale_function=lambda n: sc,
function=fft if direction == 1 else ifft)
if (shape is None):
x = np.asarray(x)
res = np.empty(x.shape, dtype=_output_dtype(x.dtype))
return iter_complementary(
x, axes,
_direct_fftnd,
{'overwrite_arg': overwrite_x, 'direction': direction, 'fsc': fsc},
res
)
else:
sc = (<object> fsc)**(1/x.ndim)
return _iter_fftnd(x, s=shape, axes=axes,
overwrite_arg=overwrite_x, scale_function=lambda n: sc,
function=fft if direction == 1 else ifft)


def fft2(x, shape=None, axes=(-2,-1), overwrite_x=False, forward_scale=1.0):
Expand Down
19 changes: 19 additions & 0 deletions mkl_fft/tests/test_fftnd.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,24 @@ def test_matrix4(self):
assert_allclose(t_strided, t_contig, rtol=r_tol, atol=a_tol)


def test_matrix5(self):
"""fftn of strided array is same as fftn of a contiguous copy"""
rs = rnd.RandomState(1234)
x = rs.randn(6, 11, 12, 13)
y = x[::-2, :, :, ::3]
r_tol, a_tol = _get_rtol_atol(y)
f = mkl_fft.fftn(y, axes=(1,2))
for i0 in range(y.shape[0]):
for i3 in range(y.shape[3]):
assert_allclose(
f[i0, :, :, i3],
mkl_fft.fftn(y[i0, :, : , i3]),
rtol=r_tol, atol=a_tol
)




class Test_Regressions(TestCase):

def setUp(self):
Expand Down Expand Up @@ -129,6 +147,7 @@ def test_rfftn_numpy(self):
tr_rfft = np.transpose(mkl_fft.rfftn_numpy(x, axes=a), a)
assert_allclose(rfft_tr, tr_rfft, rtol=r_tol, atol=a_tol)


class Test_Scales(TestCase):
def setUp(self):
pass
Expand Down

0 comments on commit 6fd2760

Please sign in to comment.