1010 MultiAggregationWindowMixin ,
1111 SingleAggregationWindowMixin ,
1212 Window ,
13+ WindowAfterUpdateCallback ,
14+ WindowBeforeUpdateCallback ,
1315 WindowKeyResult ,
1416 WindowOnLateCallback ,
1517 get_window_ranges ,
@@ -30,6 +32,8 @@ def __init__(
3032 dataframe : "StreamingDataFrame" ,
3133 step_ms : Optional [int ] = None ,
3234 on_late : Optional [WindowOnLateCallback ] = None ,
35+ before_update : Optional [WindowBeforeUpdateCallback ] = None ,
36+ after_update : Optional [WindowAfterUpdateCallback ] = None ,
3337 ):
3438 super ().__init__ (
3539 name = name ,
@@ -40,6 +44,8 @@ def __init__(
4044 self ._grace_ms = grace_ms
4145 self ._step_ms = step_ms
4246 self ._on_late = on_late
47+ self ._before_update = before_update
48+ self ._after_update = after_update
4349
4450 def final (self ) -> "StreamingDataFrame" :
4551 """
@@ -69,13 +75,17 @@ def on_update(
6975 _headers : Any ,
7076 transaction : WindowedPartitionTransaction ,
7177 ):
72- self .process_window (
78+ # Process the window and get windows triggered from callbacks
79+ _ , triggered_windows = self .process_window (
7380 value = value ,
7481 key = key ,
7582 timestamp_ms = timestamp_ms ,
83+ headers = _headers ,
7684 transaction = transaction ,
7785 )
78- return []
86+ # Yield triggered windows (from before_update/after_update callbacks)
87+ for key , window in triggered_windows :
88+ yield window , key , window ["start" ], None
7989
8090 def on_watermark (
8191 _value : Any ,
@@ -133,15 +143,20 @@ def on_update(
133143 _headers : Any ,
134144 transaction : WindowedPartitionTransaction ,
135145 ):
136- updated_windows = self .process_window (
146+ # Process the window and get both updated and triggered windows
147+ updated_windows , triggered_windows = self .process_window (
137148 value = value ,
138149 key = key ,
139150 timestamp_ms = timestamp_ms ,
151+ headers = _headers ,
140152 transaction = transaction ,
141153 )
142154 # Use window start timestamp as a new record timestamp
155+ # Yield both updated and triggered windows
143156 for key , window in updated_windows :
144157 yield window , key , window ["start" ], None
158+ for key , window in triggered_windows :
159+ yield window , key , window ["start" ], None
145160
146161 def on_watermark (
147162 _value : Any ,
@@ -169,11 +184,22 @@ def process_window(
169184 value : Any ,
170185 key : Any ,
171186 timestamp_ms : int ,
187+ headers : Any ,
172188 transaction : WindowedPartitionTransaction ,
173- ) -> Iterable [WindowKeyResult ]:
189+ ) -> tuple [Iterable [WindowKeyResult ], Iterable [WindowKeyResult ]]:
190+ """
191+ Process a window update for the given value and key.
192+
193+ Returns:
194+ A tuple of (updated_windows, triggered_windows) where:
195+ - updated_windows: Windows that were updated but not expired
196+ - triggered_windows: Windows that were expired early due to before_update/after_update callbacks
197+ """
174198 state = transaction .as_state (prefix = key )
175199 duration_ms = self ._duration_ms
176200 grace_ms = self ._grace_ms
201+ before_update = self ._before_update
202+ after_update = self ._after_update
177203
178204 collect = self .collect
179205 aggregate = self .aggregate
@@ -190,6 +216,7 @@ def process_window(
190216 max_expired_window_end = latest_timestamp - grace_ms
191217 max_expired_window_start = max_expired_window_end - duration_ms
192218 updated_windows : list [WindowKeyResult ] = []
219+ triggered_windows : list [WindowKeyResult ] = []
193220 for start , end in ranges :
194221 if start <= max_expired_window_start :
195222 late_by_ms = max_expired_window_end - timestamp_ms
@@ -207,18 +234,78 @@ def process_window(
207234 # since actual values are stored separately and combined into an array
208235 # during window expiration.
209236 aggregated = None
237+
210238 if aggregate :
211239 current_value = state .get_window (start , end )
212240 if current_value is None :
213241 current_value = self ._initialize_value ()
214242
243+ # Check before_update trigger
244+ if before_update and before_update (
245+ current_value , value , key , timestamp_ms , headers
246+ ):
247+ # Get collected values for the result
248+ # Do NOT include the current value - before_update means
249+ # we expire BEFORE adding the current value
250+ collected = state .get_from_collection (start , end ) if collect else []
251+
252+ result = self ._results (current_value , collected , start , end )
253+ triggered_windows .append ((key , result ))
254+ transaction .delete_window (start , end , prefix = key )
255+ # Note: We don't delete from collection here - normal expiration
256+ # will handle cleanup for both tumbling and hopping windows
257+ continue
258+
215259 aggregated = self ._aggregate_value (current_value , value , timestamp_ms )
216- updated_windows .append (
217- (
218- key ,
219- self ._results (aggregated , [], start , end ),
220- )
221- )
260+
261+ # Check after_update trigger
262+ if after_update and after_update (
263+ aggregated , value , key , timestamp_ms , headers
264+ ):
265+ # Get collected values for the result
266+ collected = []
267+ if collect :
268+ collected = state .get_from_collection (start , end )
269+ # Add the current value that's being collected
270+ collected .append (self ._collect_value (value ))
271+
272+ result = self ._results (aggregated , collected , start , end )
273+ triggered_windows .append ((key , result ))
274+ transaction .delete_window (start , end , prefix = key )
275+ # Note: We don't delete from collection here - normal expiration
276+ # will handle cleanup for both tumbling and hopping windows
277+ continue
278+
279+ result = self ._results (aggregated , [], start , end )
280+ updated_windows .append ((key , result ))
281+ elif collect and (before_update or after_update ):
282+ # For collect-only windows, get the old collected values
283+ old_collected = state .get_from_collection (start , end )
284+
285+ # Check before_update trigger (before adding new value)
286+ if before_update and before_update (
287+ old_collected , value , key , timestamp_ms , headers
288+ ):
289+ # Expire with the current collection (WITHOUT the new value)
290+ result = self ._results (None , old_collected , start , end )
291+ triggered_windows .append ((key , result ))
292+ transaction .delete_window (start , end , prefix = key )
293+ # Note: We don't delete from collection here - normal expiration
294+ # will handle cleanup for both tumbling and hopping windows
295+ continue
296+
297+ # Check after_update trigger (conceptually after adding new value)
298+ # For collect, "after update" means after the value would be added
299+ if after_update :
300+ new_collected = [* old_collected , self ._collect_value (value )]
301+ if after_update (new_collected , value , key , timestamp_ms , headers ):
302+ result = self ._results (None , new_collected , start , end )
303+ triggered_windows .append ((key , result ))
304+ transaction .delete_window (start , end , prefix = key )
305+ # Note: We don't delete from collection here - normal expiration
306+ # will handle cleanup for both tumbling and hopping windows
307+ continue
308+
222309 state .update_window (start , end , value = aggregated , timestamp_ms = timestamp_ms )
223310
224311 if collect :
@@ -227,7 +314,7 @@ def process_window(
227314 id = timestamp_ms ,
228315 )
229316
230- return updated_windows
317+ return updated_windows , triggered_windows
231318
232319 def expire_by_partition (
233320 self ,
0 commit comments