From 7aaeab0c0920d4996b50ef2c907b9186346af2fe Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Mon, 16 Oct 2023 15:15:49 +0100 Subject: [PATCH 01/13] Adds new enums --- monai/utils/__init__.py | 4 ++++ monai/utils/enums.py | 47 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+) diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index c973d4bfa1..38717f27fd 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -17,6 +17,8 @@ from .deprecate_utils import DeprecatedError, deprecated, deprecated_arg, deprecated_arg_default from .dist import RankFilter, evenly_divisible_all_gather, get_dist_device, string_list_all_gather from .enums import ( + AdversarialIterationEvents, + AdversarialKeys, AlgoKeys, Average, BlendMode, @@ -46,6 +48,8 @@ MetricReduction, NdimageMode, NumpyPadMode, + OrderingTransformations, + OrderingType, PatchKeys, PostFix, ProbMapKeys, diff --git a/monai/utils/enums.py b/monai/utils/enums.py index 409f979c56..39f26a0647 100644 --- a/monai/utils/enums.py +++ b/monai/utils/enums.py @@ -13,8 +13,11 @@ import random from enum import Enum +from typing import TYPE_CHECKING +from monai.config import IgniteInfo from monai.utils import deprecated +from monai.utils.module import min_version, optional_import __all__ = [ "StrEnum", @@ -88,6 +91,14 @@ def __repr__(self): return self.value +if TYPE_CHECKING: + from ignite.engine import EventEnum +else: + EventEnum, _ = optional_import( + "ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "EventEnum", as_type="base" + ) + + class NumpyPadMode(StrEnum): """ See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html @@ -692,3 +703,39 @@ class AlgoKeys(StrEnum): ALGO = "algo_instance" IS_TRAINED = "is_trained" SCORE = "best_metric" + + +class AdversarialKeys(StrEnum): + REALS = "reals" + REAL_LOGITS = "real_logits" + FAKES = "fakes" + FAKE_LOGITS = "fake_logits" + RECONSTRUCTION_LOSS = "reconstruction_loss" + GENERATOR_LOSS = "generator_loss" + DISCRIMINATOR_LOSS = "discriminator_loss" + + +class AdversarialIterationEvents(EventEnum): + RECONSTRUCTION_LOSS_COMPLETED = "reconstruction_loss_completed" + GENERATOR_FORWARD_COMPLETED = "generator_forward_completed" + GENERATOR_DISCRIMINATOR_FORWARD_COMPLETED = "generator_discriminator_forward_completed" + GENERATOR_LOSS_COMPLETED = "generator_loss_completed" + GENERATOR_BACKWARD_COMPLETED = "generator_backward_completed" + GENERATOR_MODEL_COMPLETED = "generator_model_completed" + DISCRIMINATOR_REALS_FORWARD_COMPLETED = "discriminator_reals_forward_completed" + DISCRIMINATOR_FAKES_FORWARD_COMPLETED = "discriminator_fakes_forward_completed" + DISCRIMINATOR_LOSS_COMPLETED = "discriminator_loss_completed" + DISCRIMINATOR_BACKWARD_COMPLETED = "discriminator_backward_completed" + DISCRIMINATOR_MODEL_COMPLETED = "discriminator_model_completed" + + +class OrderingType(StrEnum): + RASTER_SCAN = "raster_scan" + S_CURVE = "s_curve" + RANDOM = "random" + + +class OrderingTransformations(StrEnum): + ROTATE_90 = "rotate_90" + TRANSPOSE = "transpose" + REFLECT = "reflect" From 209054931ad70b5c4efb7fbb0b0b69c6812af2bd Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Mon, 16 Oct 2023 15:51:21 +0100 Subject: [PATCH 02/13] Add misc and component store --- monai/utils/__init__.py | 1 + monai/utils/component_store.py | 117 +++++++++++++++++++++++++++++++++ monai/utils/misc.py | 10 +++ tests/test_component_store.py | 72 ++++++++++++++++++++ 4 files changed, 200 insertions(+) create mode 100644 monai/utils/component_store.py create mode 100644 tests/test_component_store.py diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 38717f27fd..884e9f4fd6 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -13,6 +13,7 @@ # have to explicitly bring these in here to resolve circular import issues from .aliases import alias, resolve_name +from .component_store import ComponentStore from .decorators import MethodReplacer, RestartGenerator from .deprecate_utils import DeprecatedError, deprecated, deprecated_arg, deprecated_arg_default from .dist import RankFilter, evenly_divisible_all_gather, get_dist_device, string_list_all_gather diff --git a/monai/utils/component_store.py b/monai/utils/component_store.py new file mode 100644 index 0000000000..6fd8e8884f --- /dev/null +++ b/monai/utils/component_store.py @@ -0,0 +1,117 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections import namedtuple +from keyword import iskeyword +from textwrap import dedent, indent +from typing import Any, Callable, Iterable, TypeVar + +T = TypeVar("T") + + +def is_variable(name): + """Returns True if `name` is a valid Python variable name and also not a keyword.""" + return name.isidentifier() and not iskeyword(name) + + +class ComponentStore: + """ + Represents a storage object for other objects (specifically functions) keyed to a name with a description. + + These objects act as global named places for storing components for objects parameterised by component names. + Typically this is functions although other objects can be added. Printing a component store will produce a + list of members along with their docstring information if present. + + Example: + + .. code-block:: python + + TestStore = ComponentStore("Test Store", "A test store for demo purposes") + + @TestStore.add_def("my_func_name", "Some description of your function") + def _my_func(a, b): + '''A description of your function here.''' + return a * b + + print(TestStore) # will print out name, description, and 'my_func_name' with the docstring + + func = TestStore["my_func_name"] + result = func(7, 6) + + """ + + _Component = namedtuple("Component", ("description", "value")) # internal value pair + + def __init__(self, name: str, description: str) -> None: + self.components: dict[str, self._Component] = {} + self.name: str = name + self.description: str = description + + self.__doc__ = f"Component Store '{name}': {description}\n{self.__doc__ or ''}".strip() + + def add(self, name: str, desc: str, value: T) -> T: + """Store the object `value` under the name `name` with description `desc`.""" + if not is_variable(name): + raise ValueError("Name of component must be valid Python identifier") + + self.components[name] = self._Component(desc, value) + return value + + def add_def(self, name: str, desc: str) -> Callable: + """Returns a decorator which stores the decorated function under `name` with description `desc`.""" + + def deco(func): + """Decorator to add a function to a store.""" + return self.add(name, desc, func) + + return deco + + def __contains__(self, name: str) -> bool: + """Returns True if the given name is stored.""" + return name in self.components + + def __len__(self) -> int: + """Returns the number of stored components.""" + return len(self.components) + + def __iter__(self) -> Iterable: + """Yields name/component pairs.""" + for k, v in self.components.items(): + yield k, v.value + + def __str__(self): + result = f"Component Store '{self.name}': {self.description}\nAvailable components:" + for k, v in self.components.items(): + result += f"\n* {k}:" + + if hasattr(v.value, "__doc__"): + doc = indent(dedent(v.value.__doc__.lstrip("\n").rstrip()), " ") + result += f"\n{doc}\n" + else: + result += f" {v.description}" + + return result + + def __getattr__(self, name: str) -> Any: + """Returns the stored object under the given name.""" + if name in self.components: + return self.components[name].value + else: + return self.__getattribute__(name) + + def __getitem__(self, name: str) -> Any: + """Returns the stored object under the given name.""" + if name in self.components: + return self.components[name].value + else: + raise ValueError(f"Component '{name}' not found") diff --git a/monai/utils/misc.py b/monai/utils/misc.py index 3b11af41b0..86abe591fd 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -888,3 +888,13 @@ def is_sqrt(num: Sequence[int] | int) -> bool: sqrt_num = [int(math.sqrt(_num)) for _num in num] ret = [_i * _j for _i, _j in zip(sqrt_num, sqrt_num)] return ensure_tuple(ret) == num + + +def unsqueeze_right(arr: T, ndim: int) -> T: + """Append 1-sized dimensions to `arr` to create a result with `ndim` dimensions.""" + return arr[(...,) + (None,) * (ndim - arr.ndim)] + + +def unsqueeze_left(arr: T, ndim: int) -> T: + """Prepend 1-sized dimensions to `arr` to create a result with `ndim` dimensions.""" + return arr[(None,) * (ndim - arr.ndim)] diff --git a/tests/test_component_store.py b/tests/test_component_store.py new file mode 100644 index 0000000000..614f387754 --- /dev/null +++ b/tests/test_component_store.py @@ -0,0 +1,72 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +from monai.utils import ComponentStore + + +class TestComponentStore(unittest.TestCase): + def setUp(self): + self.cs = ComponentStore("TestStore", "I am a test store, please ignore") + + def test_empty(self): + self.assertEqual(len(self.cs), 0) + self.assertEqual(list(self.cs), []) + + def test_add(self): + test_obj = object() + + self.assertFalse("test_obj" in self.cs) + + self.cs.add("test_obj", "Test object", test_obj) + + self.assertTrue("test_obj" in self.cs) + + self.assertEqual(len(self.cs), 1) + self.assertEqual(list(self.cs), [("test_obj", test_obj)]) + + self.assertEqual(self.cs.test_obj, test_obj) + self.assertEqual(self.cs["test_obj"], test_obj) + + def test_add2(self): + test_obj1 = object() + test_obj2 = object() + + self.cs.add("test_obj1", "Test object", test_obj1) + self.cs.add("test_obj2", "Test object", test_obj2) + + self.assertEqual(len(self.cs), 2) + self.assertTrue("test_obj1" in self.cs) + self.assertTrue("test_obj2" in self.cs) + + def test_add_def(self): + self.assertFalse("test_func" in self.cs) + + @self.cs.add_def("test_func", "Test function") + def test_func(): + return 123 + + self.assertTrue("test_func" in self.cs) + + self.assertEqual(len(self.cs), 1) + self.assertEqual(list(self.cs), [("test_func", test_func)]) + + self.assertEqual(self.cs.test_func, test_func) + self.assertEqual(self.cs["test_func"], test_func) + + # try adding the same function again + self.cs.add_def("test_func", "Test function but with new description")(test_func) + + self.assertEqual(len(self.cs), 1) + self.assertEqual(self.cs.test_func, test_func) From 94f57ac21b36dd8f4eb3d63bc44209510d03135e Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Mon, 16 Oct 2023 15:55:46 +0100 Subject: [PATCH 03/13] Updates docs --- docs/source/utils.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/source/utils.rst b/docs/source/utils.rst index 881519936b..1fe3825c7e 100644 --- a/docs/source/utils.rst +++ b/docs/source/utils.rst @@ -76,3 +76,8 @@ State Cacher ------------ .. automodule:: monai.utils.state_cacher :members: + +Component store +--------------- +.. automodule:: monai.utils.component_store + :members: \ No newline at end of file From f94e666cb8f3b30cda73d0a8189e1e39e5c4b6d8 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Mon, 16 Oct 2023 16:03:18 +0100 Subject: [PATCH 04/13] Update docs --- docs/source/utils.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/utils.rst b/docs/source/utils.rst index 1fe3825c7e..527247799f 100644 --- a/docs/source/utils.rst +++ b/docs/source/utils.rst @@ -79,5 +79,5 @@ State Cacher Component store --------------- -.. automodule:: monai.utils.component_store - :members: \ No newline at end of file +.. autoclass:: monai.utils.component_store.ComponentStore + :members: From 1d95d92d695053fffc2b2bb6222745cea1610fc0 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Tue, 17 Oct 2023 11:34:08 +0100 Subject: [PATCH 05/13] Updates component store to make compliant with mypy --- monai/utils/component_store.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/monai/utils/component_store.py b/monai/utils/component_store.py index 6fd8e8884f..0c9c9d2b3b 100644 --- a/monai/utils/component_store.py +++ b/monai/utils/component_store.py @@ -49,11 +49,10 @@ def _my_func(a, b): result = func(7, 6) """ - - _Component = namedtuple("Component", ("description", "value")) # internal value pair + _Component = namedtuple("_Component", ("description", "value")) # internal value pair def __init__(self, name: str, description: str) -> None: - self.components: dict[str, self._Component] = {} + self.components: dict[str, ComponentStore._Component] = {} self.name: str = name self.description: str = description From 355b4d91a3d740d94b812e7698d32a50aa127586 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Tue, 17 Oct 2023 15:58:59 +0100 Subject: [PATCH 06/13] Fixes mypy issues --- monai/utils/misc.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/utils/misc.py b/monai/utils/misc.py index 86abe591fd..d6ff370f69 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -890,11 +890,11 @@ def is_sqrt(num: Sequence[int] | int) -> bool: return ensure_tuple(ret) == num -def unsqueeze_right(arr: T, ndim: int) -> T: +def unsqueeze_right(arr: NdarrayOrTensor, ndim: int) -> NdarrayOrTensor: """Append 1-sized dimensions to `arr` to create a result with `ndim` dimensions.""" return arr[(...,) + (None,) * (ndim - arr.ndim)] -def unsqueeze_left(arr: T, ndim: int) -> T: +def unsqueeze_left(arr: NdarrayOrTensor, ndim: int) -> NdarrayOrTensor: """Prepend 1-sized dimensions to `arr` to create a result with `ndim` dimensions.""" return arr[(None,) * (ndim - arr.ndim)] From 92044d5dbf7ec28d3e983ee69b10146969f1053f Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Tue, 17 Oct 2023 16:07:40 +0100 Subject: [PATCH 07/13] Fix mypy issues and add test --- monai/utils/__init__.py | 2 ++ monai/utils/component_store.py | 1 + tests/test_squeeze_unsqueeze.py | 47 +++++++++++++++++++++++++++++++++ 3 files changed, 50 insertions(+) create mode 100644 tests/test_squeeze_unsqueeze.py diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 884e9f4fd6..2c32eb2cf4 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -99,6 +99,8 @@ str2bool, str2list, to_tuple_of_dictionaries, + unsqueeze_left, + unsqueeze_right, zip_with, ) from .module import ( diff --git a/monai/utils/component_store.py b/monai/utils/component_store.py index 0c9c9d2b3b..4de85bca23 100644 --- a/monai/utils/component_store.py +++ b/monai/utils/component_store.py @@ -49,6 +49,7 @@ def _my_func(a, b): result = func(7, 6) """ + _Component = namedtuple("_Component", ("description", "value")) # internal value pair def __init__(self, name: str, description: str) -> None: diff --git a/tests/test_squeeze_unsqueeze.py b/tests/test_squeeze_unsqueeze.py new file mode 100644 index 0000000000..346eff15fc --- /dev/null +++ b/tests/test_squeeze_unsqueeze.py @@ -0,0 +1,47 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.utils import unsqueeze_left, unsqueeze_right + +RIGHT_CASES = [(np.random.rand(3, 4), 5, (3, 4, 1, 1, 1)), (torch.rand(3, 4), 5, (3, 4, 1, 1, 1))] + +LEFT_CASES = [(np.random.rand(3, 4), 5, (1, 1, 1, 3, 4)), (torch.rand(3, 4), 5, (1, 1, 1, 3, 4))] + +ALL_CASES = [ + (np.random.rand(3, 4), 2, (3, 4)), + (np.random.rand(3, 4), 0, (3, 4)), + (np.random.rand(3, 4), -1, (3, 4)), + (np.array(3), 4, (1, 1, 1, 1)), + (np.array(3), 0, ()), + (torch.rand(3, 4), 2, (3, 4)), + (torch.rand(3, 4), 0, (3, 4)), + (torch.rand(3, 4), -1, (3, 4)), + (torch.tensor(3), 4, (1, 1, 1, 1)), + (torch.tensor(3), 0, ()), +] + + +class TestUnsqueeze(unittest.TestCase): + @parameterized.expand(RIGHT_CASES + ALL_CASES) + def test_unsqueeze_right(self, arr, ndim, shape): + self.assertEqual(unsqueeze_right(arr, ndim).shape, shape) + + @parameterized.expand(LEFT_CASES + ALL_CASES) + def test_unsqueeze_left(self, arr, ndim, shape): + self.assertEqual(unsqueeze_left(arr, ndim).shape, shape) From 536ce750572de0d772ad83be43c03aad1289cd09 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Tue, 17 Oct 2023 16:29:21 +0100 Subject: [PATCH 08/13] DCO Remediation Commit for Mark Graham I, Mark Graham , hereby add my Signed-off-by to this commit: 7aaeab0c0920d4996b50ef2c907b9186346af2fe I, Mark Graham , hereby add my Signed-off-by to this commit: 209054931ad70b5c4efb7fbb0b0b69c6812af2bd I, Mark Graham , hereby add my Signed-off-by to this commit: 94f57ac21b36dd8f4eb3d63bc44209510d03135e I, Mark Graham , hereby add my Signed-off-by to this commit: f94e666cb8f3b30cda73d0a8189e1e39e5c4b6d8 I, Mark Graham , hereby add my Signed-off-by to this commit: 1d95d92d695053fffc2b2bb6222745cea1610fc0 I, Mark Graham , hereby add my Signed-off-by to this commit: 355b4d91a3d740d94b812e7698d32a50aa127586 I, Mark Graham , hereby add my Signed-off-by to this commit: 92044d5dbf7ec28d3e983ee69b10146969f1053f Signed-off-by: Mark Graham --- tests/test_squeeze_unsqueeze.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_squeeze_unsqueeze.py b/tests/test_squeeze_unsqueeze.py index 346eff15fc..e8362f5bf7 100644 --- a/tests/test_squeeze_unsqueeze.py +++ b/tests/test_squeeze_unsqueeze.py @@ -22,7 +22,7 @@ RIGHT_CASES = [(np.random.rand(3, 4), 5, (3, 4, 1, 1, 1)), (torch.rand(3, 4), 5, (3, 4, 1, 1, 1))] LEFT_CASES = [(np.random.rand(3, 4), 5, (1, 1, 1, 3, 4)), (torch.rand(3, 4), 5, (1, 1, 1, 3, 4))] - + ALL_CASES = [ (np.random.rand(3, 4), 2, (3, 4)), (np.random.rand(3, 4), 0, (3, 4)), From 57e484358e48f8744264de68a58f5886448469cb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 17 Oct 2023 15:29:50 +0000 Subject: [PATCH 09/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_squeeze_unsqueeze.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_squeeze_unsqueeze.py b/tests/test_squeeze_unsqueeze.py index e8362f5bf7..346eff15fc 100644 --- a/tests/test_squeeze_unsqueeze.py +++ b/tests/test_squeeze_unsqueeze.py @@ -22,7 +22,7 @@ RIGHT_CASES = [(np.random.rand(3, 4), 5, (3, 4, 1, 1, 1)), (torch.rand(3, 4), 5, (3, 4, 1, 1, 1))] LEFT_CASES = [(np.random.rand(3, 4), 5, (1, 1, 1, 3, 4)), (torch.rand(3, 4), 5, (1, 1, 1, 3, 4))] - + ALL_CASES = [ (np.random.rand(3, 4), 2, (3, 4)), (np.random.rand(3, 4), 0, (3, 4)), From 21950bbdfdd31b9bc5ffedd61609308710c29b0d Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Tue, 17 Oct 2023 16:30:19 +0100 Subject: [PATCH 10/13] DCO commit Signed-off-by: Mark Graham --- tests/test_squeeze_unsqueeze.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_squeeze_unsqueeze.py b/tests/test_squeeze_unsqueeze.py index e8362f5bf7..346eff15fc 100644 --- a/tests/test_squeeze_unsqueeze.py +++ b/tests/test_squeeze_unsqueeze.py @@ -22,7 +22,7 @@ RIGHT_CASES = [(np.random.rand(3, 4), 5, (3, 4, 1, 1, 1)), (torch.rand(3, 4), 5, (3, 4, 1, 1, 1))] LEFT_CASES = [(np.random.rand(3, 4), 5, (1, 1, 1, 3, 4)), (torch.rand(3, 4), 5, (1, 1, 1, 3, 4))] - + ALL_CASES = [ (np.random.rand(3, 4), 2, (3, 4)), (np.random.rand(3, 4), 0, (3, 4)), From f71fda0059c29459ff80d1bd8fd865caef83c07a Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Wed, 18 Oct 2023 13:15:25 +0100 Subject: [PATCH 11/13] Adds brief docstrings --- monai/utils/enums.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/monai/utils/enums.py b/monai/utils/enums.py index 39f26a0647..a0847dd76c 100644 --- a/monai/utils/enums.py +++ b/monai/utils/enums.py @@ -706,6 +706,20 @@ class AlgoKeys(StrEnum): class AdversarialKeys(StrEnum): + """ + Keys used by the AdversarialTrainer. + `REALS` are real images from the batch. + `FAKES` are fake images generated by the generator. Are the same as PRED. + `REAL_LOGITS` are logits of the discriminator for the real images. + `FAKE_LOGIT` are logits of the discriminator for the fake images. + `RECONSTRUCTION_LOSS` is the loss value computed by the reconstruction loss function. + `GENERATOR_LOSS` is the loss value computed by the generator loss function. It is the + discriminator loss for the fake images. That is backpropagated through the generator only. + `DISCRIMINATOR_LOSS` is the loss value computed by the discriminator loss function. It is the + discriminator loss for the real images and the fake images. That is backpropagated through the + discriminator only. + """ + REALS = "reals" REAL_LOGITS = "real_logits" FAKES = "fakes" @@ -716,6 +730,10 @@ class AdversarialKeys(StrEnum): class AdversarialIterationEvents(EventEnum): + """ + Keys used to define events as used in the AdversarialTrainer. + """ + RECONSTRUCTION_LOSS_COMPLETED = "reconstruction_loss_completed" GENERATOR_FORWARD_COMPLETED = "generator_forward_completed" GENERATOR_DISCRIMINATOR_FORWARD_COMPLETED = "generator_discriminator_forward_completed" From 64e8c9fff733c37c4f54ec974630b25b04760ebc Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Wed, 18 Oct 2023 13:48:41 +0100 Subject: [PATCH 12/13] Tests with different datatypes Signed-off-by: Mark Graham --- tests/test_squeeze_unsqueeze.py | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/tests/test_squeeze_unsqueeze.py b/tests/test_squeeze_unsqueeze.py index 346eff15fc..2adcd018c0 100644 --- a/tests/test_squeeze_unsqueeze.py +++ b/tests/test_squeeze_unsqueeze.py @@ -19,21 +19,45 @@ from monai.utils import unsqueeze_left, unsqueeze_right -RIGHT_CASES = [(np.random.rand(3, 4), 5, (3, 4, 1, 1, 1)), (torch.rand(3, 4), 5, (3, 4, 1, 1, 1))] +RIGHT_CASES = [ + (np.random.rand(3, 4).astype(np.float32), 5, (3, 4, 1, 1, 1)), + (torch.rand(3, 4).type(torch.float32), 5, (3, 4, 1, 1, 1)), + (np.random.rand(3, 4).astype(np.float64), 5, (3, 4, 1, 1, 1)), + (torch.rand(3, 4).type(torch.float64), 5, (3, 4, 1, 1, 1)), + (np.random.rand(3, 4).astype(np.int32), 5, (3, 4, 1, 1, 1)), + (torch.rand(3, 4).type(torch.int32), 5, (3, 4, 1, 1, 1)), +] -LEFT_CASES = [(np.random.rand(3, 4), 5, (1, 1, 1, 3, 4)), (torch.rand(3, 4), 5, (1, 1, 1, 3, 4))] +LEFT_CASES = [ + (np.random.rand(3, 4).astype(np.float32), 5, (1, 1, 1, 3, 4)), + (torch.rand(3, 4).type(torch.float32), 5, (1, 1, 1, 3, 4)), + (np.random.rand(3, 4).astype(np.float64), 5, (1, 1, 1, 3, 4)), + (torch.rand(3, 4).type(torch.float64), 5, (1, 1, 1, 3, 4)), + (np.random.rand(3, 4).astype(np.int32), 5, (1, 1, 1, 3, 4)), + (torch.rand(3, 4).type(torch.int32), 5, (1, 1, 1, 3, 4)), +] ALL_CASES = [ (np.random.rand(3, 4), 2, (3, 4)), (np.random.rand(3, 4), 0, (3, 4)), (np.random.rand(3, 4), -1, (3, 4)), (np.array(3), 4, (1, 1, 1, 1)), (np.array(3), 0, ()), + (np.random.rand(3, 4).astype(np.int32), 2, (3, 4)), + (np.random.rand(3, 4).astype(np.int32), 0, (3, 4)), + (np.random.rand(3, 4).astype(np.int32), -1, (3, 4)), + (np.array(3).astype(np.int32), 4, (1, 1, 1, 1)), + (np.array(3).astype(np.int32), 0, ()), (torch.rand(3, 4), 2, (3, 4)), (torch.rand(3, 4), 0, (3, 4)), (torch.rand(3, 4), -1, (3, 4)), (torch.tensor(3), 4, (1, 1, 1, 1)), (torch.tensor(3), 0, ()), + (torch.rand(3, 4).type(torch.int32), 2, (3, 4)), + (torch.rand(3, 4).type(torch.int32), 0, (3, 4)), + (torch.rand(3, 4).type(torch.int32), -1, (3, 4)), + (torch.tensor(3).type(torch.int32), 4, (1, 1, 1, 1)), + (torch.tensor(3).type(torch.int32), 0, ()), ] From 471d25d67c81d558cd0c7e91104a51fd38941541 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Wed, 18 Oct 2023 13:49:44 +0100 Subject: [PATCH 13/13] DCO Remediation Commit for Mark Graham I, Mark Graham , hereby add my Signed-off-by to this commit: f71fda0059c29459ff80d1bd8fd865caef83c07a Signed-off-by: Mark Graham --- tests/test_squeeze_unsqueeze.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_squeeze_unsqueeze.py b/tests/test_squeeze_unsqueeze.py index 2adcd018c0..2db26a6bdc 100644 --- a/tests/test_squeeze_unsqueeze.py +++ b/tests/test_squeeze_unsqueeze.py @@ -57,7 +57,7 @@ (torch.rand(3, 4).type(torch.int32), 0, (3, 4)), (torch.rand(3, 4).type(torch.int32), -1, (3, 4)), (torch.tensor(3).type(torch.int32), 4, (1, 1, 1, 1)), - (torch.tensor(3).type(torch.int32), 0, ()), + (torch.tensor(3).type(torch.int32), 0, ()), ]