Skip to content

Commit 985c623

Browse files
committed
Integrate with before/after update triggers
1 parent 940fbf4 commit 985c623

File tree

17 files changed

+939
-71
lines changed

17 files changed

+939
-71
lines changed

quixstreams/app.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
)
4646
from .platforms.quix.env import QUIX_ENVIRONMENT
4747
from .processing import ProcessingContext
48-
from .processing.watermarking import WatermarkManager
48+
from .processing.watermarking import WatermarkManager, WatermarkMessage
4949
from .runtracker import RunTracker
5050
from .sinks import SinkManager
5151
from .sources import BaseSource, SourceException, SourceManager
@@ -1008,7 +1008,9 @@ def _process_message(self, dataframe_composed: dict[str, VoidExecutor]):
10081008
)
10091009

10101010
if topic_name == self._watermark_manager.watermarks_topic.name:
1011-
watermark = self._watermark_manager.receive(message=first_row.value)
1011+
watermark = self._watermark_manager.receive(
1012+
message=cast(WatermarkMessage, first_row.value)
1013+
)
10121014
if watermark is None:
10131015
return
10141016

@@ -1073,12 +1075,12 @@ def _process_message(self, dataframe_composed: dict[str, VoidExecutor]):
10731075

10741076
# Store the message offset after it's successfully processed
10751077
self._processing_context.store_offset(
1076-
topic=topic_name, partition=partition, offset=offset
1078+
topic=topic_name, partition=partition, offset=offset or 0
10771079
)
10781080
self._run_tracker.set_message_consumed(True)
10791081

10801082
if self._on_message_processed is not None:
1081-
self._on_message_processed(topic_name, partition, offset)
1083+
self._on_message_processed(topic_name, partition, offset or 0)
10821084

