Skip to content

Commit

Permalink
Make Python plan attributes private (#608)
Browse files Browse the repository at this point in the history
* py: make plan attributes private

* py: add properties as read-only attributes

* py: test properties

* cu+py: make attributes private in cufinufft plan

* cu+py: add properties as readonly attrs

* cu+py: test properties
  • Loading branch information
janden authored Jan 28, 2025
1 parent c28744a commit cc897a5
Show file tree
Hide file tree
Showing 4 changed files with 182 additions and 82 deletions.
118 changes: 69 additions & 49 deletions python/cufinufft/cufinufft/_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,20 +83,20 @@ def __init__(self, nufft_type, n_modes, n_trans=1, eps=1e-6, isign=None,
self._plan = None

# Setup type bound methods
self.dtype = np.dtype(dtype)
self._dtype = np.dtype(dtype)

if self.dtype == np.complex128:
if self._dtype == np.complex128:
self._make_plan = _make_plan
self._setpts = _set_pts
self._exec_plan = _exec_plan
self._destroy_plan = _destroy_plan
self.real_dtype = np.float64
elif self.dtype == np.complex64:
self._real_dtype = np.float64
elif self._dtype == np.complex64:
self._make_plan = _make_planf
self._setpts = _set_ptsf
self._exec_plan = _exec_planf
self._destroy_plan = _destroy_planf
self.real_dtype = np.float32
self._real_dtype = np.float32
else:
raise TypeError("Expected complex64 or complex128.")

Expand All @@ -118,12 +118,12 @@ def __init__(self, nufft_type, n_modes, n_trans=1, eps=1e-6, isign=None,
if dim not in [1, 2, 3]:
raise ValueError("Only dimensions 1, 2, and 3 supported")

self.dim = dim
self.type = nufft_type
self.isign = isign
self.eps = float(eps)
self.n_modes = n_modes
self.n_trans = n_trans
self._dim = dim
self._type = nufft_type
self._isign = isign
self._eps = float(eps)
self._n_modes = n_modes
self._n_trans = n_trans
self._maxbatch = 1 # TODO: optimize this one day

# Get the default option values.
Expand All @@ -146,6 +146,26 @@ def __init__(self, nufft_type, n_modes, n_trans=1, eps=1e-6, isign=None,
# we want to keep around for life of instance.
self._references = []

@property
def type(self):
return self._type

@property
def dtype(self):
return self._dtype

@property
def dim(self):
return self._dim

@property
def n_modes(self):
return self._n_modes

@property
def n_trans(self):
return self._n_trans

@staticmethod
def _default_opts():
"""
Expand Down Expand Up @@ -174,15 +194,15 @@ def _init_plan(self):
# We extend the mode tuple to 3D as needed,
# and reorder from C/python ndarray.shape style input (nZ, nY, nX)
# to the (F) order expected by the low level library (nX, nY, nZ).
_n_modes = self.n_modes[::-1] + (1,) * (3 - self.dim)
_n_modes = self._n_modes[::-1] + (1,) * (3 - self._dim)
_n_modes = (c_int64 * 3)(*_n_modes)

ier = self._make_plan(self.type,
self.dim,
ier = self._make_plan(self._type,
self._dim,
_n_modes,
self.isign,
self.n_trans,
self.eps,
self._isign,
self._n_trans,
self._eps,
byref(self._plan),
self._opts)

Expand All @@ -209,20 +229,20 @@ def setpts(self, x, y=None, z=None, s=None, t=None, u=None):
points (source for type 1, target for type 2).
"""

_x = _ensure_array_type(x, "x", self.real_dtype)
_y = _ensure_array_type(y, "y", self.real_dtype)
_z = _ensure_array_type(z, "z", self.real_dtype)
_x = _ensure_array_type(x, "x", self._real_dtype)
_y = _ensure_array_type(y, "y", self._real_dtype)
_z = _ensure_array_type(z, "z", self._real_dtype)

_x, _y, _z = _ensure_valid_pts(_x, _y, _z, self.dim)
_x, _y, _z = _ensure_valid_pts(_x, _y, _z, self._dim)

M = _compat.get_array_size(_x)

if self.type == 3:
_s = _ensure_array_type(s, "s", self.real_dtype)
_t = _ensure_array_type(t, "t", self.real_dtype)
_u = _ensure_array_type(u, "u", self.real_dtype)
if self._type == 3:
_s = _ensure_array_type(s, "s", self._real_dtype)
_t = _ensure_array_type(t, "t", self._real_dtype)
_u = _ensure_array_type(u, "u", self._real_dtype)

_s, _t, _u = _ensure_valid_pts(_s, _t, _u, self.dim)
_s, _t, _u = _ensure_valid_pts(_s, _t, _u, self._dim)

N = _compat.get_array_size(_s)
else:
Expand All @@ -242,22 +262,22 @@ def setpts(self, x, y=None, z=None, s=None, t=None, u=None):
# We will also store references to these arrays.
# This keeps python from prematurely cleaning them up.
self._references.append(_x)
if self.dim >= 2:
if self._dim >= 2:
fpts_axes.insert(0, _compat.get_array_ptr(_y))
self._references.append(_y)
if self.dim >= 3:
if self._dim >= 3:
fpts_axes.insert(0, _compat.get_array_ptr(_z))
self._references.append(_z)

# Do the same for type 3
if self.type == 3:
if self._type == 3:
fpts_axes_t3 = [_compat.get_array_ptr(_s), None, None]
self._references.append(_s)
if self.dim >= 2:
if self._dim >= 2:
fpts_axes_t3.insert(0, _compat.get_array_ptr(_t))
self._references.append(_t)

if self.dim >= 3:
if self._dim >= 3:
fpts_axes_t3.insert(0, _compat.get_array_ptr(_u))
self._references.append(_u)
else:
Expand All @@ -268,8 +288,8 @@ def setpts(self, x, y=None, z=None, s=None, t=None, u=None):
M, *fpts_axes[:3],
N, *fpts_axes_t3[:3])

self.nj = M
self.nk = N
self._nj = M
self._nk = N

if ier != 0:
raise RuntimeError('Error setting non-uniform points.')
Expand Down Expand Up @@ -297,37 +317,37 @@ def execute(self, data, out=None):
The output array of the transform(s).
"""

_data = _ensure_array_type(data, "data", self.dtype)
_out = _ensure_array_type(out, "out", self.dtype, output=True)
_data = _ensure_array_type(data, "data", self._dtype)
_out = _ensure_array_type(out, "out", self._dtype, output=True)

if self.type == 1:
req_data_shape = (self.n_trans, self.nj)
req_out_shape = self.n_modes
elif self.type == 2:
req_data_shape = (self.n_trans, *self.n_modes)
req_out_shape = (self.nj,)
elif self.type == 3:
req_data_shape = (self.n_trans, self.nj)
req_out_shape = (self.nk,)
if self._type == 1:
req_data_shape = (self._n_trans, self._nj)
req_out_shape = self._n_modes
elif self._type == 2:
req_data_shape = (self._n_trans, *self._n_modes)
req_out_shape = (self._nj,)
elif self._type == 3:
req_data_shape = (self._n_trans, self._nj)
req_out_shape = (self._nk,)

_data, data_shape = _ensure_array_shape(_data, "data", req_data_shape,
allow_reshape=True)
if self.type == 1:
if self._type == 1:
batch_shape = data_shape[:-1]
else:
batch_shape = data_shape[:-self.dim]
batch_shape = data_shape[:-self._dim]

req_out_shape = batch_shape + req_out_shape

if out is None:
_out = _compat.array_empty_like(_data, req_out_shape, dtype=self.dtype)
_out = _compat.array_empty_like(_data, req_out_shape, dtype=self._dtype)
else:
_out = _ensure_array_shape(_out, "out", req_out_shape)

if self.type in [1, 3]:
if self._type in [1, 3]:
ier = self._exec_plan(self._plan, _compat.get_array_ptr(_data),
_compat.get_array_ptr(_out))
elif self.type == 2:
elif self._type == 2:
ier = self._exec_plan(self._plan, _compat.get_array_ptr(_out),
_compat.get_array_ptr(_data))

Expand Down
30 changes: 30 additions & 0 deletions python/cufinufft/tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,3 +166,33 @@ def test_opts(to_gpu, to_cpu, shape=(8, 8, 8), M=32, tol=1e-3):
fk = to_cpu(fk_gpu)

utils.verify_type1(k, c, fk, tol)


def test_cufinufft_plan_properties():
nufft_type = 2
n_modes = (8, 8)
n_trans = 2
dtype = np.complex64

plan = Plan(nufft_type, n_modes, n_trans, dtype=dtype)

assert plan.type == nufft_type
assert tuple(plan.n_modes) == n_modes
assert plan.dim == len(n_modes)
assert plan.n_trans == n_trans
assert plan.dtype == dtype

with pytest.raises(AttributeError):
plan.type = 1

with pytest.raises(AttributeError):
plan.n_modes = (4, 4)

with pytest.raises(AttributeError):
plan.dim = 1

with pytest.raises(AttributeError):
plan.n_trans = 1

with pytest.raises(AttributeError):
plan.dtype = np.float64
Loading

0 comments on commit cc897a5

Please sign in to comment.