diff --git a/streamz/core.py b/streamz/core.py index 4f541458..ee8b3fa0 100644 --- a/streamz/core.py +++ b/streamz/core.py @@ -28,6 +28,16 @@ logger = logging.getLogger(__name__) +class ClearMSG: + pass + + +class IgnoreMSG: + pass + + +# perhaps StopMSG? + def identity(x): return x @@ -277,7 +287,7 @@ def scatter(self, **kwargs): return scatter(self, **kwargs) def remove(self, predicate): - """ Only pass through elements for which the predicate returns False """ + """ Only pass through elements for which the predicate returns False""" return self.filter(lambda x: not predicate(x)) @property @@ -399,11 +409,13 @@ def __init__(self, upstream, func, *args, **kwargs): _global_sinks.add(self) def update(self, x, who=None): - result = self.func(x, *self.args, **self.kwargs) - if gen.isawaitable(result): - return result - else: - return [] + # only if not a clear msg, which is ignored + if not isinstance(x, ClearMSG) and not isinstance(x, IgnoreMSG): + result = self.func(x, *self.args, **self.kwargs) + if gen.isawaitable(result): + return result + else: + return [] @Stream.register_api() @@ -440,6 +452,9 @@ def __init__(self, upstream, func, *args, **kwargs): Stream.__init__(self, upstream, stream_name=stream_name) def update(self, x, who=None): + if isinstance(x, IgnoreMSG) or isinstance(x, ClearMSG): + return self.emit(x) + result = self.func(x, *self.args, **self.kwargs) return self._emit(result) @@ -516,6 +531,7 @@ class accumulate(Stream): def __init__(self, upstream, func, start=no_default, returns_state=False, **kwargs): + self.start = start self.func = func self.kwargs = kwargs self.state = start @@ -525,6 +541,15 @@ def __init__(self, upstream, func, start=no_default, returns_state=False, Stream.__init__(self, upstream, stream_name=stream_name) def update(self, x, who=None): + # if this is a request to clear, reset state + if isinstance(x, ClearMSG): + self.state = self.start + # and pass it through + return self.emit(x) + + if isinstance(x, IgnoreMSG): + return self.emit(x) + if self.state is no_default: self.state = x return self._emit(x) diff --git a/streamz/tests/test_core.py b/streamz/tests/test_core.py index ec3d15e5..d3261d17 100644 --- a/streamz/tests/test_core.py +++ b/streamz/tests/test_core.py @@ -931,3 +931,28 @@ def slow_write(x): if sys.version_info >= (3, 5): from streamz.tests.py3_test_core import * # noqa + + +def test_clear(): + s = Stream() + # increment + def acc1(x1, x2): + return x1 + x2 + from streamz.core import ClearMSG, IgnoreMSG + + clear_msg = ClearMSG() + ignore_msg = IgnoreMSG() + + sout = s.accumulate(acc1) + Lout = sout.sink_to_list() + sout2 = sout.map(lambda x : x + 1) + Lout2 = sout2.sink_to_list() + + s.emit(1) + s.emit(2) + s.emit(ignore_msg) + s.emit(3) + s.emit(clear_msg) + s.emit(3) + assert Lout == [1, 3, 6, 3] + assert Lout2 == [2, 4, 7, 4]