Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add type annotations for nodes #253

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions anytree/iterators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
* :any:`ZigZagGroupIter`: iterate over tree using level-order strategy returning group for every level
"""

from .abstractiter import AbstractIter # noqa
from .levelordergroupiter import LevelOrderGroupIter # noqa
from .levelorderiter import LevelOrderIter # noqa
from .postorderiter import PostOrderIter # noqa
from .preorderiter import PreOrderIter # noqa
from .zigzaggroupiter import ZigZagGroupIter # noqa
from .abstractiter import AbstractIter as AbstractIter # noqa
CoolCat467 marked this conversation as resolved.
Show resolved Hide resolved
from .levelordergroupiter import LevelOrderGroupIter as LevelOrderGroupIter # noqa
from .levelorderiter import LevelOrderIter as LevelOrderIter # noqa
from .postorderiter import PostOrderIter as PostOrderIter # noqa
from .preorderiter import PreOrderIter as PreOrderIter # noqa
from .zigzaggroupiter import ZigZagGroupIter as ZigZagGroupIter # noqa
45 changes: 34 additions & 11 deletions anytree/iterators/abstractiter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,22 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Generic, TypeVar

import six

if TYPE_CHECKING:
from collections.abc import Callable, Iterable, Iterator

from typing_extensions import Self

from ..node.lightnodemixin import LightNodeMixin
from ..node.nodemixin import NodeMixin


NodeT = TypeVar("NodeT", bound=NodeMixin[Any] | LightNodeMixin[Any], covariant=True)


class AbstractIter(six.Iterator):
class AbstractIter(six.Iterator, Generic[NodeT]):
# pylint: disable=R0205
"""
Iterate over tree starting at `node`.
Expand All @@ -14,14 +29,20 @@ class AbstractIter(six.Iterator):
maxlevel (int): maximum descending in the node hierarchy.
"""

def __init__(self, node, filter_=None, stop=None, maxlevel=None):
def __init__(
self,
node: NodeT,
filter_: Callable[[NodeT], bool] | None = None,
stop: Callable[[NodeT], bool] | None = None,
maxlevel: int | None = None,
) -> None:
self.node = node
self.filter_ = filter_
self.stop = stop
self.maxlevel = maxlevel
self.__iter = None
self.__iter: Iterator[NodeT] | None = None

def __init(self):
def __init(self) -> Iterator[NodeT]:
node = self.node
maxlevel = self.maxlevel
filter_ = self.filter_ or AbstractIter.__default_filter
Expand All @@ -30,31 +51,33 @@ def __init(self):
return self._iter(children, filter_, stop, maxlevel)

@staticmethod
def __default_filter(node):
def __default_filter(node: NodeT) -> bool:
# pylint: disable=W0613
return True

@staticmethod
def __default_stop(node):
def __default_stop(node: NodeT) -> bool:
# pylint: disable=W0613
return False

def __iter__(self):
def __iter__(self) -> Self:
return self

def __next__(self):
def __next__(self) -> NodeT:
if self.__iter is None:
self.__iter = self.__init()
return next(self.__iter)

@staticmethod
def _iter(children, filter_, stop, maxlevel):
def _iter(
children: Iterable[NodeT], filter_: Callable[[NodeT], bool], stop: Callable[[NodeT], bool], maxlevel: int | None
) -> Iterator[NodeT]:
raise NotImplementedError() # pragma: no cover

@staticmethod
def _abort_at_level(level, maxlevel):
def _abort_at_level(level: int, maxlevel: int | None) -> bool:
return maxlevel is not None and level > maxlevel

@staticmethod
def _get_children(children, stop):
def _get_children(children: Iterable[NodeT], stop: Callable[[NodeT], bool]) -> list[Any]:
return [child for child in children if not stop(child)]
16 changes: 8 additions & 8 deletions anytree/node/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
* :any:`LightNodeMixin`: A :any:`NodeMixin` using slots.
"""

from .anynode import AnyNode # noqa
from .exceptions import LoopError # noqa
from .exceptions import TreeError # noqa
from .lightnodemixin import LightNodeMixin # noqa
from .node import Node # noqa
from .nodemixin import NodeMixin # noqa
from .symlinknode import SymlinkNode # noqa
from .symlinknodemixin import SymlinkNodeMixin # noqa
from .anynode import AnyNode as AnyNode # noqa
from .exceptions import LoopError as LoopError # noqa
from .exceptions import TreeError as TreeError # noqa
from .lightnodemixin import LightNodeMixin as LightNodeMixin # noqa
from .node import Node as Node # noqa
from .nodemixin import NodeMixin as NodeMixin # noqa
from .symlinknode import SymlinkNode as SymlinkNode # noqa
from .symlinknodemixin import SymlinkNodeMixin as SymlinkNodeMixin # noqa
14 changes: 10 additions & 4 deletions anytree/node/anynode.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
# -*- coding: utf-8 -*-

from __future__ import annotations

from typing import TYPE_CHECKING, Any

from .nodemixin import NodeMixin
from .util import _repr

if TYPE_CHECKING:
from collections.abc import Iterable

class AnyNode(NodeMixin):

class AnyNode(NodeMixin[AnyNode]):
"""
A generic tree node with any `kwargs`.

Expand Down Expand Up @@ -92,12 +99,11 @@ class AnyNode(NodeMixin):
... ])
"""

def __init__(self, parent=None, children=None, **kwargs):

def __init__(self, parent: AnyNode | None = None, children: Iterable[AnyNode] | None = None, **kwargs: Any) -> None:
self.__dict__.update(kwargs)
self.parent = parent
if children:
self.children = children

def __repr__(self):
def __repr__(self) -> str:
return _repr(self)
Loading