Skip to content

Commit

Permalink
Backport improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
Mauko Quiroga committed Sep 16, 2021
1 parent 31ffaf4 commit bfe9e56
Show file tree
Hide file tree
Showing 13 changed files with 189 additions and 66 deletions.
126 changes: 96 additions & 30 deletions openfisca_core/indexed_enums/enum.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,16 @@
from __future__ import annotations

import enum
from typing import Union
from typing import Any, TypeVar, Union

import numpy

from openfisca_core.types import ArrayType, Choosable
from openfisca_core.types import ArrayType, Encodable

from .. import indexed_enums as enums
from .enum_array import EnumArray

#: Type of any encodable array.
Encodable = Union[
EnumArray,
ArrayType[Choosable],
ArrayType[bytes],
ArrayType[int],
ArrayType[str],
]
T = TypeVar("T", Encodable, bytes, int, str)


class Enum(enum.Enum):
Expand All @@ -26,39 +19,78 @@ class Enum(enum.Enum):
Whose items have an :obj:`int` index. This is useful and performant when
running simulations on large populations.
Attributes:
index (:obj:`int`): The ``index`` of the :class:`.Enum` member.
name (:obj:`str`): The ``name`` of the :class:`.Enum` member.
value: The ``value`` of the :class:`.Enum` member.
Examples:
>>> class Housing(Enum):
... owner = "Owner"
... tenant = "Tenant"
... free_lodger = "Free lodger"
... homeless = "Homeless"
... OWNER = "Owner"
... TENANT = "Tenant"
... FREE_LODGER = "Free lodger"
... HOMELESS = "Homeless"
>>> repr(Housing)
"<enum 'Housing'>"
>>> repr(Housing.TENANT)
'<Housing.TENANT(Tenant)>'
>>> str(Housing.TENANT)
'Housing.TENANT'
>>> dict([(Housing.TENANT, Housing.TENANT.value)])
{<Housing.TENANT(Tenant)>: 'Tenant'}
>>> tuple(Housing)
(<Housing.OWNER(Owner)>, <Housing.TENANT(Tenant)>, ...)
>>> Housing
<enum 'Housing'>
>>> Housing["TENANT"]
<Housing.TENANT(Tenant)>
>>> list(Housing)
[<Housing.owner: 'Owner'>, ...]
>>> Housing("Tenant")
<Housing.TENANT(Tenant)>
>>> Housing.TENANT in Housing
True
>>> len(Housing)
4
>>> Housing.tenant
<Housing.tenant: 'Tenant'>
>>> Housing.TENANT == Housing.TENANT
True
>>> Housing.TENANT != Housing.TENANT
False
>>> Housing.TENANT > Housing.TENANT
False
>>> Housing["tenant"]
<Housing.tenant: 'Tenant'>
>>> Housing.TENANT < Housing.TENANT
False
>>> Housing.tenant.index
>>> Housing.TENANT >= Housing.TENANT
True
>>> Housing.TENANT <= Housing.TENANT
True
>>> Housing.TENANT.index
1
>>> Housing.tenant.name
'tenant'
>>> Housing.TENANT.name
'TENANT'
>>> Housing.tenant.value
>>> Housing.TENANT.value
'Tenant'
"""

index: int
name: str
value: Any

def __init__(self, name: str) -> None:
""" Tweaks :class:`~enum.Enum` to add an index to each enum item.
Expand All @@ -81,18 +113,52 @@ def __init__(self, name: str) -> None:
>>> MyEnum.bar.index
1
>>> array = numpy.array([[1, 2], [3, 4]])
>>> array[MyEnum.bar.index]
array([3, 4])
"""

self.index = len(self._member_names_)

#: Bypass the slow :meth:`~enum.Enum.__eq__`.
def __repr__(self) -> str:
return f"<{self.__class__.__name__}.{self.name}({self.value})>"

def __str__(self) -> str:
return f"{self.__class__.__name__}.{self.name}"

def __lt__(self, other: object) -> bool:
if not isinstance(other, Enum):
return NotImplemented

return self.index < other.index

def __le__(self, other: object) -> bool:
if not isinstance(other, Enum):
return NotImplemented

return self.index <= other.index

def __gt__(self, other: object) -> bool:
if not isinstance(other, Enum):
return NotImplemented

return self.index > other.index

def __ge__(self, other: object) -> bool:
if not isinstance(other, Enum):
return NotImplemented

return self.index >= other.index

__eq__ = object.__eq__
"""Bypass the slow :meth:`~enum.Enum.__eq__`."""

#: :meth:`.__hash__` must also be defined as so to stay hashable.
__hash__ = object.__hash__
""":meth:`.__hash__` must also be defined as so to stay hashable."""

@classmethod
def encode(cls, array: Encodable) -> EnumArray:
def encode(cls, array: Union[EnumArray, ArrayType[T]]) -> EnumArray:
"""Encodes an encodable array into an :obj:`.EnumArray`.
Args:
Expand All @@ -112,7 +178,7 @@ def encode(cls, array: Encodable) -> EnumArray:
>>> array = numpy.array([1])
>>> enum_array = EnumArray(array, MyEnum)
>>> MyEnum.encode(enum_array)
EnumArray([<MyEnum.bar: b'bar'>])
<EnumArray([<MyEnum.bar(b'bar')>])>
# ArrayTipe[Enum]
Expand Down
63 changes: 43 additions & 20 deletions openfisca_core/indexed_enums/enum_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import numpy

