Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Port MONAI Generative utils #7134

Merged
merged 17 commits into from
Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/source/utils.rst
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,8 @@ State Cacher
------------
.. automodule:: monai.utils.state_cacher
:members:

Component store
---------------
.. autoclass:: monai.utils.component_store.ComponentStore
:members:
7 changes: 7 additions & 0 deletions monai/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,13 @@

# have to explicitly bring these in here to resolve circular import issues
from .aliases import alias, resolve_name
from .component_store import ComponentStore
from .decorators import MethodReplacer, RestartGenerator
from .deprecate_utils import DeprecatedError, deprecated, deprecated_arg, deprecated_arg_default
from .dist import RankFilter, evenly_divisible_all_gather, get_dist_device, string_list_all_gather
from .enums import (
AdversarialIterationEvents,
AdversarialKeys,
AlgoKeys,
Average,
BlendMode,
Expand Down Expand Up @@ -46,6 +49,8 @@
MetricReduction,
NdimageMode,
NumpyPadMode,
OrderingTransformations,
OrderingType,
PatchKeys,
PostFix,
ProbMapKeys,
Expand Down Expand Up @@ -94,6 +99,8 @@
str2bool,
str2list,
to_tuple_of_dictionaries,
unsqueeze_left,
unsqueeze_right,
zip_with,
)
from .module import (
Expand Down
117 changes: 117 additions & 0 deletions monai/utils/component_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections import namedtuple
from keyword import iskeyword
from textwrap import dedent, indent
from typing import Any, Callable, Iterable, TypeVar

T = TypeVar("T")


def is_variable(name):
"""Returns True if `name` is a valid Python variable name and also not a keyword."""
return name.isidentifier() and not iskeyword(name)


class ComponentStore:
wyli marked this conversation as resolved.
Show resolved Hide resolved
"""
Represents a storage object for other objects (specifically functions) keyed to a name with a description.

These objects act as global named places for storing components for objects parameterised by component names.
Typically this is functions although other objects can be added. Printing a component store will produce a
list of members along with their docstring information if present.

Example:

.. code-block:: python

TestStore = ComponentStore("Test Store", "A test store for demo purposes")

@TestStore.add_def("my_func_name", "Some description of your function")
def _my_func(a, b):
'''A description of your function here.'''
return a * b

print(TestStore) # will print out name, description, and 'my_func_name' with the docstring

func = TestStore["my_func_name"]
result = func(7, 6)

"""

_Component = namedtuple("_Component", ("description", "value")) # internal value pair

def __init__(self, name: str, description: str) -> None:
self.components: dict[str, ComponentStore._Component] = {}
self.name: str = name
self.description: str = description

self.__doc__ = f"Component Store '{name}': {description}\n{self.__doc__ or ''}".strip()

def add(self, name: str, desc: str, value: T) -> T:
"""Store the object `value` under the name `name` with description `desc`."""
if not is_variable(name):
raise ValueError("Name of component must be valid Python identifier")

self.components[name] = self._Component(desc, value)
return value

def add_def(self, name: str, desc: str) -> Callable:
"""Returns a decorator which stores the decorated function under `name` with description `desc`."""

def deco(func):
"""Decorator to add a function to a store."""
return self.add(name, desc, func)

return deco

def __contains__(self, name: str) -> bool:
"""Returns True if the given name is stored."""
return name in self.components

def __len__(self) -> int:
"""Returns the number of stored components."""
return len(self.components)

def __iter__(self) -> Iterable:
"""Yields name/component pairs."""
for k, v in self.components.items():
yield k, v.value

def __str__(self):
result = f"Component Store '{self.name}': {self.description}\nAvailable components:"
for k, v in self.components.items():
result += f"\n* {k}:"

if hasattr(v.value, "__doc__"):
doc = indent(dedent(v.value.__doc__.lstrip("\n").rstrip()), " ")
result += f"\n{doc}\n"
else:
result += f" {v.description}"

return result

def __getattr__(self, name: str) -> Any:
"""Returns the stored object under the given name."""
if name in self.components:
return self.components[name].value
else:
return self.__getattribute__(name)

def __getitem__(self, name: str) -> Any:
"""Returns the stored object under the given name."""
if name in self.components:
return self.components[name].value
else:
raise ValueError(f"Component '{name}' not found")
47 changes: 47 additions & 0 deletions monai/utils/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,11 @@

import random
from enum import Enum
from typing import TYPE_CHECKING

from monai.config import IgniteInfo
from monai.utils import deprecated
from monai.utils.module import min_version, optional_import

__all__ = [
"StrEnum",
Expand Down Expand Up @@ -88,6 +91,14 @@ def __repr__(self):
return self.value


if TYPE_CHECKING:
from ignite.engine import EventEnum
else:
EventEnum, _ = optional_import(
"ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "EventEnum", as_type="base"
)


class NumpyPadMode(StrEnum):
"""
See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
Expand Down Expand Up @@ -692,3 +703,39 @@ class AlgoKeys(StrEnum):
ALGO = "algo_instance"
IS_TRAINED = "is_trained"
SCORE = "best_metric"


class AdversarialKeys(StrEnum):
marksgraham marked this conversation as resolved.
Show resolved Hide resolved
REALS = "reals"
REAL_LOGITS = "real_logits"
FAKES = "fakes"
FAKE_LOGITS = "fake_logits"
RECONSTRUCTION_LOSS = "reconstruction_loss"
GENERATOR_LOSS = "generator_loss"
DISCRIMINATOR_LOSS = "discriminator_loss"


class AdversarialIterationEvents(EventEnum):
RECONSTRUCTION_LOSS_COMPLETED = "reconstruction_loss_completed"
GENERATOR_FORWARD_COMPLETED = "generator_forward_completed"
GENERATOR_DISCRIMINATOR_FORWARD_COMPLETED = "generator_discriminator_forward_completed"
GENERATOR_LOSS_COMPLETED = "generator_loss_completed"
GENERATOR_BACKWARD_COMPLETED = "generator_backward_completed"
GENERATOR_MODEL_COMPLETED = "generator_model_completed"
DISCRIMINATOR_REALS_FORWARD_COMPLETED = "discriminator_reals_forward_completed"
DISCRIMINATOR_FAKES_FORWARD_COMPLETED = "discriminator_fakes_forward_completed"
DISCRIMINATOR_LOSS_COMPLETED = "discriminator_loss_completed"
DISCRIMINATOR_BACKWARD_COMPLETED = "discriminator_backward_completed"
DISCRIMINATOR_MODEL_COMPLETED = "discriminator_model_completed"


class OrderingType(StrEnum):
RASTER_SCAN = "raster_scan"
S_CURVE = "s_curve"
RANDOM = "random"


class OrderingTransformations(StrEnum):
ROTATE_90 = "rotate_90"
TRANSPOSE = "transpose"
REFLECT = "reflect"
10 changes: 10 additions & 0 deletions monai/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -888,3 +888,13 @@ def is_sqrt(num: Sequence[int] | int) -> bool:
sqrt_num = [int(math.sqrt(_num)) for _num in num]
ret = [_i * _j for _i, _j in zip(sqrt_num, sqrt_num)]
return ensure_tuple(ret) == num


def unsqueeze_right(arr: NdarrayOrTensor, ndim: int) -> NdarrayOrTensor:
"""Append 1-sized dimensions to `arr` to create a result with `ndim` dimensions."""
return arr[(...,) + (None,) * (ndim - arr.ndim)]


def unsqueeze_left(arr: NdarrayOrTensor, ndim: int) -> NdarrayOrTensor:
"""Prepend 1-sized dimensions to `arr` to create a result with `ndim` dimensions."""
return arr[(None,) * (ndim - arr.ndim)]
72 changes: 72 additions & 0 deletions tests/test_component_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import unittest

from monai.utils import ComponentStore


class TestComponentStore(unittest.TestCase):
def setUp(self):
self.cs = ComponentStore("TestStore", "I am a test store, please ignore")

def test_empty(self):
self.assertEqual(len(self.cs), 0)
self.assertEqual(list(self.cs), [])

def test_add(self):
test_obj = object()

self.assertFalse("test_obj" in self.cs)

self.cs.add("test_obj", "Test object", test_obj)

self.assertTrue("test_obj" in self.cs)

self.assertEqual(len(self.cs), 1)
self.assertEqual(list(self.cs), [("test_obj", test_obj)])

self.assertEqual(self.cs.test_obj, test_obj)
self.assertEqual(self.cs["test_obj"], test_obj)

def test_add2(self):
test_obj1 = object()
test_obj2 = object()

self.cs.add("test_obj1", "Test object", test_obj1)
self.cs.add("test_obj2", "Test object", test_obj2)

self.assertEqual(len(self.cs), 2)
self.assertTrue("test_obj1" in self.cs)
self.assertTrue("test_obj2" in self.cs)

def test_add_def(self):
self.assertFalse("test_func" in self.cs)

@self.cs.add_def("test_func", "Test function")
def test_func():
return 123

self.assertTrue("test_func" in self.cs)

self.assertEqual(len(self.cs), 1)
self.assertEqual(list(self.cs), [("test_func", test_func)])

self.assertEqual(self.cs.test_func, test_func)
self.assertEqual(self.cs["test_func"], test_func)

# try adding the same function again
self.cs.add_def("test_func", "Test function but with new description")(test_func)

self.assertEqual(len(self.cs), 1)
self.assertEqual(self.cs.test_func, test_func)
47 changes: 47 additions & 0 deletions tests/test_squeeze_unsqueeze.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import unittest

import numpy as np
import torch
from parameterized import parameterized

from monai.utils import unsqueeze_left, unsqueeze_right

RIGHT_CASES = [(np.random.rand(3, 4), 5, (3, 4, 1, 1, 1)), (torch.rand(3, 4), 5, (3, 4, 1, 1, 1))]
marksgraham marked this conversation as resolved.
Show resolved Hide resolved

LEFT_CASES = [(np.random.rand(3, 4), 5, (1, 1, 1, 3, 4)), (torch.rand(3, 4), 5, (1, 1, 1, 3, 4))]

ALL_CASES = [
(np.random.rand(3, 4), 2, (3, 4)),
(np.random.rand(3, 4), 0, (3, 4)),
(np.random.rand(3, 4), -1, (3, 4)),
(np.array(3), 4, (1, 1, 1, 1)),
(np.array(3), 0, ()),
(torch.rand(3, 4), 2, (3, 4)),
(torch.rand(3, 4), 0, (3, 4)),
(torch.rand(3, 4), -1, (3, 4)),
(torch.tensor(3), 4, (1, 1, 1, 1)),
(torch.tensor(3), 0, ()),
]


class TestUnsqueeze(unittest.TestCase):
@parameterized.expand(RIGHT_CASES + ALL_CASES)
def test_unsqueeze_right(self, arr, ndim, shape):
self.assertEqual(unsqueeze_right(arr, ndim).shape, shape)

@parameterized.expand(LEFT_CASES + ALL_CASES)
def test_unsqueeze_left(self, arr, ndim, shape):
self.assertEqual(unsqueeze_left(arr, ndim).shape, shape)