Skip to content

Commit f6f6e14

Browse files
authored
Merge pull request #248 from CJ-Wright/add_upstream2
FIX: connect doesn't break zip now
2 parents cc62c7b + b68c430 commit f6f6e14

File tree

2 files changed

+119
-9
lines changed

2 files changed

+119
-9
lines changed

streamz/core.py

Lines changed: 65 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,30 @@ def _inform_asynchronous(self, asynchronous):
181181
if downstream:
182182
downstream._inform_asynchronous(asynchronous)
183183

184+
def _add_upstream(self, upstream):
185+
"""Add upstream to current upstreams, this method is overridden for
186+
classes which handle stream specific buffers/caches"""
187+
if self.upstreams == [None]:
188+
self.upstreams[0] = upstream
189+
else:
190+
self.upstreams.append(upstream)
191+
192+
def _add_downstream(self, downstream):
193+
"""Add downstream to current downstreams"""
194+
self.downstreams.add(downstream)
195+
196+
def _remove_downstream(self, downstream):
197+
"""Remove downstream from current downstreams"""
198+
self.downstreams.remove(downstream)
199+
200+
def _remove_upstream(self, upstream):
201+
"""Remove upstream from current upstreams, this method is overridden for
202+
classes which handle stream specific buffers/caches"""
203+
if len(self.upstreams) == 1:
204+
self.upstreams[0] = [None]
205+
else:
206+
self.upstreams.remove(upstream)
207+
184208
@classmethod
185209
def register_api(cls, modifier=identity):
186210
""" Add callable to Stream API
@@ -352,12 +376,8 @@ def connect(self, downstream):
352376
downstream: Stream
353377
The downstream stream to connect to
354378
'''
355-
self.downstreams.add(downstream)
356-
357-
if downstream.upstreams == [None]:
358-
downstream.upstreams = [self]
359-
else:
360-
downstream.upstreams.append(self)
379+
self._add_downstream(downstream)
380+
downstream._add_upstream(self)
361381

362382
def disconnect(self, downstream):
363383
''' Disconnect this stream to a downstream element.
@@ -367,9 +387,9 @@ def disconnect(self, downstream):
367387
downstream: Stream
368388
The downstream stream to disconnect from
369389
'''
370-
self.downstreams.remove(downstream)
390+
self._remove_downstream(downstream)
371391

372-
downstream.upstreams.remove(self)
392+
downstream._remove_upstream(self)
373393

374394
@property
375395
def upstream(self):
@@ -792,7 +812,8 @@ def update(self, x, who=None):
792812
def _check_end(self):
793813
if self.end and self.state >= self.end:
794814
# we're done
795-
self.upstream.downstreams.remove(self)
815+
for upstream in self.upstreams:
816+
upstream._remove_downstream(self)
796817

797818

798819
@Stream.register_api()
@@ -1016,6 +1037,16 @@ def __init__(self, *upstreams, **kwargs):
10161037

10171038
Stream.__init__(self, upstreams=upstreams2, **kwargs)
10181039

1040+
def _add_upstream(self, upstream):
1041+
# Override method to handle setup of buffer for new stream
1042+
self.buffers[upstream] = deque()
1043+
super(zip, self)._add_upstream(upstream)
1044+
1045+
def _remove_upstream(self, upstream):
1046+
# Override method to handle removal of buffer for stream
1047+
self.buffers.pop(upstream)
1048+
super(zip, self)._remove_upstream(upstream)
1049+
10191050
def pack_literals(self, tup):
10201051
""" Fill buffers for literals whenever we empty them """
10211052
inp = list(tup)[::-1]
@@ -1067,6 +1098,7 @@ class combine_latest(Stream):
10671098

10681099
def __init__(self, *upstreams, **kwargs):
10691100
emit_on = kwargs.pop('emit_on', None)
1101+
self._initial_emit_on = emit_on
10701102

10711103
self.last = [None for _ in upstreams]
10721104
self.missing = set(upstreams)
@@ -1080,6 +1112,30 @@ def __init__(self, *upstreams, **kwargs):
10801112
self.emit_on = upstreams
10811113
Stream.__init__(self, upstreams=upstreams, **kwargs)
10821114

1115+
def _add_upstream(self, upstream):
1116+
# Override method to handle setup of last and missing for new stream
1117+
self.last.append(None)
1118+
self.missing.update([upstream])
1119+
super(combine_latest, self)._add_upstream(upstream)
1120+
if self._initial_emit_on is None:
1121+
self.emit_on = self.upstreams
1122+
1123+
def _remove_upstream(self, upstream):
1124+
# Override method to handle removal of last and missing for stream
1125+
if self.emit_on == upstream:
1126+
raise RuntimeError("Can't remove the ``emit_on`` stream since that"
1127+
"would cause no data to be emitted. "
1128+
"Consider adding an ``emit_on`` first by "
1129+
"running ``node.emit_on=(upstream,)`` to add "
1130+
"a new ``emit_on`` or running "
1131+
"``node.emit_on=tuple(node.upstreams)`` to "
1132+
"emit on all incoming data")
1133+
self.last.pop(self.upstreams.index(upstream))
1134+
self.missing.remove(upstream)
1135+
super(combine_latest, self)._remove_upstream(upstream)
1136+
if self._initial_emit_on is None:
1137+
self.emit_on = self.upstreams
1138+
10831139
def update(self, x, who=None):
10841140
if self.missing and who in self.missing:
10851141
self.missing.remove(who)

streamz/tests/test_core.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1207,5 +1207,59 @@ def start(self):
12071207
assert flag == [True]
12081208

12091209

1210+
def test_connect_zip():
1211+
a = Stream()
1212+
b = Stream()
1213+
c = Stream()
1214+
x = a.zip(b)
1215+
L = x.sink_to_list()
1216+
c.connect(x)
1217+
a.emit(1)
1218+
b.emit(1)
1219+
assert not L
1220+
c.emit(1)
1221+
assert L == [(1, 1, 1)]
1222+
1223+
1224+
def test_disconnect_zip():
1225+
a = Stream()
1226+
b = Stream()
1227+
c = Stream()
1228+
x = a.zip(b, c)
1229+
L = x.sink_to_list()
1230+
b.disconnect(x)
1231+
a.emit(1)
1232+
b.emit(1)
1233+
assert not L
1234+
c.emit(1)
1235+
assert L == [(1, 1)]
1236+
1237+
1238+
def test_connect_combine_latest():
1239+
a = Stream()
1240+
b = Stream()
1241+
c = Stream()
1242+
x = a.combine_latest(b, emit_on=a)
1243+
L = x.sink_to_list()
1244+
c.connect(x)
1245+
b.emit(1)
1246+
c.emit(1)
1247+
a.emit(1)
1248+
assert L == [(1, 1, 1)]
1249+
1250+
1251+
def test_connect_discombine_latest():
1252+
a = Stream()
1253+
b = Stream()
1254+
c = Stream()
1255+
x = a.combine_latest(b, c, emit_on=a)
1256+
L = x.sink_to_list()
1257+
c.disconnect(x)
1258+
b.emit(1)
1259+
c.emit(1)
1260+
a.emit(1)
1261+
assert L == [(1, 1)]
1262+
1263+
12101264
if sys.version_info >= (3, 5):
12111265
from streamz.tests.py3_test_core import * # noqa

0 commit comments

Comments
 (0)