from openfisca_core.types import ArrayLike, ArrayType, Choosable
from openfisca_core.types import ArrayLike, ArrayType, Encodable


class EnumArray(numpy.ndarray):
Expand All @@ -23,21 +23,46 @@ class EnumArray(numpy.ndarray):
>>> from openfisca_core.variables import Variable
>>> class Housing(Enum):
... owner = "Owner"
... tenant = "Tenant"
... free_lodger = "Free lodger"
... homeless = "Homeless"
... Owner = "Owner"
... Tenant = "Tenant"
... FreeLodger = "Free lodger"
... Homeless = "Homeless"
>>> array = numpy.array([1])
>>> EnumArray(array, Housing)
EnumArray([<Housing.tenant: 'Tenant'>])
>>> enum_array = EnumArray(array, Housing)
>>> repr(EnumArray)
"<class 'openfisca_core.indexed_enums.enum_array.EnumArray'>"
>>> repr(enum_array)
'<EnumArray([<Housing.Tenant(Tenant)>])>'
>>> str(enum_array)
"['Tenant']"
>>> tuple(enum_array)
(1,)
>>> enum_array[0]
1
>>> enum_array[0] in enum_array
True
>>> len(enum_array)
1
>>> enum_array = EnumArray(list(Housing), Housing)
>>> enum_array[Housing.Tenant.index]
<Housing.Tenant(Tenant)>
>>> class OccupancyStatus(Variable):
... value_type = Enum
... possible_values = Housing
>>> EnumArray(array, OccupancyStatus.possible_values)
EnumArray([<Housing.tenant: 'Tenant'>])
<EnumArray([<Housing.Tenant(Tenant)>])>
.. _Subclassing ndarray:
https://numpy.org/doc/stable/user/basics.subclassing.html
Expand All @@ -47,22 +72,26 @@ class EnumArray(numpy.ndarray):
def __new__(
cls,
input_array: ArrayType[int],
possible_values: Optional[Type[Choosable]] = None,
possible_values: Optional[Type[Encodable]] = None,
) -> EnumArray:
"""See comment above"""
"""See comment above."""

obj = numpy.asarray(input_array).view(cls)
obj.possible_values = possible_values
return obj

def __array_finalize__(self, obj: Optional[ArrayType[int]]) -> None:
"""See comment above…"""

if obj is None:
return

self.possible_values = getattr(obj, "possible_values", None)

def __repr__(self) -> str:
return f"<{self.__class__.__name__}({str(self.decode())})>"

def __str__(self) -> str:
return str(self.decode_to_str())

