From 008697c5b5a27ea3ea2bd3d6e43d2581089e6df5 Mon Sep 17 00:00:00 2001 From: Mike McCann <57153404+Michael-T-McCann@users.noreply.github.com> Date: Tue, 15 Oct 2024 11:31:30 -0600 Subject: [PATCH] Fix `XRayTransform2D` projection dtype and docs (#557) Co-authored-by: Brendt Wohlberg --- scico/linop/xray/_xray.py | 44 ++++++++++++++++++++---------- scico/test/linop/xray/test_xray.py | 3 +- 2 files changed, 32 insertions(+), 15 deletions(-) diff --git a/scico/linop/xray/_xray.py b/scico/linop/xray/_xray.py index 770bf627..4fe40893 100644 --- a/scico/linop/xray/_xray.py +++ b/scico/linop/xray/_xray.py @@ -27,7 +27,7 @@ class XRayTransform2D(LinearOperator): - """Parallel ray, single axis, 2D X-ray projector. + r"""Parallel ray, single axis, 2D X-ray projector. This implementation approximates the projection of each rectangular pixel as a boxcar function (whereas the exact projection is a @@ -42,6 +42,9 @@ class XRayTransform2D(LinearOperator): accumulation of pixel values into bins (equivalently, makes the linear operator sparse). + Warning: The default pixel spacing is :math:`\sqrt{2}/2` (rather + than 1) in order to satisfy the aforementioned spacing requirement. + `x0`, `dx`, and `y0` should be expressed in units such that the detector spacing `dy` is 1.0. """ @@ -64,9 +67,11 @@ def __init__( corresponds to summing columns, and an angle of pi/4 corresponds to summing along antidiagonals. x0: (x, y) position of the corner of the pixel `im[0,0]`. By - default, `(-input_shape / 2, -input_shape / 2)`. - dx: Image pixel side length in x- and y-direction. Should be - <= 1.0 in each dimension. By default, [1.0, 1.0]. + default, `(-input_shape * dx[0] / 2, -input_shape * dx[1] / 2)`. + dx: Image pixel side length in x- and y-direction. Must be + set so that the width of a projected pixel is never + larger than 1.0. By default, [:math:`\sqrt{2}/2`, + :math:`\sqrt{2}/2`]. y0: Location of the edge of the first detector bin. By default, `-det_count / 2` det_count: Number of elements in detector. If ``None``, @@ -111,17 +116,27 @@ def __init__( super().__init__( input_shape=self.input_shape, + input_dtype=np.float32, output_shape=self.output_shape, + output_dtype=np.float32, eval_fn=self.project, adj_fn=self.back_project, ) def project(self, im: ArrayLike) -> snp.Array: - """Compute X-ray projection.""" + """Compute X-ray projection, equivalent to `H @ im`. + + Args: + im: Input array representing the image to project. + """ return XRayTransform2D._project(im, self.x0, self.dx, self.y0, self.ny, self.angles) def back_project(self, y: ArrayLike) -> snp.Array: - """Compute X-ray back projection""" + """Compute X-ray back projection, equivalent to `H.T @ y`. + + Args: + y: Input array representing the sinogram to back project. + """ return XRayTransform2D._back_project(y, self.x0, self.dx, self.nx, self.y0, self.angles) @staticmethod @@ -129,7 +144,8 @@ def back_project(self, y: ArrayLike) -> snp.Array: def _project( im: ArrayLike, x0: ArrayLike, dx: ArrayLike, y0: float, ny: int, angles: ArrayLike ) -> snp.Array: - r""" + r"""Compute X-ray projection. + Args: im: Input array, (M, N). x0: (x, y) position of the corner of the pixel im[0,0]. @@ -146,8 +162,11 @@ def _project( # ignored, while inds < 0 wrap around. So we set inds < 0 to ny. inds = jnp.where(inds >= 0, inds, ny) + # avoid incompatible types in the .add (scatter operation) + weights = weights.astype(im.dtype) + y = ( - jnp.zeros((len(angles), ny)) + jnp.zeros((len(angles), ny), dtype=im.dtype) .at[jnp.arange(len(angles)).reshape(-1, 1, 1), inds] .add(im * weights) ) @@ -161,7 +180,8 @@ def _project( def _back_project( y: ArrayLike, x0: ArrayLike, dx: ArrayLike, nx: Shape, y0: float, angles: ArrayLike ) -> ArrayLike: - r""" + r"""Compute X-ray back projection. + Args: y: Input projection, (num_angles, N). x0: (x, y) position of the corner of the pixel im[0,0]. @@ -259,10 +279,6 @@ class XRayTransform3D(LinearOperator): :meth:`XRayTransform3D.matrices_from_euler_angles` can help to make these geometry arrays. - - - - """ def __init__( @@ -279,7 +295,7 @@ def __init__( """ self.input_shape: Shape = input_shape - self.matrices = matrices + self.matrices = jnp.asarray(matrices, dtype=np.float32) self.det_shape = det_shape self.output_shape = (len(matrices), *det_shape) super().__init__( diff --git a/scico/test/linop/xray/test_xray.py b/scico/test/linop/xray/test_xray.py index cd7c0dcd..b9e12776 100644 --- a/scico/test/linop/xray/test_xray.py +++ b/scico/test/linop/xray/test_xray.py @@ -49,7 +49,7 @@ def test_apply(): def test_apply_adjoint(): im_shape = (12, 13) num_angles = 10 - x = jnp.ones(im_shape) + x = jnp.ones(im_shape, dtype=jnp.float32) angles = jnp.linspace(0, jnp.pi, num=num_angles, endpoint=False) @@ -81,6 +81,7 @@ def test_3d_scaling(): # default spacing M = XRayTransform3D.matrices_from_euler_angles(input_shape, output_shape, "X", [0.0]) H = XRayTransform3D(input_shape, matrices=M, det_shape=output_shape) + # fmt: off truth = jnp.array( [[[0.0, 0.0, 0.0, 0.0],