Skip to content

Commit

Permalink
Fix XRayTransform2D projection dtype and docs (#557)
Browse files Browse the repository at this point in the history
Co-authored-by: Brendt Wohlberg <[email protected]>
  • Loading branch information
Michael-T-McCann and bwohlberg authored Oct 15, 2024
1 parent 8dc1a2a commit 008697c
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 15 deletions.
44 changes: 30 additions & 14 deletions scico/linop/xray/_xray.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@


class XRayTransform2D(LinearOperator):
"""Parallel ray, single axis, 2D X-ray projector.
r"""Parallel ray, single axis, 2D X-ray projector.
This implementation approximates the projection of each rectangular
pixel as a boxcar function (whereas the exact projection is a
Expand All @@ -42,6 +42,9 @@ class XRayTransform2D(LinearOperator):
accumulation of pixel values into bins (equivalently, makes the
linear operator sparse).
Warning: The default pixel spacing is :math:`\sqrt{2}/2` (rather
than 1) in order to satisfy the aforementioned spacing requirement.
`x0`, `dx`, and `y0` should be expressed in units such that the
detector spacing `dy` is 1.0.
"""
Expand All @@ -64,9 +67,11 @@ def __init__(
corresponds to summing columns, and an angle of pi/4
corresponds to summing along antidiagonals.
x0: (x, y) position of the corner of the pixel `im[0,0]`. By
default, `(-input_shape / 2, -input_shape / 2)`.
dx: Image pixel side length in x- and y-direction. Should be
<= 1.0 in each dimension. By default, [1.0, 1.0].
default, `(-input_shape * dx[0] / 2, -input_shape * dx[1] / 2)`.
dx: Image pixel side length in x- and y-direction. Must be
set so that the width of a projected pixel is never
larger than 1.0. By default, [:math:`\sqrt{2}/2`,
:math:`\sqrt{2}/2`].
y0: Location of the edge of the first detector bin. By
default, `-det_count / 2`
det_count: Number of elements in detector. If ``None``,
Expand Down Expand Up @@ -111,25 +116,36 @@ def __init__(

super().__init__(
input_shape=self.input_shape,
input_dtype=np.float32,
output_shape=self.output_shape,
output_dtype=np.float32,
eval_fn=self.project,
adj_fn=self.back_project,
)

def project(self, im: ArrayLike) -> snp.Array:
"""Compute X-ray projection."""
"""Compute X-ray projection, equivalent to `H @ im`.
Args:
im: Input array representing the image to project.
"""
return XRayTransform2D._project(im, self.x0, self.dx, self.y0, self.ny, self.angles)

def back_project(self, y: ArrayLike) -> snp.Array:
"""Compute X-ray back projection"""
"""Compute X-ray back projection, equivalent to `H.T @ y`.
Args:
y: Input array representing the sinogram to back project.
"""
return XRayTransform2D._back_project(y, self.x0, self.dx, self.nx, self.y0, self.angles)

@staticmethod
@partial(jax.jit, static_argnames=["ny"])
def _project(
im: ArrayLike, x0: ArrayLike, dx: ArrayLike, y0: float, ny: int, angles: ArrayLike
) -> snp.Array:
r"""
r"""Compute X-ray projection.
Args:
im: Input array, (M, N).
x0: (x, y) position of the corner of the pixel im[0,0].
Expand All @@ -146,8 +162,11 @@ def _project(
# ignored, while inds < 0 wrap around. So we set inds < 0 to ny.
inds = jnp.where(inds >= 0, inds, ny)

# avoid incompatible types in the .add (scatter operation)
weights = weights.astype(im.dtype)

y = (
jnp.zeros((len(angles), ny))
jnp.zeros((len(angles), ny), dtype=im.dtype)
.at[jnp.arange(len(angles)).reshape(-1, 1, 1), inds]
.add(im * weights)
)
Expand All @@ -161,7 +180,8 @@ def _project(
def _back_project(
y: ArrayLike, x0: ArrayLike, dx: ArrayLike, nx: Shape, y0: float, angles: ArrayLike
) -> ArrayLike:
r"""
r"""Compute X-ray back projection.
Args:
y: Input projection, (num_angles, N).
x0: (x, y) position of the corner of the pixel im[0,0].
Expand Down Expand Up @@ -259,10 +279,6 @@ class XRayTransform3D(LinearOperator):
:meth:`XRayTransform3D.matrices_from_euler_angles` can help to
make these geometry arrays.
"""

def __init__(
Expand All @@ -279,7 +295,7 @@ def __init__(
"""

self.input_shape: Shape = input_shape
self.matrices = matrices
self.matrices = jnp.asarray(matrices, dtype=np.float32)
self.det_shape = det_shape
self.output_shape = (len(matrices), *det_shape)
super().__init__(
Expand Down
3 changes: 2 additions & 1 deletion scico/test/linop/xray/test_xray.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def test_apply():
def test_apply_adjoint():
im_shape = (12, 13)
num_angles = 10
x = jnp.ones(im_shape)
x = jnp.ones(im_shape, dtype=jnp.float32)

angles = jnp.linspace(0, jnp.pi, num=num_angles, endpoint=False)

Expand Down Expand Up @@ -81,6 +81,7 @@ def test_3d_scaling():
# default spacing
M = XRayTransform3D.matrices_from_euler_angles(input_shape, output_shape, "X", [0.0])
H = XRayTransform3D(input_shape, matrices=M, det_shape=output_shape)

# fmt: off
truth = jnp.array(
[[[0.0, 0.0, 0.0, 0.0],
Expand Down

0 comments on commit 008697c

Please sign in to comment.