Skip to content

Commit

Permalink
Merge branch 'main' into mike/xray_tweaks
Browse files Browse the repository at this point in the history
  • Loading branch information
bwohlberg authored Oct 1, 2024
2 parents cf403bf + 8dc1a2a commit 9f1eaf6
Show file tree
Hide file tree
Showing 14 changed files with 103 additions and 39 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/pytest_ubuntu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ jobs:
pip install pytest-split
pip install -r requirements.txt
pip install -r dev_requirements.txt
conda install -c conda-forge svmbir>=0.3.3
# svmbir install temporarily disabled due to import errors
#conda install -c conda-forge svmbir>=0.3.3
conda install -c conda-forge astra-toolbox
conda install -c conda-forge pyyaml
pip install --upgrade --force-reinstall scipy>=1.6.0 # Temporary fix for GLIBCXX_3.4.30 not found in conda forge version
Expand Down
9 changes: 8 additions & 1 deletion docs/source/install.rst
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ can be installed from PyPI
From GitHub
-----------

SCICO can be downloaded from the `GitHub repo
The development version of SCICO can be downloaded from the `GitHub repo
<https://github.com/lanl/scico>`_. Note that, since the SCICO repo has
a submodule, it should be cloned via the command
::
Expand All @@ -102,6 +102,13 @@ Install using the commands
pip install -e .


If a clone of the SCICO repository is not needed, it is simpler to
install directly using ``pip``
::

pip install git+https://github.com/lanl/scico



GPU Support
-----------
Expand Down
7 changes: 6 additions & 1 deletion examples/scripts/ct_astra_modl_train_foam2.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@

import jax

try:
from jax.extend.backend import get_backend # introduced in jax 0.4.33
except ImportError:
from jax.lib.xla_bridge import get_backend

from mpl_toolkits.axes_grid1 import make_axes_locatable

