diff --git a/CHANGES.rst b/CHANGES.rst index f0dbfd7c..3354e735 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -3,10 +3,20 @@ SCICO Release Notes =================== -Version 0.0.6 (unreleased) +Version 0.0.7 (unreleased) +---------------------------- + +• No changes yet. + + + +Version 0.0.6 (2024-10-25) ---------------------------- • Significant changes to ``linop.xray.astra`` API. +• Rename integrated 2D X-ray transform class to + ``linop.xray.XRayTransform2D`` and add filtered back projection method + ``fbp``. • New integrated 3D X-ray transform via ``linop.xray.XRayTransform3D``. • New functional ``functional.IsotropicTVNorm`` and faster implementation of ``functional.AnisotropicTVNorm``. @@ -18,8 +28,8 @@ Version 0.0.6 (unreleased) • Rename ``scico.flax.save_weights`` and ``scico.flax.load_weights`` to ``scico.flax.save_variables`` and ``scico.flax.load_variables`` respectively. -• Support ``jaxlib`` and ``jax`` versions 0.4.3 to 0.4.31. -• Support ``flax`` versions 0.8.0 to 0.8.3. +• Support ``jaxlib`` and ``jax`` versions 0.4.13 to 0.4.35. +• Support ``flax`` versions 0.8.0 to 0.10.0. diff --git a/MANIFEST.in b/MANIFEST.in index 9cbc39a6..d9ee9f5f 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -8,10 +8,10 @@ include pyproject.toml include pytest.ini include requirements.txt include dev_requirements.txt -include examples/scriptcheck.sh include docs/docs_requirements.txt recursive-include scico *.py -recursive-include scico/data *.png *.npz -recursive-include docs Makefile *.py *.ipynb *.rst *.bib *.css *.svg *.ico -recursive-include examples *_requirements.txt *.txt *.rst *.py +recursive-include scico/data *.png *.mpk *.rst +recursive-include docs Makefile *.py *.ipynb *.rst *.bib *.css *.svg *.png *.ico +recursive-include examples *_requirements.txt *.txt *.rst *.py *.sh +recursive-include misc *.py *.sh *.rst diff --git a/README.md b/README.md index e1caf27d..cd333e64 100644 --- a/README.md +++ b/README.md @@ -32,11 +32,12 @@ this software for published work, please cite the corresponding [JOSS Paper](https://doi.org/10.21105/joss.04722) (see bibtex entry `balke-2022-scico` in `docs/source/references.bib`). + # Installation -See the [online -documentation](https://scico.rtfd.io/en/latest/install.html) for -installation instructions. +The online documentation includes detailed +[installation instructions](https://scico.rtfd.io/en/latest/install.html). + # Usage Examples @@ -47,8 +48,11 @@ Jupyter Notebooks are provided in the to `examples/notebooks`. They are also viewable on [GitHub](https://github.com/lanl/scico-data/tree/main/notebooks) or [nbviewer](https://nbviewer.jupyter.org/github/lanl/scico-data/tree/main/notebooks/index.ipynb), -or can be run online by -[binder](https://mybinder.org/v2/gh/lanl/scico-data/binder?labpath=notebooks%2Findex.ipynb). +and can be run online on +[binder](https://mybinder.org/v2/gh/lanl/scico-data/binder?labpath=notebooks%2Findex.ipynb) +or +[google colab](https://colab.research.google.com/github/lanl/scico-data/blob/colab/notebooks/index.ipynb). + # License diff --git a/data b/data index c1233896..b186bddd 160000 --- a/data +++ b/data @@ -1 +1 @@ -Subproject commit c12338966b1b9f92554066743b1a8b664c7b7e24 +Subproject commit b186bddd170ded03be04e7921f5d86d24c92c54f diff --git a/docs/source/examples.rst b/docs/source/examples.rst index 58ba847f..b97a52bd 100644 --- a/docs/source/examples.rst +++ b/docs/source/examples.rst @@ -34,12 +34,11 @@ Computed Tomography examples/ct_svmbir_ppp_bm3d_admm_cg examples/ct_svmbir_ppp_bm3d_admm_prox examples/ct_fan_svmbir_ppp_bm3d_admm_prox - examples/ct_astra_modl_train_foam2 - examples/ct_astra_odp_train_foam2 - examples/ct_astra_unet_train_foam2 + examples/ct_modl_train_foam2 + examples/ct_odp_train_foam2 + examples/ct_unet_train_foam2 examples/ct_projector_comparison_2d examples/ct_projector_comparison_3d - examples/ct_multi_cs_tv_admm examples/ct_multi_tv_admm Deconvolution @@ -96,7 +95,7 @@ Miscellaneous examples/denoise_dncnn_universal examples/diffusercam_tv_admm examples/video_rpca_admm - examples/ct_astra_datagen_foam2 + examples/ct_datagen_foam2 examples/deconv_datagen_bsds examples/deconv_datagen_foam1 examples/denoise_datagen_bsds @@ -181,10 +180,10 @@ Machine Learning .. toctree:: :maxdepth: 1 - examples/ct_astra_datagen_foam2 - examples/ct_astra_modl_train_foam2 - examples/ct_astra_odp_train_foam2 - examples/ct_astra_unet_train_foam2 + examples/ct_datagen_foam2 + examples/ct_modl_train_foam2 + examples/ct_odp_train_foam2 + examples/ct_unet_train_foam2 examples/deconv_datagen_bsds examples/deconv_datagen_foam1 examples/deconv_modl_train_foam1 diff --git a/docs/source/references.bib b/docs/source/references.bib index 257f2428..e612e36e 100644 --- a/docs/source/references.bib +++ b/docs/source/references.bib @@ -396,6 +396,13 @@ @Article {jin-2017-unet doi = {10.1109/TIP.2017.2713099} } +@Book {kak-1988-principles, + author = {Avinash C. Kak and Malcolm Slaney}, + title = {Principles of Computerized Tomographic Imaging}, + publisher = {IEEE Press}, + year = 1988 +} + @TechReport {kamilov-2016-minimizing, author = {Ulugbek S. Kamilov}, title = {Minimizing Isotropic Total Variation without @@ -771,6 +778,7 @@ @Article {zhang-2017-dncnn pages = {3142--3155} } + @Article {zhang-2021-plug, author = {Zhang, Kai and Li, Yawei and Zuo, Wangmeng and Zhang, Lei and Van Gool, Luc and Timofte, Radu}, diff --git a/examples/jnb.py b/examples/jnb.py index 838fed95..6fe58b98 100644 --- a/examples/jnb.py +++ b/examples/jnb.py @@ -62,10 +62,10 @@ def py_file_to_string(src): # Process remainder of source file for line in srcfile: - if re.match("^input\(", line): # end processing when input statement encountered + if re.match(r"^input\(", line): # end processing when input statement encountered break line = re.sub('^r"""', '"""', line) # remove r from r""" - line = re.sub(":cite:\`([^`]+)\`", r'', line) # fix cite format + line = re.sub(r":cite:\`([^`]+)\`", r'', line) # fix cite format lines.append(line) # Backtrack through list of lines to remove trailing newlines diff --git a/examples/scripts/README.rst b/examples/scripts/README.rst index 446186a7..95e11f7c 100644 --- a/examples/scripts/README.rst +++ b/examples/scripts/README.rst @@ -33,11 +33,11 @@ Computed Tomography PPP (with BM3D) CT Reconstruction (ADMM with Fast SVMBIR Prox) `ct_fan_svmbir_ppp_bm3d_admm_prox.py `_ PPP (with BM3D) Fan-Beam CT Reconstruction - `ct_astra_modl_train_foam2.py `_ - CT Training and Reconstructions with MoDL - `ct_astra_odp_train_foam2.py `_ - CT Training and Reconstructions with ODP - `ct_astra_unet_train_foam2.py `_ + `ct_modl_train_foam2.py `_ + CT Training and Reconstruction with MoDL + `ct_odp_train_foam2.py `_ + CT Training and Reconstruction with ODP + `ct_unet_train_foam2.py `_ CT Training and Reconstructions with UNet `ct_projector_comparison_2d.py `_ 2D X-ray Transform Comparison @@ -123,7 +123,7 @@ Miscellaneous TV-Regularized 3D DiffuserCam Reconstruction `video_rpca_admm.py `_ Video Decomposition via Robust PCA - `ct_astra_datagen_foam2.py `_ + `ct_datagen_foam2.py `_ CT Data Generation for NN Training `deconv_datagen_bsds.py `_ Blurred Data Generation (Natural Images) for NN Training @@ -239,13 +239,13 @@ Sparsity Machine Learning ^^^^^^^^^^^^^^^^ - `ct_astra_datagen_foam2.py `_ + `ct_datagen_foam2.py `_ CT Data Generation for NN Training - `ct_astra_modl_train_foam2.py `_ - CT Training and Reconstructions with MoDL - `ct_astra_odp_train_foam2.py `_ - CT Training and Reconstructions with ODP - `ct_astra_unet_train_foam2.py `_ + `ct_modl_train_foam2.py `_ + CT Training and Reconstruction with MoDL + `ct_odp_train_foam2.py `_ + CT Training and Reconstruction with ODP + `ct_unet_train_foam2.py `_ CT Training and Reconstructions with UNet `deconv_datagen_bsds.py `_ Blurred Data Generation (Natural Images) for NN Training diff --git a/examples/scripts/ct_astra_3d_tv_admm.py b/examples/scripts/ct_astra_3d_tv_admm.py index b3576fda..9c462cd0 100644 --- a/examples/scripts/ct_astra_3d_tv_admm.py +++ b/examples/scripts/ct_astra_3d_tv_admm.py @@ -44,7 +44,7 @@ tangle = snp.array(create_tangle_phantom(Nx, Ny, Nz)) n_projection = 10 # number of projections -angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles +angles = np.linspace(0, np.pi, n_projection, endpoint=False) # evenly spaced projection angles C = XRayTransform3D( tangle.shape, det_count=[Nz, max(Nx, Ny)], det_spacing=[1.0, 1.0], angles=angles ) # CT projection operator diff --git a/examples/scripts/ct_astra_3d_tv_padmm.py b/examples/scripts/ct_astra_3d_tv_padmm.py index c6c09007..3b76a8be 100644 --- a/examples/scripts/ct_astra_3d_tv_padmm.py +++ b/examples/scripts/ct_astra_3d_tv_padmm.py @@ -44,7 +44,7 @@ tangle = snp.array(create_tangle_phantom(Nx, Ny, Nz)) n_projection = 10 # number of projections -angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles +angles = np.linspace(0, np.pi, n_projection, endpoint=False) # evenly spaced projection angles det_spacing = [1.0, 1.0] det_count = [Nz, max(Nx, Ny)] vectors = angle_to_vector(det_spacing, angles) @@ -56,7 +56,7 @@ y = C @ tangle # sinogram -""" +r""" Set up problem and solver. We want to minimize the functional $$\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - C \mathbf{x} diff --git a/examples/scripts/ct_astra_noreg_pcg.py b/examples/scripts/ct_astra_noreg_pcg.py index 362c8cc3..a9dab965 100644 --- a/examples/scripts/ct_astra_noreg_pcg.py +++ b/examples/scripts/ct_astra_noreg_pcg.py @@ -44,7 +44,7 @@ Configure a CT projection operator and generate synthetic measurements. """ n_projection = N # matches the phantom size so this is not few-view CT -angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles +angles = np.linspace(0, np.pi, n_projection, endpoint=False) # evenly spaced projection angles A = 1 / N * XRayTransform2D(x_gt.shape, N, 1.0, angles) # CT projection operator y = A @ x_gt # sinogram diff --git a/examples/scripts/ct_astra_tv_admm.py b/examples/scripts/ct_astra_tv_admm.py index 5349311c..fd684e3d 100644 --- a/examples/scripts/ct_astra_tv_admm.py +++ b/examples/scripts/ct_astra_tv_admm.py @@ -38,21 +38,22 @@ """ N = 512 # phantom size np.random.seed(1234) -x_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N) -x_gt = snp.array(x_gt) # convert to jax type +x_gt = snp.array(discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N)) """ Configure CT projection operator and generate synthetic measurements. """ n_projection = 45 # number of projections -angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles -A = XRayTransform2D(x_gt.shape, N, 1.0, angles) # CT projection operator +angles = np.linspace(0, np.pi, n_projection, endpoint=False) # evenly spaced projection angles +det_count = int(N * 1.05 / np.sqrt(2.0)) +det_spacing = np.sqrt(2) +A = XRayTransform2D(x_gt.shape, det_count, det_spacing, angles) # CT projection operator y = A @ x_gt # sinogram """ -Set up ADMM solver object. +Set up problem functional and ADMM solver object. """ λ = 2e0 # ℓ1 norm regularization parameter ρ = 5e0 # ADMM penalty parameter @@ -65,9 +66,7 @@ # which is used so that g(Cx) corresponds to isotropic TV. C = linop.FiniteDifference(input_shape=x_gt.shape, append=0) g = λ * functional.L21Norm() - f = loss.SquaredL2Loss(y=y, A=A) - x0 = snp.clip(A.fbp(y), 0, 1.0) solver = ADMM( diff --git a/examples/scripts/ct_astra_weighted_tv_admm.py b/examples/scripts/ct_astra_weighted_tv_admm.py index b3c285cb..eb749373 100644 --- a/examples/scripts/ct_astra_weighted_tv_admm.py +++ b/examples/scripts/ct_astra_weighted_tv_admm.py @@ -35,7 +35,6 @@ Create a ground truth image. """ N = 512 # phantom size - np.random.seed(0) x_gt = discrete_phantom(Soil(porosity=0.80), size=384) x_gt = np.ascontiguousarray(np.pad(x_gt, (64, 64))) @@ -49,8 +48,7 @@ n_projection = 360 # number of projections Io = 1e3 # source flux 𝛼 = 1e-2 # attenuation coefficient - -angles = np.linspace(0, 2 * np.pi, n_projection) # evenly spaced projection angles +angles = np.linspace(0, 2 * np.pi, n_projection, endpoint=False) # evenly spaced projection angles A = XRayTransform2D(x_gt.shape, N, 1.0, angles) # CT projection operator y_c = A @ x_gt # sinogram @@ -99,13 +97,10 @@ def postprocess(x): # shown here). ρ = 2.5e3 # ADMM penalty parameter lambda_unweighted = 3e2 # regularization strength - maxiter = 100 # number of ADMM iterations cg_tol = 1e-5 # CG relative tolerance cg_maxiter = 10 # maximum CG iterations per ADMM iteration - f = loss.SquaredL2Loss(y=y, A=A) - admm_unweighted = ADMM( f=f, g_list=[lambda_unweighted * functional.L21Norm()], @@ -137,10 +132,8 @@ def postprocess(x): $I_0$ changes. """ lambda_weighted = 5e1 - weights = snp.array(counts / Io) f = loss.SquaredL2Loss(y=y, A=A, W=linop.Diagonal(weights)) - admm_weighted = ADMM( f=f, g_list=[lambda_weighted * functional.L21Norm()], @@ -151,6 +144,7 @@ def postprocess(x): subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": cg_tol, "maxiter": cg_maxiter}), itstat_options={"display": True, "period": 10}, ) +print() admm_weighted.solve() x_weighted = postprocess(admm_weighted.x) diff --git a/examples/scripts/ct_astra_datagen_foam2.py b/examples/scripts/ct_datagen_foam2.py similarity index 93% rename from examples/scripts/ct_astra_datagen_foam2.py rename to examples/scripts/ct_datagen_foam2.py index 4e6fb97c..2fc9be59 100644 --- a/examples/scripts/ct_astra_datagen_foam2.py +++ b/examples/scripts/ct_datagen_foam2.py @@ -14,6 +14,7 @@ """ # isort: off +import os import numpy as np import logging @@ -21,6 +22,9 @@ ray.init(logging_level=logging.ERROR) # need to call init before jax import: ray-project/ray#44087 +# Set an arbitrary processor count (only applies if GPU is not available). +os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8" + from scico import plot from scico.flax.examples import load_ct_data diff --git a/examples/scripts/ct_astra_modl_train_foam2.py b/examples/scripts/ct_modl_train_foam2.py similarity index 95% rename from examples/scripts/ct_astra_modl_train_foam2.py rename to examples/scripts/ct_modl_train_foam2.py index a06d7b81..10132214 100644 --- a/examples/scripts/ct_astra_modl_train_foam2.py +++ b/examples/scripts/ct_modl_train_foam2.py @@ -5,8 +5,8 @@ # with the package. r""" -CT Training and Reconstructions with MoDL -========================================= +CT Training and Reconstruction with MoDL +======================================== This example demonstrates the training and application of a model-based deep learning (MoDL) architecture described in @@ -65,7 +65,7 @@ from scico import metric, plot from scico.flax.examples import load_ct_data from scico.flax.train.traversals import clip_positive, construct_traversal -from scico.linop.xray.astra import XRayTransform2D +from scico.linop.xray import XRayTransform2D """ Prepare parallel processing. Set an arbitrary processor count (only @@ -89,16 +89,17 @@ """ -Build CT projection operator. +Build CT projection operator. Parameters are chosen so that the operator +is equivalent to the one used to generate the training data. """ -angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles +angles = np.linspace(0, np.pi, n_projection, endpoint=False) # evenly spaced projection angles A = XRayTransform2D( input_shape=(N, N), - det_spacing=1, - det_count=N, angles=angles, -) # CT projection operator -A = (1.0 / N) * A # normalized + det_count=int(N * 1.05 / np.sqrt(2.0)), + dx=1.0 / np.sqrt(2), +) +A = (1.0 / N) * A # normalize projection operator """ diff --git a/examples/scripts/ct_multi_tv_admm.py b/examples/scripts/ct_multi_tv_admm.py index 95711679..df72e387 100644 --- a/examples/scripts/ct_multi_tv_admm.py +++ b/examples/scripts/ct_multi_tv_admm.py @@ -38,15 +38,14 @@ np.random.seed(1234) x_gt = snp.array(discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N)) -det_count = N -det_spacing = np.sqrt(2) - """ Define CT geometry and construct array of (approximately) equivalent projectors. """ n_projection = 45 # number of projections -angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles +angles = np.linspace(0, np.pi, n_projection, endpoint=False) # evenly spaced projection angles +det_count = int(N * 1.05 / np.sqrt(2.0)) +det_spacing = np.sqrt(2) projectors = { "astra": astra.XRayTransform2D( x_gt.shape, det_count, det_spacing, angles - np.pi / 2.0 diff --git a/examples/scripts/ct_astra_odp_train_foam2.py b/examples/scripts/ct_odp_train_foam2.py similarity index 93% rename from examples/scripts/ct_astra_odp_train_foam2.py rename to examples/scripts/ct_odp_train_foam2.py index 03753c1b..cad279bb 100644 --- a/examples/scripts/ct_astra_odp_train_foam2.py +++ b/examples/scripts/ct_odp_train_foam2.py @@ -5,8 +5,8 @@ # with the package. r""" -CT Training and Reconstructions with ODP -======================================== +CT Training and Reconstruction with ODP +======================================= This example demonstrates the training of the unrolled optimization with deep priors (ODP) gradient descent architecture described in @@ -72,7 +72,7 @@ from scico import metric, plot from scico.flax.examples import load_ct_data from scico.flax.train.traversals import clip_positive, construct_traversal -from scico.linop.xray.astra import XRayTransform2D +from scico.linop.xray import XRayTransform2D platform = get_backend().platform @@ -92,21 +92,22 @@ """ -Build CT projection operator. +Build CT projection operator. Parameters are chosen so that the operator +is equivalent to the one used to generate the training data. """ -angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles +angles = np.linspace(0, np.pi, n_projection, endpoint=False) # evenly spaced projection angles A = XRayTransform2D( input_shape=(N, N), - det_spacing=1, - det_count=N, angles=angles, -) # CT projection operator -A = (1.0 / N) * A # normalized + det_count=int(N * 1.05 / np.sqrt(2.0)), + dx=1.0 / np.sqrt(2), +) +A = (1.0 / N) * A # normalize projection operator """ Build training and testing structures. Inputs are the sinograms and -outpus are the original generated foams. Keep training and testing +outputs are the original generated foams. Keep training and testing partitions. """ numtr = 320 diff --git a/examples/scripts/ct_projector_comparison_2d.py b/examples/scripts/ct_projector_comparison_2d.py index 2e5d02d3..0a47c7b0 100644 --- a/examples/scripts/ct_projector_comparison_2d.py +++ b/examples/scripts/ct_projector_comparison_2d.py @@ -29,9 +29,6 @@ Create a ground truth image. """ N = 512 - -det_count = int(jnp.ceil(jnp.sqrt(2 * N**2))) - x_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N) x_gt = jnp.array(x_gt) @@ -41,17 +38,18 @@ """ num_angles = 500 angles = jnp.linspace(0, jnp.pi, num=num_angles, endpoint=False) +det_count = int(N * 1.02 / jnp.sqrt(2.0)) timer = Timer() projectors = {} timer.start("scico_init") -projectors["scico"] = XRayTransform2D((N, N), angles) +projectors["scico"] = XRayTransform2D((N, N), angles, det_count=det_count) timer.stop("scico_init") timer.start("astra_init") projectors["astra"] = astra.XRayTransform2D( - (N, N), det_count=det_count, det_spacing=1.0, angles=angles - jnp.pi / 2.0 + (N, N), det_count=det_count, det_spacing=np.sqrt(2), angles=angles - jnp.pi / 2.0 ) timer.stop("astra_init") diff --git a/examples/scripts/ct_svmbir_tv_multi.py b/examples/scripts/ct_svmbir_tv_multi.py index 8592b44f..e152c6ff 100644 --- a/examples/scripts/ct_svmbir_tv_multi.py +++ b/examples/scripts/ct_svmbir_tv_multi.py @@ -4,7 +4,7 @@ # and user license can be found in the 'LICENSE.txt' file distributed # with the package. -""" +r""" TV-Regularized CT Reconstruction (Multiple Algorithms) ====================================================== @@ -51,7 +51,7 @@ """ num_angles = int(N / 2) num_channels = N -angles = snp.linspace(0, snp.pi, num_angles, dtype=snp.float32) +angles = snp.linspace(0, snp.pi, num_angles, endpoint=False, dtype=snp.float32) A = XRayTransform(x_gt.shape, angles, num_channels) sino = A @ x_gt @@ -87,12 +87,9 @@ """ x0 = snp.array(x_mrf) weights = snp.array(weights) - λ = 1e-1 # ℓ1 norm regularization parameter - f = SVMBIRSquaredL2Loss(y=y, A=A, W=Diagonal(weights), scale=0.5) g = λ * functional.L21Norm() # regularization functional - # The append=0 option makes the results of horizontal and vertical finite # differences the same shape, which is required for the L21Norm. C = linop.FiniteDifference(input_shape=x_gt.shape, append=0) @@ -112,6 +109,7 @@ itstat_options={"display": True, "period": 10}, ) print(f"Solving on {device_info()}\n") +print("ADMM:") x_admm = solve_admm.solve() hist_admm = solve_admm.itstat_object.history(transpose=True) print(f"PSNR: {metric.psnr(x_gt, x_admm):.2f} dB\n") @@ -130,6 +128,7 @@ maxiter=50, itstat_options={"display": True, "period": 10}, ) +print("Linearized ADMM:") x_ladmm = solver_ladmm.solve() hist_ladmm = solver_ladmm.itstat_object.history(transpose=True) print(f"PSNR: {metric.psnr(x_gt, x_ladmm):.2f} dB\n") @@ -148,6 +147,7 @@ maxiter=50, itstat_options={"display": True, "period": 10}, ) +print("PDHG:") x_pdhg = solver_pdhg.solve() hist_pdhg = solver_pdhg.itstat_object.history(transpose=True) print(f"PSNR: {metric.psnr(x_gt, x_pdhg):.2f} dB\n") diff --git a/examples/scripts/ct_tv_admm.py b/examples/scripts/ct_tv_admm.py index ec48d4ea..c66a802b 100644 --- a/examples/scripts/ct_tv_admm.py +++ b/examples/scripts/ct_tv_admm.py @@ -45,15 +45,19 @@ Configure CT projection operator and generate synthetic measurements. """ n_projection = 45 # number of projections -angles = np.linspace(0, np.pi, n_projection) + np.pi / 2.0 # evenly spaced projection angles -A = XRayTransform2D((N, N), angles) # CT projection operator +angles = np.linspace(0, np.pi, n_projection, endpoint=False) # evenly spaced projection angles +det_count = int(N * 1.05 / np.sqrt(2.0)) +dx = 1.0 / np.sqrt(2) +A = XRayTransform2D( + (N, N), angles + np.pi / 2.0, det_count=det_count, dx=dx +) # CT projection operator y = A @ x_gt # sinogram """ -Set up ADMM solver object. +Set up problem functional and ADMM solver object. """ -λ = 2e0 # L1 norm regularization parameter +λ = 2e0 # ℓ1 norm regularization parameter ρ = 5e0 # ADMM penalty parameter maxiter = 25 # number of ADMM iterations cg_tol = 1e-4 # CG relative tolerance @@ -64,10 +68,8 @@ # which is used so that g(Cx) corresponds to isotropic TV. C = linop.FiniteDifference(input_shape=x_gt.shape, append=0) g = λ * functional.L21Norm() - f = loss.SquaredL2Loss(y=y, A=A) - -x0 = snp.clip(A.T(y), 0, 1.0) +x0 = snp.clip(A.fbp(y), 0, 1.0) solver = ADMM( f=f, @@ -94,18 +96,26 @@ Show the recovered image. """ -fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(15, 5)) +fig, ax = plot.subplots(nrows=1, ncols=3, figsize=(15, 5)) plot.imview(x_gt, title="Ground truth", cbar=None, fig=fig, ax=ax[0]) +plot.imview( + x0, + title="FBP Reconstruction: \nSNR: %.2f (dB), MAE: %.3f" + % (metric.snr(x_gt, x0), metric.mae(x_gt, x0)), + cbar=None, + fig=fig, + ax=ax[1], +) plot.imview( x_reconstruction, title="TV Reconstruction\nSNR: %.2f (dB), MAE: %.3f" % (metric.snr(x_gt, x_reconstruction), metric.mae(x_gt, x_reconstruction)), fig=fig, - ax=ax[1], + ax=ax[2], ) -divider = make_axes_locatable(ax[1]) +divider = make_axes_locatable(ax[2]) cax = divider.append_axes("right", size="5%", pad=0.2) -fig.colorbar(ax[1].get_images()[0], cax=cax, label="arbitrary units") +fig.colorbar(ax[2].get_images()[0], cax=cax, label="arbitrary units") fig.show() diff --git a/examples/scripts/ct_astra_unet_train_foam2.py b/examples/scripts/ct_unet_train_foam2.py similarity index 100% rename from examples/scripts/ct_astra_unet_train_foam2.py rename to examples/scripts/ct_unet_train_foam2.py diff --git a/examples/scripts/index.rst b/examples/scripts/index.rst index 4f05ba2f..e36e9fbd 100644 --- a/examples/scripts/index.rst +++ b/examples/scripts/index.rst @@ -21,9 +21,9 @@ Computed Tomography - ct_svmbir_ppp_bm3d_admm_cg.py - ct_svmbir_ppp_bm3d_admm_prox.py - ct_fan_svmbir_ppp_bm3d_admm_prox.py - - ct_astra_modl_train_foam2.py - - ct_astra_odp_train_foam2.py - - ct_astra_unet_train_foam2.py + - ct_modl_train_foam2.py + - ct_odp_train_foam2.py + - ct_unet_train_foam2.py - ct_projector_comparison_2d.py - ct_projector_comparison_3d.py - ct_multi_tv_admm.py @@ -73,7 +73,7 @@ Miscellaneous - denoise_dncnn_universal.py - diffusercam_tv_admm.py - video_rpca_admm.py - - ct_astra_datagen_foam2.py + - ct_datagen_foam2.py - deconv_datagen_bsds.py - deconv_datagen_foam1.py - denoise_datagen_bsds.py @@ -143,10 +143,10 @@ Sparsity Machine Learning ^^^^^^^^^^^^^^^^ - - ct_astra_datagen_foam2.py - - ct_astra_modl_train_foam2.py - - ct_astra_odp_train_foam2.py - - ct_astra_unet_train_foam2.py + - ct_datagen_foam2.py + - ct_modl_train_foam2.py + - ct_odp_train_foam2.py + - ct_unet_train_foam2.py - deconv_datagen_bsds.py - deconv_datagen_foam1.py - deconv_modl_train_foam1.py diff --git a/misc/conda/install_conda.sh b/misc/conda/install_conda.sh index 73defcaa..744d1836 100755 --- a/misc/conda/install_conda.sh +++ b/misc/conda/install_conda.sh @@ -97,7 +97,6 @@ rm -f /tmp/miniconda.sh export PATH="$CONDAHOME/bin:$PATH" hash -r conda config --set always_yes yes -conda install mamba -n base -c conda-forge conda update -q conda conda info -a diff --git a/misc/conda/make_conda_env.sh b/misc/conda/make_conda_env.sh index b5aa4411..cab2b7e2 100755 --- a/misc/conda/make_conda_env.sh +++ b/misc/conda/make_conda_env.sh @@ -50,7 +50,7 @@ EOF ) # Requirements that cannot be installed via conda (i.e. have to use pip) NOCONDA=$(cat <<-EOF -flax bm3d bm4d py2jn colour_demosaicing hyperopt ray[tune,train] +flax orbax-checkpoint bm3d bm4d py2jn colour_demosaicing hyperopt ray[tune,train] EOF ) @@ -217,19 +217,16 @@ eval "$(conda shell.bash hook)" # required to avoid errors re: `conda init` conda activate $ENVNM # Q: why not `source activate`? A: not always in the path # Add conda-forge channel -conda config --env --append channels conda-forge - -# Install mamba -conda install mamba -n base -c conda-forge +conda config --append channels conda-forge # Install required conda packages (and extra useful packages) -mamba install $CONDA_FLAGS $CONDAREQ ipython +conda install $CONDA_FLAGS $CONDAREQ ipython # Utility ffmpeg is required by imageio for reading mp4 video files # it can also be installed via the system package manager, .e.g. # sudo apt install ffmpeg if [ "$(which ffmpeg)" = '' ]; then - mamba install $CONDA_FLAGS ffmpeg + conda install $CONDA_FLAGS ffmpeg fi # Install jaxlib and jax diff --git a/requirements.txt b/requirements.txt index 1b0e5359..838a6b18 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,8 +4,8 @@ scipy>=1.6.0 imageio>=2.17 tifffile matplotlib -jaxlib>=0.4.3,<=0.4.33 -jax>=0.4.3,<=0.4.33 +jaxlib>=0.4.13,<=0.4.35 +jax>=0.4.13,<=0.4.35 orbax-checkpoint>=0.5.0 -flax>=0.8.0,<=0.9.0 +flax>=0.8.0,<=0.10.0 pyabel>=0.9.0 diff --git a/scico/__init__.py b/scico/__init__.py index c4050238..d71ac2a7 100644 --- a/scico/__init__.py +++ b/scico/__init__.py @@ -1,4 +1,4 @@ -# Copyright (C) 2021-2023 by SCICO Developers +# Copyright (C) 2021-2024 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the @@ -8,7 +8,7 @@ solving the inverse problems that arise in scientific imaging applications. """ -__version__ = "0.0.6.dev0" +__version__ = "0.0.7.dev0" import logging import sys @@ -16,14 +16,7 @@ # isort: off # Suppress jax device warning. See https://github.com/google/jax/issues/6805 -# This only works for jax>0.3.23; for earlier versions, the getLogger -# argument should be "absl". Two filters are included here due to a change -# in jax between versions 0.4.2 and 0.4.8, both of which are supported by -# scico. -logging.getLogger("jax._src.lib.xla_bridge").addFilter( # jax 0.4.2 - logging.Filter("No GPU/TPU found, falling back to CPU.") -) -logging.getLogger("jax._src.xla_bridge").addFilter( # jax 0.4.8 +logging.getLogger("jax._src.xla_bridge").addFilter( # jax 0.4.8 and later logging.Filter("No GPU/TPU found, falling back to CPU.") ) diff --git a/scico/flax/examples/data_generation.py b/scico/flax/examples/data_generation.py index e7a99030..e89f10cf 100644 --- a/scico/flax/examples/data_generation.py +++ b/scico/flax/examples/data_generation.py @@ -51,16 +51,9 @@ class UnitCircle: from jax.lib.xla_bridge import get_backend from scico.linop import CircularConvolve +from scico.linop.xray import XRayTransform2D from scico.numpy import Array -try: - import astra # noqa: F401 -except ImportError: - have_astra = False -else: - have_astra = True - from scico.linop.xray.astra import XRayTransform2D - class Foam2(UnitCircle): """Foam-like material with two attenuations. @@ -218,10 +211,8 @@ def generate_ct_data( - **sino** : (:class:`jax.Array`): Corresponding sinograms. - **fbp** : (:class:`jax.Array`) Corresponding filtered back projections. """ - if not (have_ray and have_xdesign and have_astra): - raise RuntimeError( - "Packages ray, xdesign, and astra are required for use of this function." - ) + if not (have_ray and have_xdesign): + raise RuntimeError("Packages ray and xdesign are required for use of this function.") # Generate input data. start_time = time() @@ -234,17 +225,17 @@ def generate_ct_data( # Configure a CT projection operator to generate synthetic measurements. angles = np.linspace(0, jnp.pi, nproj) # evenly spaced projection angles - gt_sh = (size, size) - detector_spacing = 1.0 - A = XRayTransform2D(gt_sh, size, detector_spacing, angles) # X-ray transform operator - + gt_shape = (size, size) + dx = 1.0 / np.sqrt(2) + det_count = int(size * 1.05 / np.sqrt(2.0)) + A = XRayTransform2D(gt_shape, angles, dx=dx, det_count=det_count) # Compute sinograms in parallel. start_time = time() if nproc > 1: # shard array imgshd = img.reshape((nproc, -1, size, size, 1)) sinoshd = batched_f(A, imgshd) - sino = sinoshd.reshape((-1, nproj, size, 1)) + sino = sinoshd.reshape((-1, nproj, sinoshd.shape[-2], 1)) else: sino = vector_f(A, img) @@ -261,8 +252,8 @@ def generate_ct_data( # Normalize sinogram. sino = sino / size - # Shift FBP to [0,1] range. - fbp = (fbp - fbp.min()) / (fbp.max() - fbp.min()) + # Clip FBP to [0,1] range. + fbp = np.clip(fbp, 0, 1) if verbose: # pragma: no cover platform = get_backend().platform diff --git a/scico/functional/_functional.py b/scico/functional/_functional.py index 256791a7..1cbef6fc 100644 --- a/scico/functional/_functional.py +++ b/scico/functional/_functional.py @@ -117,7 +117,7 @@ def conj_prox( return v - lam * self.prox(v / lam, 1.0 / lam, **kwargs) def grad(self, x: Union[Array, BlockArray]): - r"""Evaluates the gradient of this functional at :math:`\mb{x}`. + r"""Evaluate the gradient of this functional at :math:`\mb{x}`. Args: x: Point at which to evaluate gradient. diff --git a/scico/linop/xray/_xray.py b/scico/linop/xray/_xray.py index 770bf627..9d459db5 100644 --- a/scico/linop/xray/_xray.py +++ b/scico/linop/xray/_xray.py @@ -9,7 +9,7 @@ from functools import partial -from typing import Optional +from typing import Optional, Tuple from warnings import warn import numpy as np @@ -27,7 +27,7 @@ class XRayTransform2D(LinearOperator): - """Parallel ray, single axis, 2D X-ray projector. + r"""Parallel ray, single axis, 2D X-ray projector. This implementation approximates the projection of each rectangular pixel as a boxcar function (whereas the exact projection is a @@ -42,6 +42,9 @@ class XRayTransform2D(LinearOperator): accumulation of pixel values into bins (equivalently, makes the linear operator sparse). + Warning: The default pixel spacing is :math:`\sqrt{2}/2` (rather + than 1) in order to satisfy the aforementioned spacing requirement. + `x0`, `dx`, and `y0` should be expressed in units such that the detector spacing `dy` is 1.0. """ @@ -64,9 +67,11 @@ def __init__( corresponds to summing columns, and an angle of pi/4 corresponds to summing along antidiagonals. x0: (x, y) position of the corner of the pixel `im[0,0]`. By - default, `(-input_shape / 2, -input_shape / 2)`. - dx: Image pixel side length in x- and y-direction. Should be - <= 1.0 in each dimension. By default, [1.0, 1.0]. + default, `(-input_shape * dx[0] / 2, -input_shape * dx[1] / 2)`. + dx: Image pixel side length in x- and y-direction (axis 0 and + 1 respectively). Must be set so that the width of a + projected pixel is never larger than 1.0. By default, + [:math:`\sqrt{2}/2`, :math:`\sqrt{2}/2`]. y0: Location of the edge of the first detector bin. By default, `-det_count / 2` det_count: Number of elements in detector. If ``None``, @@ -109,27 +114,106 @@ def __init__( self.y0 = y0 self.dy = 1.0 + self.fbp_filter: Optional[snp.Array] = None + self.fbp_mask: Optional[snp.Array] = None + super().__init__( input_shape=self.input_shape, + input_dtype=np.float32, output_shape=self.output_shape, + output_dtype=np.float32, eval_fn=self.project, adj_fn=self.back_project, ) def project(self, im: ArrayLike) -> snp.Array: - """Compute X-ray projection.""" + """Compute X-ray projection, equivalent to `H @ im`. + + Args: + im: Input array representing the image to project. + """ return XRayTransform2D._project(im, self.x0, self.dx, self.y0, self.ny, self.angles) def back_project(self, y: ArrayLike) -> snp.Array: - """Compute X-ray back projection""" + """Compute X-ray back projection, equivalent to `H.T @ y`. + + Args: + y: Input array representing the sinogram to back project. + """ return XRayTransform2D._back_project(y, self.x0, self.dx, self.nx, self.y0, self.angles) + def fbp(self, y: ArrayLike) -> snp.Array: + r"""Compute filtered back projection (FBP) inverse of projection. + + Compute the filtered back projection inverse by filtering each + row of the sinogram with the filter defined in (61) in + :cite:`kak-1988-principles` and then back projecting. The + projection angles are assumed to be evenly spaced in + :math:`[0, \pi)`; reconstruction quality may be poor if + this assumption is violated. Poor quality reconstructions should + also be expected when `dx[0]` and `dx[1]` are not equal. + + Args: + y: Input projection, (num_angles, N). + + Returns: + FBP inverse of projection. + """ + N = y.shape[1] + + if self.fbp_filter is None: + nvec = jnp.arange(N) - (N - 1) // 2 + self.fbp_filter = XRayTransform2D._ramp_filter(nvec, 1.0).reshape(1, -1) + + if self.fbp_mask is None: + unit_sino = jnp.ones(self.output_shape, dtype=np.float32) + # Threshold is multiplied by 0.99... fudge factor to account for numerical errors + # in back projection. + self.fbp_mask = self.back_project(unit_sino) >= (self.output_shape[0] * (1.0 - 1e-5)) # type: ignore + + # Apply ramp filter in the frequency domain, padding to avoid + # boundary effects + h = self.fbp_filter + hf = jnp.fft.fft(h, n=2 * N - 1, axis=1) + yf = jnp.fft.fft(y, n=2 * N - 1, axis=1) + hy = jnp.fft.ifft(hf * yf, n=2 * N - 1, axis=1)[ + :, (N - 1) // 2 : -(N - 1) // 2 + ].real.astype(jnp.float32) + + x = (jnp.pi * self.dx[0] * self.dx[1] / y.shape[0]) * self.fbp_mask * self.back_project(hy) # type: ignore + return x + + @staticmethod + def _ramp_filter(x: ArrayLike, tau: float) -> snp.Array: + """Compute coefficients of ramp filter used in FBP. + + Compute coefficients of ramp filter used in FBP, as defined in + (61) in :cite:`kak-1988-principles`. + + Args: + x: Sampling locations at which to compute filter coefficients. + tau: Sampling rate. + + Returns: + Spatial-domain coefficients of ramp filter. + """ + # The (x == 0) term in x**2 * np.pi**2 * tau**2 + (x == 0) + # is included to avoid division by zero warnings when x == 1 + # since np.where evaluates all values for both True and False + # branches. + return jnp.where( + x == 0, + 1.0 / (4.0 * tau**2), + jnp.where(x % 2, -1.0 / (x**2 * np.pi**2 * tau**2 + (x == 0)), 0), + ) + @staticmethod @partial(jax.jit, static_argnames=["ny"]) def _project( im: ArrayLike, x0: ArrayLike, dx: ArrayLike, y0: float, ny: int, angles: ArrayLike ) -> snp.Array: - r""" + r"""Compute X-ray projection. + Args: im: Input array, (M, N). x0: (x, y) position of the corner of the pixel im[0,0]. @@ -142,17 +226,20 @@ def _project( """ nx = im.shape inds, weights = XRayTransform2D._calc_weights(x0, dx, nx, angles, y0) - # Handle out of bounds indices. In the .at call, inds >= y0 are - # ignored, while inds < 0 wrap around. So we set inds < 0 to ny. - inds = jnp.where(inds >= 0, inds, ny) + # avoid incompatible types in the .add (scatter operation) + weights = weights.astype(im.dtype) + + # Handle out of bounds indices by setting weight to zero + weights_valid = jnp.where((inds >= 0) * (inds < ny), weights, 0.0) y = ( - jnp.zeros((len(angles), ny)) + jnp.zeros((len(angles), ny), dtype=im.dtype) .at[jnp.arange(len(angles)).reshape(-1, 1, 1), inds] - .add(im * weights) + .add(im * weights_valid) ) - y = y.at[jnp.arange(len(angles)).reshape(-1, 1, 1), inds + 1].add(im * (1 - weights)) + weights_valid = jnp.where((inds + 1 >= 0) * (inds + 1 < ny), 1 - weights, 0.0) + y = y.at[jnp.arange(len(angles)).reshape(-1, 1, 1), inds + 1].add(im * weights_valid) return y @@ -160,8 +247,9 @@ def _project( @partial(jax.jit, static_argnames=["nx"]) def _back_project( y: ArrayLike, x0: ArrayLike, dx: ArrayLike, nx: Shape, y0: float, angles: ArrayLike - ) -> ArrayLike: - r""" + ) -> snp.Array: + r"""Compute X-ray back projection. + Args: y: Input projection, (num_angles, N). x0: (x, y) position of the corner of the pixel im[0,0]. @@ -174,24 +262,25 @@ def _back_project( """ ny = y.shape[1] inds, weights = XRayTransform2D._calc_weights(x0, dx, nx, angles, y0) - # Handle out of bounds indices. In the .at call, inds >= y0 are - # ignored, while inds < 0 wrap around. So we set inds < 0 to ny. - inds = jnp.where(inds >= 0, inds, ny) + # Handle out of bounds indices by setting weight to zero + weights_valid = jnp.where((inds >= 0) * (inds < ny), weights, 0.0) # the idea: [y[0, inds[0]], y[1, inds[1]], ...] - HTy = jnp.sum(y[jnp.arange(len(angles)).reshape(-1, 1, 1), inds] * weights, axis=0) + HTy = jnp.sum(y[jnp.arange(len(angles)).reshape(-1, 1, 1), inds] * weights_valid, axis=0) + + weights_valid = jnp.where((inds + 1 >= 0) * (inds + 1 < ny), 1 - weights, 0.0) HTy = HTy + jnp.sum( - y[jnp.arange(len(angles)).reshape(-1, 1, 1), inds + 1] * (1 - weights), axis=0 + y[jnp.arange(len(angles)).reshape(-1, 1, 1), inds + 1] * weights_valid, axis=0 ) - return HTy + return HTy.astype(jnp.float32) @staticmethod @partial(jax.jit, static_argnames=["nx"]) @partial(jax.vmap, in_axes=(None, None, None, 0, None)) def _calc_weights( - x0: ArrayLike, dx: ArrayLike, nx: Shape, angle: float, y0: float - ) -> snp.Array: + x0: ArrayLike, dx: ArrayLike, nx: Shape, angles: ArrayLike, y0: float + ) -> Tuple[snp.Array, snp.Array]: """ Args: @@ -199,12 +288,12 @@ def _calc_weights( dx: Pixel side length in x- and y-direction. Units are such that the detector bins have length 1.0. nx: Input image shape. - angle: (num_angles,) array of angles in radians. Pixels are + angles: (num_angles,) array of angles in radians. Pixels are projected onto units vectors pointing in these directions. (This argument is `vmap`ed.) y0: Location of the edge of the first detector bin. """ - u = [jnp.cos(angle), jnp.sin(angle)] + u = [jnp.cos(angles), jnp.sin(angles)] Px0 = x0[0] * u[0] + x0[1] * u[1] - y0 Pdx = [dx[0] * u[0], dx[1] * u[1]] Pxmin = jnp.min(jnp.array([Px0, Px0 + Pdx[0], Px0 + Pdx[1], Px0 + Pdx[0] + Pdx[1]])) @@ -259,10 +348,6 @@ class XRayTransform3D(LinearOperator): :meth:`XRayTransform3D.matrices_from_euler_angles` can help to make these geometry arrays. - - - - """ def __init__( @@ -279,7 +364,7 @@ def __init__( """ self.input_shape: Shape = input_shape - self.matrices = matrices + self.matrices = jnp.asarray(matrices, dtype=np.float32) self.det_shape = det_shape self.output_shape = (len(matrices), *det_shape) super().__init__( @@ -344,7 +429,7 @@ def _project_single( return proj @staticmethod - def _back_project(proj: ArrayLike, matrices: ArrayLike, input_shape: Shape) -> ArrayLike: + def _back_project(proj: ArrayLike, matrices: ArrayLike, input_shape: Shape) -> snp.Array: r""" Args: proj: Input (set of) projection(s). @@ -385,7 +470,7 @@ def _back_project_single( @staticmethod def _calc_weights( - input_shape: Shape, matrix: snp.Array, output_shape: Shape, slice_offset: int = 0 + input_shape: Shape, matrix: snp.Array, det_shape: Shape, slice_offset: int = 0 ) -> snp.Array: # pixel (0, 0, 0) has its center at (0.5, 0.5, 0.5) x = jnp.mgrid[: input_shape[0], : input_shape[1], : input_shape[2]] + 0.5 # (3, ...) @@ -403,13 +488,46 @@ def _calc_weights( left_edge = Px - w / 2 to_next = jnp.minimum(jnp.ceil(left_edge) - left_edge, w) ul_ind = jnp.floor(left_edge).astype("int32") - ul_ind = jnp.where(ul_ind < 0, max(output_shape), ul_ind) # otherwise negative values wrap ul_weight = to_next[0] * to_next[1] * (1 / w**2) ur_weight = (w - to_next[0]) * to_next[1] * (1 / w**2) ll_weight = to_next[0] * (w - to_next[1]) * (1 / w**2) lr_weight = (w - to_next[0]) * (w - to_next[1]) * (1 / w**2) + # set weights to zero out of bounds + ul_weight = jnp.where( + (ul_ind[0] >= 0) + * (ul_ind[0] < det_shape[0]) + * (ul_ind[1] >= 0) + * (ul_ind[1] < det_shape[1]), + ul_weight, + 0.0, + ) + ur_weight = jnp.where( + (ul_ind[0] + 1 >= 0) + * (ul_ind[0] + 1 < det_shape[0]) + * (ul_ind[1] >= 0) + * (ul_ind[1] < det_shape[1]), + ur_weight, + 0.0, + ) + ll_weight = jnp.where( + (ul_ind[0] >= 0) + * (ul_ind[0] < det_shape[0]) + * (ul_ind[1] + 1 >= 0) + * (ul_ind[1] + 1 < det_shape[1]), + ll_weight, + 0.0, + ) + lr_weight = jnp.where( + (ul_ind[0] + 1 >= 0) + * (ul_ind[0] + 1 < det_shape[0]) + * (ul_ind[1] + 1 >= 0) + * (ul_ind[1] + 1 < det_shape[1]), + lr_weight, + 0.0, + ) + return ul_ind, ul_weight, ur_weight, ll_weight, lr_weight @staticmethod diff --git a/scico/linop/xray/astra.py b/scico/linop/xray/astra.py index 3f50df6f..d73c60cf 100644 --- a/scico/linop/xray/astra.py +++ b/scico/linop/xray/astra.py @@ -462,6 +462,9 @@ class XRayTransform3D(LinearOperator): # pragma: no cover `ASTRA toolbox `_. The `3D geometries `__ "parallel3d" and "parallel3d_vec" are supported by this interface. + **NB:** A GPU is required for the primary functionality of this + class; if no GPU is available, projections and back projections will + fail with an "Unknown algorithm type" error. The volume is fixed with respect to the coordinate system, centered at the origin, as illustrated below: diff --git a/scico/plot.py b/scico/plot.py index 6c0375be..72ae4962 100644 --- a/scico/plot.py +++ b/scico/plot.py @@ -13,6 +13,7 @@ # This module is copied from https://github.com/bwohlberg/sporco +import os import sys import numpy as np @@ -820,7 +821,8 @@ def config_notebook_plotting(): Configure plotting functions for inline plotting within a Jupyter Notebook shell. This function has no effect when not within a notebook shell, and may therefore be used within a normal python - script. + script. If environment variable ``MATPLOTLIB_IPYNB_BACKEND`` is set, + the matplotlib backend is explicitly set to the specified value. """ # Check whether running within a notebook shell and have @@ -828,8 +830,9 @@ def config_notebook_plotting(): module = sys.modules[__name__] if _in_notebook() and module.plot.__name__ == "plot": - # Set inline backend (i.e. %matplotlib inline) if in a notebook shell - set_notebook_plot_backend() + # Set backend if specified by environment variable + if "MATPLOTLIB_IPYNB_BACKEND" in os.environ: + set_notebook_plot_backend(os.environ["MATPLOTLIB_IPYNB_BACKEND"]) # Replace plot function with a wrapper function that discards # its return value (within a notebook with inline plotting, plots diff --git a/scico/test/flax/test_examples_flax.py b/scico/test/flax/test_examples_flax.py index 72c084dd..ce265d05 100644 --- a/scico/test/flax/test_examples_flax.py +++ b/scico/test/flax/test_examples_flax.py @@ -12,7 +12,6 @@ generate_ct_data, generate_foam1_images, generate_foam2_images, - have_astra, have_ray, have_xdesign, ) @@ -75,8 +74,8 @@ def random_data_gen(seed, N, ndata): @pytest.mark.skipif( - not have_astra or not have_ray or not have_xdesign, - reason="astra, ray, or xdesign package not installed", + not have_ray or not have_xdesign, + reason="ray or xdesign package not installed", ) def test_ct_data_generation(): N = 32 @@ -90,7 +89,7 @@ def random_img_gen(seed, size, ndata): img, sino, fbp = generate_ct_data(nimg, N, nproj, imgfunc=random_img_gen) assert img.shape == (nimg, N, N, 1) - assert sino.shape == (nimg, nproj, N, 1) + assert sino.shape == (nimg, nproj, sino.shape[2], 1) assert fbp.shape == (nimg, N, N, 1) diff --git a/scico/test/flax/test_inv.py b/scico/test/flax/test_inv.py index b63a88a8..03c43736 100644 --- a/scico/test/flax/test_inv.py +++ b/scico/test/flax/test_inv.py @@ -6,18 +6,12 @@ import jax.numpy as jnp from jax import lax -import pytest - from scico import flax as sflax from scico import random from scico.flax.examples import PaddedCircularConvolve, build_blur_kernel -from scico.flax.examples.data_generation import have_astra from scico.flax.train.traversals import clip_positive, clip_range, construct_traversal from scico.linop import CircularConvolve, Identity - -if have_astra: - from scico.linop.xray.astra import XRayTransform2D - +from scico.linop.xray import XRayTransform2D os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8" @@ -153,7 +147,6 @@ def test_train_odpdcnv_default(self): np.testing.assert_array_less(1e-2 * np.ones(alphaval.shape), alphaval) -@pytest.mark.skipif(not have_astra, reason="astra package not installed") class TestCT: def setup_method(self, method): self.N = 32 # signal size @@ -162,12 +155,9 @@ def setup_method(self, method): xt, key = random.randn((2 * self.bsize, self.N, self.N, self.chn), seed=4321) self.nproj = 60 # number of projections - angles = np.linspace(0, np.pi, self.nproj) # evenly spaced projection angles + angles = np.linspace(0, np.pi, self.nproj, endpoint=False, dtype=np.float32) self.opCT = XRayTransform2D( - input_shape=(self.N, self.N), - det_count=self.N, - det_spacing=1.0, - angles=angles, + input_shape=(self.N, self.N), det_count=self.N, angles=angles, dx=0.9999 / np.sqrt(2.0) ) # Radon transform operator a_f = lambda v: jnp.atleast_3d(self.opCT(v.squeeze())) y = lax.map(a_f, xt) diff --git a/scico/test/linop/xray/test_xray.py b/scico/test/linop/xray/test_xray_2d.py similarity index 53% rename from scico/test/linop/xray/test_xray.py rename to scico/test/linop/xray/test_xray_2d.py index cd7c0dcd..3a9d7488 100644 --- a/scico/test/linop/xray/test_xray.py +++ b/scico/test/linop/xray/test_xray_2d.py @@ -1,11 +1,14 @@ import numpy as np +import jax import jax.numpy as jnp import pytest import scico -from scico.linop.xray import XRayTransform2D, XRayTransform3D +import scico.linop +from scico.linop.xray import XRayTransform2D +from scico.metric import psnr @pytest.mark.filterwarnings("error") @@ -49,7 +52,7 @@ def test_apply(): def test_apply_adjoint(): im_shape = (12, 13) num_angles = 10 - x = jnp.ones(im_shape) + x = jnp.ones(im_shape, dtype=jnp.float32) angles = jnp.linspace(0, jnp.pi, num=num_angles, endpoint=False) @@ -71,44 +74,42 @@ def test_apply_adjoint(): assert y.shape[1] == det_count -def test_3d_scaling(): - x = jnp.zeros((4, 4, 1)) - x = x.at[1:3, 1:3, 0].set(1.0) - - input_shape = x.shape - output_shape = x.shape[:2] - - # default spacing - M = XRayTransform3D.matrices_from_euler_angles(input_shape, output_shape, "X", [0.0]) - H = XRayTransform3D(input_shape, matrices=M, det_shape=output_shape) - # fmt: off - truth = jnp.array( - [[[0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 1.0, 0.0], - [0.0, 1.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0]]] - ) # fmt: on - np.testing.assert_allclose(H @ x, truth) - - # bigger voxels in the x (first index) direction - M = XRayTransform3D.matrices_from_euler_angles( - input_shape, output_shape, "X", [0.0], voxel_spacing=[2.0, 1.0, 1.0] - ) - H = XRayTransform3D(input_shape, matrices=M, det_shape=output_shape) - # fmt: off - truth = jnp.array( - [[[0. , 0.5, 0.5, 0. ], - [0. , 0.5, 0.5, 0. ], - [0. , 0.5, 0.5, 0. ], - [0. , 0.5, 0.5, 0. ]]] - ) # fmt: on - np.testing.assert_allclose(H @ x, truth) - - # bigger detector pixels in the x (first index) direction - M = XRayTransform3D.matrices_from_euler_angles( - input_shape, output_shape, "X", [0.0], det_spacing=[2.0, 1.0] - ) - H = XRayTransform3D(input_shape, matrices=M, det_shape=output_shape) - # fmt: off - truth = None # fmt: on # TODO: Check this case more closely. - # np.testing.assert_allclose(H @ x, truth) +def test_matched_adjoint(): + """See https://github.com/lanl/scico/issues/560.""" + N = 16 + det_count = int(N * 1.05 / np.sqrt(2.0)) + dx = 1.0 / np.sqrt(2) + n_projection = 3 + angles = np.linspace(0, np.pi, n_projection, endpoint=False) + A = XRayTransform2D((N, N), angles, det_count=det_count, dx=dx) + assert scico.linop.valid_adjoint(A, A.T, eps=1e-5) + + +@pytest.mark.parametrize("dx", [0.5, 1.0 / np.sqrt(2)]) +@pytest.mark.parametrize("det_count_factor", [1.02 / np.sqrt(2.0), 1.0]) +def test_fbp(dx, det_count_factor): + N = 256 + x_gt = np.zeros((N, N), dtype=np.float32) + N4 = N // 4 + x_gt[N4:-N4, N4:-N4] = 1.0 + + det_count = int(det_count_factor * N) + n_proj = 360 + angles = np.linspace(0, np.pi, n_proj, endpoint=False) + A = XRayTransform2D(x_gt.shape, angles, det_count=det_count, dx=dx) + y = A(x_gt) + x_fbp = A.fbp(y) + assert psnr(x_gt, x_fbp) > 28 + + +def test_fbp_jit(): + N = 64 + x_gt = np.ones((N, N), dtype=np.float32) + + det_count = N + n_proj = 90 + angles = np.linspace(0, np.pi, n_proj, endpoint=False) + A = XRayTransform2D(x_gt.shape, angles, det_count=det_count) + y = A(x_gt) + fbp = jax.jit(A.fbp) + x_fbp = fbp(y) diff --git a/scico/test/linop/xray/test_xray_3d.py b/scico/test/linop/xray/test_xray_3d.py new file mode 100644 index 00000000..d96217a4 --- /dev/null +++ b/scico/test/linop/xray/test_xray_3d.py @@ -0,0 +1,66 @@ +import numpy as np + +import jax.numpy as jnp + +import scico.linop +from scico.linop.xray import XRayTransform3D + + +def test_matched_adjoint(): + """See https://github.com/lanl/scico/issues/560.""" + N = 16 + det_count = int(N * 1.05 / np.sqrt(2.0)) + n_projection = 3 + + input_shape = (N, N, N) + det_shape = (det_count, det_count) + + M = XRayTransform3D.matrices_from_euler_angles( + input_shape, det_shape, "X", np.linspace(0, np.pi, n_projection, endpoint=False) + ) + H = XRayTransform3D(input_shape, matrices=M, det_shape=det_shape) + + assert scico.linop.valid_adjoint(H, H.T, eps=1e-5) + + +def test_scaling(): + x = jnp.zeros((4, 4, 1)) + x = x.at[1:3, 1:3, 0].set(1.0) + + input_shape = x.shape + det_shape = x.shape[:2] + + # default spacing + M = XRayTransform3D.matrices_from_euler_angles(input_shape, det_shape, "X", [0.0]) + H = XRayTransform3D(input_shape, matrices=M, det_shape=det_shape) + # fmt: off + truth = jnp.array( + [[[0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0]]] + ) # fmt: on + np.testing.assert_allclose(H @ x, truth) + + # bigger voxels in the x (first index) direction + M = XRayTransform3D.matrices_from_euler_angles( + input_shape, det_shape, "X", [0.0], voxel_spacing=[2.0, 1.0, 1.0] + ) + H = XRayTransform3D(input_shape, matrices=M, det_shape=det_shape) + # fmt: off + truth = jnp.array( + [[[0. , 0.5, 0.5, 0. ], + [0. , 0.5, 0.5, 0. ], + [0. , 0.5, 0.5, 0. ], + [0. , 0.5, 0.5, 0. ]]] + ) # fmt: on + np.testing.assert_allclose(H @ x, truth) + + # bigger detector pixels in the x (first index) direction + M = XRayTransform3D.matrices_from_euler_angles( + input_shape, det_shape, "X", [0.0], det_spacing=[2.0, 1.0] + ) + H = XRayTransform3D(input_shape, matrices=M, det_shape=det_shape) + # fmt: off + truth = None # fmt: on # TODO: Check this case more closely. + # np.testing.assert_allclose(H @ x, truth)