Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
324 changes: 114 additions & 210 deletions cadence/_internal/workflow/decision_events_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
particularly focusing on decision-related events for replay and execution.
"""

from dataclasses import dataclass, field
from typing import List, Optional
from dataclasses import dataclass
from typing import Iterator, List, Optional

from cadence._internal.workflow.history_event_iterator import HistoryEventsIterator
from cadence.api.v1.history_pb2 import HistoryEvent
from cadence.api.v1.service_worker_pb2 import PollForDecisionTaskResponse

Expand All @@ -19,33 +20,21 @@ class DecisionEvents:
Represents events for a single decision iteration.
"""

events: List[HistoryEvent] = field(default_factory=list)
markers: List[HistoryEvent] = field(default_factory=list)
replay: bool = False
replay_current_time_milliseconds: Optional[int] = None
next_decision_event_id: Optional[int] = None
input: List[HistoryEvent]
output: List[HistoryEvent]
markers: List[HistoryEvent]
replay: bool
replay_current_time_milliseconds: int
next_decision_event_id: int

def get_events(self) -> List[HistoryEvent]:
"""Return all events in this decision iteration."""
return self.events

def get_markers(self) -> List[HistoryEvent]:
"""Return marker events."""
return self.markers

def is_replay(self) -> bool:
"""Check if this decision is in replay mode."""
return self.replay

def get_event_by_id(self, event_id: int) -> Optional[HistoryEvent]:
"""Retrieve a specific event by ID, returns None if not found."""
for event in self.events:
def get_output_event_by_id(self, event_id: int) -> Optional[HistoryEvent]:
for event in self.input:
if hasattr(event, "event_id") and event.event_id == event_id:
return event
return None


class DecisionEventsIterator:
class DecisionEventsIterator(Iterator[DecisionEvents]):
"""
Iterator for processing decision events from workflow history.

Expand All @@ -54,207 +43,122 @@ class DecisionEventsIterator:
"""

def __init__(
self, decision_task: PollForDecisionTaskResponse, events: List[HistoryEvent]
self,
decision_task: PollForDecisionTaskResponse,
events: List[HistoryEvent],
):
self._decision_task = decision_task
self._events: List[HistoryEvent] = events
self._decision_task_started_event: Optional[HistoryEvent] = None
self._next_decision_event_id = 1
self._replay = True
self._events: HistoryEventsIterator = HistoryEventsIterator(events)
self._next_decision_event_id: Optional[int] = None
self._replay_current_time_milliseconds: Optional[int] = None

self._event_index = 0
# Find first decision task started event
for i, event in enumerate(self._events):
if _is_decision_task_started(event):
self._event_index = i
break

async def has_next_decision_events(self) -> bool:
# Look for the next DecisionTaskStarted event from current position
for i in range(self._event_index, len(self._events)):
if _is_decision_task_started(self._events[i]):
return True

return False

async def next_decision_events(self) -> DecisionEvents:
# Find next DecisionTaskStarted event
start_index = None
for i in range(self._event_index, len(self._events)):
if _is_decision_task_started(self._events[i]):
start_index = i
break

if start_index is None:
raise StopIteration("No more decision events")

decision_events = DecisionEvents()
decision_events.replay = self._replay
decision_events.replay_current_time_milliseconds = (
self._replay_current_time_milliseconds
)
decision_events.next_decision_event_id = self._next_decision_event_id

# Process DecisionTaskStarted event
decision_task_started = self._events[start_index]
self._decision_task_started_event = decision_task_started
decision_events.events.append(decision_task_started)

# Update replay time if available
if decision_task_started.event_time:
self._replay_current_time_milliseconds = (
decision_task_started.event_time.seconds * 1000
)
decision_events.replay_current_time_milliseconds = (
self._replay_current_time_milliseconds
)

# Process subsequent events until we find the corresponding DecisionTask completion
current_index = start_index + 1
while current_index < len(self._events):
event = self._events[current_index]
decision_events.events.append(event)

