Skip to content

Commit

Permalink
Improve mask mechanism
Browse files Browse the repository at this point in the history
  • Loading branch information
bwohlberg committed Oct 15, 2024
1 parent e485390 commit b6bded8
Showing 1 changed file with 17 additions and 13 deletions.
30 changes: 17 additions & 13 deletions scico/linop/xray/_xray.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,9 @@ def __init__(
self.y0 = y0
self.dy = 1.0

self.fbp_filter: Optional[snp.Array] = None
self.fbp_mask: Optional[snp.Array] = None

super().__init__(
input_shape=self.input_shape,
input_dtype=np.float32,
Expand Down Expand Up @@ -155,27 +158,28 @@ def fbp(self, y: ArrayLike) -> snp.Array:
Returns:
FBP inverse of projection.
"""

N = y.shape[1]
nvec = jnp.arange(N) - (N - 1) // 2
h = XRayTransform2D._ramp_filter(nvec, 1.0).reshape(1, -1)

if self.fbp_filter is None:
nvec = jnp.arange(N) - (N - 1) // 2
self.fbp_filter = XRayTransform2D._ramp_filter(nvec, 1.0).reshape(1, -1)

if self.fbp_mask is None:
unit_sino = jnp.ones(self.output_shape, dtype=np.float32)
# Threshold is multiplied by 0.99... fudge factor to account for numerical errors
# in back projection.
self.fbp_mask = self.back_project(unit_sino) >= (self.output_shape[0] * (1.0 - 1e-5)) # type: ignore

# Apply ramp filter in the frequency domain, padding to avoid
# boundary effects
h = self.fbp_filter
hf = jnp.fft.fft(h, n=2 * N - 1, axis=1)
yf = jnp.fft.fft(y, n=2 * N - 1, axis=1)
hy = jnp.fft.ifft(hf * yf, n=2 * N - 1, axis=1)[
:, (N - 1) // 2 : -(N - 1) // 2
].real.astype(jnp.float32)

x = (jnp.pi * self.dx[0] * self.dx[1] / y.shape[0]) * self.back_project(hy)
# Mask out the invalid region of the reconstruction
gi, gj = jnp.mgrid[: x.shape[0], : x.shape[1]]
x = jnp.where(
jnp.sqrt((gi - x.shape[0] / 2) ** 2 + (gj - x.shape[1] / 2) ** 2) < min(x.shape) / 2,
x,
0.0,
)
x = (jnp.pi * self.dx[0] * self.dx[1] / y.shape[0]) * self.fbp_mask * self.back_project(hy) # type: ignore
return x

@staticmethod
Expand Down Expand Up @@ -242,7 +246,7 @@ def _project(
@partial(jax.jit, static_argnames=["nx"])
def _back_project(
y: ArrayLike, x0: ArrayLike, dx: ArrayLike, nx: Shape, y0: float, angles: ArrayLike
) -> ArrayLike:
) -> snp.Array:
r"""Compute X-ray back projection.
Args:
Expand Down Expand Up @@ -424,7 +428,7 @@ def _project_single(
return proj

@staticmethod
def _back_project(proj: ArrayLike, matrices: ArrayLike, input_shape: Shape) -> ArrayLike:
def _back_project(proj: ArrayLike, matrices: ArrayLike, input_shape: Shape) -> snp.Array:
r"""
Args:
proj: Input (set of) projection(s).
Expand Down

0 comments on commit b6bded8

Please sign in to comment.