From 8c781be7a5a41af14da8b3b012adad48d9ec4215 Mon Sep 17 00:00:00 2001 From: Julien Lhermitte Date: Mon, 23 Oct 2017 16:43:14 -0400 Subject: [PATCH 1/2] First adding for a clear on accumulator --- streamz/core.py | 33 +++++++++++++++++++++++++++------ streamz/tests/test_core.py | 20 ++++++++++++++++++++ 2 files changed, 47 insertions(+), 6 deletions(-) diff --git a/streamz/core.py b/streamz/core.py index 4f541458..6f403012 100644 --- a/streamz/core.py +++ b/streamz/core.py @@ -28,6 +28,18 @@ logger = logging.getLogger(__name__) +class ClearMSG: + pass + + +class IgnoreMSG: + pass + + +class StopMSG: + pass + + def identity(x): return x @@ -277,7 +289,7 @@ def scatter(self, **kwargs): return scatter(self, **kwargs) def remove(self, predicate): - """ Only pass through elements for which the predicate returns False """ + """ Only pass through elements for which the predicate returns False""" return self.filter(lambda x: not predicate(x)) @property @@ -399,11 +411,13 @@ def __init__(self, upstream, func, *args, **kwargs): _global_sinks.add(self) def update(self, x, who=None): - result = self.func(x, *self.args, **self.kwargs) - if gen.isawaitable(result): - return result - else: - return [] + # only if not a clear msg, which is ignored + if not isinstance(x, ClearMSG): + result = self.func(x, *self.args, **self.kwargs) + if gen.isawaitable(result): + return result + else: + return [] @Stream.register_api() @@ -516,6 +530,7 @@ class accumulate(Stream): def __init__(self, upstream, func, start=no_default, returns_state=False, **kwargs): + self.start = start self.func = func self.kwargs = kwargs self.state = start @@ -525,6 +540,12 @@ def __init__(self, upstream, func, start=no_default, returns_state=False, Stream.__init__(self, upstream, stream_name=stream_name) def update(self, x, who=None): + # if this is a request to clear, reset state + if isinstance(x, ClearMSG): + self.state = self.start + # and pass it through + return self.emit(x) + if self.state is no_default: self.state = x return self._emit(x) diff --git a/streamz/tests/test_core.py b/streamz/tests/test_core.py index ec3d15e5..58aab7c7 100644 --- a/streamz/tests/test_core.py +++ b/streamz/tests/test_core.py @@ -931,3 +931,23 @@ def slow_write(x): if sys.version_info >= (3, 5): from streamz.tests.py3_test_core import * # noqa + + +def test_clear(): + s = Stream() + # increment + def acc1(x1, x2): + return x1 + x2 + from streamz.core import ClearMSG + + cc = ClearMSG() + + sout = s.accumulate(acc1) + + Lout = sout.sink_to_list() + + s.emit(1) + s.emit(2) + s.emit(cc) + s.emit(3) + assert Lout == [1, 3, 3] From 7d7905211427bb89f432835af364b873d083520c Mon Sep 17 00:00:00 2001 From: Julien Lhermitte Date: Mon, 23 Oct 2017 16:59:31 -0400 Subject: [PATCH 2/2] added ignore --- streamz/core.py | 12 ++++++++---- streamz/tests/test_core.py | 15 ++++++++++----- 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/streamz/core.py b/streamz/core.py index 6f403012..ee8b3fa0 100644 --- a/streamz/core.py +++ b/streamz/core.py @@ -36,9 +36,7 @@ class IgnoreMSG: pass -class StopMSG: - pass - +# perhaps StopMSG? def identity(x): return x @@ -412,7 +410,7 @@ def __init__(self, upstream, func, *args, **kwargs): def update(self, x, who=None): # only if not a clear msg, which is ignored - if not isinstance(x, ClearMSG): + if not isinstance(x, ClearMSG) and not isinstance(x, IgnoreMSG): result = self.func(x, *self.args, **self.kwargs) if gen.isawaitable(result): return result @@ -454,6 +452,9 @@ def __init__(self, upstream, func, *args, **kwargs): Stream.__init__(self, upstream, stream_name=stream_name) def update(self, x, who=None): + if isinstance(x, IgnoreMSG) or isinstance(x, ClearMSG): + return self.emit(x) + result = self.func(x, *self.args, **self.kwargs) return self._emit(result) @@ -546,6 +547,9 @@ def update(self, x, who=None): # and pass it through return self.emit(x) + if isinstance(x, IgnoreMSG): + return self.emit(x) + if self.state is no_default: self.state = x return self._emit(x) diff --git a/streamz/tests/test_core.py b/streamz/tests/test_core.py index 58aab7c7..d3261d17 100644 --- a/streamz/tests/test_core.py +++ b/streamz/tests/test_core.py @@ -938,16 +938,21 @@ def test_clear(): # increment def acc1(x1, x2): return x1 + x2 - from streamz.core import ClearMSG + from streamz.core import ClearMSG, IgnoreMSG - cc = ClearMSG() + clear_msg = ClearMSG() + ignore_msg = IgnoreMSG() sout = s.accumulate(acc1) - Lout = sout.sink_to_list() + sout2 = sout.map(lambda x : x + 1) + Lout2 = sout2.sink_to_list() s.emit(1) s.emit(2) - s.emit(cc) + s.emit(ignore_msg) + s.emit(3) + s.emit(clear_msg) s.emit(3) - assert Lout == [1, 3, 3] + assert Lout == [1, 3, 6, 3] + assert Lout2 == [2, 4, 7, 4]