Skip to content

Commit

Permalink
Merge pull request #199 from nucleic/default_atomdict
Browse files Browse the repository at this point in the history
defaultatomdict container and related member
  • Loading branch information
MatthieuDartiailh authored May 5, 2023
2 parents 5767ab1 + 96ae3a3 commit 392f64e
Show file tree
Hide file tree
Showing 17 changed files with 734 additions and 11 deletions.
5 changes: 4 additions & 1 deletion atom/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,12 @@
atomlist,
atomref,
atomset,
defaultatomdict,
)
from .coerced import Coerced
from .containerlist import ContainerList
from .delegator import Delegator
from .dict import Dict
from .dict import DefaultDict, Dict
from .enum import Enum
from .event import Event
from .instance import ForwardInstance, Instance
Expand Down Expand Up @@ -85,13 +86,15 @@
"Validate",
"atomclist",
"atomdict",
"defaultatomdict",
"atomlist",
"atomref",
"atomset",
"Coerced",
"ContainerList",
"Delegator",
"Dict",
"DefaultDict",
"Enum",
"Event",
"ForwardInstance",
Expand Down
3 changes: 3 additions & 0 deletions atom/catom.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,7 @@ class atomlist(List[T]): ...
class atomclist(atomlist[T]): ...
class atomset(Set[T]): ...
class atomdict(Dict[KT, VT]): ...
class defaultatomdict(Dict[KT, VT]): ...

A = TypeVar("A", bound=CAtom)

Expand All @@ -446,6 +447,7 @@ class DefaultValue(IntEnum):
CallObject_ObjectName = ...
Delegate = ...
Dict = ...
DefaultDict = ...
List = ...
MemberMethod_Object = ...
NonOptional = ...
Expand Down Expand Up @@ -524,6 +526,7 @@ class Validate(IntEnum):
ContainerList = ...
Delegate = ...
Dict = ...
DefaultDict = ...
Enum = ...
Float = ...
FloatPromote = ...
Expand Down
139 changes: 139 additions & 0 deletions atom/dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
#
# The full license is in the file LICENSE, distributed with this software.
# --------------------------------------------------------------------------------------
from collections import defaultdict

from .catom import DefaultValue, Member, Validate
from .instance import Instance
from .typing_utils import extract_types, is_optional
Expand Down Expand Up @@ -90,3 +92,140 @@ def clone(self):
mode, _ = self.validate_mode
clone.set_validate_mode(mode, (key_clone, value_clone))
return clone


class _DefaultWrapper:
__slots__ = ("wrapped",)

def __init__(self, wrapped):
self.wrapped = wrapped

def __call__(self, atom):
return self.wrapped()

def __repr__(self):
return repr(self.wrapped)


class DefaultDict(Member):
"""A value of type `dict` implementing __missing__"""

__slots__ = ()

def __init__(self, key=None, value=None, default=None, *, missing=None):
"""Initialize a DefaultDict.
Parameters
----------
key : Member, type, tuple of types, or None, optional
A member to use for validating the types of keys allowed in
the dict. This can also be a type or a tuple of types, which
will be wrapped with an Instance member. If this is not
given, no key validation is performed.
value : Member, type, tuple of types, or None, optional
A member to use for validating the types of values allowed
in the dict. This can also be a type or a tuple of types,
which will be wrapped with an Instance member. If this is
not given, no value validation is performed.
default : dict or None, optional
The default dict of items. A new copy of this dict will be
created for each atom instance.
missing : Callable[[], Any] or None, optional
Factory to build a default value for a missing key in the dictionary.
"""
self.set_default_value_mode(DefaultValue.DefaultDict, default)
if key is not None and not isinstance(key, Member):
opt, types = is_optional(extract_types(key))
key = Instance(types, optional=opt)
if value is not None and not isinstance(value, Member):
opt, types = is_optional(extract_types(value))
# Assume a default value can be created to avoid the need to specify a
# missing factory in simple case even for custom types.
value = Instance(types, optional=opt, args=())

if missing is not None:
if not callable(missing):
raise ValueError(
f"The missing argument expect a callable, got {missing}"
)
try:
missing()
except Exception as e:
raise ValueError(
"The missing argument expect a callable taking no argument. "
"Trying to call it with not argument failed with the chained "
"exception."
) from e
missing = _DefaultWrapper(missing)

if isinstance(default, defaultdict):
if missing is not None:
raise ValueError(
"Both a missing factory and a default value which is a default "
"dictionary were specified. When using a default dict as default "
"value missing should be omitted."
)
missing = _DefaultWrapper(default.default_factory)

if (
missing is None
and value is not None
and value.default_value_mode[0]
not in (DefaultValue.NoOp, DefaultValue.NonOptional)
):
missing = value.do_default_value

if missing is None:
raise ValueError(
"No missing value factory was specified and none could be "
"deduced from the value member."
)

self.set_validate_mode(Validate.DefaultDict, (key, value, missing))

def set_name(self, name):
"""Assign the name to this member.
This method is called by the Atom metaclass when a class is
created. This makes sure the name of the internal members are
also updated.
"""
super().set_name(name)
key, value, _ = self.validate_mode[1]
if key is not None:
key.set_name(name + "|key")
if value is not None:
value.set_name(name + "|value")

def set_index(self, index):
"""Assign the index to this member.
This method is called by the Atom metaclass when a class is
created. This makes sure the index of the internal members are
also updated.
"""
super().set_index(index)
key, value, _ = self.validate_mode[1]
if key is not None:
key.set_index(index)
if value is not None:
value.set_index(index)

def clone(self):
"""Create a clone of the member.
This will clone the internal dict key and value members if they exist.
"""
clone = super().clone()
mode, (key, value, missing) = self.validate_mode
key_clone = key.clone() if key is not None else None
value_clone = value.clone() if value is not None else None
clone.set_validate_mode(mode, (key_clone, value_clone, missing))
return clone
2 changes: 2 additions & 0 deletions atom/dict.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -356,3 +356,5 @@ class Dict(Member[TDict[KT, VT], TDict[KT, VT]]):
value: Member[VT, VT],
default: Optional[TDict[Any, Any]] = None,
) -> Dict[KT, VT]: ...

class DefaultDict(Dict[KT, VT]): ...
12 changes: 10 additions & 2 deletions atom/meta/annotation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@
# The full license is in the file LICENSE, distributed with this software.
# --------------------------------------------------------------------------------------
import collections.abc
from collections import defaultdict
from typing import Any, ClassVar, MutableMapping, Type

from ..catom import Member
from ..dict import Dict as ADict
from ..dict import DefaultDict, Dict as ADict
from ..instance import Instance
from ..list import List as AList
from ..scalars import Bool, Bytes, Callable as ACallable, Float, Int, Str, Value
Expand All @@ -30,6 +31,7 @@
bytes: Bytes,
list: AList,
dict: ADict,
defaultdict: DefaultDict,
set: ASet,
tuple: ATuple,
collections.abc.Callable: ACallable,
Expand Down Expand Up @@ -60,7 +62,13 @@ def generate_member_from_type_or_generic(
elif len(types) == 1 and types[0] in _TYPE_TO_MEMBER:
t = types[0]
m_cls = _TYPE_TO_MEMBER[t]
if annotate_type_containers and t in (list, dict, set, tuple):
if annotate_type_containers and t in (
list,
dict,
collections.defaultdict,
set,
tuple,
):
# We can only validate homogeneous tuple so far so we ignore other cases
if t is tuple:
if (...) in parameters or len(set(parameters)) == 1:
Expand Down
Loading

0 comments on commit 392f64e

Please sign in to comment.