@@ -1,4 +1,5 @@
 from glob import glob
+import json
 import os
 
 import time
@@ -453,11 +454,12 @@ def _close_consumer(self):
 class FromKafkaBatched(Stream):
     """Base class for both local and cluster-based batched kafka processing"""
     def __init__(self, topic, consumer_params, poll_interval='1s',
-                 npartitions=1, **kwargs):
+                 npartitions=1, checkpointing=None, **kwargs):
         self.consumer_params = consumer_params
         self.topic = topic
         self.npartitions = npartitions
         self.positions = [0] * npartitions
+        self.checkpointing = checkpointing  # checkpoint directory, or None
         self.poll_interval = convert_interval(poll_interval)
         self.stopped = True
 
@@ -470,18 +472,52 @@ def poll_kafka(self):
         try:
             while not self.stopped:
                 out = []
+
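+                # Before each poll cycle, reload the newest checkpoint so a
+                # restarted pipeline resumes from its last committed offsets
+                # rather than from the beginning of each partition.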
+                latest_checkpoint = {}
+                if self.checkpointing is not None:
+                    if not os.path.exists(self.checkpointing):
+                        os.makedirs(self.checkpointing)
+                    topic_path = self.checkpointing + '/' + self.topic
+                    if not os.path.exists(topic_path):
+                        os.makedirs(topic_path)
+                    checkpoints_list = os.listdir(topic_path)
+                    if len(checkpoints_list) > 0:
+                        # compare file numbers, not names: lexically,
+                        # '9.txt' would sort above '10.txt'
+                        previous_checkpoint = max(
+                            checkpoints_list,
+                            key=lambda f: int(f.split('.')[0]))
+                        fname = topic_path + '/' + previous_checkpoint
+                        with open(fname) as fr:
+                            latest_checkpoint = json.loads(fr.readlines()[-1])
+
                 for partition in range(self.npartitions):
                     tp = ck.TopicPartition(self.topic, partition, 0)
                     try:
                         low, high = self.consumer.get_watermark_offsets(
                             tp, timeout=0.1)
                     except (RuntimeError, ck.KafkaException):
                         continue
+
                     current_position = self.positions[partition]
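+                    # an offset recorded in the checkpoint overrides the
+                    # position tracked in memory for this consumer group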
+                    group = self.consumer_params['group.id']
+
+                    if group in latest_checkpoint:
+                        offsets = latest_checkpoint[group]
+                        if str(partition) in offsets:
+                            current_position = offsets[str(partition)]
+
                     lowest = max(current_position, low)
                     if high > lowest:
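+                        # the checkpoint directory travels with each batch
+                        # so that workers can record offsets after fetching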
                         out.append((self.consumer_params, self.topic, partition,
-                                    lowest, high - 1))
+                                    lowest, high - 1, self.checkpointing))
                         self.positions[partition] = high
 
                 for part in out:
@@ -507,7 +543,8 @@ def start(self):
 
 @Stream.register_api(staticmethod)
 def from_kafka_batched(topic, consumer_params, poll_interval='1s',
-                       npartitions=1, start=False, dask=False, **kwargs):
+                       npartitions=1, start=False, dask=False,
+                       checkpointing=None, **kwargs):
511536 """ Get messages from Kafka in batches
512537
513538 Uses the confluent-kafka library,
@@ -549,7 +586,8 @@ def from_kafka_batched(topic, consumer_params, poll_interval='1s',
         kwargs['loop'] = default_client().loop
     source = FromKafkaBatched(topic, consumer_params,
                               poll_interval=poll_interval,
-                              npartitions=npartitions, **kwargs)
+                              npartitions=npartitions,
+                              checkpointing=checkpointing, **kwargs)
     if dask:
         source = source.scatter()
 
@@ -559,7 +597,50 @@ def from_kafka_batched(topic, consumer_params, poll_interval='1s',
     return source.starmap(get_message_batch)
 
 
-def get_message_batch(kafka_params, topic, partition, low, high, timeout=None):
+def add_checkpoint(group, checkpoint, path):
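+    """Persist the latest committed offsets as a numbered JSON checkpoint.
+
+    ``checkpoint`` is the confluent_kafka.TopicPartition returned by
+    Consumer.commit(asynchronous=False). Each file <path>/<topic>/<n>.txt
+    holds one JSON object {group: {partition: offset}}; only the five
+    most recent files are kept.
+    """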
+    topic = checkpoint.topic
+    partition = checkpoint.partition
+    offset = checkpoint.offset
+    latest_checkpoint = {}
+    previous_checkpoint = None
+    if not os.path.exists(path):
+        os.makedirs(path)
+    path = path + '/' + topic
+    if not os.path.exists(path):
+        os.makedirs(path)
+    checkpoints_list = os.listdir(path)
+    if len(checkpoints_list) > 0:
+        previous_checkpoint = max(checkpoints_list,
+                                  key=lambda f: int(f.split('.')[0]))
+        with open(path + '/' + previous_checkpoint) as fr:
+            latest_checkpoint = json.loads(fr.readlines()[0])
+        # only keep the five most recent checkpoints
+        if len(checkpoints_list) >= 5:
+            oldest = min(checkpoints_list, key=lambda f: int(f.split('.')[0]))
+            os.remove(path + '/' + oldest)
+    if group not in latest_checkpoint:
+        latest_checkpoint[group] = {}
+    # store the partition key as a string so it matches what json.loads
+    # yields when the checkpoint is read back
+    latest_checkpoint[group][str(partition)] = offset
+    if previous_checkpoint is None:
+        new_checkpoint = '1.txt'
+    else:
+        previous_batch = int(previous_checkpoint.split('.')[0])
+        new_checkpoint = str(previous_batch + 1) + '.txt'
+    with open(path + '/' + new_checkpoint, 'w') as fw:
+        fw.write(json.dumps(latest_checkpoint) + '\n')
+
+
+def get_message_batch(kafka_params, topic, partition, low, high,
+                      checkpointing, timeout=None):
563623 """Fetch a batch of kafka messages in given topic/partition
564624
565625 This will block until messages are available, or timeout is reached.
@@ -583,5 +664,11 @@ def get_message_batch(kafka_params, topic, partition, low, high, timeout=None):
         if timeout is not None and time.time() - t0 > timeout:
             break
     finally:
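+        # commit synchronously so the returned TopicPartition list carries
+        # the offsets that were actually committed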
+        if checkpointing is not None:
+            checkpoint = consumer.commit(asynchronous=False)
+            add_checkpoint(kafka_params['group.id'], checkpoint[0],
+                           checkpointing)
         consumer.close()
     return out