diff --git a/tests/test_tree.py b/tests/test_tree.py index fcb67ce..eb7cfd3 100644 --- a/tests/test_tree.py +++ b/tests/test_tree.py @@ -617,6 +617,25 @@ class SubTree(Tree): node = tree.create_node() self.assertTrue(isinstance(node, SubNode)) + def test_subclassing_extra_kwargs(self): + class SubNode(Node): + def __init__(self, some_argument=None, **kwargs): + self.some_property = some_argument + super().__init__(**kwargs) + + class SubTree(Tree): + node_class = SubNode + + tree = SubTree() + node = tree.create_node(some_argument="some_value") + self.assertTrue(isinstance(node, SubNode)) + self.assertEqual(node.some_property, "some_value") + + tree = Tree(node_class=SubNode) + node = tree.create_node(some_argument="some_value") + self.assertTrue(isinstance(node, SubNode)) + self.assertEqual(node.some_property, "some_value") + def test_shallow_copy_hermetic_pointers(self): # tree 1 # Hárry diff --git a/treelib/tree.py b/treelib/tree.py index 6ee18c9..42535fc 100644 --- a/treelib/tree.py +++ b/treelib/tree.py @@ -336,12 +336,12 @@ def contains(self, nid): """Check if the tree contains node of given id""" return True if nid in self._nodes else False - def create_node(self, tag=None, identifier=None, parent=None, data=None): + def create_node(self, tag=None, identifier=None, parent=None, data=None, **kwargs): """ Create a child node for given @parent node. If ``identifier`` is absent, a UUID will be generated automatically. """ - node = self.node_class(tag=tag, identifier=identifier, data=data) + node = self.node_class(tag=tag, identifier=identifier, data=data, **kwargs) self.add_node(node, parent) return node