
Commit 11e2b2d

Author: chinmaychandak
Merge of 2 parents: 48072af + 5443517

File tree

2 files changed: +151 / -5 lines

streamz/sources.py

Lines changed: 68 additions & 5 deletions
@@ -1,4 +1,5 @@
 from glob import glob
+import json
 import os
 
 import time
@@ -453,11 +454,12 @@ def _close_consumer(self):
 class FromKafkaBatched(Stream):
     """Base class for both local and cluster-based batched kafka processing"""
     def __init__(self, topic, consumer_params, poll_interval='1s',
-                 npartitions=1, **kwargs):
+                 npartitions=1, checkpointing=None, **kwargs):
         self.consumer_params = consumer_params
         self.topic = topic
         self.npartitions = npartitions
         self.positions = [0] * npartitions
+        self.checkpointing = checkpointing
         self.poll_interval = convert_interval(poll_interval)
         self.stopped = True
 
@@ -470,18 +472,40 @@ def poll_kafka(self):
         try:
             while not self.stopped:
                 out = []
+
+                latest_checkpoint = {}
+                if self.checkpointing is not None:
+                    if not os.path.exists(self.checkpointing):
+                        os.makedirs(self.checkpointing)
+                    topic_path = self.checkpointing + '/' + self.topic
+                    if not os.path.exists(topic_path):
+                        os.makedirs(topic_path)
+                    checkpoints_list = os.listdir(topic_path)
+                    if len(checkpoints_list) > 0:
+                        previous_checkpoint = max(checkpoints_list)
+                        with open(topic_path + '/' + previous_checkpoint, 'r') as fr:
+                            latest_checkpoint = json.loads(fr.readlines()[-1])
+                        fr.close()
+
                 for partition in range(self.npartitions):
                     tp = ck.TopicPartition(self.topic, partition, 0)
                     try:
                         low, high = self.consumer.get_watermark_offsets(
                             tp, timeout=0.1)
                     except (RuntimeError, ck.KafkaException):
                         continue
+
                     current_position = self.positions[partition]
+                    group = self.consumer_params['group.id']
+
+                    if group in latest_checkpoint.keys():
+                        if str(partition) in latest_checkpoint[group].keys():
+                            current_position = latest_checkpoint[group][str(partition)]
+
                     lowest = max(current_position, low)
                     if high > lowest:
                         out.append((self.consumer_params, self.topic, partition,
-                                    lowest, high - 1))
+                                    lowest, high - 1, self.checkpointing))
                         self.positions[partition] = high
 
                 for part in out:
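For reference, the checkpoint files that poll_kafka reads above are plain-text files containing one JSON object per line, mapping a consumer group to per-partition offsets. A minimal sketch of the lookup, with a made-up file name, group, and offset (illustrative only, not part of the diff):

import json

# Hypothetical contents of 'custreamz_checkpoints/my-topic/3.txt':
# {"streamz-test": {"0": 10}}
# i.e. {group.id: {partition (as a string): next offset to read}}
latest_checkpoint = json.loads('{"streamz-test": {"0": 10}}')

group, partition = 'streamz-test', 0
if group in latest_checkpoint and str(partition) in latest_checkpoint[group]:
    current_position = latest_checkpoint[group][str(partition)]  # 10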
@@ -507,7 +531,8 @@ def start(self):
 
 @Stream.register_api(staticmethod)
 def from_kafka_batched(topic, consumer_params, poll_interval='1s',
-                       npartitions=1, start=False, dask=False, **kwargs):
+                       npartitions=1, start=False, dask=False,
+                       checkpointing=None, **kwargs):
     """ Get messages from Kafka in batches
 
     Uses the confluent-kafka library,
@@ -549,7 +574,8 @@ def from_kafka_batched(topic, consumer_params, poll_interval='1s',
         kwargs['loop'] = default_client().loop
     source = FromKafkaBatched(topic, consumer_params,
                               poll_interval=poll_interval,
-                              npartitions=npartitions, **kwargs)
+                              npartitions=npartitions,
+                              checkpointing=checkpointing, **kwargs)
     if dask:
         source = source.scatter()
 
@@ -559,7 +585,41 @@ def from_kafka_batched(topic, consumer_params, poll_interval='1s',
     return source.starmap(get_message_batch)
 
 
-def get_message_batch(kafka_params, topic, partition, low, high, timeout=None):
+def add_checkpoint(group, checkpoint, path):
+    topic = checkpoint.topic
+    partition = checkpoint.partition
+    offset = checkpoint.offset
+    latest_checkpoint = {}
+    previous_checkpoint = None
+    if not os.path.exists(path):
+        os.makedirs(path)
+    path = path + '/' + topic
+    if not os.path.exists(path):
+        os.makedirs(path)
+    checkpoints_list = os.listdir(path)
+    if len(checkpoints_list) > 0:
+        previous_checkpoint = max(checkpoints_list)
+        with open(path + '/' + previous_checkpoint, 'r') as fr:
+            latest_checkpoint = json.loads(fr.readlines()[0])
+        fr.close()
+        # Only maintain the last 5 checkpoints
+        if len(checkpoints_list) == 5:
+            os.system('rm -rf ' + path + '/' + min(checkpoints_list))
+    if group not in latest_checkpoint.keys():
+        latest_checkpoint[group] = {}
+    latest_checkpoint[group][partition] = offset
+    print(latest_checkpoint)
+    if previous_checkpoint is None:
+        new_checkpoint = '1.txt'
+    else:
+        previous_batch = int(previous_checkpoint.split('.')[0])
+        new_checkpoint = str(previous_batch + 1) + '.txt'
+    with open(path + '/' + new_checkpoint, 'a+') as fw:
+        fw.write(json.dumps(latest_checkpoint) + '\n')
+    fw.close()
+
+
+def get_message_batch(kafka_params, topic, partition, low, high, checkpointing, timeout=None):
     """Fetch a batch of kafka messages in given topic/partition
 
     This will block until messages are available, or timeout is reached.
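add_checkpoint above names files 1.txt, 2.txt, ... per topic and prunes the oldest once five exist. A standalone sketch of the same naming scheme (directory layout is illustrative); note it compares the numeric prefix explicitly, since max()/min() on raw file names is lexicographic ('10.txt' sorts before '2.txt'):

import os

def next_checkpoint_name(topic_path):
    # Checkpoint files are numbered '1.txt', '2.txt', ...; compare the
    # numeric prefix rather than the raw name to stay correct past '9.txt'.
    numbers = [int(f.split('.')[0]) for f in os.listdir(topic_path)]
    return '1.txt' if not numbers else str(max(numbers) + 1) + '.txt'

Similarly, os.remove would be a more portable way to prune a single checkpoint file than shelling out to rm -rf.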
@@ -583,5 +643,8 @@ def get_message_batch(kafka_params, topic, partition, low, high, timeout=None):
         if timeout is not None and time.time() - t0 > timeout:
             break
     finally:
+        if checkpointing is not None:
+            checkpoint = consumer.commit(asynchronous=False)
+            add_checkpoint(kafka_params['group.id'], checkpoint[0], checkpointing)
         consumer.close()
     return out
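Putting the sources.py changes together, a checkpointed pipeline might be started as below. The broker address, topic name, and checkpoint directory are illustrative; the resume behaviour is what the diff above implements:

from streamz import Stream

consumer_params = {'bootstrap.servers': 'localhost:9092',
                   'group.id': 'streamz-test'}

# Offsets are committed after every batch (consumer.commit in
# get_message_batch) and mirrored to files under
# 'custreamz_checkpoints/<topic>/', so restarting with the same group.id
# and checkpointing directory resumes from the last committed batch.
source = Stream.from_kafka_batched('my-topic', consumer_params,
                                   poll_interval='1s', npartitions=1,
                                   checkpointing='custreamz_checkpoints')
L = source.sink_to_list()
source.start()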

streamz/tests/test_kafka.py

Lines changed: 83 additions & 0 deletions
@@ -6,6 +6,7 @@
 import requests
 import shlex
 import subprocess
+import time
 from tornado import gen
 
 from ..core import Stream
@@ -200,6 +201,88 @@ def test_kafka_batch():
     stream.upstream.stopped = True
 
 
+def test_kafka_batch_checkpointing():
+    bootstrap_servers = 'localhost:9092'
+    ARGS = {'bootstrap.servers': bootstrap_servers,
+            'group.id': 'streamz-test'}
+    ARGS1 = {'bootstrap.servers': bootstrap_servers,
+             'group.id': 'streamz-test'}
+    ARGS2 = {'bootstrap.servers': bootstrap_servers,
+             'group.id': 'streamz-test'}
+    ARGS3 = {'bootstrap.servers': bootstrap_servers,
+             'group.id': 'streamz-test'}
+    with kafka_service() as kafka:
+        kafka, TOPIC = kafka
+        for i in range(10):
+            kafka.produce(TOPIC, b'value-%d' % i)
+        kafka.flush()
+
+        stream = Stream.from_kafka_batched(TOPIC, ARGS, checkpointing='custreamz_checkpoints')
+        out = stream.sink_to_list()
+        stream.start()
+        wait_for(lambda: any(out) and out[-1][-1] == b'value-9', 10, period=0.2)
+        assert out[-1][-1] == b'value-9'
+        stream.upstream.stopped = True
+
+        stream1 = Stream.from_kafka_batched(TOPIC, ARGS1, checkpointing=None)
+        out1 = stream1.sink_to_list()
+        stream1.start()
+        wait_for(lambda: any(out1) and out1[-1][-1] == b'value-9', 10, period=0.2)
+        assert out1[-1][-1] == b'value-9'
+        stream1.upstream.stopped = True
+
+        stream2 = Stream.from_kafka_batched(TOPIC, ARGS2, checkpointing='custreamz_checkpoints1')
+        out2 = stream2.sink_to_list()
+        stream2.start()
+        wait_for(lambda: any(out2) and out2[-1][-1] == b'value-9', 10, period=0.2)
+        assert out2[-1][-1] == b'value-9'
+        stream2.upstream.stopped = True
+
+        for i in range(10, 20):
+            kafka.produce(TOPIC, b'value-%d' % i)
+        kafka.flush()
+
+        stream3 = Stream.from_kafka_batched(TOPIC, ARGS3, checkpointing='custreamz_checkpoints')
+        out3 = stream3.sink_to_list()
+        stream3.start()
+        wait_for(lambda: any(out3) and out3[-1][0] == b'value-10' and out3[-1][-1] == b'value-19', 10, period=0.2)
+
+        for i in range(20, 25):
+            kafka.produce(TOPIC, b'value-%d' % i)
+        kafka.flush()
+        time.sleep(5)
+        checkpoints_list = os.listdir('custreamz_checkpoints/' + TOPIC)
+        assert len(checkpoints_list) == 3
+
+        for i in range(25, 30):
+            kafka.produce(TOPIC, b'value-%d' % i)
+        kafka.flush()
+        time.sleep(5)
+        checkpoints_list = os.listdir('custreamz_checkpoints/' + TOPIC)
+        assert len(checkpoints_list) == 4
+
+        for i in range(30, 35):
+            kafka.produce(TOPIC, b'value-%d' % i)
+        kafka.flush()
+        time.sleep(5)
+        checkpoints_list = os.listdir('custreamz_checkpoints/' + TOPIC)
+        assert len(checkpoints_list) == 5
+
+        for i in range(35, 40):
+            kafka.produce(TOPIC, b'value-%d' % i)
+        kafka.flush()
+        time.sleep(5)
+        checkpoints_list = os.listdir('custreamz_checkpoints/' + TOPIC)
+        assert len(checkpoints_list) == 5
+
+        for i in range(40, 45):
+            kafka.produce(TOPIC, b'value-%d' % i)
+        kafka.flush()
+        time.sleep(5)
+        checkpoints_list = os.listdir('custreamz_checkpoints/' + TOPIC)
+        assert len(checkpoints_list) == 5
+
+
 @gen_cluster(client=True, timeout=60)
 def test_kafka_dask_batch(c, s, w1, w2):
     j = random.randint(0, 10000)
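As a sanity check on what the test above leaves behind, the newest checkpoint file can be decoded directly. This helper is illustrative only (it is not part of the test suite) and assumes the directory layout created by the diff:

import json
import os

def read_latest_checkpoint(path, topic):
    topic_path = os.path.join(path, topic)
    files = os.listdir(topic_path)
    if not files:
        return {}
    newest = max(files, key=lambda f: int(f.split('.')[0]))
    with open(os.path.join(topic_path, newest)) as f:
        # Each line is a JSON object: {group.id: {partition: offset}}
        return json.loads(f.readlines()[-1])

# e.g. read_latest_checkpoint('custreamz_checkpoints', TOPIC) might return
# {'streamz-test': {'0': 45}} after the final batch of the test.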
