Skip to content

Commit

Permalink
Add Stream node constructor for sub-classing #442
Browse files Browse the repository at this point in the history
  • Loading branch information
florentbr committed Dec 10, 2021
1 parent 7453431 commit f9d993c
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 0 deletions.
20 changes: 20 additions & 0 deletions streamz/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,22 @@ def __str__(self):

class APIRegisterMixin(object):

def _new_node(self, cls, args, kwargs):
""" Constructor for downstream nodes.
Examples
--------
To provide inheritance through nodes :
>>> class MyStream(Stream):
>>>
>>> def _new_node(self, cls, args, kwargs):
>>> if not issubclass(cls, MyStream):
>>> cls = type(cls.__name__, (cls, MyStream), dict(cls.__dict__))
>>> return cls(*args, **kwargs)
"""
return cls(*args, **kwargs)

@classmethod
def register_api(cls, modifier=identity, attribute_name=None):
""" Add callable to Stream API
Expand Down Expand Up @@ -158,6 +174,10 @@ def register_api(cls, modifier=identity, attribute_name=None):
def _(func):
@functools.wraps(func)
def wrapped(*args, **kwargs):
if identity is not staticmethod and args:
self = args[0]
if isinstance(self, APIRegisterMixin):
return self._new_node(func, args, kwargs)
return func(*args, **kwargs)
name = attribute_name if attribute_name else func.__name__
setattr(cls, name, modifier(wrapped))
Expand Down
29 changes: 29 additions & 0 deletions streamz/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1367,6 +1367,35 @@ class foo(NewStream):
assert not hasattr(Stream(), 'foo')


def test_subclass_node():

def add(x) : return x + 1

class MyStream(Stream):
def _new_node(self, cls, args, kwargs):
if not issubclass(cls, MyStream):
cls = type(cls.__name__, (cls, MyStream), dict(cls.__dict__))
return cls(*args, **kwargs)

@MyStream.register_api()
class foo(sz.sinks.sink):
pass

stream = MyStream()
lst = list()

node = stream.map(add)
assert isinstance(node, sz.core.map)
assert isinstance(node, MyStream)

node = node.foo(lst.append)
assert isinstance(node, sz.sinks.sink)
assert isinstance(node, MyStream)

stream.emit(100)
assert lst == [ 101 ]


@gen_test()
def test_latest():
source = Stream(asynchronous=True)
Expand Down

0 comments on commit f9d993c

Please sign in to comment.