add unique
CJ-Wright committed Dec 15, 2018
1 parent 231710e commit 7869276
Showing 3 changed files with 96 additions and 3 deletions.
1 change: 1 addition & 0 deletions rapidz/core.py
@@ -1187,6 +1187,7 @@ def __init__(self, upstream, history=None, key=identity, **kwargs):
from zict import LRU

self.seen = LRU(history, self.seen)
# TODO: pull this out from history
self.non_hash_seen = deque(maxlen=history)

Stream.__init__(self, upstream, **kwargs)
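
The hunk above shows the bookkeeping the TODO refers to: a zict ``LRU`` bounds the history of hashable values while a deque of the same length holds values that cannot be hashed. Below is a minimal stand-alone sketch of that dual-cache pattern; it is illustrative only, not the committed implementation, the ``BoundedSeen`` name and ``add`` method are made up, and it assumes ``zict`` is installed.

from collections import deque


class BoundedSeen:
    """Track recently seen elements with a fixed history size."""

    def __init__(self, history=None):
        self.seen = {}
        if history:
            from zict import LRU
            # Evict the least recently used keys once ``history`` entries exist.
            self.seen = LRU(history, self.seen)
        # Unhashable values cannot live in the mapping; keep the most
        # recent ``history`` of them in a deque instead.
        self.non_hash_seen = deque(maxlen=history)

    def add(self, x):
        """Return True if ``x`` has not been seen recently."""
        try:
            if x in self.seen:
                return False
            self.seen[x] = True
            return True
        except TypeError:  # unhashable value, e.g. a list
            if x in self.non_hash_seen:
                return False
            self.non_hash_seen.append(x)
            return True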
49 changes: 47 additions & 2 deletions rapidz/parallel.py
@@ -340,6 +340,11 @@ def update(self, x, who=None):
tup = tuple(self.last)
client = self.default_client()
l = []
# Only fall back on buffered data when it is not the incoming data.
# Not emitting when the incoming data is bad is fine, but in serial
# mode bad data would never have reached this node, so the buffered
# data must always be good data.
for t, up in szip(tup, self.upstreams):
if up == who:
a = t
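
The comment above spells out the filter-proofing rule for the parallel combine_latest: the node may skip emitting when the incoming value was rejected by a filter, but the per-upstream buffer must only ever hold good values. A hedged sketch of that rule in isolation follows; the helper name is made up and ``NULL_COMPUTE`` stands in for the sentinel rapidz uses to mark filtered-out results, so this is not the committed implementation.

NULL_COMPUTE = object()  # stand-in for rapidz's "rejected by a filter" sentinel


def next_buffered_value(incoming, previously_buffered):
    """Keep the buffer good: only overwrite it when the incoming value
    was not rejected upstream, otherwise fall back on the prior value."""
    if incoming is NULL_COMPUTE:
        return previously_buffered
    return incoming
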
@@ -362,7 +367,7 @@ class delay(ParallelStream, core.delay):
class latest(ParallelStream, core.latest):
pass


# TODO: needs to be filter proofed
@args_kwargs
@ParallelStream.register_api()
class partition(ParallelStream, core.partition):
@@ -375,6 +380,7 @@ class rate_limit(ParallelStream, core.rate_limit):
pass


# TODO: needs to be filter proofed
@args_kwargs
@ParallelStream.register_api()
class sliding_window(ParallelStream, core.sliding_window):
@@ -392,7 +398,7 @@ class timed_window(ParallelStream, core.timed_window):
class union(ParallelStream, core.union):
pass


# TODO: needs to be filter proofed
@args_kwargs
@ParallelStream.register_api()
class zip(ParallelStream):
@@ -467,3 +473,42 @@ class filenames(ParallelStream, sources.filenames):
@ParallelStream.register_api(staticmethod)
class from_textfile(ParallelStream, sources.from_textfile):
pass


def is_unique(x, past):
if x in past:
return NULL_COMPUTE
return x

@args_kwargs
@ParallelStream.register_api()
class unique(ParallelStream):
""" Avoid sending through repeated elements
This deduplicates a stream so that only new elements pass through.
You can control how much of a history is stored with the ``history=``
parameter. For example setting ``history=1`` avoids sending through
elements when one is repeated right after the other.
Examples
--------
>>> source = Stream()
>>> source.unique(history=1).sink(print)
>>> for x in [1, 1, 2, 2, 2, 1, 3]:
... source.emit(x)
1
2
1
3
"""

def __init__(self, upstream, history=None, **kwargs):
self.history = history
self.past = []
ParallelStream.__init__(self, upstream, **kwargs)

def update(self, x, who=None):
client = self.default_client()
ret = client.submit(is_unique, x, self.past)
self.past.append(ret)
return self._emit(ret)
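
For context, here is a serial analogue of the node above; it is illustrative only and not part of the commit. Instead of silently dropping repeats, ``is_unique`` maps them to the ``NULL_COMPUTE`` sentinel, so the membership test can run remotely via ``client.submit`` and filter-aware downstream nodes can discard the marker later.

NULL_COMPUTE = object()  # stand-in for rapidz's sentinel


def is_unique(x, past):
    """Return x if it has not been seen before, else the sentinel."""
    if x in past:
        return NULL_COMPUTE
    return x


past = []
results = []
for x in [1, 1, 2, 2, 2, 1, 3]:
    r = is_unique(x, past)
    past.append(x)  # remember every raw value; history is unbounded here
    if r is not NULL_COMPUTE:
        results.append(r)

assert results == [1, 2, 3]

Note that the committed parallel node buffers the futures returned by ``client.submit`` rather than raw values; since Dask resolves futures passed as task arguments, the ``past`` list seen inside ``is_unique`` on a worker should presumably hold the prior results themselves.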
49 changes: 48 additions & 1 deletion rapidz/tests/test_parallel_filter.py
@@ -12,7 +12,9 @@

gen_test = pytest.mark.gen_test

test_params = ["thread", thread_default_client]
test_params = ["thread",
thread_default_client
]


@pytest.mark.parametrize("backend", test_params)
@@ -32,6 +34,7 @@ def test_filter_combine_latest(backend):
yield source.emit(i)

assert L == LL
s.default_client().shutdown()


@pytest.mark.parametrize("backend", test_params)
@@ -51,6 +54,7 @@ def test_filter_combine_latest_odd(backend):
yield source.emit(i)

assert L == LL
s.default_client().shutdown()


@pytest.mark.parametrize("backend", test_params)
@@ -71,3 +75,46 @@ def test_filter_combine_latest_emit_on(backend):
yield source.emit(i)

assert L == LL
s.default_client().shutdown()


@pytest.mark.parametrize("backend", test_params)
@gen_test()
def test_filter_combine_latest_triple(backend):
source = Stream(asynchronous=True)

s = scatter(source, backend=backend)
futures = s.filter(lambda x: x % 3 == 1).combine_latest(s)
L = futures.gather().sink_to_list()

presents = source.filter(lambda x: x % 3 == 1).combine_latest(source)

LL = presents.sink_to_list()

for i in range(10):
yield source.emit(i)

assert L == LL
s.default_client().shutdown()


@pytest.mark.parametrize("backend", test_params)
@gen_test()
def test_unique(backend):
source = Stream(asynchronous=True)

s = scatter(source, backend=backend)
futures = s.unique()
L = futures.gather().sink_to_list()

presents = source.unique()

LL = presents.sink_to_list()

for i in range(10):
if i % 2 == 1:
i = i - 1
yield source.emit(i)

assert L == LL
s.default_client().shutdown()