10831085
def _on_assign(self, _, topic_partitions: List[TopicPartition]):
10841086
"""
@@ -1104,6 +1106,7 @@ def _on_assign(self, _, topic_partitions: List[TopicPartition]):
11041106
)
11051107
for i in range(
11061108
self._watermark_manager.watermarks_topic.broker_config.num_partitions
1109+
or 1
11071110
)
11081111
]
11091112
# TODO: The set is used because the watermark tp can already be present in the "topic_partitions"

quixstreams/core/stream/functions/apply.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@ def wrapper(
4848
timestamp: int,
4949
headers: Any,
5050
is_watermark: bool = False,
51-
on_watermark=self.on_watermark,
5251
) -> None:
5352
# Execute a function on a single value and wrap results into a list
5453
# to expand them downstream

quixstreams/core/stream/functions/base.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import abc
2-
from typing import Any
2+
from typing import Any, Optional
33

44
from quixstreams.utils.pickle import pickle_copier
55

@@ -18,9 +18,11 @@ class StreamFunction(abc.ABC):
1818

1919
expand: bool = False
2020

21-
def __init__(self, func: StreamCallback):
21+
def __init__(
22+
self, func: StreamCallback, on_watermark: Optional[StreamCallback] = None
23+
):
2224
self.func = func
23-
self.on_watermark = None
25+
self.on_watermark = on_watermark
2426

2527
@abc.abstractmethod
2628
def get_executor(self, *child_executors: VoidExecutor) -> VoidExecutor:

quixstreams/core/stream/functions/transform.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -65,12 +65,14 @@ def wrapper(
6565
timestamp: int,
6666
headers: Any,
6767
is_watermark: bool = False,
68-
on_watermark=self.on_watermark,
6968
):
7069
if is_watermark:
71-
if on_watermark is not None:
70+
if self.on_watermark is not None:
7271
# React on the new watermark if "on_watermark" is defined
73-
result = self.on_watermark(None, None, timestamp, ())
72+
watermark_func = cast(
73+
TransformExpandedCallback, self.on_watermark
74+
)
75+
result = watermark_func(None, None, timestamp, ())
7476
for new_value, new_key, new_timestamp, new_headers in result:
7577
child_executor(
7678
new_value,
@@ -102,13 +104,13 @@ def wrapper(
102104
timestamp: int,
103105
headers: Any,
104106
is_watermark: bool = False,
105-
on_watermark=self.on_watermark,
106107
):
107108
if is_watermark:
108-
if on_watermark is not None:
109+
if self.on_watermark is not None:
109110
# React on the new watermark if "on_watermark" is defined
110-
new_value, new_key, new_timestamp, new_headers = (
111-
self.on_watermark(None, None, timestamp, ())
111+
watermark_func = cast(TransformCallback, self.on_watermark)
112+
new_value, new_key, new_timestamp, new_headers = watermark_func(
113+
None, None, timestamp, ()
112114
)
113115
child_executor(
114116
new_value,

quixstreams/dataframe/dataframe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1708,7 +1708,7 @@ def _sink_callback(
17081708
headers=headers,
17091709
partition=ctx.partition,
17101710
topic=ctx.topic,
1711-
offset=ctx.offset,
1711+
offset=ctx.offset or 0,
17121712
)
17131713

17141714
# uses apply without returning to make this operation terminal

quixstreams/dataframe/windows/base.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@
3333
WindowResult: TypeAlias = dict[str, Any]
3434
WindowKeyResult: TypeAlias = tuple[Any, WindowResult]
3535
Message: TypeAlias = tuple[WindowResult, Any, int, Any]
36+
WindowBeforeUpdateCallback: TypeAlias = Callable[[Any, Any, Any, int, Any], bool]
37+
WindowAfterUpdateCallback: TypeAlias = Callable[[Any, Any, Any, int, Any], bool]
3638

3739
WindowAggregateFunc = Callable[[Any, Any], Any]
3840

@@ -58,6 +60,25 @@ def __init__(
5860
def name(self) -> str:
5961
return self._name
6062

63+
@abstractmethod
64+
def process_window(
65+
self,
66+
value: Any,
67+
key: Any,
68+
timestamp_ms: int,
69+
headers: Any,
70+
transaction: WindowedPartitionTransaction,
71+
) -> tuple[Iterable[WindowKeyResult], Iterable[WindowKeyResult]]:
72+
"""
73+
Process a window update for the given value and key.
74+
75+
Returns:
76+
A tuple of (updated_windows, triggered_windows) where:
77+
- updated_windows: Windows that were updated but not expired
78+
- triggered_windows: Windows that were expired early due to before_update/after_update callbacks
79+
"""
80+
pass
81+
6182
def register_store(self) -> None:
6283
TopicManager.ensure_topics_copartitioned(*self._dataframe.topics)
6384
# Create a config for the changelog topic based on the underlying SDF topics
@@ -126,6 +147,7 @@ def final(self) -> "StreamingDataFrame":
126147
If some message keys appear irregularly in the stream, the latest windows
127148
can remain unprocessed until the message the same key is received.
128149
"""
150+
...
129151

130152
@abstractmethod
131153
def current(self) -> "StreamingDataFrame":

quixstreams/dataframe/windows/count_based.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ def window_callback(
9090
value=value,
9191
key=key,
9292
timestamp_ms=timestamp_ms,
93+
headers=_headers,
9394
transaction=transaction,
9495
)
9596
# Use window start timestamp as a new record timestamp
@@ -135,6 +136,7 @@ def window_callback(
135136
value=value,
136137
key=key,
137138
timestamp_ms=timestamp_ms,
139+
headers=_headers,
138140
transaction=transaction,
139141
)
140142

