Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/callable protocols 2 #599

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,13 @@ The semantic versioning only considers the public API as described in
:ref:`api-ref`. Components not mentioned in :ref:`api-ref` or different import
paths are considered internals and can change in minor and patch releases.

v4.34.1 (2024-10-??)
--------------------

Fixed
^^^^^
- Fix callable protocol inheritance.
Comment on lines +17 to +19
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

v4.34.0 hasn't been released yet, so this entry should be there instead of a new v4.34.1 section. Also please add a link to this pull request.



v4.34.0 (2024-10-??)
--------------------
Expand Down
5 changes: 3 additions & 2 deletions jsonargparse/_typehints.py
Original file line number Diff line number Diff line change
Expand Up @@ -1101,14 +1101,15 @@ def adapt_typehints(


def implements_protocol(value, protocol) -> bool:
allowed_dunder_methods = {"__call__"}
Comment on lines 1102 to +1104
Copy link
Member

@mauvilsa mauvilsa Oct 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def implements_protocol(value, protocol) -> bool:
allowed_dunder_methods = {"__call__"}
protocol_irrelevant_dunder_methods = {
"__init__", "__new__", "__del__", "__getattr__", "__getattribute__", "__setattr__",
"__delattr__", "__reduce__", "__reduce_ex__", "__getstate__", "__setstate__"
}
def implements_protocol(value, protocol) -> bool:

Better to have the set as a global, instead of inside the function.

from jsonargparse._parameter_resolvers import get_signature_parameters
from jsonargparse._postponed_annotations import get_return_type

if not inspect.isclass(value):
if not inspect.isclass(value): # Should we check if callables implement protocols?
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It does make sense that a function could implement a callable protocol. But I would leave this for a different pull request.

return False
members = 0
for name, _ in inspect.getmembers(protocol, predicate=inspect.isfunction):
if name.startswith("_"):
if name.startswith("_") and name not in allowed_dunder_methods:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Give me some time and I analyze the dunder methods in general and maybe in this pull request we do more than __call__.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if name.startswith("_") and name not in allowed_dunder_methods:
is_dunder = name.startswith("__") and name.endswith("__")
if (not is_dunder and name.startswith("_")) or (is_dunder and name in protocol_irrelevant_dunder_methods):

This is what I came up with.

continue
if not hasattr(value, name):
return False
Expand Down
55 changes: 55 additions & 0 deletions jsonargparse_tests/test_subclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -1485,6 +1485,61 @@
ctx.match("Not a valid subclass of Interface")


# callable protocol tests


class CallableInterface(Protocol):
def __call__(self, items: List[float]) -> List[float]: ...
Dismissed Show dismissed Hide dismissed


class ImplementsCallableInterface1:
def __init__(self, batch_size: int):
self.batch_size = batch_size

def __call__(self, items: List[float]) -> List[float]:
return items


def implements_callable_interface2(items: List[float]) -> List[float]:
return items


class NotImplementsCallableInterface1:
def __call__(self, items: str) -> List[float]:
return []


class NotImplementsCallableInterface2:
def __call__(self, items: List[float], extra: int) -> List[float]:
return items


class NotImplementsCallableInterface3:
def __call__(self, items: List[float]) -> None:
return


def not_implements_callable_interface4(items: str) -> List[float]:
return []


@pytest.mark.parametrize(
"expected, value",
[
(True, ImplementsCallableInterface1),
(False, ImplementsCallableInterface1(1)),
(True, implements_callable_interface2),
(False, NotImplementsCallableInterface1),
(False, NotImplementsCallableInterface2),
(False, NotImplementsCallableInterface3),
(False, not_implements_callable_interface4),
(False, object),
],
)
def test_implements_callable_protocol(expected, value):
assert implements_protocol(value, CallableInterface) is expected


# parameter skip tests


Expand Down
Loading