diff --git a/.github/workflows/pytest_ubuntu.yml b/.github/workflows/pytest_ubuntu.yml index e8c0a1705..67b506178 100644 --- a/.github/workflows/pytest_ubuntu.yml +++ b/.github/workflows/pytest_ubuntu.yml @@ -60,7 +60,8 @@ jobs: pip install pytest-split pip install -r requirements.txt pip install -r dev_requirements.txt - conda install -c conda-forge svmbir>=0.3.3 + # svmbir install temporarily disabled due to import errors + #conda install -c conda-forge svmbir>=0.3.3 conda install -c conda-forge astra-toolbox conda install -c conda-forge pyyaml pip install --upgrade --force-reinstall scipy>=1.6.0 # Temporary fix for GLIBCXX_3.4.30 not found in conda forge version diff --git a/docs/source/install.rst b/docs/source/install.rst index 0931ef820..a68ba3272 100644 --- a/docs/source/install.rst +++ b/docs/source/install.rst @@ -87,7 +87,7 @@ can be installed from PyPI From GitHub ----------- -SCICO can be downloaded from the `GitHub repo +The development version of SCICO can be downloaded from the `GitHub repo `_. Note that, since the SCICO repo has a submodule, it should be cloned via the command :: @@ -102,6 +102,13 @@ Install using the commands pip install -e . +If a clone of the SCICO repository is not needed, it is simpler to +install directly using ``pip`` +:: + + pip install git+https://github.com/lanl/scico + + GPU Support ----------- diff --git a/examples/scripts/ct_astra_modl_train_foam2.py b/examples/scripts/ct_astra_modl_train_foam2.py index be218d5eb..a06d7b81a 100644 --- a/examples/scripts/ct_astra_modl_train_foam2.py +++ b/examples/scripts/ct_astra_modl_train_foam2.py @@ -54,6 +54,11 @@ import jax +try: + from jax.extend.backend import get_backend # introduced in jax 0.4.33 +except ImportError: + from jax.lib.xla_bridge import get_backend + from mpl_toolkits.axes_grid1 import make_axes_locatable from scico import flax as sflax @@ -67,7 +72,7 @@ applies if GPU is not available). """ os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8" -platform = jax.lib.xla_bridge.get_backend().platform +platform = get_backend().platform print("Platform: ", platform) diff --git a/examples/scripts/ct_astra_odp_train_foam2.py b/examples/scripts/ct_astra_odp_train_foam2.py index 4a8355e36..03753c1b7 100644 --- a/examples/scripts/ct_astra_odp_train_foam2.py +++ b/examples/scripts/ct_astra_odp_train_foam2.py @@ -61,6 +61,11 @@ import jax +try: + from jax.extend.backend import get_backend # introduced in jax 0.4.33 +except ImportError: + from jax.lib.xla_bridge import get_backend + from mpl_toolkits.axes_grid1 import make_axes_locatable from scico import flax as sflax @@ -70,7 +75,7 @@ from scico.linop.xray.astra import XRayTransform2D -platform = jax.lib.xla_bridge.get_backend().platform +platform = get_backend().platform print("Platform: ", platform) diff --git a/examples/scripts/ct_astra_unet_train_foam2.py b/examples/scripts/ct_astra_unet_train_foam2.py index bae623b3b..fc88e7e3b 100644 --- a/examples/scripts/ct_astra_unet_train_foam2.py +++ b/examples/scripts/ct_astra_unet_train_foam2.py @@ -27,6 +27,11 @@ import jax +try: + from jax.extend.backend import get_backend # introduced in jax 0.4.33 +except ImportError: + from jax.lib.xla_bridge import get_backend + import numpy as np from mpl_toolkits.axes_grid1 import make_axes_locatable @@ -36,7 +41,7 @@ from scico.flax.examples import load_ct_data -platform = jax.lib.xla_bridge.get_backend().platform +platform = get_backend().platform print("Platform: ", platform) diff --git a/examples/scripts/deconv_modl_train_foam1.py b/examples/scripts/deconv_modl_train_foam1.py index 69b19d939..38bdf3daf 100644 --- a/examples/scripts/deconv_modl_train_foam1.py +++ b/examples/scripts/deconv_modl_train_foam1.py @@ -58,6 +58,11 @@ import jax +try: + from jax.extend.backend import get_backend # introduced in jax 0.4.33 +except ImportError: + from jax.lib.xla_bridge import get_backend + from mpl_toolkits.axes_grid1 import make_axes_locatable from scico import flax as sflax @@ -67,7 +72,7 @@ from scico.linop import CircularConvolve -platform = jax.lib.xla_bridge.get_backend().platform +platform = get_backend().platform print("Platform: ", platform) diff --git a/examples/scripts/deconv_odp_train_foam1.py b/examples/scripts/deconv_odp_train_foam1.py index 9887fe894..aa49ec919 100644 --- a/examples/scripts/deconv_odp_train_foam1.py +++ b/examples/scripts/deconv_odp_train_foam1.py @@ -66,6 +66,11 @@ import jax +try: + from jax.extend.backend import get_backend # introduced in jax 0.4.33 +except ImportError: + from jax.lib.xla_bridge import get_backend + from mpl_toolkits.axes_grid1 import make_axes_locatable from scico import flax as sflax @@ -75,7 +80,7 @@ from scico.linop import CircularConvolve -platform = jax.lib.xla_bridge.get_backend().platform +platform = get_backend().platform print("Platform: ", platform) diff --git a/examples/scripts/denoise_dncnn_train_bsds.py b/examples/scripts/denoise_dncnn_train_bsds.py index ac9fcb755..dc3da300a 100644 --- a/examples/scripts/denoise_dncnn_train_bsds.py +++ b/examples/scripts/denoise_dncnn_train_bsds.py @@ -24,6 +24,11 @@ import jax +try: + from jax.extend.backend import get_backend # introduced in jax 0.4.33 +except ImportError: + from jax.lib.xla_bridge import get_backend + from mpl_toolkits.axes_grid1 import make_axes_locatable from scico import flax as sflax @@ -31,7 +36,7 @@ from scico.flax.examples import load_image_data -platform = jax.lib.xla_bridge.get_backend().platform +platform = get_backend().platform print("Platform: ", platform) diff --git a/requirements.txt b/requirements.txt index 8bfc6f165..1b0e5359e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,8 +4,8 @@ scipy>=1.6.0 imageio>=2.17 tifffile matplotlib -jaxlib>=0.4.3,<=0.4.31 -jax>=0.4.3,<=0.4.31 -orbax-checkpoint<=0.5.7 -flax>=0.8.0,<=0.8.3 +jaxlib>=0.4.3,<=0.4.33 +jax>=0.4.3,<=0.4.33 +orbax-checkpoint>=0.5.0 +flax>=0.8.0,<=0.9.0 pyabel>=0.9.0 diff --git a/scico/flax/examples/data_generation.py b/scico/flax/examples/data_generation.py index ca7d3c073..e7a990301 100644 --- a/scico/flax/examples/data_generation.py +++ b/scico/flax/examples/data_generation.py @@ -45,6 +45,11 @@ class UnitCircle: import jax import jax.numpy as jnp +try: + from jax.extend.backend import get_backend # introduced in jax 0.4.33 +except ImportError: + from jax.lib.xla_bridge import get_backend + from scico.linop import CircularConvolve from scico.numpy import Array @@ -260,7 +265,7 @@ def generate_ct_data( fbp = (fbp - fbp.min()) / (fbp.max() - fbp.min()) if verbose: # pragma: no cover - platform = jax.lib.xla_bridge.get_backend().platform + platform = get_backend().platform print(f"{'Platform':26s}{':':4s}{platform}") print(f"{'Device count':26s}{':':4s}{jax.device_count()}") print(f"{'Data generation':19s}{'time[s]:':10s}{time_dtgen:>7.2f}") @@ -333,7 +338,7 @@ def generate_blur_data( blurn = jnp.clip(blurn, 0, 1) if verbose: # pragma: no cover - platform = jax.lib.xla_bridge.get_backend().platform + platform = get_backend().platform print(f"{'Platform':26s}{':':4s}{platform}") print(f"{'Device count':26s}{':':4s}{jax.device_count()}") print(f"{'Data generation':19s}{'time[s]:':10s}{time_dtgen:>7.2f}") diff --git a/scico/flax/train/checkpoints.py b/scico/flax/train/checkpoints.py index dac4fd872..fc6460d42 100644 --- a/scico/flax/train/checkpoints.py +++ b/scico/flax/train/checkpoints.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2022-2023 by SCICO Developers +# Copyright (C) 2022-2024 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the @@ -13,9 +13,7 @@ import jax -import orbax.checkpoint - -from flax.training import orbax_utils +import orbax.checkpoint as ocp from .state import TrainState from .typed_dict import ConfigDict @@ -48,13 +46,20 @@ def checkpoint_restore( if isinstance(workdir_, str): workdir_ = Path(workdir_) if workdir_.exists(): - orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer() - checkpoint_manager = orbax.checkpoint.CheckpointManager(workdir_, orbax_checkpointer) - step = checkpoint_manager.latest_step() + options = ocp.CheckpointManagerOptions() + mngr = ocp.CheckpointManager( + workdir_, + item_names=("state", "config"), + options=options, + ) + step = mngr.latest_step() if step is not None: - target = {"state": state, "config": {}} - ckpt = checkpoint_manager.restore(step, items=target) - state = ckpt["state"] + restored = mngr.restore( + step, args=ocp.args.Composite(state=ocp.args.StandardRestore(state)) + ) + mngr.wait_until_finished() + mngr.close() + state = restored.state elif not ok_no_ckpt: raise FileNotFoundError("Could not read from checkpoint: " + str(workdir)) @@ -74,13 +79,23 @@ def checkpoint_save(state: TrainState, config: ConfigDict, workdir: Union[str, P workdir: Path in which to store checkpoint files. """ if jax.process_index() == 0: - orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer() - # Bundle config and model parameters together - ckpt = {"state": state, "config": config} - save_args = orbax_utils.save_args_from_target(ckpt) - options = orbax.checkpoint.CheckpointManagerOptions(max_to_keep=3, create=True) - checkpoint_manager = orbax.checkpoint.CheckpointManager( - workdir, orbax_checkpointer, options + options = ocp.CheckpointManagerOptions(max_to_keep=3, create=True) + mngr = ocp.CheckpointManager( + workdir, + item_names=("state", "config"), + options=options, ) step = int(state.step) - checkpoint_manager.save(step, ckpt, save_kwargs={"save_args": save_args}) + # Remove non-serializable partial functools in post_lst if it exists + config_ = config.copy() + if "post_lst" in config_: + config_.pop("post_lst", None) # type: ignore + mngr.save( + step, + args=ocp.args.Composite( + state=ocp.args.StandardSave(state), + config=ocp.args.JsonSave(config_), + ), + ) + mngr.wait_until_finished() + mngr.close() diff --git a/scico/functional/_denoiser.py b/scico/functional/_denoiser.py index 59350ce42..4d6230997 100644 --- a/scico/functional/_denoiser.py +++ b/scico/functional/_denoiser.py @@ -128,7 +128,8 @@ def prox(self, x: Array, lam: float = 1.0, **kwargs) -> Array: # type: ignore r"""Apply DnCNN denoiser. *Warning*: The `lam` parameter is ignored, and has no effect on - the output. + the output for :class:`.DnCNN` objects initialized with + :code:`variant` parameter values other than `6N` and `17N`. Args: x: Input array. diff --git a/scico/numpy/_blockarray.py b/scico/numpy/_blockarray.py index c01b25d2f..6e97edf47 100644 --- a/scico/numpy/_blockarray.py +++ b/scico/numpy/_blockarray.py @@ -8,7 +8,7 @@ """Block array class.""" import inspect -from functools import wraps +from functools import WRAPPER_ASSIGNMENTS, wraps from typing import Callable import jax @@ -174,10 +174,15 @@ def prop_ba(self): def _da_method_wrapper(method_name): method = getattr(Array, method_name) - if method.__name__ is None: - return method + # Don't try to set attributes that are None. Not clear why some + # functions/methods (e.g. block_until_ready) have None values + # for these attributes. + wrapper_assignments = WRAPPER_ASSIGNMENTS + for attr in ("__name__", "__qualname__"): + if getattr(method, attr) is None: + wrapper_assignments = tuple(x for x in wrapper_assignments if x != attr) - @wraps(method) + @wraps(method, assigned=wrapper_assignments) def method_ba(self, *args, **kwargs): result = tuple(getattr(x, method_name)(*args, **kwargs) for x in self) diff --git a/scico/numpy/_wrapped_function_lists.py b/scico/numpy/_wrapped_function_lists.py index 6e9c3a163..217bfc43a 100644 --- a/scico/numpy/_wrapped_function_lists.py +++ b/scico/numpy/_wrapped_function_lists.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2022-2023 by SCICO Developers +# Copyright (C) 2022-2024 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SPORCO package. Details of the copyright # and user license can be found in the 'LICENSE.txt' file distributed @@ -84,7 +84,7 @@ "arccosh", "arctanh", "around", - "round_", + "round", "rint", "fix", "floor",