@@ -95,6 +95,10 @@ def __init__(
9595 # the oldest partitions can be efficiently removed, maintaining the most recent partitions.
9696 self ._cursor_per_partition : OrderedDict [str , ConcurrentCursor ] = OrderedDict ()
9797 self ._semaphore_per_partition : OrderedDict [str , threading .Semaphore ] = OrderedDict ()
98+
99+ # Parent-state tracking: store each partition’s parent state in creation order
100+ self ._partition_parent_state_map : OrderedDict [str , Mapping [str , Any ]] = OrderedDict ()
101+
98102 self ._finished_partitions : set [str ] = set ()
99103 self ._lock = threading .Lock ()
100104 self ._timer = Timer ()
@@ -155,11 +159,62 @@ def close_partition(self, partition: Partition) -> None:
155159 and self ._semaphore_per_partition [partition_key ]._value == 0
156160 ):
157161 self ._update_global_cursor (cursor .state [self .cursor_field .cursor_field_key ])
158- self ._emit_state_message ()
162+
163+ self ._check_and_update_parent_state ()
164+
165+ self ._emit_state_message ()
166+
167+ def _check_and_update_parent_state (self ) -> None :
168+ """
169+ Pop the leftmost partition state from _partition_parent_state_map only if
170+ *all partitions* up to (and including) that partition key in _semaphore_per_partition
171+ are fully finished (i.e. in _finished_partitions and semaphore._value == 0).
172+ Additionally, delete finished semaphores with a value of 0 to free up memory,
173+ as they are only needed to track errors and completion status.
174+ """
175+ last_closed_state = None
176+
177+ while self ._partition_parent_state_map :
178+ # Look at the earliest partition key in creation order
179+ earliest_key = next (iter (self ._partition_parent_state_map ))
180+
181+ # Verify ALL partitions from the left up to earliest_key are finished
182+ all_left_finished = True
183+ for p_key , sem in list (
184+ self ._semaphore_per_partition .items ()
185+ ): # Use list to allow modification during iteration
186+ # If any earlier partition is still not finished, we must stop
187+ if p_key not in self ._finished_partitions or sem ._value != 0 :
188+ all_left_finished = False
189+ break
190+ # Once we've reached earliest_key in the semaphore order, we can stop checking
191+ if p_key == earliest_key :
192+ break
193+
194+ # If the partitions up to earliest_key are not all finished, break the while-loop
195+ if not all_left_finished :
196+ break
197+
198+ # Pop the leftmost entry from parent-state map
199+ _ , closed_parent_state = self ._partition_parent_state_map .popitem (last = False )
200+ last_closed_state = closed_parent_state
201+
202+ # Clean up finished semaphores with value 0 up to and including earliest_key
203+ for p_key in list (self ._semaphore_per_partition .keys ()):
204+ sem = self ._semaphore_per_partition [p_key ]
205+ if p_key in self ._finished_partitions and sem ._value == 0 :
206+ del self ._semaphore_per_partition [p_key ]
207+ logger .debug (f"Deleted finished semaphore for partition { p_key } with value 0" )
208+ if p_key == earliest_key :
209+ break
210+
211+ # Update _parent_state if we popped at least one partition
212+ if last_closed_state is not None :
213+ self ._parent_state = last_closed_state
159214
160215 def ensure_at_least_one_state_emitted (self ) -> None :
161216 """
162- The platform expect to have at least one state message on successful syncs. Hence, whatever happens, we expect this method to be
217+ The platform expects at least one state message on successful syncs. Hence, whatever happens, we expect this method to be
163218 called.
164219 """
165220 if not any (
@@ -201,13 +256,19 @@ def stream_slices(self) -> Iterable[StreamSlice]:
201256
202257 slices = self ._partition_router .stream_slices ()
203258 self ._timer .start ()
204- for partition in slices :
205- yield from self ._generate_slices_from_partition (partition )
259+ for partition , last , parent_state in iterate_with_last_flag_and_state (
260+ slices , self ._partition_router .get_stream_state
261+ ):
262+ yield from self ._generate_slices_from_partition (partition , parent_state )
206263
207- def _generate_slices_from_partition (self , partition : StreamSlice ) -> Iterable [StreamSlice ]:
264+ def _generate_slices_from_partition (
265+ self , partition : StreamSlice , parent_state : Mapping [str , Any ]
266+ ) -> Iterable [StreamSlice ]:
208267 # Ensure the maximum number of partitions is not exceeded
209268 self ._ensure_partition_limit ()
210269
270+ partition_key = self ._to_partition_key (partition .partition )
271+
211272 cursor = self ._cursor_per_partition .get (self ._to_partition_key (partition .partition ))
212273 if not cursor :
213274 cursor = self ._create_cursor (
@@ -216,18 +277,26 @@ def _generate_slices_from_partition(self, partition: StreamSlice) -> Iterable[St
216277 )
217278 with self ._lock :
218279 self ._number_of_partitions += 1
219- self ._cursor_per_partition [self ._to_partition_key (partition .partition )] = cursor
220- self ._semaphore_per_partition [self ._to_partition_key (partition .partition )] = (
221- threading .Semaphore (0 )
222- )
280+ self ._cursor_per_partition [partition_key ] = cursor
281+ self ._semaphore_per_partition [partition_key ] = threading .Semaphore (0 )
282+
283+ with self ._lock :
284+ if (
285+ len (self ._partition_parent_state_map ) == 0
286+ or self ._partition_parent_state_map [
287+ next (reversed (self ._partition_parent_state_map ))
288+ ]
289+ != parent_state
290+ ):
291+ self ._partition_parent_state_map [partition_key ] = deepcopy (parent_state )
223292
224293 for cursor_slice , is_last_slice , _ in iterate_with_last_flag_and_state (
225294 cursor .stream_slices (),
226295 lambda : None ,
227296 ):
228- self ._semaphore_per_partition [self . _to_partition_key ( partition . partition ) ].release ()
297+ self ._semaphore_per_partition [partition_key ].release ()
229298 if is_last_slice :
230- self ._finished_partitions .add (self . _to_partition_key ( partition . partition ) )
299+ self ._finished_partitions .add (partition_key )
231300 yield StreamSlice (
232301 partition = partition , cursor_slice = cursor_slice , extra_fields = partition .extra_fields
233302 )
@@ -257,9 +326,9 @@ def _ensure_partition_limit(self) -> None:
257326 while len (self ._cursor_per_partition ) > self .DEFAULT_MAX_PARTITIONS_NUMBER - 1 :
258327 # Try removing finished partitions first
259328 for partition_key in list (self ._cursor_per_partition .keys ()):
260- if (
261- partition_key in self ._finished_partitions
262- and self ._semaphore_per_partition [partition_key ]._value == 0
329+ if partition_key in self . _finished_partitions and (
330+ partition_key not in self ._semaphore_per_partition
331+ or self ._semaphore_per_partition [partition_key ]._value == 0
263332 ):
264333 oldest_partition = self ._cursor_per_partition .pop (
265334 partition_key
@@ -338,9 +407,6 @@ def _set_initial_state(self, stream_state: StreamState) -> None:
338407 self ._cursor_per_partition [self ._to_partition_key (state ["partition" ])] = (
339408 self ._create_cursor (state ["cursor" ])
340409 )
341- self ._semaphore_per_partition [self ._to_partition_key (state ["partition" ])] = (
342- threading .Semaphore (0 )
343- )
344410
345411 # set default state for missing partitions if it is per partition with fallback to global
346412 if self ._GLOBAL_STATE_KEY in stream_state :
0 commit comments