from scico import flax as sflax
Expand All @@ -67,7 +72,7 @@
applies if GPU is not available).
"""
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
platform = jax.lib.xla_bridge.get_backend().platform
platform = get_backend().platform
print("Platform: ", platform)


Expand Down
7 changes: 6 additions & 1 deletion examples/scripts/ct_astra_odp_train_foam2.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@

import jax

try:
from jax.extend.backend import get_backend # introduced in jax 0.4.33
except ImportError:
from jax.lib.xla_bridge import get_backend

from mpl_toolkits.axes_grid1 import make_axes_locatable

from scico import flax as sflax
Expand All @@ -70,7 +75,7 @@
from scico.linop.xray.astra import XRayTransform2D


platform = jax.lib.xla_bridge.get_backend().platform
platform = get_backend().platform
print("Platform: ", platform)


Expand Down
7 changes: 6 additions & 1 deletion examples/scripts/ct_astra_unet_train_foam2.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@

import jax

try:
from jax.extend.backend import get_backend # introduced in jax 0.4.33
except ImportError:
from jax.lib.xla_bridge import get_backend

import numpy as np

from mpl_toolkits.axes_grid1 import make_axes_locatable
Expand All @@ -36,7 +41,7 @@
from scico.flax.examples import load_ct_data


platform = jax.lib.xla_bridge.get_backend().platform
platform = get_backend().platform
print("Platform: ", platform)


Expand Down
7 changes: 6 additions & 1 deletion examples/scripts/deconv_modl_train_foam1.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@

import jax

try:
from jax.extend.backend import get_backend # introduced in jax 0.4.33
except ImportError:
from jax.lib.xla_bridge import get_backend

from mpl_toolkits.axes_grid1 import make_axes_locatable

from scico import flax as sflax
Expand All @@ -67,7 +72,7 @@
from scico.linop import CircularConvolve


platform = jax.lib.xla_bridge.get_backend().platform
platform = get_backend().platform
print("Platform: ", platform)


Expand Down
7 changes: 6 additions & 1 deletion examples/scripts/deconv_odp_train_foam1.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@

import jax

try:
from jax.extend.backend import get_backend # introduced in jax 0.4.33
except ImportError:
from jax.lib.xla_bridge import get_backend

from mpl_toolkits.axes_grid1 import make_axes_locatable

from scico import flax as sflax
Expand All @@ -75,7 +80,7 @@
from scico.linop import CircularConvolve


platform = jax.lib.xla_bridge.get_backend().platform
platform = get_backend().platform
print("Platform: ", platform)


Expand Down
7 changes: 6 additions & 1 deletion examples/scripts/denoise_dncnn_train_bsds.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,19 @@

import jax

try:
from jax.extend.backend import get_backend # introduced in jax 0.4.33
except ImportError:
from jax.lib.xla_bridge import get_backend

from mpl_toolkits.axes_grid1 import make_axes_locatable

from scico import flax as sflax
from scico import metric, plot
from scico.flax.examples import load_image_data


platform = jax.lib.xla_bridge.get_backend().platform
platform = get_backend().platform
print("Platform: ", platform)


Expand Down
8 changes: 4 additions & 4 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ scipy>=1.6.0
imageio>=2.17
tifffile
matplotlib
jaxlib>=0.4.3,<=0.4.31
jax>=0.4.3,<=0.4.31
orbax-checkpoint<=0.5.7
flax>=0.8.0,<=0.8.3
jaxlib>=0.4.3,<=0.4.33
jax>=0.4.3,<=0.4.33
orbax-checkpoint>=0.5.0
flax>=0.8.0,<=0.9.0
pyabel>=0.9.0
9 changes: 7 additions & 2 deletions scico/flax/examples/data_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ class UnitCircle:
import jax
import jax.numpy as jnp

try:
from jax.extend.backend import get_backend # introduced in jax 0.4.33
except ImportError:
from jax.lib.xla_bridge import get_backend

from scico.linop import CircularConvolve
from scico.numpy import Array

Expand Down Expand Up @@ -260,7 +265,7 @@ def generate_ct_data(
fbp = (fbp - fbp.min()) / (fbp.max() - fbp.min())

if verbose: # pragma: no cover
platform = jax.lib.xla_bridge.get_backend().platform
platform = get_backend().platform
print(f"{'Platform':26s}{':':4s}{platform}")
print(f"{'Device count':26s}{':':4s}{jax.device_count()}")
print(f"{'Data generation':19s}{'time[s]:':10s}{time_dtgen:>7.2f}")
Expand Down Expand Up @@ -333,7 +338,7 @@ def generate_blur_data(
blurn = jnp.clip(blurn, 0, 1)

if verbose: # pragma: no cover
platform = jax.lib.xla_bridge.get_backend().platform
platform = get_backend().platform
print(f"{'Platform':26s}{':':4s}{platform}")
print(f"{'Device count':26s}{':':4s}{jax.device_count()}")
print(f"{'Data generation':19s}{'time[s]:':10s}{time_dtgen:>7.2f}")
Expand Down
51 changes: 33 additions & 18 deletions scico/flax/train/checkpoints.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2022-2023 by SCICO Developers
# Copyright (C) 2022-2024 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
Expand All @@ -13,9 +13,7 @@

import jax

import orbax.checkpoint

from flax.training import orbax_utils
import orbax.checkpoint as ocp

from .state import TrainState
from .typed_dict import ConfigDict
Expand Down Expand Up @@ -48,13 +46,20 @@ def checkpoint_restore(
if isinstance(workdir_, str):
workdir_ = Path(workdir_)
if workdir_.exists():
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
checkpoint_manager = orbax.checkpoint.CheckpointManager(workdir_, orbax_checkpointer)
step = checkpoint_manager.latest_step()
options = ocp.CheckpointManagerOptions()
mngr = ocp.CheckpointManager(
workdir_,
item_names=("state", "config"),
options=options,
)
step = mngr.latest_step()
if step is not None:
target = {"state": state, "config": {}}
ckpt = checkpoint_manager.restore(step, items=target)
state = ckpt["state"]
restored = mngr.restore(
step, args=ocp.args.Composite(state=ocp.args.StandardRestore(state))
)
mngr.wait_until_finished()
mngr.close()
state = restored.state
elif not ok_no_ckpt:
raise FileNotFoundError("Could not read from checkpoint: " + str(workdir))

Expand All @@ -74,13 +79,23 @@ def checkpoint_save(state: TrainState, config: ConfigDict, workdir: Union[str, P
workdir: Path in which to store checkpoint files.
"""
if jax.process_index() == 0:
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
# Bundle config and model parameters together
ckpt = {"state": state, "config": config}
save_args = orbax_utils.save_args_from_target(ckpt)
options = orbax.checkpoint.CheckpointManagerOptions(max_to_keep=3, create=True)
checkpoint_manager = orbax.checkpoint.CheckpointManager(
workdir, orbax_checkpointer, options
options = ocp.CheckpointManagerOptions(max_to_keep=3, create=True)
mngr = ocp.CheckpointManager(
workdir,
item_names=("state", "config"),
options=options,
)
step = int(state.step)
checkpoint_manager.save(step, ckpt, save_kwargs={"save_args": save_args})
# Remove non-serializable partial functools in post_lst if it exists
config_ = config.copy()
if "post_lst" in config_:
config_.pop("post_lst", None) # type: ignore
mngr.save(
step,
args=ocp.args.Composite(
state=ocp.args.StandardSave(state),
config=ocp.args.JsonSave(config_),
),
)
mngr.wait_until_finished()
mngr.close()
3 changes: 2 additions & 1 deletion scico/functional/_denoiser.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,8 @@ def prox(self, x: Array, lam: float = 1.0, **kwargs) -> Array: # type: ignore
r"""Apply DnCNN denoiser.
*Warning*: The `lam` parameter is ignored, and has no effect on
the output.
the output for :class:`.DnCNN` objects initialized with
:code:`variant` parameter values other than `6N` and `17N`.
Args:
x: Input array.
Expand Down
13 changes: 9 additions & 4 deletions scico/numpy/_blockarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"""Block array class."""

import inspect
from functools import wraps
from functools import WRAPPER_ASSIGNMENTS, wraps
from typing import Callable

import jax
Expand Down Expand Up @@ -174,10 +174,15 @@ def prop_ba(self):
def _da_method_wrapper(method_name):
method = getattr(Array, method_name)

if method.__name__ is None:
return method
# Don't try to set attributes that are None. Not clear why some
# functions/methods (e.g. block_until_ready) have None values
# for these attributes.
wrapper_assignments = WRAPPER_ASSIGNMENTS
for attr in ("__name__", "__qualname__"):
if getattr(method, attr) is None:
wrapper_assignments = tuple(x for x in wrapper_assignments if x != attr)

@wraps(method)
@wraps(method, assigned=wrapper_assignments)
def method_ba(self, *args, **kwargs):
result = tuple(getattr(x, method_name)(*args, **kwargs) for x in self)

Expand Down
4 changes: 2 additions & 2 deletions scico/numpy/_wrapped_function_lists.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2022-2023 by SCICO Developers
# Copyright (C) 2022-2024 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SPORCO package. Details of the copyright
# and user license can be found in the 'LICENSE.txt' file distributed
Expand Down Expand Up @@ -84,7 +84,7 @@
"arccosh",
"arctanh",
"around",
"round_",
"round",
"rint",
"fix",
"floor",
Expand Down

0 comments on commit 9f1eaf6

Please sign in to comment.