# Categorize the event
if _is_marker_recorded(event):
decision_events.markers.append(event)
elif _is_decision_task_completion(event):
# This marks the end of this decision iteration
self._process_decision_completion_event(event, decision_events)
current_index += 1 # Move past this event
break

current_index += 1

# Update the event index for next iteration
self._event_index = current_index

# Update the next decision event ID
if decision_events.events:
last_event = decision_events.events[-1]
if hasattr(last_event, "event_id"):
self._next_decision_event_id = last_event.event_id + 1
def __iter__(self):
return self

# Check if this is the last decision events
# Set replay to false only if there are no more decision events after this one
# Check directly without calling has_next_decision_events to avoid recursion
has_more = False
for i in range(self._event_index, len(self._events)):
if _is_decision_task_started(self._events[i]):
has_more = True
def __next__(self) -> DecisionEvents:
"""
Process the next decision batch.
1. Find the next valid decision task started event during replay or last scheduled decision task events for non-replay
2. Collect the decision input events before the decision task
3. Collect the decision output events after the decision task

Relay mode is determined by checking if the decision task is completed or not
"""
decision_input_events: List[HistoryEvent] = []
decision_output_events: List[HistoryEvent] = []
decision_event: Optional[HistoryEvent] = None
for event in self._events:
match event.WhichOneof("attributes"):
case "decision_task_started_event_attributes":
next_event = self._events.peek()

# latest event, not replay, assign started event as decision event insteaad
if next_event is None:
decision_event = event
break

match next_event.WhichOneof("attributes"):
case (
"decision_task_failed_event_attributes"
| "decision_task_timed_out_event_attributes"
):
# skip failed / timed out decision tasks and continue searching
next(self._events)
continue
case "decision_task_completed_event_attributes":
# found decision task completed event, stop
decision_event = next(self._events)
break
case _:
raise ValueError(
f"unexpected event type after decision task started event: {next_event}"
)

case _:
decision_input_events.append(event)

if not decision_event:
raise StopIteration("no decision event found")

# collect decision output events
while self._events.has_next():
nxt = self._events.peek() if self._events.has_next() else None
if nxt and not is_decision_event(nxt):
break
decision_output_events.append(next(self._events))

if not has_more:
self._replay = False
decision_events.replay = False

return decision_events
replay_current_time_milliseconds = decision_event.event_time.ToMilliseconds()

def _process_decision_completion_event(
self, event: HistoryEvent, decision_events: DecisionEvents
):
"""Process the decision completion event and update state."""

# Check if we're still in replay mode
# This is determined by comparing event IDs with the current decision task's started event ID
replay: bool
next_decision_event_id: int
if (
self._decision_task_started_event
and hasattr(self._decision_task_started_event, "event_id")
and hasattr(event, "event_id")
):
# If this completion event ID is >= the current decision task's started event ID,
# we're no longer in replay mode
current_task_started_id = (
getattr(self._decision_task.started_event_id, "value", 0)
if hasattr(self._decision_task, "started_event_id")
else 0
)

if event.event_id >= current_task_started_id:
self._replay = False
decision_events.replay = False

def get_replay_current_time_milliseconds(self) -> Optional[int]:
"""Get the current replay time in milliseconds."""
return self._replay_current_time_milliseconds

def is_replay_mode(self) -> bool:
"""Check if the iterator is currently in replay mode."""
return self._replay

def __aiter__(self):
return self

async def __anext__(self) -> DecisionEvents:
if not await self.has_next_decision_events():
raise StopAsyncIteration
return await self.next_decision_events()
decision_event.WhichOneof("attributes")
== "decision_task_completed_event_attributes"
): # completed decision task
replay = True
next_decision_event_id = decision_event.event_id + 1
else:
replay = False
next_decision_event_id = decision_event.event_id + 2

# collect marker events
markers = [m for m in decision_output_events if is_marker_event(m)]

return DecisionEvents(
input=decision_input_events,
output=decision_output_events,
markers=markers,
replay=replay,
replay_current_time_milliseconds=replay_current_time_milliseconds,
next_decision_event_id=next_decision_event_id,
)


