-
Notifications
You must be signed in to change notification settings - Fork 11
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: Add functional API for algorithm contributtions #876
base: develop
Are you sure you want to change the base?
Changes from all commits
dd0bee5
423288f
0799fef
350a363
526d8f1
42f7739
47b4d4f
2f8c470
a858a74
3f0f9dc
2525ce4
94bfd79
02e7763
fb941dd
077f13a
77adef1
bebab4c
2beb83c
c0a3770
e7480a5
413ead0
ef72b97
05e2a86
7749300
ac7e352
2d75267
be6bef0
fb1a0db
a97ceef
d555168
ad9946a
a104055
28b697a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,6 +16,10 @@ | |
from PartSegCore.utils import BaseModel | ||
from PartSegImage import Channel | ||
|
||
T = typing.TypeVar("T", bound="AlgorithmDescribeBase") | ||
|
||
TypeT = typing.Type[T] | ||
|
||
|
||
class AlgorithmDescribeNotFound(Exception): | ||
""" | ||
|
@@ -116,7 +120,7 @@ def _partial_abstractmethod(funcobj): | |
|
||
|
||
class AlgorithmDescribeBaseMeta(ABCMeta): | ||
def __new__(cls, name, bases, attrs, **kwargs): | ||
def __new__(cls, name, bases, attrs, method_from_fun=None, additional_parameters=None, **kwargs): | ||
cls2 = super().__new__(cls, name, bases, attrs, **kwargs) | ||
if ( | ||
not inspect.isabstract(cls2) | ||
|
@@ -125,8 +129,181 @@ def __new__(cls, name, bases, attrs, **kwargs): | |
): | ||
raise RuntimeError("class need to have __argument_class__ set or get_fields functions defined") | ||
cls2.__new_style__ = getattr(cls2.get_fields, "__is_partial_abstractmethod__", False) | ||
cls2.__from_function__ = getattr(cls2, "__from_function__", False) | ||
cls2.__abstract_getters__ = {} | ||
cls2.__method_name__ = method_from_fun or getattr(cls2, "__method_name__", None) | ||
cls2.__additional_parameters_name__ = additional_parameters or getattr( | ||
cls2, "__additional_parameters_name__", None | ||
) | ||
if cls2.__additional_parameters_name__ is None: | ||
cls2.__additional_parameters_name__ = cls._get_calculation_method_params_name(cls2) | ||
|
||
cls2.__support_from_function__ = ( | ||
cls2.__method_name__ is not None and cls2.__additional_parameters_name__ is not None | ||
) | ||
|
||
cls2.__abstract_getters__, cls2.__support_from_function__ = cls._get_abstract_getters( | ||
cls2, cls2.__support_from_function__, method_from_fun | ||
) | ||
return cls2 | ||
|
||
@staticmethod | ||
def _get_abstract_getters( | ||
cls2, support_from_function, calculation_method | ||
) -> typing.Tuple[typing.Dict[str, typing.Any], bool]: | ||
abstract_getters = {} | ||
if hasattr(cls2, "__abstractmethods__") and cls2.__abstractmethods__: | ||
# get all abstract methods that starts with `get_` | ||
for method_name in cls2.__abstractmethods__: | ||
if method_name.startswith("get_"): | ||
method = getattr(cls2, method_name) | ||
if "return" not in method.__annotations__: | ||
msg = f"Method {method_name} of {cls2.__qualname__} need to have return type defined" | ||
try: | ||
file_name = inspect.getsourcefile(method) | ||
line = inspect.getsourcelines(method)[1] | ||
msg += f" in {file_name}:{line}" | ||
except TypeError: | ||
pass | ||
raise RuntimeError(msg) | ||
|
||
abstract_getters[method_name[4:]] = getattr(cls2, method_name).__annotations__["return"] | ||
elif method_name != calculation_method: | ||
support_from_function = False | ||
return abstract_getters, support_from_function | ||
|
||
@staticmethod | ||
def _get_calculation_method_params_name(cls2) -> typing.Optional[str]: | ||
if cls2.__method_name__ is None: | ||
return None | ||
signature = inspect.signature(getattr(cls2, cls2.__method_name__)) | ||
if "arguments" in signature.parameters: | ||
return "arguments" | ||
if "params" in signature.parameters: | ||
return "params" | ||
if "parameters" in signature.parameters: | ||
return "parameters" | ||
raise RuntimeError(f"Cannot determine arguments parameter name in {cls2.__method_name__}") | ||
Comment on lines
+175
to
+186
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The Enhance the method to support a wider range of parameter naming conventions or provide a mechanism for developers to specify the parameter name explicitly, improving the API's flexibility. |
||
|
||
@staticmethod | ||
def _validate_if_all_abstract_getters_are_defined(abstract_getters, kwargs): | ||
abstract_getters_set = set(abstract_getters) | ||
kwargs_set = set(kwargs.keys()) | ||
|
||
if abstract_getters_set != kwargs_set: | ||
# Provide a nice error message with information about what is missing and is obsolete | ||
missing_text = ", ".join(sorted(abstract_getters_set - kwargs_set)) | ||
if missing_text: | ||
missing_text = f"Not all abstract methods are provided, missing: {missing_text}." | ||
else: | ||
missing_text = "" | ||
extra_text = ", ".join(sorted(kwargs_set - abstract_getters_set)) | ||
if extra_text: | ||
extra_text = f"There are extra attributes in call: {extra_text}." | ||
else: | ||
extra_text = "" | ||
|
||
raise ValueError(f"{missing_text} {extra_text}") | ||
Comment on lines
+188
to
+206
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The Improve the error message generated by |
||
|
||
@staticmethod | ||
def _validate_function_parameters(func, method, method_name) -> set: | ||
""" | ||
Validate if all parameters without default values are defined in self.__calculation_method__ | ||
|
||
:param func: function to validate | ||
:return: set of parameters that should be dropped | ||
""" | ||
signature = inspect.signature(func) | ||
base_method_signature = inspect.signature(method) | ||
take_all = False | ||
|
||
for parameter in signature.parameters.values(): | ||
if parameter.kind in {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.POSITIONAL_ONLY}: | ||
raise ValueError(f"Function {func} should not have positional only parameters") | ||
if ( | ||
parameter.default is inspect.Parameter.empty | ||
and parameter.name not in base_method_signature.parameters | ||
and parameter.kind != inspect.Parameter.VAR_KEYWORD | ||
): | ||
raise ValueError(f"Parameter {parameter.name} is not defined in {method_name} method") | ||
|
||
if parameter.kind == inspect.Parameter.VAR_KEYWORD: | ||
take_all = True | ||
|
||
if take_all: | ||
return set() | ||
|
||
return { | ||
parameters.name | ||
for parameters in base_method_signature.parameters.values() | ||
Comment on lines
+208
to
+238
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The Enhance the error handling in |
||
if parameters.name not in signature.parameters | ||
} | ||
|
||
@staticmethod | ||
def _get_argument_class_from_signature(func, argument_name: str): | ||
signature = inspect.signature(func) | ||
if argument_name not in signature.parameters: | ||
return BaseModel | ||
return signature.parameters[argument_name].annotation | ||
|
||
@staticmethod | ||
def _get_parameters_from_signature(func): | ||
signature = inspect.signature(func) | ||
return [parameters.name for parameters in signature.parameters.values()] | ||
|
||
def from_function(self, func=None, **kwargs): | ||
"""generate new class from function""" | ||
|
||
# Test if all abstract methods values are provided in kwargs | ||
|
||
if not self.__support_from_function__: | ||
raise RuntimeError("This class does not support from_function method") | ||
|
||
self._validate_if_all_abstract_getters_are_defined(self.__abstract_getters__, kwargs) | ||
|
||
# check if all values have correct type | ||
for key, value in kwargs.items(): | ||
if not isinstance(value, self.__abstract_getters__[key]): | ||
raise TypeError(f"Value for {key} should be {self.__abstract_getters__[key]}") | ||
|
||
def _getter_by_name(name): | ||
def _func(): | ||
return kwargs[name] | ||
|
||
return _func | ||
|
||
parameters_order = self._get_parameters_from_signature(getattr(self, self.__method_name__)) | ||
|
||
def _class_generator(func_): | ||
drop_attr = self._validate_function_parameters( | ||
func_, getattr(self, self.__method_name__), self.__method_name__ | ||
) | ||
|
||
@wraps(func_) | ||
def _calculate_method(*args, **kwargs_): | ||
for attr, name in zip(args, parameters_order): | ||
if name in kwargs_: | ||
raise ValueError(f"Parameter {name} is defined twice") | ||
kwargs_[name] = attr | ||
|
||
for name in drop_attr: | ||
kwargs_.pop(name, None) | ||
return func_(**kwargs_) | ||
|
||
class_dkt = {f"get_{name}": _getter_by_name(name) for name in self.__abstract_getters__} | ||
|
||
class_dkt[self.__method_name__] = _calculate_method | ||
class_dkt["__argument_class__"] = self._get_argument_class_from_signature( | ||
func_, self.__additional_parameters_name__ | ||
) | ||
class_dkt["__from_function__"] = True | ||
|
||
return type(func_.__name__.replace("_", " ").title().replace(" ", ""), (self,), class_dkt) | ||
|
||
if func is None: | ||
return _class_generator | ||
return _class_generator(func) | ||
Comment on lines
+254
to
+305
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The Consider refactoring the |
||
|
||
|
||
class AlgorithmDescribeBase(ABC, metaclass=AlgorithmDescribeBaseMeta): | ||
""" | ||
Comment on lines
129
to
309
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
The modifications to the Consider refining the logic that determines the support for |
||
|
@@ -138,6 +315,11 @@ class AlgorithmDescribeBase(ABC, metaclass=AlgorithmDescribeBaseMeta): | |
__argument_class__: typing.Optional[typing.Type[PydanticBaseModel]] = None | ||
__new_style__: bool | ||
|
||
def __new__(cls, *args, **kwargs): | ||
if cls.__from_function__: | ||
return getattr(cls, cls.__method_name__)(*args, **kwargs) | ||
return super().__new__(cls) | ||
Comment on lines
+318
to
+321
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The Evaluate the impact of the modified |
||
|
||
@classmethod | ||
def get_doc_from_fields(cls): | ||
resp = "{\n" | ||
|
@@ -150,6 +332,29 @@ def get_doc_from_fields(cls): | |
resp += "}\n" | ||
return resp | ||
|
||
@classmethod | ||
@typing.overload | ||
def from_function(cls: TypeT, func: typing.Callable[..., typing.Any], **kwargs) -> TypeT: | ||
... | ||
|
||
@classmethod | ||
@typing.overload | ||
def from_function(cls: TypeT, **kwargs) -> typing.Callable[[typing.Callable[..., typing.Any]], TypeT]: | ||
... | ||
|
||
@classmethod | ||
def from_function( | ||
cls: TypeT, func=None, **kwargs | ||
) -> typing.Union[TypeT, typing.Callable[[typing.Callable], TypeT]]: | ||
def _from_function(func_) -> typing.Type["AlgorithmDescribeBase"]: | ||
if "name" not in kwargs: | ||
kwargs["name"] = func_.__name__.replace("_", " ").title() | ||
return AlgorithmDescribeBaseMeta.from_function(cls, func_, **kwargs) | ||
|
||
if func is None: | ||
return _from_function | ||
return _from_function(func) | ||
Comment on lines
+345
to
+356
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The Enhance the documentation for the |
||
|
||
@classmethod | ||
@abstractmethod | ||
def get_name(cls) -> str: | ||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -68,7 +68,7 @@ class SaveBase(AlgorithmDescribeBase, ABC): | |||||
|
||||||
@classmethod | ||||||
@abstractmethod | ||||||
def get_short_name(cls): | ||||||
def get_short_name(cls) -> str: | ||||||
raise NotImplementedError | ||||||
|
||||||
@classmethod | ||||||
|
@@ -102,10 +102,12 @@ def get_default_extension(cls): | |||||
|
||||||
@classmethod | ||||||
def need_segmentation(cls): | ||||||
"""If method requires segmentation (ROI) to work, or could work with image only""" | ||||||
return True | ||||||
|
||||||
@classmethod | ||||||
def need_mask(cls): | ||||||
"""If `mask` is required for perform save""" | ||||||
return False | ||||||
|
||||||
@classmethod | ||||||
|
@@ -132,7 +134,7 @@ class LoadBase(AlgorithmDescribeBase, ABC): | |||||
|
||||||
@classmethod | ||||||
@abstractmethod | ||||||
def get_short_name(cls): | ||||||
def get_short_name(cls) -> str: | ||||||
raise NotImplementedError | ||||||
|
||||||
@classmethod | ||||||
|
@@ -161,8 +163,7 @@ def get_name_with_suffix(cls): | |||||
|
||||||
@classmethod | ||||||
def get_extensions(cls) -> typing.List[str]: | ||||||
match = re.match(r".*\((.*)\)", cls.get_name()) | ||||||
if match is None: | ||||||
if (match := re.match(r".*\((.*)\)", cls.get_name())) is None: | ||||||
raise ValueError(f"No extensions found in {cls.get_name()}") | ||||||
extensions = match[1].split(" ") | ||||||
if not all(x.startswith("*.") for x in extensions): | ||||||
|
@@ -205,7 +206,7 @@ def load_metadata_base(data: typing.Union[str, Path]): | |||||
try: | ||||||
decoded_data = json.loads(str(data), object_hook=partseg_object_hook) | ||||||
except Exception: # pragma: no cover | ||||||
raise e # noqa: B904 | ||||||
raise e from None | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Modifying exception handling in - raise e from None
+ raise e Committable suggestion
Suggested change
|
||||||
|
||||||
return decoded_data | ||||||
|
||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
_get_abstract_getters
method has been modified to support the new functionality related to creating classes from functions. This method now checks for abstract methods starting withget_
and validates their return type annotations. While this approach ensures that abstract getters are properly defined, it may be too restrictive by excluding potentially valid abstract methods that do not follow this naming convention. Additionally, the error message generated when a return type is not defined could be enhanced to provide more specific guidance on how to correct the issue.Consider allowing more flexibility in the naming convention of abstract methods and improve the error message for missing return type annotations to offer clearer guidance on resolving the issue.