diff --git a/docs/source/utils.rst b/docs/source/utils.rst index 881519936b..527247799f 100644 --- a/docs/source/utils.rst +++ b/docs/source/utils.rst @@ -76,3 +76,8 @@ State Cacher ------------ .. automodule:: monai.utils.state_cacher :members: + +Component store +--------------- +.. autoclass:: monai.utils.component_store.ComponentStore + :members: diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 82f944ccb8..2c32eb2cf4 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -18,6 +18,8 @@ from .deprecate_utils import DeprecatedError, deprecated, deprecated_arg, deprecated_arg_default from .dist import RankFilter, evenly_divisible_all_gather, get_dist_device, string_list_all_gather from .enums import ( + AdversarialIterationEvents, + AdversarialKeys, AlgoKeys, Average, BlendMode, @@ -47,6 +49,8 @@ MetricReduction, NdimageMode, NumpyPadMode, + OrderingTransformations, + OrderingType, PatchKeys, PostFix, ProbMapKeys, @@ -95,6 +99,8 @@ str2bool, str2list, to_tuple_of_dictionaries, + unsqueeze_left, + unsqueeze_right, zip_with, ) from .module import ( diff --git a/monai/utils/enums.py b/monai/utils/enums.py index 409f979c56..a0847dd76c 100644 --- a/monai/utils/enums.py +++ b/monai/utils/enums.py @@ -13,8 +13,11 @@ import random from enum import Enum +from typing import TYPE_CHECKING +from monai.config import IgniteInfo from monai.utils import deprecated +from monai.utils.module import min_version, optional_import __all__ = [ "StrEnum", @@ -88,6 +91,14 @@ def __repr__(self): return self.value +if TYPE_CHECKING: + from ignite.engine import EventEnum +else: + EventEnum, _ = optional_import( + "ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "EventEnum", as_type="base" + ) + + class NumpyPadMode(StrEnum): """ See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html @@ -692,3 +703,57 @@ class AlgoKeys(StrEnum): ALGO = "algo_instance" IS_TRAINED = "is_trained" SCORE = "best_metric" + + +class AdversarialKeys(StrEnum): + """ + Keys used by the AdversarialTrainer. + `REALS` are real images from the batch. + `FAKES` are fake images generated by the generator. Are the same as PRED. + `REAL_LOGITS` are logits of the discriminator for the real images. + `FAKE_LOGIT` are logits of the discriminator for the fake images. + `RECONSTRUCTION_LOSS` is the loss value computed by the reconstruction loss function. + `GENERATOR_LOSS` is the loss value computed by the generator loss function. It is the + discriminator loss for the fake images. That is backpropagated through the generator only. + `DISCRIMINATOR_LOSS` is the loss value computed by the discriminator loss function. It is the + discriminator loss for the real images and the fake images. That is backpropagated through the + discriminator only. + """ + + REALS = "reals" + REAL_LOGITS = "real_logits" + FAKES = "fakes" + FAKE_LOGITS = "fake_logits" + RECONSTRUCTION_LOSS = "reconstruction_loss" + GENERATOR_LOSS = "generator_loss" + DISCRIMINATOR_LOSS = "discriminator_loss" + + +class AdversarialIterationEvents(EventEnum): + """ + Keys used to define events as used in the AdversarialTrainer. + """ + + RECONSTRUCTION_LOSS_COMPLETED = "reconstruction_loss_completed" + GENERATOR_FORWARD_COMPLETED = "generator_forward_completed" + GENERATOR_DISCRIMINATOR_FORWARD_COMPLETED = "generator_discriminator_forward_completed" + GENERATOR_LOSS_COMPLETED = "generator_loss_completed" + GENERATOR_BACKWARD_COMPLETED = "generator_backward_completed" + GENERATOR_MODEL_COMPLETED = "generator_model_completed" + DISCRIMINATOR_REALS_FORWARD_COMPLETED = "discriminator_reals_forward_completed" + DISCRIMINATOR_FAKES_FORWARD_COMPLETED = "discriminator_fakes_forward_completed" + DISCRIMINATOR_LOSS_COMPLETED = "discriminator_loss_completed" + DISCRIMINATOR_BACKWARD_COMPLETED = "discriminator_backward_completed" + DISCRIMINATOR_MODEL_COMPLETED = "discriminator_model_completed" + + +class OrderingType(StrEnum): + RASTER_SCAN = "raster_scan" + S_CURVE = "s_curve" + RANDOM = "random" + + +class OrderingTransformations(StrEnum): + ROTATE_90 = "rotate_90" + TRANSPOSE = "transpose" + REFLECT = "reflect" diff --git a/monai/utils/misc.py b/monai/utils/misc.py index 3b11af41b0..d6ff370f69 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -888,3 +888,13 @@ def is_sqrt(num: Sequence[int] | int) -> bool: sqrt_num = [int(math.sqrt(_num)) for _num in num] ret = [_i * _j for _i, _j in zip(sqrt_num, sqrt_num)] return ensure_tuple(ret) == num + + +def unsqueeze_right(arr: NdarrayOrTensor, ndim: int) -> NdarrayOrTensor: + """Append 1-sized dimensions to `arr` to create a result with `ndim` dimensions.""" + return arr[(...,) + (None,) * (ndim - arr.ndim)] + + +def unsqueeze_left(arr: NdarrayOrTensor, ndim: int) -> NdarrayOrTensor: + """Prepend 1-sized dimensions to `arr` to create a result with `ndim` dimensions.""" + return arr[(None,) * (ndim - arr.ndim)] diff --git a/tests/test_squeeze_unsqueeze.py b/tests/test_squeeze_unsqueeze.py new file mode 100644 index 0000000000..2db26a6bdc --- /dev/null +++ b/tests/test_squeeze_unsqueeze.py @@ -0,0 +1,71 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.utils import unsqueeze_left, unsqueeze_right + +RIGHT_CASES = [ + (np.random.rand(3, 4).astype(np.float32), 5, (3, 4, 1, 1, 1)), + (torch.rand(3, 4).type(torch.float32), 5, (3, 4, 1, 1, 1)), + (np.random.rand(3, 4).astype(np.float64), 5, (3, 4, 1, 1, 1)), + (torch.rand(3, 4).type(torch.float64), 5, (3, 4, 1, 1, 1)), + (np.random.rand(3, 4).astype(np.int32), 5, (3, 4, 1, 1, 1)), + (torch.rand(3, 4).type(torch.int32), 5, (3, 4, 1, 1, 1)), +] + + +LEFT_CASES = [ + (np.random.rand(3, 4).astype(np.float32), 5, (1, 1, 1, 3, 4)), + (torch.rand(3, 4).type(torch.float32), 5, (1, 1, 1, 3, 4)), + (np.random.rand(3, 4).astype(np.float64), 5, (1, 1, 1, 3, 4)), + (torch.rand(3, 4).type(torch.float64), 5, (1, 1, 1, 3, 4)), + (np.random.rand(3, 4).astype(np.int32), 5, (1, 1, 1, 3, 4)), + (torch.rand(3, 4).type(torch.int32), 5, (1, 1, 1, 3, 4)), +] +ALL_CASES = [ + (np.random.rand(3, 4), 2, (3, 4)), + (np.random.rand(3, 4), 0, (3, 4)), + (np.random.rand(3, 4), -1, (3, 4)), + (np.array(3), 4, (1, 1, 1, 1)), + (np.array(3), 0, ()), + (np.random.rand(3, 4).astype(np.int32), 2, (3, 4)), + (np.random.rand(3, 4).astype(np.int32), 0, (3, 4)), + (np.random.rand(3, 4).astype(np.int32), -1, (3, 4)), + (np.array(3).astype(np.int32), 4, (1, 1, 1, 1)), + (np.array(3).astype(np.int32), 0, ()), + (torch.rand(3, 4), 2, (3, 4)), + (torch.rand(3, 4), 0, (3, 4)), + (torch.rand(3, 4), -1, (3, 4)), + (torch.tensor(3), 4, (1, 1, 1, 1)), + (torch.tensor(3), 0, ()), + (torch.rand(3, 4).type(torch.int32), 2, (3, 4)), + (torch.rand(3, 4).type(torch.int32), 0, (3, 4)), + (torch.rand(3, 4).type(torch.int32), -1, (3, 4)), + (torch.tensor(3).type(torch.int32), 4, (1, 1, 1, 1)), + (torch.tensor(3).type(torch.int32), 0, ()), +] + + +class TestUnsqueeze(unittest.TestCase): + @parameterized.expand(RIGHT_CASES + ALL_CASES) + def test_unsqueeze_right(self, arr, ndim, shape): + self.assertEqual(unsqueeze_right(arr, ndim).shape, shape) + + @parameterized.expand(LEFT_CASES + ALL_CASES) + def test_unsqueeze_left(self, arr, ndim, shape): + self.assertEqual(unsqueeze_left(arr, ndim).shape, shape)