From ccd32ca5e9e84562d2f388b45b6724b5c77c1f57 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Fri, 27 Oct 2023 11:24:34 -0500 Subject: [PATCH] [WIP] 7145 common factory class (#7159) Fixes https://github.com/Project-MONAI/MONAI/issues/7145 ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Mark Graham --- monai/networks/layers/factories.py | 235 +++++++++++++++++++++-------- monai/utils/__init__.py | 1 + monai/utils/component_store.py | 125 +++++++++++++++ tests/test_component_store.py | 72 +++++++++ 4 files changed, 369 insertions(+), 64 deletions(-) create mode 100644 monai/utils/component_store.py create mode 100644 tests/test_component_store.py diff --git a/monai/networks/layers/factories.py b/monai/networks/layers/factories.py index bb56b0c0c5..38ee68cbee 100644 --- a/monai/networks/layers/factories.py +++ b/monai/networks/layers/factories.py @@ -68,40 +68,40 @@ def use_factory(fact_args): import torch.nn as nn from monai.networks.utils import has_nvfuser_instance_norm -from monai.utils import look_up_option, optional_import +from monai.utils import ComponentStore, look_up_option, optional_import __all__ = ["LayerFactory", "Dropout", "Norm", "Act", "Conv", "Pool", "Pad", "split_args"] -class LayerFactory: +class LayerFactory(ComponentStore): """ Factory object for creating layers, this uses given factory functions to actually produce the types or constructing callables. These functions are referred to by name and can be added at any time. """ - def __init__(self) -> None: - self.factories: dict[str, Callable] = {} + def __init__(self, name: str, description: str) -> None: + super().__init__(name, description) + self.__doc__ = ( + f"Layer Factory '{name}': {description}\n".strip() + + "\nPlease see :py:class:`monai.networks.layers.split_args` for additional args parsing." + + "\n\nThe supported members are:" + ) - @property - def names(self) -> tuple[str, ...]: + def add_factory_callable(self, name: str, func: Callable, desc: str | None = None) -> None: """ - Produces all factory names. + Add the factory function to this object under the given name, with optional description. """ + description: str = desc or func.__doc__ or "" + self.add(name.upper(), description, func) + # append name to the docstring + assert self.__doc__ is not None + self.__doc__ += f"{', ' if len(self.names)>1 else ' '}``{name}``" - return tuple(self.factories) - - def add_factory_callable(self, name: str, func: Callable) -> None: + def add_factory_class(self, name: str, cls: type, desc: str | None = None) -> None: """ - Add the factory function to this object under the given name. + Adds a factory function which returns the supplied class under the given name, with optional description. """ - - self.factories[name.upper()] = func - self.__doc__ = ( - "The supported member" - + ("s are: " if len(self.names) > 1 else " is: ") - + ", ".join(f"``{name}``" for name in self.names) - + ".\nPlease see :py:class:`monai.networks.layers.split_args` for additional args parsing." - ) + self.add_factory_callable(name, lambda x=None: cls, desc) def factory_function(self, name: str) -> Callable: """ @@ -126,8 +126,9 @@ def get_constructor(self, factory_name: str, *args) -> Any: if not isinstance(factory_name, str): raise TypeError(f"factory_name must a str but is {type(factory_name).__name__}.") - func = look_up_option(factory_name.upper(), self.factories) - return func(*args) + component = look_up_option(factory_name.upper(), self.components) + + return component.value(*args) def __getitem__(self, args) -> Any: """ @@ -153,7 +154,7 @@ def __getattr__(self, key): as if they were constants, eg. `Fact.FOO` for a factory Fact with factory function foo. """ - if key in self.factories: + if key in self.components: return key return super().__getattribute__(key) @@ -194,56 +195,60 @@ def split_args(args): # Define factories for these layer types - -Dropout = LayerFactory() -Norm = LayerFactory() -Act = LayerFactory() -Conv = LayerFactory() -Pool = LayerFactory() -Pad = LayerFactory() +Dropout = LayerFactory(name="Dropout layers", description="Factory for creating dropout layers.") +Norm = LayerFactory(name="Normalization layers", description="Factory for creating normalization layers.") +Act = LayerFactory(name="Activation layers", description="Factory for creating activation layers.") +Conv = LayerFactory(name="Convolution layers", description="Factory for creating convolution layers.") +Pool = LayerFactory(name="Pooling layers", description="Factory for creating pooling layers.") +Pad = LayerFactory(name="Padding layers", description="Factory for creating padding layers.") @Dropout.factory_function("dropout") def dropout_factory(dim: int) -> type[nn.Dropout | nn.Dropout2d | nn.Dropout3d]: + """ + Dropout layers in 1,2,3 dimensions. + + Args: + dim: desired dimension of the dropout layer + + Returns: + Dropout[dim]d + """ types = (nn.Dropout, nn.Dropout2d, nn.Dropout3d) return types[dim - 1] -@Dropout.factory_function("alphadropout") -def alpha_dropout_factory(_dim): - return nn.AlphaDropout +Dropout.add_factory_class("alphadropout", nn.AlphaDropout) @Norm.factory_function("instance") def instance_factory(dim: int) -> type[nn.InstanceNorm1d | nn.InstanceNorm2d | nn.InstanceNorm3d]: + """ + Instance normalization layers in 1,2,3 dimensions. + + Args: + dim: desired dimension of the instance normalization layer + + Returns: + InstanceNorm[dim]d + """ types = (nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d) return types[dim - 1] @Norm.factory_function("batch") def batch_factory(dim: int) -> type[nn.BatchNorm1d | nn.BatchNorm2d | nn.BatchNorm3d]: - types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d) - return types[dim - 1] - - -@Norm.factory_function("group") -def group_factory(_dim) -> type[nn.GroupNorm]: - return nn.GroupNorm - - -@Norm.factory_function("layer") -def layer_factory(_dim) -> type[nn.LayerNorm]: - return nn.LayerNorm - - -@Norm.factory_function("localresponse") -def local_response_factory(_dim) -> type[nn.LocalResponseNorm]: - return nn.LocalResponseNorm + """ + Batch normalization layers in 1,2,3 dimensions. + Args: + dim: desired dimension of the batch normalization layer -@Norm.factory_function("syncbatch") -def sync_batch_factory(_dim) -> type[nn.SyncBatchNorm]: - return nn.SyncBatchNorm + Returns: + BatchNorm[dim]d + """ + types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d) + return types[dim - 1] @Norm.factory_function("instance_nvfuser") @@ -274,22 +279,34 @@ def instance_nvfuser_factory(dim): return optional_import("apex.normalization", name="InstanceNorm3dNVFuser")[0] -Act.add_factory_callable("elu", lambda: nn.modules.ELU) -Act.add_factory_callable("relu", lambda: nn.modules.ReLU) -Act.add_factory_callable("leakyrelu", lambda: nn.modules.LeakyReLU) -Act.add_factory_callable("prelu", lambda: nn.modules.PReLU) -Act.add_factory_callable("relu6", lambda: nn.modules.ReLU6) -Act.add_factory_callable("selu", lambda: nn.modules.SELU) -Act.add_factory_callable("celu", lambda: nn.modules.CELU) -Act.add_factory_callable("gelu", lambda: nn.modules.GELU) -Act.add_factory_callable("sigmoid", lambda: nn.modules.Sigmoid) -Act.add_factory_callable("tanh", lambda: nn.modules.Tanh) -Act.add_factory_callable("softmax", lambda: nn.modules.Softmax) -Act.add_factory_callable("logsoftmax", lambda: nn.modules.LogSoftmax) +Norm.add_factory_class("group", nn.GroupNorm) +Norm.add_factory_class("layer", nn.LayerNorm) +Norm.add_factory_class("localresponse", nn.LocalResponseNorm) +Norm.add_factory_class("syncbatch", nn.SyncBatchNorm) + + +Act.add_factory_class("elu", nn.modules.ELU) +Act.add_factory_class("relu", nn.modules.ReLU) +Act.add_factory_class("leakyrelu", nn.modules.LeakyReLU) +Act.add_factory_class("prelu", nn.modules.PReLU) +Act.add_factory_class("relu6", nn.modules.ReLU6) +Act.add_factory_class("selu", nn.modules.SELU) +Act.add_factory_class("celu", nn.modules.CELU) +Act.add_factory_class("gelu", nn.modules.GELU) +Act.add_factory_class("sigmoid", nn.modules.Sigmoid) +Act.add_factory_class("tanh", nn.modules.Tanh) +Act.add_factory_class("softmax", nn.modules.Softmax) +Act.add_factory_class("logsoftmax", nn.modules.LogSoftmax) @Act.factory_function("swish") def swish_factory(): + """ + Swish activation layer. + + Returns: + Swish + """ from monai.networks.blocks.activation import Swish return Swish @@ -297,6 +314,12 @@ def swish_factory(): @Act.factory_function("memswish") def memswish_factory(): + """ + Memory efficient swish activation layer. + + Returns: + MemoryEfficientSwish + """ from monai.networks.blocks.activation import MemoryEfficientSwish return MemoryEfficientSwish @@ -304,6 +327,12 @@ def memswish_factory(): @Act.factory_function("mish") def mish_factory(): + """ + Mish activation layer. + + Returns: + Mish + """ from monai.networks.blocks.activation import Mish return Mish @@ -311,6 +340,12 @@ def mish_factory(): @Act.factory_function("geglu") def geglu_factory(): + """ + GEGLU activation layer. + + Returns: + GEGLU + """ from monai.networks.blocks.activation import GEGLU return GEGLU @@ -318,47 +353,119 @@ def geglu_factory(): @Conv.factory_function("conv") def conv_factory(dim: int) -> type[nn.Conv1d | nn.Conv2d | nn.Conv3d]: + """ + Convolutional layers in 1,2,3 dimensions. + + Args: + dim: desired dimension of the convolutional layer + + Returns: + Conv[dim]d + """ types = (nn.Conv1d, nn.Conv2d, nn.Conv3d) return types[dim - 1] @Conv.factory_function("convtrans") def convtrans_factory(dim: int) -> type[nn.ConvTranspose1d | nn.ConvTranspose2d | nn.ConvTranspose3d]: + """ + Transposed convolutional layers in 1,2,3 dimensions. + + Args: + dim: desired dimension of the transposed convolutional layer + + Returns: + ConvTranspose[dim]d + """ types = (nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d) return types[dim - 1] @Pool.factory_function("max") def maxpooling_factory(dim: int) -> type[nn.MaxPool1d | nn.MaxPool2d | nn.MaxPool3d]: + """ + Max pooling layers in 1,2,3 dimensions. + + Args: + dim: desired dimension of the max pooling layer + + Returns: + MaxPool[dim]d + """ types = (nn.MaxPool1d, nn.MaxPool2d, nn.MaxPool3d) return types[dim - 1] @Pool.factory_function("adaptivemax") def adaptive_maxpooling_factory(dim: int) -> type[nn.AdaptiveMaxPool1d | nn.AdaptiveMaxPool2d | nn.AdaptiveMaxPool3d]: + """ + Adaptive max pooling layers in 1,2,3 dimensions. + + Args: + dim: desired dimension of the adaptive max pooling layer + + Returns: + AdaptiveMaxPool[dim]d + """ types = (nn.AdaptiveMaxPool1d, nn.AdaptiveMaxPool2d, nn.AdaptiveMaxPool3d) return types[dim - 1] @Pool.factory_function("avg") def avgpooling_factory(dim: int) -> type[nn.AvgPool1d | nn.AvgPool2d | nn.AvgPool3d]: + """ + Average pooling layers in 1,2,3 dimensions. + + Args: + dim: desired dimension of the average pooling layer + + Returns: + AvgPool[dim]d + """ types = (nn.AvgPool1d, nn.AvgPool2d, nn.AvgPool3d) return types[dim - 1] @Pool.factory_function("adaptiveavg") def adaptive_avgpooling_factory(dim: int) -> type[nn.AdaptiveAvgPool1d | nn.AdaptiveAvgPool2d | nn.AdaptiveAvgPool3d]: + """ + Adaptive average pooling layers in 1,2,3 dimensions. + + Args: + dim: desired dimension of the adaptive average pooling layer + + Returns: + AdaptiveAvgPool[dim]d + """ types = (nn.AdaptiveAvgPool1d, nn.AdaptiveAvgPool2d, nn.AdaptiveAvgPool3d) return types[dim - 1] @Pad.factory_function("replicationpad") def replication_pad_factory(dim: int) -> type[nn.ReplicationPad1d | nn.ReplicationPad2d | nn.ReplicationPad3d]: + """ + Replication padding layers in 1,2,3 dimensions. + + Args: + dim: desired dimension of the replication padding layer + + Returns: + ReplicationPad[dim]d + """ types = (nn.ReplicationPad1d, nn.ReplicationPad2d, nn.ReplicationPad3d) return types[dim - 1] @Pad.factory_function("constantpad") def constant_pad_factory(dim: int) -> type[nn.ConstantPad1d | nn.ConstantPad2d | nn.ConstantPad3d]: + """ + Constant padding layers in 1,2,3 dimensions. + + Args: + dim: desired dimension of the constant padding layer + + Returns: + ConstantPad[dim]d + """ types = (nn.ConstantPad1d, nn.ConstantPad2d, nn.ConstantPad3d) return types[dim - 1] diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index c973d4bfa1..82f944ccb8 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -13,6 +13,7 @@ # have to explicitly bring these in here to resolve circular import issues from .aliases import alias, resolve_name +from .component_store import ComponentStore from .decorators import MethodReplacer, RestartGenerator from .deprecate_utils import DeprecatedError, deprecated, deprecated_arg, deprecated_arg_default from .dist import RankFilter, evenly_divisible_all_gather, get_dist_device, string_list_all_gather diff --git a/monai/utils/component_store.py b/monai/utils/component_store.py new file mode 100644 index 0000000000..d1e71eaebf --- /dev/null +++ b/monai/utils/component_store.py @@ -0,0 +1,125 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections import namedtuple +from keyword import iskeyword +from textwrap import dedent, indent +from typing import Any, Callable, Iterable, TypeVar + +T = TypeVar("T") + + +def is_variable(name): + """Returns True if `name` is a valid Python variable name and also not a keyword.""" + return name.isidentifier() and not iskeyword(name) + + +class ComponentStore: + """ + Represents a storage object for other objects (specifically functions) keyed to a name with a description. + + These objects act as global named places for storing components for objects parameterised by component names. + Typically this is functions although other objects can be added. Printing a component store will produce a + list of members along with their docstring information if present. + + Example: + + .. code-block:: python + + TestStore = ComponentStore("Test Store", "A test store for demo purposes") + + @TestStore.add_def("my_func_name", "Some description of your function") + def _my_func(a, b): + '''A description of your function here.''' + return a * b + + print(TestStore) # will print out name, description, and 'my_func_name' with the docstring + + func = TestStore["my_func_name"] + result = func(7, 6) + + """ + + _Component = namedtuple("_Component", ("description", "value")) # internal value pair + + def __init__(self, name: str, description: str) -> None: + self.components: dict[str, ComponentStore._Component] = {} + self.name: str = name + self.description: str = description + + self.__doc__ = f"Component Store '{name}': {description}\n{self.__doc__ or ''}".strip() + + def add(self, name: str, desc: str, value: T) -> T: + """Store the object `value` under the name `name` with description `desc`.""" + if not is_variable(name): + raise ValueError("Name of component must be valid Python identifier") + + self.components[name] = self._Component(desc, value) + return value + + def add_def(self, name: str, desc: str) -> Callable: + """Returns a decorator which stores the decorated function under `name` with description `desc`.""" + + def deco(func): + """Decorator to add a function to a store.""" + return self.add(name, desc, func) + + return deco + + @property + def names(self) -> tuple[str, ...]: + """ + Produces all factory names. + """ + + return tuple(self.components) + + def __contains__(self, name: str) -> bool: + """Returns True if the given name is stored.""" + return name in self.components + + def __len__(self) -> int: + """Returns the number of stored components.""" + return len(self.components) + + def __iter__(self) -> Iterable: + """Yields name/component pairs.""" + for k, v in self.components.items(): + yield k, v.value + + def __str__(self): + result = f"Component Store '{self.name}': {self.description}\nAvailable components:" + for k, v in self.components.items(): + result += f"\n* {k}:" + + if hasattr(v.value, "__doc__") and v.value.__doc__: + doc = indent(dedent(v.value.__doc__.lstrip("\n").rstrip()), " ") + result += f"\n{doc}\n" + else: + result += f" {v.description}" + + return result + + def __getattr__(self, name: str) -> Any: + """Returns the stored object under the given name.""" + if name in self.components: + return self.components[name].value + else: + return self.__getattribute__(name) + + def __getitem__(self, name: str) -> Any: + """Returns the stored object under the given name.""" + if name in self.components: + return self.components[name].value + else: + raise ValueError(f"Component '{name}' not found") diff --git a/tests/test_component_store.py b/tests/test_component_store.py new file mode 100644 index 0000000000..614f387754 --- /dev/null +++ b/tests/test_component_store.py @@ -0,0 +1,72 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +from monai.utils import ComponentStore + + +class TestComponentStore(unittest.TestCase): + def setUp(self): + self.cs = ComponentStore("TestStore", "I am a test store, please ignore") + + def test_empty(self): + self.assertEqual(len(self.cs), 0) + self.assertEqual(list(self.cs), []) + + def test_add(self): + test_obj = object() + + self.assertFalse("test_obj" in self.cs) + + self.cs.add("test_obj", "Test object", test_obj) + + self.assertTrue("test_obj" in self.cs) + + self.assertEqual(len(self.cs), 1) + self.assertEqual(list(self.cs), [("test_obj", test_obj)]) + + self.assertEqual(self.cs.test_obj, test_obj) + self.assertEqual(self.cs["test_obj"], test_obj) + + def test_add2(self): + test_obj1 = object() + test_obj2 = object() + + self.cs.add("test_obj1", "Test object", test_obj1) + self.cs.add("test_obj2", "Test object", test_obj2) + + self.assertEqual(len(self.cs), 2) + self.assertTrue("test_obj1" in self.cs) + self.assertTrue("test_obj2" in self.cs) + + def test_add_def(self): + self.assertFalse("test_func" in self.cs) + + @self.cs.add_def("test_func", "Test function") + def test_func(): + return 123 + + self.assertTrue("test_func" in self.cs) + + self.assertEqual(len(self.cs), 1) + self.assertEqual(list(self.cs), [("test_func", test_func)]) + + self.assertEqual(self.cs.test_func, test_func) + self.assertEqual(self.cs["test_func"], test_func) + + # try adding the same function again + self.cs.add_def("test_func", "Test function but with new description")(test_func) + + self.assertEqual(len(self.cs), 1) + self.assertEqual(self.cs.test_func, test_func)