diff --git a/CHANGES.rst b/CHANGES.rst
index f0dbfd7c..3354e735 100644
--- a/CHANGES.rst
+++ b/CHANGES.rst
@@ -3,10 +3,20 @@ SCICO Release Notes
===================
-Version 0.0.6 (unreleased)
+Version 0.0.7 (unreleased)
+----------------------------
+
+• No changes yet.
+
+
+
+Version 0.0.6 (2024-10-25)
----------------------------
• Significant changes to ``linop.xray.astra`` API.
+• Rename integrated 2D X-ray transform class to
+ ``linop.xray.XRayTransform2D`` and add filtered back projection method
+ ``fbp``.
• New integrated 3D X-ray transform via ``linop.xray.XRayTransform3D``.
• New functional ``functional.IsotropicTVNorm`` and faster implementation
of ``functional.AnisotropicTVNorm``.
@@ -18,8 +28,8 @@ Version 0.0.6 (unreleased)
• Rename ``scico.flax.save_weights`` and ``scico.flax.load_weights`` to
``scico.flax.save_variables`` and ``scico.flax.load_variables``
respectively.
-• Support ``jaxlib`` and ``jax`` versions 0.4.3 to 0.4.31.
-• Support ``flax`` versions 0.8.0 to 0.8.3.
+• Support ``jaxlib`` and ``jax`` versions 0.4.13 to 0.4.35.
+• Support ``flax`` versions 0.8.0 to 0.10.0.
diff --git a/MANIFEST.in b/MANIFEST.in
index 9cbc39a6..d9ee9f5f 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -8,10 +8,10 @@ include pyproject.toml
include pytest.ini
include requirements.txt
include dev_requirements.txt
-include examples/scriptcheck.sh
include docs/docs_requirements.txt
recursive-include scico *.py
-recursive-include scico/data *.png *.npz
-recursive-include docs Makefile *.py *.ipynb *.rst *.bib *.css *.svg *.ico
-recursive-include examples *_requirements.txt *.txt *.rst *.py
+recursive-include scico/data *.png *.mpk *.rst
+recursive-include docs Makefile *.py *.ipynb *.rst *.bib *.css *.svg *.png *.ico
+recursive-include examples *_requirements.txt *.txt *.rst *.py *.sh
+recursive-include misc *.py *.sh *.rst
diff --git a/README.md b/README.md
index e1caf27d..cd333e64 100644
--- a/README.md
+++ b/README.md
@@ -32,11 +32,12 @@ this software for published work, please cite the corresponding [JOSS
Paper](https://doi.org/10.21105/joss.04722) (see bibtex entry
`balke-2022-scico` in `docs/source/references.bib`).
+
# Installation
-See the [online
-documentation](https://scico.rtfd.io/en/latest/install.html) for
-installation instructions.
+The online documentation includes detailed
+[installation instructions](https://scico.rtfd.io/en/latest/install.html).
+
# Usage Examples
@@ -47,8 +48,11 @@ Jupyter Notebooks are provided in the
to `examples/notebooks`. They are also viewable on
[GitHub](https://github.com/lanl/scico-data/tree/main/notebooks) or
[nbviewer](https://nbviewer.jupyter.org/github/lanl/scico-data/tree/main/notebooks/index.ipynb),
-or can be run online by
-[binder](https://mybinder.org/v2/gh/lanl/scico-data/binder?labpath=notebooks%2Findex.ipynb).
+and can be run online on
+[binder](https://mybinder.org/v2/gh/lanl/scico-data/binder?labpath=notebooks%2Findex.ipynb)
+or
+[google colab](https://colab.research.google.com/github/lanl/scico-data/blob/colab/notebooks/index.ipynb).
+
# License
diff --git a/data b/data
index c1233896..b186bddd 160000
--- a/data
+++ b/data
@@ -1 +1 @@
-Subproject commit c12338966b1b9f92554066743b1a8b664c7b7e24
+Subproject commit b186bddd170ded03be04e7921f5d86d24c92c54f
diff --git a/docs/source/examples.rst b/docs/source/examples.rst
index 58ba847f..b97a52bd 100644
--- a/docs/source/examples.rst
+++ b/docs/source/examples.rst
@@ -34,12 +34,11 @@ Computed Tomography
examples/ct_svmbir_ppp_bm3d_admm_cg
examples/ct_svmbir_ppp_bm3d_admm_prox
examples/ct_fan_svmbir_ppp_bm3d_admm_prox
- examples/ct_astra_modl_train_foam2
- examples/ct_astra_odp_train_foam2
- examples/ct_astra_unet_train_foam2
+ examples/ct_modl_train_foam2
+ examples/ct_odp_train_foam2
+ examples/ct_unet_train_foam2
examples/ct_projector_comparison_2d
examples/ct_projector_comparison_3d
- examples/ct_multi_cs_tv_admm
examples/ct_multi_tv_admm
Deconvolution
@@ -96,7 +95,7 @@ Miscellaneous
examples/denoise_dncnn_universal
examples/diffusercam_tv_admm
examples/video_rpca_admm
- examples/ct_astra_datagen_foam2
+ examples/ct_datagen_foam2
examples/deconv_datagen_bsds
examples/deconv_datagen_foam1
examples/denoise_datagen_bsds
@@ -181,10 +180,10 @@ Machine Learning
.. toctree::
:maxdepth: 1
- examples/ct_astra_datagen_foam2
- examples/ct_astra_modl_train_foam2
- examples/ct_astra_odp_train_foam2
- examples/ct_astra_unet_train_foam2
+ examples/ct_datagen_foam2
+ examples/ct_modl_train_foam2
+ examples/ct_odp_train_foam2
+ examples/ct_unet_train_foam2
examples/deconv_datagen_bsds
examples/deconv_datagen_foam1
examples/deconv_modl_train_foam1
diff --git a/docs/source/references.bib b/docs/source/references.bib
index 257f2428..e612e36e 100644
--- a/docs/source/references.bib
+++ b/docs/source/references.bib
@@ -396,6 +396,13 @@ @Article {jin-2017-unet
doi = {10.1109/TIP.2017.2713099}
}
+@Book {kak-1988-principles,
+ author = {Avinash C. Kak and Malcolm Slaney},
+ title = {Principles of Computerized Tomographic Imaging},
+ publisher = {IEEE Press},
+ year = 1988
+}
+
@TechReport {kamilov-2016-minimizing,
author = {Ulugbek S. Kamilov},
title = {Minimizing Isotropic Total Variation without
@@ -771,6 +778,7 @@ @Article {zhang-2017-dncnn
pages = {3142--3155}
}
+
@Article {zhang-2021-plug,
author = {Zhang, Kai and Li, Yawei and Zuo, Wangmeng and
Zhang, Lei and Van Gool, Luc and Timofte, Radu},
diff --git a/examples/jnb.py b/examples/jnb.py
index 838fed95..6fe58b98 100644
--- a/examples/jnb.py
+++ b/examples/jnb.py
@@ -62,10 +62,10 @@ def py_file_to_string(src):
# Process remainder of source file
for line in srcfile:
- if re.match("^input\(", line): # end processing when input statement encountered
+ if re.match(r"^input\(", line): # end processing when input statement encountered
break
line = re.sub('^r"""', '"""', line) # remove r from r"""
- line = re.sub(":cite:\`([^`]+)\`", r'', line) # fix cite format
+ line = re.sub(r":cite:\`([^`]+)\`", r'', line) # fix cite format
lines.append(line)
# Backtrack through list of lines to remove trailing newlines
diff --git a/examples/scripts/README.rst b/examples/scripts/README.rst
index 446186a7..95e11f7c 100644
--- a/examples/scripts/README.rst
+++ b/examples/scripts/README.rst
@@ -33,11 +33,11 @@ Computed Tomography
PPP (with BM3D) CT Reconstruction (ADMM with Fast SVMBIR Prox)
`ct_fan_svmbir_ppp_bm3d_admm_prox.py `_
PPP (with BM3D) Fan-Beam CT Reconstruction
- `ct_astra_modl_train_foam2.py `_
- CT Training and Reconstructions with MoDL
- `ct_astra_odp_train_foam2.py `_
- CT Training and Reconstructions with ODP
- `ct_astra_unet_train_foam2.py `_
+ `ct_modl_train_foam2.py `_
+ CT Training and Reconstruction with MoDL
+ `ct_odp_train_foam2.py `_
+ CT Training and Reconstruction with ODP
+ `ct_unet_train_foam2.py `_
CT Training and Reconstructions with UNet
`ct_projector_comparison_2d.py `_
2D X-ray Transform Comparison
@@ -123,7 +123,7 @@ Miscellaneous
TV-Regularized 3D DiffuserCam Reconstruction
`video_rpca_admm.py `_
Video Decomposition via Robust PCA
- `ct_astra_datagen_foam2.py `_
+ `ct_datagen_foam2.py `_
CT Data Generation for NN Training
`deconv_datagen_bsds.py `_
Blurred Data Generation (Natural Images) for NN Training
@@ -239,13 +239,13 @@ Sparsity
Machine Learning
^^^^^^^^^^^^^^^^
- `ct_astra_datagen_foam2.py `_
+ `ct_datagen_foam2.py `_
CT Data Generation for NN Training
- `ct_astra_modl_train_foam2.py `_
- CT Training and Reconstructions with MoDL
- `ct_astra_odp_train_foam2.py `_
- CT Training and Reconstructions with ODP
- `ct_astra_unet_train_foam2.py `_
+ `ct_modl_train_foam2.py `_
+ CT Training and Reconstruction with MoDL
+ `ct_odp_train_foam2.py `_
+ CT Training and Reconstruction with ODP
+ `ct_unet_train_foam2.py `_
CT Training and Reconstructions with UNet
`deconv_datagen_bsds.py `_
Blurred Data Generation (Natural Images) for NN Training
diff --git a/examples/scripts/ct_astra_3d_tv_admm.py b/examples/scripts/ct_astra_3d_tv_admm.py
index b3576fda..9c462cd0 100644
--- a/examples/scripts/ct_astra_3d_tv_admm.py
+++ b/examples/scripts/ct_astra_3d_tv_admm.py
@@ -44,7 +44,7 @@
tangle = snp.array(create_tangle_phantom(Nx, Ny, Nz))
n_projection = 10 # number of projections
-angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles
+angles = np.linspace(0, np.pi, n_projection, endpoint=False) # evenly spaced projection angles
C = XRayTransform3D(
tangle.shape, det_count=[Nz, max(Nx, Ny)], det_spacing=[1.0, 1.0], angles=angles
) # CT projection operator
diff --git a/examples/scripts/ct_astra_3d_tv_padmm.py b/examples/scripts/ct_astra_3d_tv_padmm.py
index c6c09007..3b76a8be 100644
--- a/examples/scripts/ct_astra_3d_tv_padmm.py
+++ b/examples/scripts/ct_astra_3d_tv_padmm.py
@@ -44,7 +44,7 @@
tangle = snp.array(create_tangle_phantom(Nx, Ny, Nz))
n_projection = 10 # number of projections
-angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles
+angles = np.linspace(0, np.pi, n_projection, endpoint=False) # evenly spaced projection angles
det_spacing = [1.0, 1.0]
det_count = [Nz, max(Nx, Ny)]
vectors = angle_to_vector(det_spacing, angles)
@@ -56,7 +56,7 @@
y = C @ tangle # sinogram
-"""
+r"""
Set up problem and solver. We want to minimize the functional
$$\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - C \mathbf{x}
diff --git a/examples/scripts/ct_astra_noreg_pcg.py b/examples/scripts/ct_astra_noreg_pcg.py
index 362c8cc3..a9dab965 100644
--- a/examples/scripts/ct_astra_noreg_pcg.py
+++ b/examples/scripts/ct_astra_noreg_pcg.py
@@ -44,7 +44,7 @@
Configure a CT projection operator and generate synthetic measurements.
"""
n_projection = N # matches the phantom size so this is not few-view CT
-angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles
+angles = np.linspace(0, np.pi, n_projection, endpoint=False) # evenly spaced projection angles
A = 1 / N * XRayTransform2D(x_gt.shape, N, 1.0, angles) # CT projection operator
y = A @ x_gt # sinogram
diff --git a/examples/scripts/ct_astra_tv_admm.py b/examples/scripts/ct_astra_tv_admm.py
index 5349311c..fd684e3d 100644
--- a/examples/scripts/ct_astra_tv_admm.py
+++ b/examples/scripts/ct_astra_tv_admm.py
@@ -38,21 +38,22 @@
"""
N = 512 # phantom size
np.random.seed(1234)
-x_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N)
-x_gt = snp.array(x_gt) # convert to jax type
+x_gt = snp.array(discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N))
"""
Configure CT projection operator and generate synthetic measurements.
"""
n_projection = 45 # number of projections
-angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles
-A = XRayTransform2D(x_gt.shape, N, 1.0, angles) # CT projection operator
+angles = np.linspace(0, np.pi, n_projection, endpoint=False) # evenly spaced projection angles
+det_count = int(N * 1.05 / np.sqrt(2.0))
+det_spacing = np.sqrt(2)
+A = XRayTransform2D(x_gt.shape, det_count, det_spacing, angles) # CT projection operator
y = A @ x_gt # sinogram
"""
-Set up ADMM solver object.
+Set up problem functional and ADMM solver object.
"""
λ = 2e0 # ℓ1 norm regularization parameter
ρ = 5e0 # ADMM penalty parameter
@@ -65,9 +66,7 @@
# which is used so that g(Cx) corresponds to isotropic TV.
C = linop.FiniteDifference(input_shape=x_gt.shape, append=0)
g = λ * functional.L21Norm()
-
f = loss.SquaredL2Loss(y=y, A=A)
-
x0 = snp.clip(A.fbp(y), 0, 1.0)
solver = ADMM(
diff --git a/examples/scripts/ct_astra_weighted_tv_admm.py b/examples/scripts/ct_astra_weighted_tv_admm.py
index b3c285cb..eb749373 100644
--- a/examples/scripts/ct_astra_weighted_tv_admm.py
+++ b/examples/scripts/ct_astra_weighted_tv_admm.py
@@ -35,7 +35,6 @@
Create a ground truth image.
"""
N = 512 # phantom size
-
np.random.seed(0)
x_gt = discrete_phantom(Soil(porosity=0.80), size=384)
x_gt = np.ascontiguousarray(np.pad(x_gt, (64, 64)))
@@ -49,8 +48,7 @@
n_projection = 360 # number of projections
Io = 1e3 # source flux
𝛼 = 1e-2 # attenuation coefficient
-
-angles = np.linspace(0, 2 * np.pi, n_projection) # evenly spaced projection angles
+angles = np.linspace(0, 2 * np.pi, n_projection, endpoint=False) # evenly spaced projection angles
A = XRayTransform2D(x_gt.shape, N, 1.0, angles) # CT projection operator
y_c = A @ x_gt # sinogram
@@ -99,13 +97,10 @@ def postprocess(x):
# shown here).
ρ = 2.5e3 # ADMM penalty parameter
lambda_unweighted = 3e2 # regularization strength
-
maxiter = 100 # number of ADMM iterations
cg_tol = 1e-5 # CG relative tolerance
cg_maxiter = 10 # maximum CG iterations per ADMM iteration
-
f = loss.SquaredL2Loss(y=y, A=A)
-
admm_unweighted = ADMM(
f=f,
g_list=[lambda_unweighted * functional.L21Norm()],
@@ -137,10 +132,8 @@ def postprocess(x):
$I_0$ changes.
"""
lambda_weighted = 5e1
-
weights = snp.array(counts / Io)
f = loss.SquaredL2Loss(y=y, A=A, W=linop.Diagonal(weights))
-
admm_weighted = ADMM(
f=f,
g_list=[lambda_weighted * functional.L21Norm()],
@@ -151,6 +144,7 @@ def postprocess(x):
subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": cg_tol, "maxiter": cg_maxiter}),
itstat_options={"display": True, "period": 10},
)
+print()
admm_weighted.solve()
x_weighted = postprocess(admm_weighted.x)
diff --git a/examples/scripts/ct_astra_datagen_foam2.py b/examples/scripts/ct_datagen_foam2.py
similarity index 93%
rename from examples/scripts/ct_astra_datagen_foam2.py
rename to examples/scripts/ct_datagen_foam2.py
index 4e6fb97c..2fc9be59 100644
--- a/examples/scripts/ct_astra_datagen_foam2.py
+++ b/examples/scripts/ct_datagen_foam2.py
@@ -14,6 +14,7 @@
"""
# isort: off
+import os
import numpy as np
import logging
@@ -21,6 +22,9 @@
ray.init(logging_level=logging.ERROR) # need to call init before jax import: ray-project/ray#44087
+# Set an arbitrary processor count (only applies if GPU is not available).
+os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
+
from scico import plot
from scico.flax.examples import load_ct_data
diff --git a/examples/scripts/ct_astra_modl_train_foam2.py b/examples/scripts/ct_modl_train_foam2.py
similarity index 95%
rename from examples/scripts/ct_astra_modl_train_foam2.py
rename to examples/scripts/ct_modl_train_foam2.py
index a06d7b81..10132214 100644
--- a/examples/scripts/ct_astra_modl_train_foam2.py
+++ b/examples/scripts/ct_modl_train_foam2.py
@@ -5,8 +5,8 @@
# with the package.
r"""
-CT Training and Reconstructions with MoDL
-=========================================
+CT Training and Reconstruction with MoDL
+========================================
This example demonstrates the training and application of a
model-based deep learning (MoDL) architecture described in
@@ -65,7 +65,7 @@
from scico import metric, plot
from scico.flax.examples import load_ct_data
from scico.flax.train.traversals import clip_positive, construct_traversal
-from scico.linop.xray.astra import XRayTransform2D
+from scico.linop.xray import XRayTransform2D
"""
Prepare parallel processing. Set an arbitrary processor count (only
@@ -89,16 +89,17 @@
"""
-Build CT projection operator.
+Build CT projection operator. Parameters are chosen so that the operator
+is equivalent to the one used to generate the training data.
"""
-angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles
+angles = np.linspace(0, np.pi, n_projection, endpoint=False) # evenly spaced projection angles
A = XRayTransform2D(
input_shape=(N, N),
- det_spacing=1,
- det_count=N,
angles=angles,
-) # CT projection operator
-A = (1.0 / N) * A # normalized
+ det_count=int(N * 1.05 / np.sqrt(2.0)),
+ dx=1.0 / np.sqrt(2),
+)
+A = (1.0 / N) * A # normalize projection operator
"""
diff --git a/examples/scripts/ct_multi_tv_admm.py b/examples/scripts/ct_multi_tv_admm.py
index 95711679..df72e387 100644
--- a/examples/scripts/ct_multi_tv_admm.py
+++ b/examples/scripts/ct_multi_tv_admm.py
@@ -38,15 +38,14 @@
np.random.seed(1234)
x_gt = snp.array(discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N))
-det_count = N
-det_spacing = np.sqrt(2)
-
"""
Define CT geometry and construct array of (approximately) equivalent projectors.
"""
n_projection = 45 # number of projections
-angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles
+angles = np.linspace(0, np.pi, n_projection, endpoint=False) # evenly spaced projection angles
+det_count = int(N * 1.05 / np.sqrt(2.0))
+det_spacing = np.sqrt(2)
projectors = {
"astra": astra.XRayTransform2D(
x_gt.shape, det_count, det_spacing, angles - np.pi / 2.0
diff --git a/examples/scripts/ct_astra_odp_train_foam2.py b/examples/scripts/ct_odp_train_foam2.py
similarity index 93%
rename from examples/scripts/ct_astra_odp_train_foam2.py
rename to examples/scripts/ct_odp_train_foam2.py
index 03753c1b..cad279bb 100644
--- a/examples/scripts/ct_astra_odp_train_foam2.py
+++ b/examples/scripts/ct_odp_train_foam2.py
@@ -5,8 +5,8 @@
# with the package.
r"""
-CT Training and Reconstructions with ODP
-========================================
+CT Training and Reconstruction with ODP
+=======================================
This example demonstrates the training of the unrolled optimization with
deep priors (ODP) gradient descent architecture described in
@@ -72,7 +72,7 @@
from scico import metric, plot
from scico.flax.examples import load_ct_data
from scico.flax.train.traversals import clip_positive, construct_traversal
-from scico.linop.xray.astra import XRayTransform2D
+from scico.linop.xray import XRayTransform2D
platform = get_backend().platform
@@ -92,21 +92,22 @@
"""
-Build CT projection operator.
+Build CT projection operator. Parameters are chosen so that the operator
+is equivalent to the one used to generate the training data.
"""
-angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles
+angles = np.linspace(0, np.pi, n_projection, endpoint=False) # evenly spaced projection angles
A = XRayTransform2D(
input_shape=(N, N),
- det_spacing=1,
- det_count=N,
angles=angles,
-) # CT projection operator
-A = (1.0 / N) * A # normalized
+ det_count=int(N * 1.05 / np.sqrt(2.0)),
+ dx=1.0 / np.sqrt(2),
+)
+A = (1.0 / N) * A # normalize projection operator
"""
Build training and testing structures. Inputs are the sinograms and
-outpus are the original generated foams. Keep training and testing
+outputs are the original generated foams. Keep training and testing
partitions.
"""
numtr = 320
diff --git a/examples/scripts/ct_projector_comparison_2d.py b/examples/scripts/ct_projector_comparison_2d.py
index 2e5d02d3..0a47c7b0 100644
--- a/examples/scripts/ct_projector_comparison_2d.py
+++ b/examples/scripts/ct_projector_comparison_2d.py
@@ -29,9 +29,6 @@
Create a ground truth image.
"""
N = 512
-
-det_count = int(jnp.ceil(jnp.sqrt(2 * N**2)))
-
x_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N)
x_gt = jnp.array(x_gt)
@@ -41,17 +38,18 @@
"""
num_angles = 500
angles = jnp.linspace(0, jnp.pi, num=num_angles, endpoint=False)
+det_count = int(N * 1.02 / jnp.sqrt(2.0))
timer = Timer()
projectors = {}
timer.start("scico_init")
-projectors["scico"] = XRayTransform2D((N, N), angles)
+projectors["scico"] = XRayTransform2D((N, N), angles, det_count=det_count)
timer.stop("scico_init")
timer.start("astra_init")
projectors["astra"] = astra.XRayTransform2D(
- (N, N), det_count=det_count, det_spacing=1.0, angles=angles - jnp.pi / 2.0
+ (N, N), det_count=det_count, det_spacing=np.sqrt(2), angles=angles - jnp.pi / 2.0
)
timer.stop("astra_init")
diff --git a/examples/scripts/ct_svmbir_tv_multi.py b/examples/scripts/ct_svmbir_tv_multi.py
index 8592b44f..e152c6ff 100644
--- a/examples/scripts/ct_svmbir_tv_multi.py
+++ b/examples/scripts/ct_svmbir_tv_multi.py
@@ -4,7 +4,7 @@
# and user license can be found in the 'LICENSE.txt' file distributed
# with the package.
-"""
+r"""
TV-Regularized CT Reconstruction (Multiple Algorithms)
======================================================
@@ -51,7 +51,7 @@
"""
num_angles = int(N / 2)
num_channels = N
-angles = snp.linspace(0, snp.pi, num_angles, dtype=snp.float32)
+angles = snp.linspace(0, snp.pi, num_angles, endpoint=False, dtype=snp.float32)
A = XRayTransform(x_gt.shape, angles, num_channels)
sino = A @ x_gt
@@ -87,12 +87,9 @@
"""
x0 = snp.array(x_mrf)
weights = snp.array(weights)
-
λ = 1e-1 # ℓ1 norm regularization parameter
-
f = SVMBIRSquaredL2Loss(y=y, A=A, W=Diagonal(weights), scale=0.5)
g = λ * functional.L21Norm() # regularization functional
-
# The append=0 option makes the results of horizontal and vertical finite
# differences the same shape, which is required for the L21Norm.
C = linop.FiniteDifference(input_shape=x_gt.shape, append=0)
@@ -112,6 +109,7 @@
itstat_options={"display": True, "period": 10},
)
print(f"Solving on {device_info()}\n")
+print("ADMM:")
x_admm = solve_admm.solve()
hist_admm = solve_admm.itstat_object.history(transpose=True)
print(f"PSNR: {metric.psnr(x_gt, x_admm):.2f} dB\n")
@@ -130,6 +128,7 @@
maxiter=50,
itstat_options={"display": True, "period": 10},
)
+print("Linearized ADMM:")
x_ladmm = solver_ladmm.solve()
hist_ladmm = solver_ladmm.itstat_object.history(transpose=True)
print(f"PSNR: {metric.psnr(x_gt, x_ladmm):.2f} dB\n")
@@ -148,6 +147,7 @@
maxiter=50,
itstat_options={"display": True, "period": 10},
)
+print("PDHG:")
x_pdhg = solver_pdhg.solve()
hist_pdhg = solver_pdhg.itstat_object.history(transpose=True)
print(f"PSNR: {metric.psnr(x_gt, x_pdhg):.2f} dB\n")
diff --git a/examples/scripts/ct_tv_admm.py b/examples/scripts/ct_tv_admm.py
index ec48d4ea..c66a802b 100644
--- a/examples/scripts/ct_tv_admm.py
+++ b/examples/scripts/ct_tv_admm.py
@@ -45,15 +45,19 @@
Configure CT projection operator and generate synthetic measurements.
"""
n_projection = 45 # number of projections
-angles = np.linspace(0, np.pi, n_projection) + np.pi / 2.0 # evenly spaced projection angles
-A = XRayTransform2D((N, N), angles) # CT projection operator
+angles = np.linspace(0, np.pi, n_projection, endpoint=False) # evenly spaced projection angles
+det_count = int(N * 1.05 / np.sqrt(2.0))
+dx = 1.0 / np.sqrt(2)
+A = XRayTransform2D(
+ (N, N), angles + np.pi / 2.0, det_count=det_count, dx=dx
+) # CT projection operator
y = A @ x_gt # sinogram
"""
-Set up ADMM solver object.
+Set up problem functional and ADMM solver object.
"""
-λ = 2e0 # L1 norm regularization parameter
+λ = 2e0 # ℓ1 norm regularization parameter
ρ = 5e0 # ADMM penalty parameter
maxiter = 25 # number of ADMM iterations
cg_tol = 1e-4 # CG relative tolerance
@@ -64,10 +68,8 @@
# which is used so that g(Cx) corresponds to isotropic TV.
C = linop.FiniteDifference(input_shape=x_gt.shape, append=0)
g = λ * functional.L21Norm()
-
f = loss.SquaredL2Loss(y=y, A=A)
-
-x0 = snp.clip(A.T(y), 0, 1.0)
+x0 = snp.clip(A.fbp(y), 0, 1.0)
solver = ADMM(
f=f,
@@ -94,18 +96,26 @@
Show the recovered image.
"""
-fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(15, 5))
+fig, ax = plot.subplots(nrows=1, ncols=3, figsize=(15, 5))
plot.imview(x_gt, title="Ground truth", cbar=None, fig=fig, ax=ax[0])
+plot.imview(
+ x0,
+ title="FBP Reconstruction: \nSNR: %.2f (dB), MAE: %.3f"
+ % (metric.snr(x_gt, x0), metric.mae(x_gt, x0)),
+ cbar=None,
+ fig=fig,
+ ax=ax[1],
+)
plot.imview(
x_reconstruction,
title="TV Reconstruction\nSNR: %.2f (dB), MAE: %.3f"
% (metric.snr(x_gt, x_reconstruction), metric.mae(x_gt, x_reconstruction)),
fig=fig,
- ax=ax[1],
+ ax=ax[2],
)
-divider = make_axes_locatable(ax[1])
+divider = make_axes_locatable(ax[2])
cax = divider.append_axes("right", size="5%", pad=0.2)
-fig.colorbar(ax[1].get_images()[0], cax=cax, label="arbitrary units")
+fig.colorbar(ax[2].get_images()[0], cax=cax, label="arbitrary units")
fig.show()
diff --git a/examples/scripts/ct_astra_unet_train_foam2.py b/examples/scripts/ct_unet_train_foam2.py
similarity index 100%
rename from examples/scripts/ct_astra_unet_train_foam2.py
rename to examples/scripts/ct_unet_train_foam2.py
diff --git a/examples/scripts/index.rst b/examples/scripts/index.rst
index 4f05ba2f..e36e9fbd 100644
--- a/examples/scripts/index.rst
+++ b/examples/scripts/index.rst
@@ -21,9 +21,9 @@ Computed Tomography
- ct_svmbir_ppp_bm3d_admm_cg.py
- ct_svmbir_ppp_bm3d_admm_prox.py
- ct_fan_svmbir_ppp_bm3d_admm_prox.py
- - ct_astra_modl_train_foam2.py
- - ct_astra_odp_train_foam2.py
- - ct_astra_unet_train_foam2.py
+ - ct_modl_train_foam2.py
+ - ct_odp_train_foam2.py
+ - ct_unet_train_foam2.py
- ct_projector_comparison_2d.py
- ct_projector_comparison_3d.py
- ct_multi_tv_admm.py
@@ -73,7 +73,7 @@ Miscellaneous
- denoise_dncnn_universal.py
- diffusercam_tv_admm.py
- video_rpca_admm.py
- - ct_astra_datagen_foam2.py
+ - ct_datagen_foam2.py
- deconv_datagen_bsds.py
- deconv_datagen_foam1.py
- denoise_datagen_bsds.py
@@ -143,10 +143,10 @@ Sparsity
Machine Learning
^^^^^^^^^^^^^^^^
- - ct_astra_datagen_foam2.py
- - ct_astra_modl_train_foam2.py
- - ct_astra_odp_train_foam2.py
- - ct_astra_unet_train_foam2.py
+ - ct_datagen_foam2.py
+ - ct_modl_train_foam2.py
+ - ct_odp_train_foam2.py
+ - ct_unet_train_foam2.py
- deconv_datagen_bsds.py
- deconv_datagen_foam1.py
- deconv_modl_train_foam1.py
diff --git a/misc/conda/install_conda.sh b/misc/conda/install_conda.sh
index 73defcaa..744d1836 100755
--- a/misc/conda/install_conda.sh
+++ b/misc/conda/install_conda.sh
@@ -97,7 +97,6 @@ rm -f /tmp/miniconda.sh
export PATH="$CONDAHOME/bin:$PATH"
hash -r
conda config --set always_yes yes
-conda install mamba -n base -c conda-forge
conda update -q conda
conda info -a
diff --git a/misc/conda/make_conda_env.sh b/misc/conda/make_conda_env.sh
index b5aa4411..cab2b7e2 100755
--- a/misc/conda/make_conda_env.sh
+++ b/misc/conda/make_conda_env.sh
@@ -50,7 +50,7 @@ EOF
)
# Requirements that cannot be installed via conda (i.e. have to use pip)
NOCONDA=$(cat <<-EOF
-flax bm3d bm4d py2jn colour_demosaicing hyperopt ray[tune,train]
+flax orbax-checkpoint bm3d bm4d py2jn colour_demosaicing hyperopt ray[tune,train]
EOF
)
@@ -217,19 +217,16 @@ eval "$(conda shell.bash hook)" # required to avoid errors re: `conda init`
conda activate $ENVNM # Q: why not `source activate`? A: not always in the path
# Add conda-forge channel
-conda config --env --append channels conda-forge
-
-# Install mamba
-conda install mamba -n base -c conda-forge
+conda config --append channels conda-forge
# Install required conda packages (and extra useful packages)
-mamba install $CONDA_FLAGS $CONDAREQ ipython
+conda install $CONDA_FLAGS $CONDAREQ ipython
# Utility ffmpeg is required by imageio for reading mp4 video files
# it can also be installed via the system package manager, .e.g.
# sudo apt install ffmpeg
if [ "$(which ffmpeg)" = '' ]; then
- mamba install $CONDA_FLAGS ffmpeg
+ conda install $CONDA_FLAGS ffmpeg
fi
# Install jaxlib and jax
diff --git a/requirements.txt b/requirements.txt
index 1b0e5359..838a6b18 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,8 +4,8 @@ scipy>=1.6.0
imageio>=2.17
tifffile
matplotlib
-jaxlib>=0.4.3,<=0.4.33
-jax>=0.4.3,<=0.4.33
+jaxlib>=0.4.13,<=0.4.35
+jax>=0.4.13,<=0.4.35
orbax-checkpoint>=0.5.0
-flax>=0.8.0,<=0.9.0
+flax>=0.8.0,<=0.10.0
pyabel>=0.9.0
diff --git a/scico/__init__.py b/scico/__init__.py
index c4050238..d71ac2a7 100644
--- a/scico/__init__.py
+++ b/scico/__init__.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2023 by SCICO Developers
+# Copyright (C) 2021-2024 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
@@ -8,7 +8,7 @@
solving the inverse problems that arise in scientific imaging applications.
"""
-__version__ = "0.0.6.dev0"
+__version__ = "0.0.7.dev0"
import logging
import sys
@@ -16,14 +16,7 @@
# isort: off
# Suppress jax device warning. See https://github.com/google/jax/issues/6805
-# This only works for jax>0.3.23; for earlier versions, the getLogger
-# argument should be "absl". Two filters are included here due to a change
-# in jax between versions 0.4.2 and 0.4.8, both of which are supported by
-# scico.
-logging.getLogger("jax._src.lib.xla_bridge").addFilter( # jax 0.4.2
- logging.Filter("No GPU/TPU found, falling back to CPU.")
-)
-logging.getLogger("jax._src.xla_bridge").addFilter( # jax 0.4.8
+logging.getLogger("jax._src.xla_bridge").addFilter( # jax 0.4.8 and later
logging.Filter("No GPU/TPU found, falling back to CPU.")
)
diff --git a/scico/flax/examples/data_generation.py b/scico/flax/examples/data_generation.py
index e7a99030..e89f10cf 100644
--- a/scico/flax/examples/data_generation.py
+++ b/scico/flax/examples/data_generation.py
@@ -51,16 +51,9 @@ class UnitCircle:
from jax.lib.xla_bridge import get_backend
from scico.linop import CircularConvolve
+from scico.linop.xray import XRayTransform2D
from scico.numpy import Array
-try:
- import astra # noqa: F401
-except ImportError:
- have_astra = False
-else:
- have_astra = True
- from scico.linop.xray.astra import XRayTransform2D
-
class Foam2(UnitCircle):
"""Foam-like material with two attenuations.
@@ -218,10 +211,8 @@ def generate_ct_data(
- **sino** : (:class:`jax.Array`): Corresponding sinograms.
- **fbp** : (:class:`jax.Array`) Corresponding filtered back projections.
"""
- if not (have_ray and have_xdesign and have_astra):
- raise RuntimeError(
- "Packages ray, xdesign, and astra are required for use of this function."
- )
+ if not (have_ray and have_xdesign):
+ raise RuntimeError("Packages ray and xdesign are required for use of this function.")
# Generate input data.
start_time = time()
@@ -234,17 +225,17 @@ def generate_ct_data(
# Configure a CT projection operator to generate synthetic measurements.
angles = np.linspace(0, jnp.pi, nproj) # evenly spaced projection angles
- gt_sh = (size, size)
- detector_spacing = 1.0
- A = XRayTransform2D(gt_sh, size, detector_spacing, angles) # X-ray transform operator
-
+ gt_shape = (size, size)
+ dx = 1.0 / np.sqrt(2)
+ det_count = int(size * 1.05 / np.sqrt(2.0))
+ A = XRayTransform2D(gt_shape, angles, dx=dx, det_count=det_count)
# Compute sinograms in parallel.
start_time = time()
if nproc > 1:
# shard array
imgshd = img.reshape((nproc, -1, size, size, 1))
sinoshd = batched_f(A, imgshd)
- sino = sinoshd.reshape((-1, nproj, size, 1))
+ sino = sinoshd.reshape((-1, nproj, sinoshd.shape[-2], 1))
else:
sino = vector_f(A, img)
@@ -261,8 +252,8 @@ def generate_ct_data(
# Normalize sinogram.
sino = sino / size
- # Shift FBP to [0,1] range.
- fbp = (fbp - fbp.min()) / (fbp.max() - fbp.min())
+ # Clip FBP to [0,1] range.
+ fbp = np.clip(fbp, 0, 1)
if verbose: # pragma: no cover
platform = get_backend().platform
diff --git a/scico/functional/_functional.py b/scico/functional/_functional.py
index 256791a7..1cbef6fc 100644
--- a/scico/functional/_functional.py
+++ b/scico/functional/_functional.py
@@ -117,7 +117,7 @@ def conj_prox(
return v - lam * self.prox(v / lam, 1.0 / lam, **kwargs)
def grad(self, x: Union[Array, BlockArray]):
- r"""Evaluates the gradient of this functional at :math:`\mb{x}`.
+ r"""Evaluate the gradient of this functional at :math:`\mb{x}`.
Args:
x: Point at which to evaluate gradient.
diff --git a/scico/linop/xray/_xray.py b/scico/linop/xray/_xray.py
index 770bf627..9d459db5 100644
--- a/scico/linop/xray/_xray.py
+++ b/scico/linop/xray/_xray.py
@@ -9,7 +9,7 @@
from functools import partial
-from typing import Optional
+from typing import Optional, Tuple
from warnings import warn
import numpy as np
@@ -27,7 +27,7 @@
class XRayTransform2D(LinearOperator):
- """Parallel ray, single axis, 2D X-ray projector.
+ r"""Parallel ray, single axis, 2D X-ray projector.
This implementation approximates the projection of each rectangular
pixel as a boxcar function (whereas the exact projection is a
@@ -42,6 +42,9 @@ class XRayTransform2D(LinearOperator):
accumulation of pixel values into bins (equivalently, makes the
linear operator sparse).
+ Warning: The default pixel spacing is :math:`\sqrt{2}/2` (rather
+ than 1) in order to satisfy the aforementioned spacing requirement.
+
`x0`, `dx`, and `y0` should be expressed in units such that the
detector spacing `dy` is 1.0.
"""
@@ -64,9 +67,11 @@ def __init__(
corresponds to summing columns, and an angle of pi/4
corresponds to summing along antidiagonals.
x0: (x, y) position of the corner of the pixel `im[0,0]`. By
- default, `(-input_shape / 2, -input_shape / 2)`.
- dx: Image pixel side length in x- and y-direction. Should be
- <= 1.0 in each dimension. By default, [1.0, 1.0].
+ default, `(-input_shape * dx[0] / 2, -input_shape * dx[1] / 2)`.
+ dx: Image pixel side length in x- and y-direction (axis 0 and
+ 1 respectively). Must be set so that the width of a
+ projected pixel is never larger than 1.0. By default,
+ [:math:`\sqrt{2}/2`, :math:`\sqrt{2}/2`].
y0: Location of the edge of the first detector bin. By
default, `-det_count / 2`
det_count: Number of elements in detector. If ``None``,
@@ -109,27 +114,106 @@ def __init__(
self.y0 = y0
self.dy = 1.0
+ self.fbp_filter: Optional[snp.Array] = None
+ self.fbp_mask: Optional[snp.Array] = None
+
super().__init__(
input_shape=self.input_shape,
+ input_dtype=np.float32,
output_shape=self.output_shape,
+ output_dtype=np.float32,
eval_fn=self.project,
adj_fn=self.back_project,
)
def project(self, im: ArrayLike) -> snp.Array:
- """Compute X-ray projection."""
+ """Compute X-ray projection, equivalent to `H @ im`.
+
+ Args:
+ im: Input array representing the image to project.
+ """
return XRayTransform2D._project(im, self.x0, self.dx, self.y0, self.ny, self.angles)
def back_project(self, y: ArrayLike) -> snp.Array:
- """Compute X-ray back projection"""
+ """Compute X-ray back projection, equivalent to `H.T @ y`.
+
+ Args:
+ y: Input array representing the sinogram to back project.
+ """
return XRayTransform2D._back_project(y, self.x0, self.dx, self.nx, self.y0, self.angles)
+ def fbp(self, y: ArrayLike) -> snp.Array:
+ r"""Compute filtered back projection (FBP) inverse of projection.
+
+ Compute the filtered back projection inverse by filtering each
+ row of the sinogram with the filter defined in (61) in
+ :cite:`kak-1988-principles` and then back projecting. The
+ projection angles are assumed to be evenly spaced in
+ :math:`[0, \pi)`; reconstruction quality may be poor if
+ this assumption is violated. Poor quality reconstructions should
+ also be expected when `dx[0]` and `dx[1]` are not equal.
+
+ Args:
+ y: Input projection, (num_angles, N).
+
+ Returns:
+ FBP inverse of projection.
+ """
+ N = y.shape[1]
+
+ if self.fbp_filter is None:
+ nvec = jnp.arange(N) - (N - 1) // 2
+ self.fbp_filter = XRayTransform2D._ramp_filter(nvec, 1.0).reshape(1, -1)
+
+ if self.fbp_mask is None:
+ unit_sino = jnp.ones(self.output_shape, dtype=np.float32)
+ # Threshold is multiplied by 0.99... fudge factor to account for numerical errors
+ # in back projection.
+ self.fbp_mask = self.back_project(unit_sino) >= (self.output_shape[0] * (1.0 - 1e-5)) # type: ignore
+
+ # Apply ramp filter in the frequency domain, padding to avoid
+ # boundary effects
+ h = self.fbp_filter
+ hf = jnp.fft.fft(h, n=2 * N - 1, axis=1)
+ yf = jnp.fft.fft(y, n=2 * N - 1, axis=1)
+ hy = jnp.fft.ifft(hf * yf, n=2 * N - 1, axis=1)[
+ :, (N - 1) // 2 : -(N - 1) // 2
+ ].real.astype(jnp.float32)
+
+ x = (jnp.pi * self.dx[0] * self.dx[1] / y.shape[0]) * self.fbp_mask * self.back_project(hy) # type: ignore
+ return x
+
+ @staticmethod
+ def _ramp_filter(x: ArrayLike, tau: float) -> snp.Array:
+ """Compute coefficients of ramp filter used in FBP.
+
+ Compute coefficients of ramp filter used in FBP, as defined in
+ (61) in :cite:`kak-1988-principles`.
+
+ Args:
+ x: Sampling locations at which to compute filter coefficients.
+ tau: Sampling rate.
+
+ Returns:
+ Spatial-domain coefficients of ramp filter.
+ """
+ # The (x == 0) term in x**2 * np.pi**2 * tau**2 + (x == 0)
+ # is included to avoid division by zero warnings when x == 1
+ # since np.where evaluates all values for both True and False
+ # branches.
+ return jnp.where(
+ x == 0,
+ 1.0 / (4.0 * tau**2),
+ jnp.where(x % 2, -1.0 / (x**2 * np.pi**2 * tau**2 + (x == 0)), 0),
+ )
+
@staticmethod
@partial(jax.jit, static_argnames=["ny"])
def _project(
im: ArrayLike, x0: ArrayLike, dx: ArrayLike, y0: float, ny: int, angles: ArrayLike
) -> snp.Array:
- r"""
+ r"""Compute X-ray projection.
+
Args:
im: Input array, (M, N).
x0: (x, y) position of the corner of the pixel im[0,0].
@@ -142,17 +226,20 @@ def _project(
"""
nx = im.shape
inds, weights = XRayTransform2D._calc_weights(x0, dx, nx, angles, y0)
- # Handle out of bounds indices. In the .at call, inds >= y0 are
- # ignored, while inds < 0 wrap around. So we set inds < 0 to ny.
- inds = jnp.where(inds >= 0, inds, ny)
+ # avoid incompatible types in the .add (scatter operation)
+ weights = weights.astype(im.dtype)
+
+ # Handle out of bounds indices by setting weight to zero
+ weights_valid = jnp.where((inds >= 0) * (inds < ny), weights, 0.0)
y = (
- jnp.zeros((len(angles), ny))
+ jnp.zeros((len(angles), ny), dtype=im.dtype)
.at[jnp.arange(len(angles)).reshape(-1, 1, 1), inds]
- .add(im * weights)
+ .add(im * weights_valid)
)
- y = y.at[jnp.arange(len(angles)).reshape(-1, 1, 1), inds + 1].add(im * (1 - weights))
+ weights_valid = jnp.where((inds + 1 >= 0) * (inds + 1 < ny), 1 - weights, 0.0)
+ y = y.at[jnp.arange(len(angles)).reshape(-1, 1, 1), inds + 1].add(im * weights_valid)
return y
@@ -160,8 +247,9 @@ def _project(
@partial(jax.jit, static_argnames=["nx"])
def _back_project(
y: ArrayLike, x0: ArrayLike, dx: ArrayLike, nx: Shape, y0: float, angles: ArrayLike
- ) -> ArrayLike:
- r"""
+ ) -> snp.Array:
+ r"""Compute X-ray back projection.
+
Args:
y: Input projection, (num_angles, N).
x0: (x, y) position of the corner of the pixel im[0,0].
@@ -174,24 +262,25 @@ def _back_project(
"""
ny = y.shape[1]
inds, weights = XRayTransform2D._calc_weights(x0, dx, nx, angles, y0)
- # Handle out of bounds indices. In the .at call, inds >= y0 are
- # ignored, while inds < 0 wrap around. So we set inds < 0 to ny.
- inds = jnp.where(inds >= 0, inds, ny)
+ # Handle out of bounds indices by setting weight to zero
+ weights_valid = jnp.where((inds >= 0) * (inds < ny), weights, 0.0)
# the idea: [y[0, inds[0]], y[1, inds[1]], ...]
- HTy = jnp.sum(y[jnp.arange(len(angles)).reshape(-1, 1, 1), inds] * weights, axis=0)
+ HTy = jnp.sum(y[jnp.arange(len(angles)).reshape(-1, 1, 1), inds] * weights_valid, axis=0)
+
+ weights_valid = jnp.where((inds + 1 >= 0) * (inds + 1 < ny), 1 - weights, 0.0)
HTy = HTy + jnp.sum(
- y[jnp.arange(len(angles)).reshape(-1, 1, 1), inds + 1] * (1 - weights), axis=0
+ y[jnp.arange(len(angles)).reshape(-1, 1, 1), inds + 1] * weights_valid, axis=0
)
- return HTy
+ return HTy.astype(jnp.float32)
@staticmethod
@partial(jax.jit, static_argnames=["nx"])
@partial(jax.vmap, in_axes=(None, None, None, 0, None))
def _calc_weights(
- x0: ArrayLike, dx: ArrayLike, nx: Shape, angle: float, y0: float
- ) -> snp.Array:
+ x0: ArrayLike, dx: ArrayLike, nx: Shape, angles: ArrayLike, y0: float
+ ) -> Tuple[snp.Array, snp.Array]:
"""
Args:
@@ -199,12 +288,12 @@ def _calc_weights(
dx: Pixel side length in x- and y-direction. Units are such
that the detector bins have length 1.0.
nx: Input image shape.
- angle: (num_angles,) array of angles in radians. Pixels are
+ angles: (num_angles,) array of angles in radians. Pixels are
projected onto units vectors pointing in these directions.
(This argument is `vmap`ed.)
y0: Location of the edge of the first detector bin.
"""
- u = [jnp.cos(angle), jnp.sin(angle)]
+ u = [jnp.cos(angles), jnp.sin(angles)]
Px0 = x0[0] * u[0] + x0[1] * u[1] - y0
Pdx = [dx[0] * u[0], dx[1] * u[1]]
Pxmin = jnp.min(jnp.array([Px0, Px0 + Pdx[0], Px0 + Pdx[1], Px0 + Pdx[0] + Pdx[1]]))
@@ -259,10 +348,6 @@ class XRayTransform3D(LinearOperator):
:meth:`XRayTransform3D.matrices_from_euler_angles` can help to
make these geometry arrays.
-
-
-
-
"""
def __init__(
@@ -279,7 +364,7 @@ def __init__(
"""
self.input_shape: Shape = input_shape
- self.matrices = matrices
+ self.matrices = jnp.asarray(matrices, dtype=np.float32)
self.det_shape = det_shape
self.output_shape = (len(matrices), *det_shape)
super().__init__(
@@ -344,7 +429,7 @@ def _project_single(
return proj
@staticmethod
- def _back_project(proj: ArrayLike, matrices: ArrayLike, input_shape: Shape) -> ArrayLike:
+ def _back_project(proj: ArrayLike, matrices: ArrayLike, input_shape: Shape) -> snp.Array:
r"""
Args:
proj: Input (set of) projection(s).
@@ -385,7 +470,7 @@ def _back_project_single(
@staticmethod
def _calc_weights(
- input_shape: Shape, matrix: snp.Array, output_shape: Shape, slice_offset: int = 0
+ input_shape: Shape, matrix: snp.Array, det_shape: Shape, slice_offset: int = 0
) -> snp.Array:
# pixel (0, 0, 0) has its center at (0.5, 0.5, 0.5)
x = jnp.mgrid[: input_shape[0], : input_shape[1], : input_shape[2]] + 0.5 # (3, ...)
@@ -403,13 +488,46 @@ def _calc_weights(
left_edge = Px - w / 2
to_next = jnp.minimum(jnp.ceil(left_edge) - left_edge, w)
ul_ind = jnp.floor(left_edge).astype("int32")
- ul_ind = jnp.where(ul_ind < 0, max(output_shape), ul_ind) # otherwise negative values wrap
ul_weight = to_next[0] * to_next[1] * (1 / w**2)
ur_weight = (w - to_next[0]) * to_next[1] * (1 / w**2)
ll_weight = to_next[0] * (w - to_next[1]) * (1 / w**2)
lr_weight = (w - to_next[0]) * (w - to_next[1]) * (1 / w**2)
+ # set weights to zero out of bounds
+ ul_weight = jnp.where(
+ (ul_ind[0] >= 0)
+ * (ul_ind[0] < det_shape[0])
+ * (ul_ind[1] >= 0)
+ * (ul_ind[1] < det_shape[1]),
+ ul_weight,
+ 0.0,
+ )
+ ur_weight = jnp.where(
+ (ul_ind[0] + 1 >= 0)
+ * (ul_ind[0] + 1 < det_shape[0])
+ * (ul_ind[1] >= 0)
+ * (ul_ind[1] < det_shape[1]),
+ ur_weight,
+ 0.0,
+ )
+ ll_weight = jnp.where(
+ (ul_ind[0] >= 0)
+ * (ul_ind[0] < det_shape[0])
+ * (ul_ind[1] + 1 >= 0)
+ * (ul_ind[1] + 1 < det_shape[1]),
+ ll_weight,
+ 0.0,
+ )
+ lr_weight = jnp.where(
+ (ul_ind[0] + 1 >= 0)
+ * (ul_ind[0] + 1 < det_shape[0])
+ * (ul_ind[1] + 1 >= 0)
+ * (ul_ind[1] + 1 < det_shape[1]),
+ lr_weight,
+ 0.0,
+ )
+
return ul_ind, ul_weight, ur_weight, ll_weight, lr_weight
@staticmethod
diff --git a/scico/linop/xray/astra.py b/scico/linop/xray/astra.py
index 3f50df6f..d73c60cf 100644
--- a/scico/linop/xray/astra.py
+++ b/scico/linop/xray/astra.py
@@ -462,6 +462,9 @@ class XRayTransform3D(LinearOperator): # pragma: no cover
`ASTRA toolbox `_.
The `3D geometries `__
"parallel3d" and "parallel3d_vec" are supported by this interface.
+ **NB:** A GPU is required for the primary functionality of this
+ class; if no GPU is available, projections and back projections will
+ fail with an "Unknown algorithm type" error.
The volume is fixed with respect to the coordinate system, centered
at the origin, as illustrated below:
diff --git a/scico/plot.py b/scico/plot.py
index 6c0375be..72ae4962 100644
--- a/scico/plot.py
+++ b/scico/plot.py
@@ -13,6 +13,7 @@
# This module is copied from https://github.com/bwohlberg/sporco
+import os
import sys
import numpy as np
@@ -820,7 +821,8 @@ def config_notebook_plotting():
Configure plotting functions for inline plotting within a Jupyter
Notebook shell. This function has no effect when not within a
notebook shell, and may therefore be used within a normal python
- script.
+ script. If environment variable ``MATPLOTLIB_IPYNB_BACKEND`` is set,
+ the matplotlib backend is explicitly set to the specified value.
"""
# Check whether running within a notebook shell and have
@@ -828,8 +830,9 @@ def config_notebook_plotting():
module = sys.modules[__name__]
if _in_notebook() and module.plot.__name__ == "plot":
- # Set inline backend (i.e. %matplotlib inline) if in a notebook shell
- set_notebook_plot_backend()
+ # Set backend if specified by environment variable
+ if "MATPLOTLIB_IPYNB_BACKEND" in os.environ:
+ set_notebook_plot_backend(os.environ["MATPLOTLIB_IPYNB_BACKEND"])
# Replace plot function with a wrapper function that discards
# its return value (within a notebook with inline plotting, plots
diff --git a/scico/test/flax/test_examples_flax.py b/scico/test/flax/test_examples_flax.py
index 72c084dd..ce265d05 100644
--- a/scico/test/flax/test_examples_flax.py
+++ b/scico/test/flax/test_examples_flax.py
@@ -12,7 +12,6 @@
generate_ct_data,
generate_foam1_images,
generate_foam2_images,
- have_astra,
have_ray,
have_xdesign,
)
@@ -75,8 +74,8 @@ def random_data_gen(seed, N, ndata):
@pytest.mark.skipif(
- not have_astra or not have_ray or not have_xdesign,
- reason="astra, ray, or xdesign package not installed",
+ not have_ray or not have_xdesign,
+ reason="ray or xdesign package not installed",
)
def test_ct_data_generation():
N = 32
@@ -90,7 +89,7 @@ def random_img_gen(seed, size, ndata):
img, sino, fbp = generate_ct_data(nimg, N, nproj, imgfunc=random_img_gen)
assert img.shape == (nimg, N, N, 1)
- assert sino.shape == (nimg, nproj, N, 1)
+ assert sino.shape == (nimg, nproj, sino.shape[2], 1)
assert fbp.shape == (nimg, N, N, 1)
diff --git a/scico/test/flax/test_inv.py b/scico/test/flax/test_inv.py
index b63a88a8..03c43736 100644
--- a/scico/test/flax/test_inv.py
+++ b/scico/test/flax/test_inv.py
@@ -6,18 +6,12 @@
import jax.numpy as jnp
from jax import lax
-import pytest
-
from scico import flax as sflax
from scico import random
from scico.flax.examples import PaddedCircularConvolve, build_blur_kernel
-from scico.flax.examples.data_generation import have_astra
from scico.flax.train.traversals import clip_positive, clip_range, construct_traversal
from scico.linop import CircularConvolve, Identity
-
-if have_astra:
- from scico.linop.xray.astra import XRayTransform2D
-
+from scico.linop.xray import XRayTransform2D
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
@@ -153,7 +147,6 @@ def test_train_odpdcnv_default(self):
np.testing.assert_array_less(1e-2 * np.ones(alphaval.shape), alphaval)
-@pytest.mark.skipif(not have_astra, reason="astra package not installed")
class TestCT:
def setup_method(self, method):
self.N = 32 # signal size
@@ -162,12 +155,9 @@ def setup_method(self, method):
xt, key = random.randn((2 * self.bsize, self.N, self.N, self.chn), seed=4321)
self.nproj = 60 # number of projections
- angles = np.linspace(0, np.pi, self.nproj) # evenly spaced projection angles
+ angles = np.linspace(0, np.pi, self.nproj, endpoint=False, dtype=np.float32)
self.opCT = XRayTransform2D(
- input_shape=(self.N, self.N),
- det_count=self.N,
- det_spacing=1.0,
- angles=angles,
+ input_shape=(self.N, self.N), det_count=self.N, angles=angles, dx=0.9999 / np.sqrt(2.0)
) # Radon transform operator
a_f = lambda v: jnp.atleast_3d(self.opCT(v.squeeze()))
y = lax.map(a_f, xt)
diff --git a/scico/test/linop/xray/test_xray.py b/scico/test/linop/xray/test_xray_2d.py
similarity index 53%
rename from scico/test/linop/xray/test_xray.py
rename to scico/test/linop/xray/test_xray_2d.py
index cd7c0dcd..3a9d7488 100644
--- a/scico/test/linop/xray/test_xray.py
+++ b/scico/test/linop/xray/test_xray_2d.py
@@ -1,11 +1,14 @@
import numpy as np
+import jax
import jax.numpy as jnp
import pytest
import scico
-from scico.linop.xray import XRayTransform2D, XRayTransform3D
+import scico.linop
+from scico.linop.xray import XRayTransform2D
+from scico.metric import psnr
@pytest.mark.filterwarnings("error")
@@ -49,7 +52,7 @@ def test_apply():
def test_apply_adjoint():
im_shape = (12, 13)
num_angles = 10
- x = jnp.ones(im_shape)
+ x = jnp.ones(im_shape, dtype=jnp.float32)
angles = jnp.linspace(0, jnp.pi, num=num_angles, endpoint=False)
@@ -71,44 +74,42 @@ def test_apply_adjoint():
assert y.shape[1] == det_count
-def test_3d_scaling():
- x = jnp.zeros((4, 4, 1))
- x = x.at[1:3, 1:3, 0].set(1.0)
-
- input_shape = x.shape
- output_shape = x.shape[:2]
-
- # default spacing
- M = XRayTransform3D.matrices_from_euler_angles(input_shape, output_shape, "X", [0.0])
- H = XRayTransform3D(input_shape, matrices=M, det_shape=output_shape)
- # fmt: off
- truth = jnp.array(
- [[[0.0, 0.0, 0.0, 0.0],
- [0.0, 1.0, 1.0, 0.0],
- [0.0, 1.0, 1.0, 0.0],
- [0.0, 0.0, 0.0, 0.0]]]
- ) # fmt: on
- np.testing.assert_allclose(H @ x, truth)
-
- # bigger voxels in the x (first index) direction
- M = XRayTransform3D.matrices_from_euler_angles(
- input_shape, output_shape, "X", [0.0], voxel_spacing=[2.0, 1.0, 1.0]
- )
- H = XRayTransform3D(input_shape, matrices=M, det_shape=output_shape)
- # fmt: off
- truth = jnp.array(
- [[[0. , 0.5, 0.5, 0. ],
- [0. , 0.5, 0.5, 0. ],
- [0. , 0.5, 0.5, 0. ],
- [0. , 0.5, 0.5, 0. ]]]
- ) # fmt: on
- np.testing.assert_allclose(H @ x, truth)
-
- # bigger detector pixels in the x (first index) direction
- M = XRayTransform3D.matrices_from_euler_angles(
- input_shape, output_shape, "X", [0.0], det_spacing=[2.0, 1.0]
- )
- H = XRayTransform3D(input_shape, matrices=M, det_shape=output_shape)
- # fmt: off
- truth = None # fmt: on # TODO: Check this case more closely.
- # np.testing.assert_allclose(H @ x, truth)
+def test_matched_adjoint():
+ """See https://github.com/lanl/scico/issues/560."""
+ N = 16
+ det_count = int(N * 1.05 / np.sqrt(2.0))
+ dx = 1.0 / np.sqrt(2)
+ n_projection = 3
+ angles = np.linspace(0, np.pi, n_projection, endpoint=False)
+ A = XRayTransform2D((N, N), angles, det_count=det_count, dx=dx)
+ assert scico.linop.valid_adjoint(A, A.T, eps=1e-5)
+
+
+@pytest.mark.parametrize("dx", [0.5, 1.0 / np.sqrt(2)])
+@pytest.mark.parametrize("det_count_factor", [1.02 / np.sqrt(2.0), 1.0])
+def test_fbp(dx, det_count_factor):
+ N = 256
+ x_gt = np.zeros((N, N), dtype=np.float32)
+ N4 = N // 4
+ x_gt[N4:-N4, N4:-N4] = 1.0
+
+ det_count = int(det_count_factor * N)
+ n_proj = 360
+ angles = np.linspace(0, np.pi, n_proj, endpoint=False)
+ A = XRayTransform2D(x_gt.shape, angles, det_count=det_count, dx=dx)
+ y = A(x_gt)
+ x_fbp = A.fbp(y)
+ assert psnr(x_gt, x_fbp) > 28
+
+
+def test_fbp_jit():
+ N = 64
+ x_gt = np.ones((N, N), dtype=np.float32)
+
+ det_count = N
+ n_proj = 90
+ angles = np.linspace(0, np.pi, n_proj, endpoint=False)
+ A = XRayTransform2D(x_gt.shape, angles, det_count=det_count)
+ y = A(x_gt)
+ fbp = jax.jit(A.fbp)
+ x_fbp = fbp(y)
diff --git a/scico/test/linop/xray/test_xray_3d.py b/scico/test/linop/xray/test_xray_3d.py
new file mode 100644
index 00000000..d96217a4
--- /dev/null
+++ b/scico/test/linop/xray/test_xray_3d.py
@@ -0,0 +1,66 @@
+import numpy as np
+
+import jax.numpy as jnp
+
+import scico.linop
+from scico.linop.xray import XRayTransform3D
+
+
+def test_matched_adjoint():
+ """See https://github.com/lanl/scico/issues/560."""
+ N = 16
+ det_count = int(N * 1.05 / np.sqrt(2.0))
+ n_projection = 3
+
+ input_shape = (N, N, N)
+ det_shape = (det_count, det_count)
+
+ M = XRayTransform3D.matrices_from_euler_angles(
+ input_shape, det_shape, "X", np.linspace(0, np.pi, n_projection, endpoint=False)
+ )
+ H = XRayTransform3D(input_shape, matrices=M, det_shape=det_shape)
+
+ assert scico.linop.valid_adjoint(H, H.T, eps=1e-5)
+
+
+def test_scaling():
+ x = jnp.zeros((4, 4, 1))
+ x = x.at[1:3, 1:3, 0].set(1.0)
+
+ input_shape = x.shape
+ det_shape = x.shape[:2]
+
+ # default spacing
+ M = XRayTransform3D.matrices_from_euler_angles(input_shape, det_shape, "X", [0.0])
+ H = XRayTransform3D(input_shape, matrices=M, det_shape=det_shape)
+ # fmt: off
+ truth = jnp.array(
+ [[[0.0, 0.0, 0.0, 0.0],
+ [0.0, 1.0, 1.0, 0.0],
+ [0.0, 1.0, 1.0, 0.0],
+ [0.0, 0.0, 0.0, 0.0]]]
+ ) # fmt: on
+ np.testing.assert_allclose(H @ x, truth)
+
+ # bigger voxels in the x (first index) direction
+ M = XRayTransform3D.matrices_from_euler_angles(
+ input_shape, det_shape, "X", [0.0], voxel_spacing=[2.0, 1.0, 1.0]
+ )
+ H = XRayTransform3D(input_shape, matrices=M, det_shape=det_shape)
+ # fmt: off
+ truth = jnp.array(
+ [[[0. , 0.5, 0.5, 0. ],
+ [0. , 0.5, 0.5, 0. ],
+ [0. , 0.5, 0.5, 0. ],
+ [0. , 0.5, 0.5, 0. ]]]
+ ) # fmt: on
+ np.testing.assert_allclose(H @ x, truth)
+
+ # bigger detector pixels in the x (first index) direction
+ M = XRayTransform3D.matrices_from_euler_angles(
+ input_shape, det_shape, "X", [0.0], det_spacing=[2.0, 1.0]
+ )
+ H = XRayTransform3D(input_shape, matrices=M, det_shape=det_shape)
+ # fmt: off
+ truth = None # fmt: on # TODO: Check this case more closely.
+ # np.testing.assert_allclose(H @ x, truth)