[WIP] 7145 common factory class (#7159)

Fixes #7145

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.
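
Since this is marked as a non-breaking change, the public factory usage should stay as it is today. A minimal sketch of that usage for reference (this follows the access patterns shown in the diff below; it is not new API introduced by this PR):

```python
# Sketch of how the refactored factories are expected to be used after this change;
# the factory names and lookup behaviour follow the diff below.
from monai.networks.layers.factories import Act, Conv, Norm

norm_type = Norm[Norm.BATCH, 2]   # batch_factory(2) -> nn.BatchNorm2d
conv_type = Conv["conv", 3]       # conv_factory(3) -> nn.Conv3d
act_type = Act["relu"]            # class factories take no dimension argument

layer = norm_type(num_features=16)
```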

---------

Signed-off-by: Mark Graham <[email protected]>
marksgraham authored Oct 27, 2023
1 parent 487f98b commit ccd32ca
Showing 4 changed files with 369 additions and 64 deletions.
235 changes: 171 additions & 64 deletions monai/networks/layers/factories.py
@@ -68,40 +68,40 @@ def use_factory(fact_args):
import torch.nn as nn

from monai.networks.utils import has_nvfuser_instance_norm
from monai.utils import look_up_option, optional_import
from monai.utils import ComponentStore, look_up_option, optional_import

__all__ = ["LayerFactory", "Dropout", "Norm", "Act", "Conv", "Pool", "Pad", "split_args"]


class LayerFactory:
class LayerFactory(ComponentStore):
"""
Factory object for creating layers, this uses given factory functions to actually produce the types or constructing
callables. These functions are referred to by name and can be added at any time.
"""

def __init__(self) -> None:
self.factories: dict[str, Callable] = {}
def __init__(self, name: str, description: str) -> None:
super().__init__(name, description)
self.__doc__ = (
f"Layer Factory '{name}': {description}\n".strip()
+ "\nPlease see :py:class:`monai.networks.layers.split_args` for additional args parsing."
+ "\n\nThe supported members are:"
)

@property
def names(self) -> tuple[str, ...]:
"""
Produces all factory names.
"""

return tuple(self.factories)

def add_factory_callable(self, name: str, func: Callable) -> None:
"""
Add the factory function to this object under the given name.
"""

self.factories[name.upper()] = func
self.__doc__ = (
"The supported member"
+ ("s are: " if len(self.names) > 1 else " is: ")
+ ", ".join(f"``{name}``" for name in self.names)
+ ".\nPlease see :py:class:`monai.networks.layers.split_args` for additional args parsing."
)

def add_factory_callable(self, name: str, func: Callable, desc: str | None = None) -> None:
"""
Add the factory function to this object under the given name, with optional description.
"""
description: str = desc or func.__doc__ or ""
self.add(name.upper(), description, func)
# append name to the docstring
assert self.__doc__ is not None
self.__doc__ += f"{', ' if len(self.names)>1 else ' '}``{name}``"

def add_factory_class(self, name: str, cls: type, desc: str | None = None) -> None:
"""
Adds a factory function which returns the supplied class under the given name, with optional description.
"""
self.add_factory_callable(name, lambda x=None: cls, desc)

def factory_function(self, name: str) -> Callable:
"""
@@ -126,8 +126,9 @@ def get_constructor(self, factory_name: str, *args) -> Any:
if not isinstance(factory_name, str):
raise TypeError(f"factory_name must be a str but is {type(factory_name).__name__}.")

func = look_up_option(factory_name.upper(), self.factories)
return func(*args)
component = look_up_option(factory_name.upper(), self.components)

return component.value(*args)

def __getitem__(self, args) -> Any:
"""
@@ -153,7 +154,7 @@ def __getattr__(self, key):
as if they were constants, eg. `Fact.FOO` for a factory Fact with factory function foo.
"""

if key in self.factories:
if key in self.components:
return key

return super().__getattribute__(key)
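
For context, a short sketch of the lookup behaviour these hunks preserve, assuming the existing `monai.networks.layers.split_args` contract:

```python
from monai.networks.layers.factories import Act, Norm, split_args

# Attribute access returns the stored name, so factory members behave like constants.
assert Norm.INSTANCE == "INSTANCE"

# split_args separates a ("name", kwargs) specification into name and keyword arguments.
name, kwargs = split_args(("prelu", {"num_parameters": 1, "init": 0.1}))
prelu_layer = Act[name](**kwargs)  # nn.PReLU(num_parameters=1, init=0.1)
```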
@@ -194,56 +195,60 @@ def split_args(args):


# Define factories for these layer types

Dropout = LayerFactory()
Norm = LayerFactory()
Act = LayerFactory()
Conv = LayerFactory()
Pool = LayerFactory()
Pad = LayerFactory()
Dropout = LayerFactory(name="Dropout layers", description="Factory for creating dropout layers.")
Norm = LayerFactory(name="Normalization layers", description="Factory for creating normalization layers.")
Act = LayerFactory(name="Activation layers", description="Factory for creating activation layers.")
Conv = LayerFactory(name="Convolution layers", description="Factory for creating convolution layers.")
Pool = LayerFactory(name="Pooling layers", description="Factory for creating pooling layers.")
Pad = LayerFactory(name="Padding layers", description="Factory for creating padding layers.")


@Dropout.factory_function("dropout")
def dropout_factory(dim: int) -> type[nn.Dropout | nn.Dropout2d | nn.Dropout3d]:
"""
Dropout layers in 1,2,3 dimensions.
Args:
dim: desired dimension of the dropout layer
Returns:
Dropout[dim]d
"""
types = (nn.Dropout, nn.Dropout2d, nn.Dropout3d)
return types[dim - 1]


@Dropout.factory_function("alphadropout")
def alpha_dropout_factory(_dim):
return nn.AlphaDropout
Dropout.add_factory_class("alphadropout", nn.AlphaDropout)


@Norm.factory_function("instance")
def instance_factory(dim: int) -> type[nn.InstanceNorm1d | nn.InstanceNorm2d | nn.InstanceNorm3d]:
"""
Instance normalization layers in 1,2,3 dimensions.
Args:
dim: desired dimension of the instance normalization layer
Returns:
InstanceNorm[dim]d
"""
types = (nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)
return types[dim - 1]


@Norm.factory_function("batch")
def batch_factory(dim: int) -> type[nn.BatchNorm1d | nn.BatchNorm2d | nn.BatchNorm3d]:
types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
return types[dim - 1]


@Norm.factory_function("group")
def group_factory(_dim) -> type[nn.GroupNorm]:
return nn.GroupNorm


@Norm.factory_function("layer")
def layer_factory(_dim) -> type[nn.LayerNorm]:
return nn.LayerNorm


@Norm.factory_function("localresponse")
def local_response_factory(_dim) -> type[nn.LocalResponseNorm]:
return nn.LocalResponseNorm
"""
Batch normalization layers in 1,2,3 dimensions.
Args:
dim: desired dimension of the batch normalization layer
@Norm.factory_function("syncbatch")
def sync_batch_factory(_dim) -> type[nn.SyncBatchNorm]:
return nn.SyncBatchNorm
Returns:
BatchNorm[dim]d
"""
types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
return types[dim - 1]


@Norm.factory_function("instance_nvfuser")
@@ -274,91 +279,193 @@ def instance_nvfuser_factory(dim):
return optional_import("apex.normalization", name="InstanceNorm3dNVFuser")[0]


Act.add_factory_callable("elu", lambda: nn.modules.ELU)
Act.add_factory_callable("relu", lambda: nn.modules.ReLU)
Act.add_factory_callable("leakyrelu", lambda: nn.modules.LeakyReLU)
Act.add_factory_callable("prelu", lambda: nn.modules.PReLU)
Act.add_factory_callable("relu6", lambda: nn.modules.ReLU6)
Act.add_factory_callable("selu", lambda: nn.modules.SELU)
Act.add_factory_callable("celu", lambda: nn.modules.CELU)
Act.add_factory_callable("gelu", lambda: nn.modules.GELU)
Act.add_factory_callable("sigmoid", lambda: nn.modules.Sigmoid)
Act.add_factory_callable("tanh", lambda: nn.modules.Tanh)
Act.add_factory_callable("softmax", lambda: nn.modules.Softmax)
Act.add_factory_callable("logsoftmax", lambda: nn.modules.LogSoftmax)
Norm.add_factory_class("group", nn.GroupNorm)
Norm.add_factory_class("layer", nn.LayerNorm)
Norm.add_factory_class("localresponse", nn.LocalResponseNorm)
Norm.add_factory_class("syncbatch", nn.SyncBatchNorm)


Act.add_factory_class("elu", nn.modules.ELU)
Act.add_factory_class("relu", nn.modules.ReLU)
Act.add_factory_class("leakyrelu", nn.modules.LeakyReLU)
Act.add_factory_class("prelu", nn.modules.PReLU)
Act.add_factory_class("relu6", nn.modules.ReLU6)
Act.add_factory_class("selu", nn.modules.SELU)
Act.add_factory_class("celu", nn.modules.CELU)
Act.add_factory_class("gelu", nn.modules.GELU)
Act.add_factory_class("sigmoid", nn.modules.Sigmoid)
Act.add_factory_class("tanh", nn.modules.Tanh)
Act.add_factory_class("softmax", nn.modules.Softmax)
Act.add_factory_class("logsoftmax", nn.modules.LogSoftmax)


@Act.factory_function("swish")
def swish_factory():
"""
Swish activation layer.
Returns:
Swish
"""
from monai.networks.blocks.activation import Swish

return Swish


@Act.factory_function("memswish")
def memswish_factory():
"""
Memory efficient swish activation layer.
Returns:
MemoryEfficientSwish
"""
from monai.networks.blocks.activation import MemoryEfficientSwish

return MemoryEfficientSwish


@Act.factory_function("mish")
def mish_factory():
"""
Mish activation layer.
Returns:
Mish
"""
from monai.networks.blocks.activation import Mish

return Mish


@Act.factory_function("geglu")
def geglu_factory():
"""
GEGLU activation layer.
Returns:
GEGLU
"""
from monai.networks.blocks.activation import GEGLU

return GEGLU


@Conv.factory_function("conv")
def conv_factory(dim: int) -> type[nn.Conv1d | nn.Conv2d | nn.Conv3d]:
"""
Convolutional layers in 1,2,3 dimensions.
Args:
dim: desired dimension of the convolutional layer
Returns:
Conv[dim]d
"""
types = (nn.Conv1d, nn.Conv2d, nn.Conv3d)
return types[dim - 1]


@Conv.factory_function("convtrans")
def convtrans_factory(dim: int) -> type[nn.ConvTranspose1d | nn.ConvTranspose2d | nn.ConvTranspose3d]:
"""
Transposed convolutional layers in 1,2,3 dimensions.
Args:
dim: desired dimension of the transposed convolutional layer
Returns:
ConvTranspose[dim]d
"""
types = (nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)
return types[dim - 1]


@Pool.factory_function("max")
def maxpooling_factory(dim: int) -> type[nn.MaxPool1d | nn.MaxPool2d | nn.MaxPool3d]:
"""
Max pooling layers in 1,2,3 dimensions.
Args:
dim: desired dimension of the max pooling layer
Returns:
MaxPool[dim]d
"""
types = (nn.MaxPool1d, nn.MaxPool2d, nn.MaxPool3d)
return types[dim - 1]


@Pool.factory_function("adaptivemax")
def adaptive_maxpooling_factory(dim: int) -> type[nn.AdaptiveMaxPool1d | nn.AdaptiveMaxPool2d | nn.AdaptiveMaxPool3d]:
"""
Adaptive max pooling layers in 1,2,3 dimensions.
Args:
dim: desired dimension of the adaptive max pooling layer
Returns:
AdaptiveMaxPool[dim]d
"""
types = (nn.AdaptiveMaxPool1d, nn.AdaptiveMaxPool2d, nn.AdaptiveMaxPool3d)
return types[dim - 1]


@Pool.factory_function("avg")
def avgpooling_factory(dim: int) -> type[nn.AvgPool1d | nn.AvgPool2d | nn.AvgPool3d]:
"""
Average pooling layers in 1,2,3 dimensions.
Args:
dim: desired dimension of the average pooling layer
Returns:
AvgPool[dim]d
"""
types = (nn.AvgPool1d, nn.AvgPool2d, nn.AvgPool3d)
return types[dim - 1]


@Pool.factory_function("adaptiveavg")
def adaptive_avgpooling_factory(dim: int) -> type[nn.AdaptiveAvgPool1d | nn.AdaptiveAvgPool2d | nn.AdaptiveAvgPool3d]:
"""
Adaptive average pooling layers in 1,2,3 dimensions.
Args:
dim: desired dimension of the adaptive average pooling layer
Returns:
AdaptiveAvgPool[dim]d
"""
types = (nn.AdaptiveAvgPool1d, nn.AdaptiveAvgPool2d, nn.AdaptiveAvgPool3d)
return types[dim - 1]


@Pad.factory_function("replicationpad")
def replication_pad_factory(dim: int) -> type[nn.ReplicationPad1d | nn.ReplicationPad2d | nn.ReplicationPad3d]:
"""
Replication padding layers in 1,2,3 dimensions.
Args:
dim: desired dimension of the replication padding layer
Returns:
ReplicationPad[dim]d
"""
types = (nn.ReplicationPad1d, nn.ReplicationPad2d, nn.ReplicationPad3d)
return types[dim - 1]


@Pad.factory_function("constantpad")
def constant_pad_factory(dim: int) -> type[nn.ConstantPad1d | nn.ConstantPad2d | nn.ConstantPad3d]:
"""
Constant padding layers in 1,2,3 dimensions.
Args:
dim: desired dimension of the constant padding layer
Returns:
ConstantPad[dim]d
"""
types = (nn.ConstantPad1d, nn.ConstantPad2d, nn.ConstantPad3d)
return types[dim - 1]
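
As a usage note, registering new layers with the refactored factories could look something like the sketch below. The `MyNorm` class, the `"mynorm"`/`"mygroup"` names, and the description string are hypothetical and only for illustration:

```python
import torch.nn as nn

from monai.networks.layers.factories import Norm


class MyNorm(nn.Identity):
    """Hypothetical stand-in normalization layer used only for illustration."""


# Register a class directly; an explicit description can be supplied.
Norm.add_factory_class("mynorm", MyNorm, desc="Hypothetical custom normalization layer.")


# Or register a dimension-aware factory function; its docstring becomes the description.
@Norm.factory_function("mygroup")
def my_group_factory(_dim) -> type[nn.GroupNorm]:
    """GroupNorm regardless of spatial dimensions."""
    return nn.GroupNorm


assert Norm["mynorm", 3] is MyNorm
assert Norm[Norm.MYGROUP, 2] is nn.GroupNorm
```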
1 change: 1 addition & 0 deletions monai/utils/__init__.py
@@ -13,6 +13,7 @@

# have to explicitly bring these in here to resolve circular import issues
from .aliases import alias, resolve_name
from .component_store import ComponentStore
from .decorators import MethodReplacer, RestartGenerator
from .deprecate_utils import DeprecatedError, deprecated, deprecated_arg, deprecated_arg_default
from .dist import RankFilter, evenly_divisible_all_gather, get_dist_device, string_list_all_gather
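
The import above exposes `ComponentStore` from `monai.utils`. A minimal sketch of the interface the new `LayerFactory` relies on, stated as an assumption for illustration only (the real `monai.utils.ComponentStore` may differ in details):

```python
from __future__ import annotations

from collections import namedtuple
from typing import Any

# Sketch of the ComponentStore interface assumed by the LayerFactory changes above:
# a (name, description) constructor, an add() method, and .components/.names views.
Component = namedtuple("Component", ("description", "value"))


class ComponentStoreSketch:
    def __init__(self, name: str, description: str) -> None:
        self.name = name
        self.description = description
        self.components: dict[str, Component] = {}

    @property
    def names(self) -> tuple[str, ...]:
        return tuple(self.components)

    def add(self, name: str, desc: str, value: Any) -> Any:
        self.components[name] = Component(desc, value)
        return value
```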