def __eq__(self, other: Any) -> Union[ArrayType[bool], bool]:
"""Compare equality with the item index.
Expand Down Expand Up @@ -163,7 +192,7 @@ def _forbidden_operation(self, other: Any) -> NoReturn:
__and__ = _forbidden_operation
__or__ = _forbidden_operation

def decode(self) -> ArrayLike[Choosable]:
def decode(self) -> ArrayLike[Encodable]:
"""Decodes itself to a normal array.
Returns:
Expand All @@ -179,7 +208,7 @@ def decode(self) -> ArrayLike[Choosable]:
>>> array = numpy.array([1])
>>> enum_array = EnumArray(array, MyEnum)
>>> enum_array.decode()
array([<MyEnum.bar: b'bar'>]...)
array([<MyEnum.bar(b'bar')>], dtype=object)
"""

Expand Down Expand Up @@ -212,9 +241,3 @@ def decode_to_str(self) -> ArrayType[str]:
[self == item.index for item in self.possible_values],
[item.name for item in self.possible_values],
)

def __repr__(self) -> str:
return f"{self.__class__.__name__}({str(self.decode())})"

def __str__(self) -> str:
return str(self.decode_to_str())
Empty file.
12 changes: 0 additions & 12 deletions openfisca_core/indexed_enums/tests/test_enum_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,3 @@ def test_enum_array_any_other_operation(enum_array):

with pytest.raises(TypeError, match = "Forbidden operation."):
enum_array * 1


def test_enum_array___repr__(enum_array):
"""Enum arrays have a custom debugging representation."""

assert repr(enum_array) == "EnumArray([<MyEnum.bar: b'bar'>])"


def test_enum_array___str__(enum_array):
"""Enum arrays have a custom end-user representation."""

assert str(enum_array) == "['bar']"
3 changes: 2 additions & 1 deletion openfisca_core/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,11 @@

from .protocols import ( # noqa: F401
Buildable,
Choosable,
Descriptable,
Encodable,
Modelable,
Personifiable,
Representable,
Rolifiable,
Timeable,
)
2 changes: 1 addition & 1 deletion openfisca_core/types/callables/formulas.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import Callable

from ..protocols.timeable import Timeable
from ..protocols._aggregatable import Aggregatable
from ..protocols._instantizable import Instantizable
from ..protocols._timeable import Timeable
from ..data_types import ArrayType

ParamsType = Callable[[Timeable], Instantizable]
Expand Down
3 changes: 2 additions & 1 deletion openfisca_core/types/protocols/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@
# See: https://www.python.org/dev/peps/pep-0008/#imports

from .buildable import Buildable # noqa: F401
from .choosable import Choosable # noqa: F401
from .descriptable import Descriptable # noqa: F401
from .encodable import Encodable # noqa: F401
from .modelable import Modelable # noqa: F401
from .personifiable import Personifiable # noqa: F401
from .representable import Representable # noqa: F401
from .rolifiable import Rolifiable # noqa: F401
from .timeable import Timeable # noqa: F401
2 changes: 2 additions & 0 deletions openfisca_core/types/protocols/buildable.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,11 @@ def __init__(self, builder: RT, buildee: Type[ET]) -> None:
@abc.abstractmethod
def __call__(self, items: Iterable[EL]) -> Sequence[ET]:
"""A concrete builder implements :meth:`.__call__`."""

...

@abc.abstractmethod
def build(self, item: EL) -> ET:
"""A concrete builder implements :meth:`.build`."""

...
23 changes: 23 additions & 0 deletions openfisca_core/types/protocols/encodable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import abc
from typing import Any

from typing_extensions import Protocol


class Encodable(Protocol):
"""Base type for any model implementing a literal list of choices.
Type-checking against abstractions rather than implementations helps in
(a) decoupling the codebse, thanks to structural subtyping, and
(b) documenting/enforcing the blueprints of the different OpenFisca models.
.. versionadded:: 35.8.0
"""

@classmethod
@abc.abstractmethod
def encode(cls, array: Any) -> Any:
"""A concrete encodable model implements :meth:`.encode`."""

...
Loading

0 comments on commit bfe9e56

Please sign in to comment.