Skip to content

Commit

Permalink
Robust access of typing subscripts.
Browse files Browse the repository at this point in the history
Use functions `get_origin` and `get_args` instead of direct attributes. Refs #107.
  • Loading branch information
coady committed Dec 29, 2023
1 parent 7bf967a commit bbccb81
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 14 deletions.
2 changes: 0 additions & 2 deletions docs/requirements.in
Original file line number Diff line number Diff line change
@@ -1,4 +1,2 @@
mkdocs-material
mkdocstrings[python]
mkdocs-jupyter
ipykernel
33 changes: 22 additions & 11 deletions multimethod/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import inspect
import itertools
import types
import typing
from collections.abc import Callable, Iterable, Iterator, Mapping
from typing import Any, Literal, Optional, TypeVar, Union
from typing import get_type_hints, overload as tp_overload
from typing import Any, Literal, Optional, TypeVar, Union, get_type_hints, overload as tp_overload

__version__ = '1.10'

Expand All @@ -16,6 +16,16 @@ class DispatchError(TypeError):
pass


def get_origin(tp):
return tp.__origin__ if isinstance(tp, subtype) else typing.get_origin(tp)


def get_args(tp) -> tuple:
if isinstance(tp, subtype) or typing.get_origin(tp) is Callable:
return tp.__args__
return typing.get_args(tp)


class subtype(abc.ABCMeta):
"""A normalized generic type which checks subscripts.
Expand All @@ -35,10 +45,10 @@ def __new__(cls, tp, *args):
if not tp.__constraints__:
return object
tp = Union[tp.__constraints__]
origin = getattr(tp, '__origin__', tp)
origin = get_origin(tp) or tp
if hasattr(types, 'UnionType') and isinstance(tp, types.UnionType):
origin = Union # `|` syntax added in 3.10
args = tuple(map(cls, getattr(tp, '__args__', args)))
args = tuple(map(cls, get_args(tp) or args))
if set(args) <= {object} and not (origin is tuple and args):
return origin
bases = (origin,) if type(origin) in (type, abc.ABCMeta) else ()
Expand All @@ -61,8 +71,8 @@ def __hash__(self) -> int:
return hash(self.key())

def __subclasscheck__(self, subclass):
origin = getattr(subclass, '__origin__', subclass)
args = getattr(subclass, '__args__', ())
origin = get_origin(subclass) or subclass
args = get_args(subclass)
if origin is Literal:
return all(isinstance(arg, self) for arg in args)
if origin is Union:
Expand Down Expand Up @@ -107,19 +117,19 @@ def __instancecheck__(self, instance):

def origins(self) -> Iterator[type]:
"""Generate origins which would need subscript checking."""
origin = getattr(self, '__origin__', None) # also called as a staticmethod
origin = get_origin(self)
if origin is Literal:
yield from set(map(type, self.__args__))
elif origin is Union:
for cls in self.__args__:
yield from subtype.origins(cls) # type: ignore
yield from subtype.origins(cls)
elif origin is not None:
yield origin


def distance(cls, subclass: type) -> int:
"""Return estimated distance between classes for tie-breaking."""
if getattr(cls, '__origin__', None) is Union:
if get_origin(cls) is Union:
return min(distance(arg, subclass) for arg in cls.__args__)
mro = type.mro(subclass) if isinstance(subclass, type) else subclass.mro()
return mro.index(cls if cls in mro else object)
Expand Down Expand Up @@ -186,6 +196,7 @@ def instances(self, *args) -> bool:
class multimethod(dict):
"""A callable directed acyclic graph of methods."""

__name__: str
pending: set
generics: list[set]

Expand Down Expand Up @@ -218,7 +229,7 @@ def register(self, *args) -> Callable:
"""
if len(args) == 1 and hasattr(args[0], '__annotations__'):
multimethod.__init__(self, *args)
return self if self.__name__ == args[0].__name__ else args[0] # type: ignore
return self if self.__name__ == args[0].__name__ else args[0]
return lambda func: self.__setitem__(args, func) or func

def __get__(self, instance, owner):
Expand Down Expand Up @@ -276,7 +287,7 @@ def select(self, types: tuple, keys: set[signature]) -> Callable:
funcs = {self[key] for key in keys}
if len(funcs) == 1:
return funcs.pop()
raise DispatchError(f"{self.__name__}: {len(keys)} methods found", types, keys) # type: ignore
raise DispatchError(f"{self.__name__}: {len(keys)} methods found", types, keys)

def __missing__(self, types: tuple) -> Callable:
"""Find and cache the next applicable method of given types."""
Expand Down
2 changes: 1 addition & 1 deletion tests/test_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def test_roshambo():


# methods
class cls(object):
class cls:
method = multidispatch(lambda self, other: None)

@method.register(Iterable, object)
Expand Down
6 changes: 6 additions & 0 deletions tests/test_subscripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,9 @@ def test_final():
d = {'': 0}
assert isinstance(d, subtype(Mapping[str, int]))
assert isinstance(d.keys(), tp)


def test_args():
tp = type('', (), {'__args__': None})
assert subtype(tp) is tp
assert not issubclass(tp, subtype(list[int]))

0 comments on commit bbccb81

Please sign in to comment.