# Patch overview (three files under rapidz/):
#   * rapidz/core.py - adds a TODO comment above the ``non_hash_seen`` deque.
#   * rapidz/parallel.py - adds explanatory and TODO comments, then appends a
#     new ``is_unique`` helper and a ``unique`` ParallelStream node.
#   * rapidz/tests/test_parallel_filter.py - reflows ``test_params``, calls
#     ``s.default_client().shutdown()`` at the end of the existing tests, and
#     adds ``test_filter_combine_latest_triple`` and ``test_unique``.
#
# How the new ``unique`` node behaves, as written in this patch:
#   * ``update`` submits ``is_unique(x, self.past)`` to the default client,
#     appends the object returned by ``submit`` to ``self.past`` and emits it.
#   * ``is_unique`` returns ``NULL_COMPUTE`` when ``x`` is already in ``past``
#     and returns ``x`` unchanged otherwise.
#
# NOTE(review): ``unique.__init__`` stores ``history`` but nothing in this
#   patch reads it, and ``self.past`` is a plain list appended to on every
#   update. The docstring example (``history=1`` printing 1, 2, 1, 3)
#   presumably needs a bounded history - confirm before relying on it.
# NOTE(review): ``NULL_COMPUTE`` is not defined in the hunks shown; presumably
#   it is already in scope in rapidz/parallel.py - verify.
# NOTE(review): line breaks and indentation below were restored by convention
#   from a whitespace-collapsed copy; check context-line whitespace against
#   the original commit before applying with a strict tool.
diff --git a/rapidz/core.py b/rapidz/core.py
index 63ccc6b..a52a03e 100644
--- a/rapidz/core.py
+++ b/rapidz/core.py
@@ -1187,6 +1187,7 @@ def __init__(self, upstream, history=None, key=identity, **kwargs):
             from zict import LRU
 
             self.seen = LRU(history, self.seen)
+            # TODO: pull this out from history
             self.non_hash_seen = deque(maxlen=history)
 
         Stream.__init__(self, upstream, **kwargs)
diff --git a/rapidz/parallel.py b/rapidz/parallel.py
index 581182c..7f930ae 100644
--- a/rapidz/parallel.py
+++ b/rapidz/parallel.py
@@ -340,6 +340,11 @@ def update(self, x, who=None):
             tup = tuple(self.last)
             client = self.default_client()
             l = []
+            # we only want to fall back on prior data if it is not the
+            # incoming data
+            # It is fine to not emit if the incoming data is bad, but in
+            # serial mode the bad data would have never gotten to the node
+            # so we need to have the buffered data only be good data
             for t, up in szip(tup, self.upstreams):
                 if up == who:
                     a = t
@@ -362,7 +367,7 @@ class delay(ParallelStream, core.delay):
 class latest(ParallelStream, core.latest):
     pass
 
-
+# TODO: needs to be filter proofed
 @args_kwargs
 @ParallelStream.register_api()
 class partition(ParallelStream, core.partition):
@@ -375,6 +380,7 @@ class rate_limit(ParallelStream, core.rate_limit):
     pass
 
 
+# TODO: needs to be filter proofed
 @args_kwargs
 @ParallelStream.register_api()
 class sliding_window(ParallelStream, core.sliding_window):
@@ -392,7 +398,7 @@ class timed_window(ParallelStream, core.timed_window):
 class union(ParallelStream, core.union):
     pass
 
-
+# TODO: needs to be filter proofed
 @args_kwargs
 @ParallelStream.register_api()
 class zip(ParallelStream):
@@ -467,3 +473,42 @@ class filenames(ParallelStream, sources.filenames):
 @ParallelStream.register_api(staticmethod)
 class from_textfile(ParallelStream, sources.from_textfile):
     pass
+
+
+def is_unique(x, past):
+    if x in past:
+        return NULL_COMPUTE
+    return x
+
+@args_kwargs
+@ParallelStream.register_api()
+class unique(ParallelStream):
+    """ Avoid sending through repeated elements
+
+    This deduplicates a stream so that only new elements pass through.
+    You can control how much of a history is stored with the ``history=``
+    parameter. For example setting ``history=1`` avoids sending through
+    elements when one is repeated right after the other.
+
+    Examples
+    --------
+    >>> source = Stream()
+    >>> source.unique(history=1).sink(print)
+    >>> for x in [1, 1, 2, 2, 2, 1, 3]:
+    ...     source.emit(x)
+    1
+    2
+    1
+    3
+    """
+
+    def __init__(self, upstream, history=None, **kwargs):
+        self.history = history
+        self.past = []
+        ParallelStream.__init__(self, upstream, **kwargs)
+
+    def update(self, x, who=None):
+        client = self.default_client()
+        ret = client.submit(is_unique, x, self.past)
+        self.past.append(ret)
+        return self._emit(ret)
diff --git a/rapidz/tests/test_parallel_filter.py b/rapidz/tests/test_parallel_filter.py
index 37a6280..69570cf 100644
--- a/rapidz/tests/test_parallel_filter.py
+++ b/rapidz/tests/test_parallel_filter.py
@@ -12,7 +12,9 @@
 
 gen_test = pytest.mark.gen_test
 
-test_params = ["thread", thread_default_client]
+test_params = ["thread",
+               thread_default_client
+               ]
 
 
 @pytest.mark.parametrize("backend", test_params)
@@ -32,6 +34,7 @@ def test_filter_combine_latest(backend):
         yield source.emit(i)
 
     assert L == LL
+    s.default_client().shutdown()
 
 
 @pytest.mark.parametrize("backend", test_params)
@@ -51,6 +54,7 @@ def test_filter_combine_latest_odd(backend):
         yield source.emit(i)
 
     assert L == LL
+    s.default_client().shutdown()
 
 
 @pytest.mark.parametrize("backend", test_params)
@@ -71,3 +75,46 @@ def test_filter_combine_latest_emit_on(backend):
         yield source.emit(i)
 
     assert L == LL
+    s.default_client().shutdown()
+
+
+@pytest.mark.parametrize("backend", test_params)
+@gen_test()
+def test_filter_combine_latest_triple(backend):
+    source = Stream(asynchronous=True)
+
+    s = scatter(source, backend=backend)
+    futures = s.filter(lambda x: x % 3 == 1).combine_latest(s)
+    L = futures.gather().sink_to_list()
+
+    presents = source.filter(lambda x: x % 3 == 1).combine_latest(source)
+
+    LL = presents.sink_to_list()
+
+    for i in range(10):
+        yield source.emit(i)
+
+    assert L == LL
+    s.default_client().shutdown()
+
+
+@pytest.mark.parametrize("backend", test_params)
+@gen_test()
+def test_unique(backend):
+    source = Stream(asynchronous=True)
+
+    s = scatter(source, backend=backend)
+    futures = s.unique()
+    L = futures.gather().sink_to_list()
+
+    presents = source.unique()
+
+    LL = presents.sink_to_list()
+
+    for i in range(10):
+        if i % 2 == 1:
+            i = i - 1
+        yield source.emit(i)
+
+    assert L == LL
+    s.default_client().shutdown()