# Utility functions
def is_decision_event(event: HistoryEvent) -> bool:
"""Check if an event is a decision-related event."""
return (
_is_decision_task_started(event)
or _is_decision_task_completed(event)
or _is_decision_task_failed(event)
or _is_decision_task_timed_out(event)
"""Check if an event is a decision output event."""
return event is not None and event.WhichOneof("attributes") in set(
[
"activity_task_scheduled_event_attributes",
"start_child_workflow_execution_initiated_event_attributes",
"timer_started_event_attributes",
"workflow_execution_completed_event_attributes",
"workflow_execution_failed_event_attributes",
"workflow_execution_canceled_event_attributes",
"workflow_execution_continued_as_new_event_attributes",
"activity_task_cancel_requested_event_attributes",
"request_cancel_activity_task_failed_event_attributes",
"timer_canceled_event_attributes",
"cancel_timer_failed_event_attributes",
"request_cancel_external_workflow_execution_initiated_event_attributes",
"marker_recorded_event_attributes",
"signal_external_workflow_execution_initiated_event_attributes",
"upsert_workflow_search_attributes_event_attributes",
]
)


def is_marker_event(event: HistoryEvent) -> bool:
"""Check if an event is a marker event."""
return _is_marker_recorded(event)


def extract_event_timestamp_millis(event: HistoryEvent) -> Optional[int]:
"""Extract timestamp from an event in milliseconds."""
if hasattr(event, "event_time") and event.HasField("event_time"):
seconds = getattr(event.event_time, "seconds", 0)
return seconds * 1000 if seconds > 0 else None
return None


def _is_decision_task_started(event: HistoryEvent) -> bool:
"""Check if event is DecisionTaskStarted."""
return hasattr(event, "decision_task_started_event_attributes") and event.HasField(
"decision_task_started_event_attributes"
)


def _is_decision_task_completed(event: HistoryEvent) -> bool:
"""Check if event is DecisionTaskCompleted."""
return hasattr(
event, "decision_task_completed_event_attributes"
) and event.HasField("decision_task_completed_event_attributes")


def _is_decision_task_failed(event: HistoryEvent) -> bool:
"""Check if event is DecisionTaskFailed."""
return hasattr(event, "decision_task_failed_event_attributes") and event.HasField(
"decision_task_failed_event_attributes"
)


def _is_decision_task_timed_out(event: HistoryEvent) -> bool:
"""Check if event is DecisionTaskTimedOut."""
return hasattr(
event, "decision_task_timed_out_event_attributes"
) and event.HasField("decision_task_timed_out_event_attributes")


def _is_marker_recorded(event: HistoryEvent) -> bool:
"""Check if event is MarkerRecorded."""
return hasattr(event, "marker_recorded_event_attributes") and event.HasField(
"marker_recorded_event_attributes"
)


def _is_decision_task_completion(event: HistoryEvent) -> bool:
"""Check if event is any kind of decision task completion."""
return (
_is_decision_task_completed(event)
or _is_decision_task_failed(event)
or _is_decision_task_timed_out(event)
return bool(
event is not None
and event.WhichOneof("attributes") == "marker_recorded_event_attributes"
)
24 changes: 24 additions & 0 deletions cadence/_internal/workflow/history_event_iterator.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Iterator, List, Optional
from cadence.api.v1.history_pb2 import HistoryEvent
from cadence.api.v1.service_worker_pb2 import PollForDecisionTaskResponse
from cadence.api.v1.service_workflow_pb2 import (
GetWorkflowExecutionHistoryRequest,
Expand Down Expand Up @@ -32,3 +34,25 @@ async def iterate_history_events(
)
current_page = response.history.events
next_page_token = response.next_page_token


class HistoryEventsIterator(Iterator[HistoryEvent]):
def __init__(self, events: List[HistoryEvent]):
self._iter = iter(events)
self._current = next(self._iter, None)

def __iter__(self):
return self

def __next__(self) -> HistoryEvent:
if not self._current:
raise StopIteration("No more events")
event = self._current
self._current = next(self._iter, None)
return event

def has_next(self) -> bool:
return self._current is not None

def peek(self) -> Optional[HistoryEvent]:
return self._current
Loading