Skip to content

Commit 090b9ba

Browse files
martindurant authored and mrocklin committed
Add to_kafka stream and fix kafka testing issues (#230)
1 parent 1486318 commit 090b9ba

File tree

5 files changed

+157
-68
lines changed

5 files changed

+157
-68
lines changed

streamz/core.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1231,6 +1231,88 @@ def cb(self):
12311231
yield self._emit(x)
12321232

12331233

1234+
@Stream.register_api()
class to_kafka(Stream):
    """ Writes data in the stream to Kafka

    This stream accepts a string or bytes object. Call ``flush`` to ensure all
    messages are pushed. Responses from Kafka are pushed downstream.

    Parameters
    ----------
    topic : string
        The topic which to write
    producer_config : dict
        Settings to set up the stream, see
        https://docs.confluent.io/current/clients/confluent-kafka-python/#configuration
        https://github.com/edenhill/librdkafka/blob/master/CONFIGURATION.md
        Examples:
        bootstrap.servers: Connection string (host:port) to Kafka

    Examples
    --------
    >>> from streamz import Stream
    >>> ARGS = {'bootstrap.servers': 'localhost:9092'}
    >>> source = Stream()
    >>> kafka = source.map(lambda x: str(x)).to_kafka('test', ARGS)
    <to_kafka>
    >>> for i in range(10):
    ...     source.emit(i)
    >>> kafka.flush()
    """
    def __init__(self, upstream, topic, producer_config, **kwargs):
        # Imported lazily so the rest of streamz works without
        # confluent_kafka installed.
        import confluent_kafka as ck

        self.topic = topic
        self.producer = ck.Producer(producer_config)

        # ensure_io_loop=True: this node needs a running event loop so the
        # ``poll`` callback below can service the producer periodically.
        Stream.__init__(self, upstream, ensure_io_loop=True, **kwargs)
        self.stopped = False
        self.polltime = 0.2
        self.loop.add_callback(self.poll)
        # One pending future per produced message, resolved in ``cb`` when
        # the delivery report arrives.  ``cb`` pops from the front, which
        # assumes delivery reports arrive in produce order
        # -- NOTE(review): confirm librdkafka preserves report order on
        # retries; out-of-order reports would resolve the wrong future.
        self.futures = []

    @gen.coroutine
    def poll(self):
        # Periodically service the producer so that delivery callbacks
        # (``cb``) actually get invoked; runs until ``self.stopped``.
        while not self.stopped:
            # executes callbacks for any delivered data, in this thread
            # if no messages were sent, nothing happens
            self.producer.poll(0)
            yield gen.sleep(self.polltime)

    def update(self, x, who=None):
        # Future completed (in ``cb``) once Kafka acknowledges delivery
        # of ``x``; returned so upstream can apply backpressure.
        future = gen.Future()
        self.futures.append(future)

        @gen.coroutine
        def _():
            while True:
                try:
                    # this runs asynchronously, in C-K's thread
                    self.producer.produce(self.topic, x, callback=self.cb)
                    return
                except BufferError:
                    # local produce queue is full: wait for the poll loop
                    # to drain it, then retry the same message
                    yield gen.sleep(self.polltime)
                except Exception as e:
                    # unrecoverable produce error: fail this message's future
                    future.set_exception(e)
                    return

        self.loop.add_callback(_)
        return future

    @gen.coroutine
    def cb(self, err, msg):
        # Delivery-report callback from confluent_kafka (err, msg signature).
        # Resolve the oldest pending future and push the delivered payload
        # downstream on success; otherwise propagate the error.
        future = self.futures.pop(0)
        if msg is not None and msg.value() is not None:
            future.set_result(None)
            yield self._emit(msg.value())
        else:
            # NOTE(review): if both err and msg are None this raises
            # AttributeError on ``msg.error()`` -- presumed unreachable,
            # since confluent_kafka supplies at least one of the two.
            future.set_exception(err or msg.error())

    def flush(self, timeout=-1):
        # Block until every queued message has been delivered (or the
        # timeout, in seconds, expires; -1 means wait indefinitely).
        self.producer.flush(timeout)
1314+
1315+
12341316
def sync(loop, func, *args, **kwargs):
12351317
"""
12361318
Run coroutine in loop running in separate thread.

streamz/sources.py

Lines changed: 25 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from glob import glob
22
import os
3-
import weakref
43

54
import time
65
import tornado.ioloop
@@ -180,8 +179,7 @@ class from_kafka(Source):
180179
... {'bootstrap.servers': 'localhost:9092',
181180
... 'group.id': 'streamz'}) # doctest: +SKIP
182181
"""
183-
def __init__(self, topics, consumer_params, poll_interval=0.1, start=False,
184-
**kwargs):
182+
def __init__(self, topics, consumer_params, poll_interval=0.1, start=False, **kwargs):
185183
self.cpars = consumer_params
186184
self.consumer = None
187185
self.topics = topics
@@ -194,7 +192,7 @@ def __init__(self, topics, consumer_params, poll_interval=0.1, start=False,
194192
def do_poll(self):
195193
if self.consumer is not None:
196194
msg = self.consumer.poll(0)
197-
if msg and msg.value():
195+
if msg and msg.value() and msg.error() is None:
198196
return msg.value()
199197

200198
@gen.coroutine
@@ -207,26 +205,19 @@ def poll_kafka(self):
207205
yield gen.sleep(self.poll_interval)
208206
if self.stopped:
209207
break
208+
self._close_consumer()
210209

211210
def start(self):
212211
import confluent_kafka as ck
213-
import distributed
214212
if self.stopped:
215-
finalize = distributed.compatibility.finalize
216213
self.stopped = False
217-
self.loop.add_callback(self.poll_kafka)
218214
self.consumer = ck.Consumer(self.cpars)
219215
self.consumer.subscribe(self.topics)
216+
tp = ck.TopicPartition(self.topics[0], 0, 0)
220217

221-
def close(ref):
222-
ob = ref()
223-
if ob is not None and ob.consumer is not None:
224-
consumer = ob.consumer
225-
ob.consumer = None
226-
consumer.unsubscribe()
227-
consumer.close() # may raise with latest ck, that's OK
228-
229-
finalize(self, close, weakref.ref(self))
218+
# blocks for consumer thread to come up
219+
self.consumer.get_watermark_offsets(tp)
220+
self.loop.add_callback(self.poll_kafka)
230221

231222
def _close_consumer(self):
232223
if self.consumer is not None:
@@ -253,36 +244,43 @@ def __init__(self, topic, consumer_params, poll_interval='1s',
253244
@gen.coroutine
254245
def poll_kafka(self):
255246
import confluent_kafka as ck
256-
consumer = ck.Consumer(self.consumer_params)
257247

258248
try:
259249
while not self.stopped:
260250
out = []
261-
262251
for partition in range(self.npartitions):
263252
tp = ck.TopicPartition(self.topic, partition, 0)
264253
try:
265-
low, high = consumer.get_watermark_offsets(tp,
266-
timeout=0.1)
254+
low, high = self.consumer.get_watermark_offsets(
255+
tp, timeout=0.1)
267256
except (RuntimeError, ck.KafkaException):
268257
continue
269258
current_position = self.positions[partition]
270259
lowest = max(current_position, low)
271-
out.append((self.consumer_params, self.topic, partition,
272-
lowest, high - 1))
273-
self.positions[partition] = high
260+
if high > lowest:
261+
out.append((self.consumer_params, self.topic, partition,
262+
lowest, high - 1))
263+
self.positions[partition] = high
274264

275265
for part in out:
276266
yield self._emit(part)
277267

278268
else:
279269
yield gen.sleep(self.poll_interval)
280270
finally:
281-
consumer.close()
271+
self.consumer.unsubscribe()
272+
self.consumer.close()
282273

283274
def start(self):
284-
self.stopped = False
285-
self.loop.add_callback(self.poll_kafka)
275+
import confluent_kafka as ck
276+
if self.stopped:
277+
self.consumer = ck.Consumer(self.consumer_params)
278+
self.stopped = False
279+
tp = ck.TopicPartition(self.topic, 0, 0)
280+
281+
# blocks for consumer thread to come up
282+
self.consumer.get_watermark_offsets(tp)
283+
self.loop.add_callback(self.poll_kafka)
286284

287285

288286
@Stream.register_api(staticmethod)
@@ -352,7 +350,7 @@ def get_message_batch(kafka_params, topic, partition, low, high, timeout=None):
352350
try:
353351
while True:
354352
msg = consumer.poll(0)
355-
if msg and msg.value():
353+
if msg and msg.value() and msg.error() is None:
356354
if high >= msg.offset():
357355
out.append(msg.value())
358356
if high <= msg.offset():

streamz/tests/test_core.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -942,8 +942,8 @@ def dont_test_stream_kwargs(clean): # noqa: F811
942942
sin.emit(1)
943943

944944

945-
@pytest.fixture # noqa: F811
946-
def thread(loop):
945+
@pytest.fixture
946+
def thread(loop): # noqa: F811
947947
from threading import Thread, Event
948948
thread = Thread(target=loop.start)
949949
thread.daemon = True

streamz/tests/test_dask.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,8 @@ def test_zip(c, s, a, b):
7373
assert L == [(1, 'a'), (2, 'b')]
7474

7575

76-
@pytest.mark.slow # noqa: F811
77-
def test_sync(loop):
76+
@pytest.mark.slow
77+
def test_sync(loop): # noqa: F811
7878
with cluster() as (s, [a, b]):
7979
with Client(s['address'], loop=loop) as client: # noqa: F841
8080
source = Stream()
@@ -90,8 +90,8 @@ def f():
9090
assert L == list(map(inc, range(10)))
9191

9292

93-
@pytest.mark.slow # noqa: F811
94-
def test_sync_2(loop):
93+
@pytest.mark.slow
94+
def test_sync_2(loop): # noqa: F811
9595
with cluster() as (s, [a, b]):
9696
with Client(s['address'], loop=loop): # noqa: F841
9797
source = Stream()
@@ -131,8 +131,8 @@ def test_buffer(c, s, a, b):
131131
assert source.loop == c.loop
132132

133133

134-
@pytest.mark.slow # noqa: F811
135-
def test_buffer_sync(loop):
134+
@pytest.mark.slow
135+
def test_buffer_sync(loop): # noqa: F811
136136
with cluster() as (s, [a, b]):
137137
with Client(s['address'], loop=loop) as c: # noqa: F841
138138
source = Stream()
@@ -155,9 +155,9 @@ def test_buffer_sync(loop):
155155
assert L == list(map(inc, range(10)))
156156

157157

158-
@pytest.mark.xfail(reason='') # noqa: F811
158+
@pytest.mark.xfail(reason='')
159159
@pytest.mark.slow
160-
def test_stream_shares_client_loop(loop):
160+
def test_stream_shares_client_loop(loop): # noqa: F811
161161
with cluster() as (s, [a, b]):
162162
with Client(s['address'], loop=loop) as client: # noqa: F841
163163
source = Stream()

0 commit comments

Comments (0)