quixstreams/dataframe/windows/sliding.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def process_window(
1818
timestamp_ms: int,
1919
headers: Any,
2020
transaction: WindowedPartitionTransaction,
21-
) -> Iterable[WindowKeyResult]:
21+
) -> tuple[Iterable[WindowKeyResult], Iterable[WindowKeyResult]]:
2222
"""
2323
The algorithm is based on the concept that each message
2424
is associated with a left and a right window.
@@ -87,7 +87,7 @@ def process_window(
8787
timestamp_ms=timestamp_ms,
8888
late_by_ms=max_expired_window_end - timestamp_ms,
8989
)
90-
return []
90+
return [], []
9191

9292
right_start = timestamp_ms + 1
9393
right_end = right_start + duration
@@ -256,7 +256,9 @@ def process_window(
256256
if collect:
257257
state.add_to_collection(value=self._collect_value(value), id=timestamp_ms)
258258

259-
return reversed(updated_windows)
259+
# Sliding windows don't support before_update/after_update callbacks yet,
260+
# so triggered_windows is always empty
261+
return reversed(updated_windows), []
260262

261263
def expire_by_partition(
262264
self,

quixstreams/dataframe/windows/time_based.py

Lines changed: 98 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
MultiAggregationWindowMixin,
1111
SingleAggregationWindowMixin,
1212
Window,
13+
WindowAfterUpdateCallback,
14+
WindowBeforeUpdateCallback,
1315
WindowKeyResult,
1416
WindowOnLateCallback,
1517
get_window_ranges,
@@ -30,6 +32,8 @@ def __init__(
3032
dataframe: "StreamingDataFrame",
3133
step_ms: Optional[int] = None,
3234
on_late: Optional[WindowOnLateCallback] = None,
35+
before_update: Optional[WindowBeforeUpdateCallback] = None,
36+
after_update: Optional[WindowAfterUpdateCallback] = None,
3337
):
3438
super().__init__(
3539
name=name,
@@ -40,6 +44,8 @@ def __init__(
4044
self._grace_ms = grace_ms
4145
self._step_ms = step_ms
4246
self._on_late = on_late
47+
self._before_update = before_update
48+
self._after_update = after_update
4349

4450
def final(self) -> "StreamingDataFrame":
4551
"""
@@ -69,13 +75,17 @@ def on_update(
6975
_headers: Any,
7076
transaction: WindowedPartitionTransaction,
7177
):
72-
self.process_window(
78+
# Process the window and get windows triggered from callbacks
79+
_, triggered_windows = self.process_window(
7380
value=value,
7481
key=key,
7582
timestamp_ms=timestamp_ms,
83+
headers=_headers,
7684
transaction=transaction,
7785
)
78-
return []
86+
# Yield triggered windows (from before_update/after_update callbacks)
87+
for key, window in triggered_windows:
88+
yield window, key, window["start"], None
7989

8090
def on_watermark(
8191
_value: Any,
@@ -133,15 +143,20 @@ def on_update(
133143
_headers: Any,
134144
transaction: WindowedPartitionTransaction,
135145
):
136-
updated_windows = self.process_window(
146+
# Process the window and get both updated and triggered windows
147+
updated_windows, triggered_windows = self.process_window(
137148
value=value,
138149
key=key,
139150
timestamp_ms=timestamp_ms,
151+
headers=_headers,
140152
transaction=transaction,
141153
)
142154
# Use window start timestamp as a new record timestamp
155+
# Yield both updated and triggered windows
143156
for key, window in updated_windows:
144157
yield window, key, window["start"], None
158+
for key, window in triggered_windows:
159+
yield window, key, window["start"], None
145160

146161
def on_watermark(
147162
_value: Any,
@@ -169,11 +184,22 @@ def process_window(
169184
value: Any,
170185
key: Any,
171186
timestamp_ms: int,
187+
headers: Any,
172188
transaction: WindowedPartitionTransaction,
173-
) -> Iterable[WindowKeyResult]:
189+
) -> tuple[Iterable[WindowKeyResult], Iterable[WindowKeyResult]]:
190+
"""
191+
Process a window update for the given value and key.
192+
193+
Returns:
194+
A tuple of (updated_windows, triggered_windows) where:
195+
- updated_windows: Windows that were updated but not expired
196+
- triggered_windows: Windows that were expired early due to before_update/after_update callbacks
197+
"""
174198
state = transaction.as_state(prefix=key)
175199
duration_ms = self._duration_ms
176200
grace_ms = self._grace_ms
201+
before_update = self._before_update
202+
after_update = self._after_update
177203

178204
collect = self.collect
179205
aggregate = self.aggregate
@@ -190,6 +216,7 @@ def process_window(
190216
max_expired_window_end = latest_timestamp - grace_ms
191217
max_expired_window_start = max_expired_window_end - duration_ms
192218
updated_windows: list[WindowKeyResult] = []
219+
triggered_windows: list[WindowKeyResult] = []
193220
for start, end in ranges:
194221
if start <= max_expired_window_start:
195222
late_by_ms = max_expired_window_end - timestamp_ms
@@ -207,18 +234,78 @@ def process_window(
207234
# since actual values are stored separately and combined into an array
208235
# during window expiration.
209236
aggregated = None
237+
210238
if aggregate:
211239
current_value = state.get_window(start, end)
212240
if current_value is None:
213241
current_value = self._initialize_value()
214242

243+
# Check before_update trigger
244+
if before_update and before_update(
245+
current_value, value, key, timestamp_ms, headers
246+
):
247+
# Get collected values for the result
248+
# Do NOT include the current value - before_update means
249+
# we expire BEFORE adding the current value
250+
collected = state.get_from_collection(start, end) if collect else []
251+
252+
result = self._results(current_value, collected, start, end)
253+
triggered_windows.append((key, result))
254+
transaction.delete_window(start, end, prefix=key)
255+
# Note: We don't delete from collection here - normal expiration
256+
# will handle cleanup for both tumbling and hopping windows
257+
continue
258+
215259
aggregated = self._aggregate_value(current_value, value, timestamp_ms)
216-
updated_windows.append(
217-
(
218-
key,
219-
self._results(aggregated, [], start, end),
220-
)
221-
)
260+
261+
# Check after_update trigger
262+
if after_update and after_update(
263+
aggregated, value, key, timestamp_ms, headers
264+
):
265+
# Get collected values for the result
266+
collected = []
267+
if collect:
268+
collected = state.get_from_collection(start, end)
269+
# Add the current value that's being collected
270+
collected.append(self._collect_value(value))
271+
272+
result = self._results(aggregated, collected, start, end)
273+
triggered_windows.append((key, result))
274+
transaction.delete_window(start, end, prefix=key)
275+
# Note: We don't delete from collection here - normal expiration
276+
# will handle cleanup for both tumbling and hopping windows
277+
continue
278+
279+
result = self._results(aggregated, [], start, end)
280+
updated_windows.append((key, result))
281+
elif collect and (before_update or after_update):
282+
# For collect-only windows, get the old collected values
283+
old_collected = state.get_from_collection(start, end)
284+
285+
# Check before_update trigger (before adding new value)
286+
if before_update and before_update(
287+
old_collected, value, key, timestamp_ms, headers
288+
):
289+
# Expire with the current collection (WITHOUT the new value)
290+
result = self._results(None, old_collected, start, end)
291+
triggered_windows.append((key, result))
292+
transaction.delete_window(start, end, prefix=key)
293+
# Note: We don't delete from collection here - normal expiration
294+
# will handle cleanup for both tumbling and hopping windows
295+
continue
296+
297+
# Check after_update trigger (conceptually after adding new value)
298+
# For collect, "after update" means after the value would be added
299+
if after_update:
300+
new_collected = [*old_collected, self._collect_value(value)]
301+
if after_update(new_collected, value, key, timestamp_ms, headers):
302+
result = self._results(None, new_collected, start, end)
303+
triggered_windows.append((key, result))
304+
transaction.delete_window(start, end, prefix=key)
305+
# Note: We don't delete from collection here - normal expiration
306+
# will handle cleanup for both tumbling and hopping windows
307+
continue
308+
222309
state.update_window(start, end, value=aggregated, timestamp_ms=timestamp_ms)
223310

224311
if collect:
@@ -227,7 +314,7 @@ def process_window(
227314
id=timestamp_ms,
228315
)
229316

230-
return updated_windows
317+
return updated_windows, triggered_windows
231318

232319
def expire_by_partition(
233320
self,

0 commit comments

Comments (0)