Skip to content

Commit 57cab52

Browse files
committed
refactor: use async iterator instead of event
Replicates graphql/graphql-js@95bf842
1 parent bef2723 commit 57cab52

File tree

2 files changed

+68
-50
lines changed

2 files changed

+68
-50
lines changed

src/graphql/execution/incremental_graph.py

Lines changed: 55 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,14 @@
22

33
from __future__ import annotations
44

5-
from asyncio import Event, Task, ensure_future
5+
from asyncio import CancelledError, Future, Task, ensure_future
66
from typing import (
77
TYPE_CHECKING,
88
Any,
9+
AsyncGenerator,
910
Awaitable,
10-
Iterator,
11+
Generator,
12+
Iterable,
1113
Sequence,
1214
cast,
1315
)
@@ -42,17 +44,17 @@ class IncrementalGraph:
4244

4345
_pending: dict[SubsequentResultRecord, None]
4446
_new_pending: dict[SubsequentResultRecord, None]
45-
_completed_result_queue: list[IncrementalDataRecordResult]
47+
_completed_queue: list[IncrementalDataRecordResult]
48+
_next_queue: list[Future[Iterable[IncrementalDataRecordResult]]]
4649

47-
_resolve: Event | None
48-
_tasks: set[Task[Any]]
50+
_tasks: set[Task[Any]]  # TODO: verify this attribute is still used
4951

5052
def __init__(self) -> None:
5153
"""Initialize the IncrementalGraph."""
5254
self._pending = {}
5355
self._new_pending = {}
54-
self._completed_result_queue = []
55-
self._resolve = None # lazy initialization
56+
self._completed_queue = []
57+
self._next_queue = []
5658
self._tasks = set()
5759

5860
def add_incremental_data_records(
@@ -95,11 +97,11 @@ async def enqueue_deferred(
9597
async def enqueue_stream(
9698
stream_result: Awaitable[StreamItemsResult],
9799
) -> None:
98-
self._enqueue_completed_stream_items(await stream_result)
100+
self._enqueue(await stream_result)
99101

100102
self._add_task(enqueue_stream(stream_result))
101103
else:
102-
self._enqueue_completed_stream_items(stream_result) # type: ignore
104+
self._enqueue(stream_result) # type: ignore
103105

104106
def get_new_pending(self) -> list[SubsequentResultRecord]:
105107
"""Get new pending subsequent result records."""
@@ -122,12 +124,21 @@ def get_new_pending(self) -> list[SubsequentResultRecord]:
122124
self._new_pending.clear()
123125
return new_pending
124126

125-
def completed_results(self) -> Iterator[IncrementalDataRecordResult]:
126-
"""Yield completed incremental data record results."""
127-
queue = self._completed_result_queue
128-
while queue:
129-
completed_result = queue.pop(0)
130-
yield completed_result
127+
async def completed_incremental_data(
128+
self,
129+
) -> AsyncGenerator[Iterable[IncrementalDataRecordResult], None]:
130+
"""Asynchronously yield completed incremental data record results."""
131+
while True:
132+
if self._completed_queue:
133+
first_result = self._completed_queue.pop(0)
134+
yield self._yield_current_completed_incremental_data(first_result)
135+
else:
136+
future: Future[Iterable[IncrementalDataRecordResult]] = Future()
137+
self._next_queue.append(future)
138+
try:
139+
yield await future
140+
except CancelledError:
141+
break # pragma: no cover
131142

132143
def has_next(self) -> bool:
133144
"""Check if there are more results to process."""
@@ -143,12 +154,12 @@ def complete_deferred_fragment(
143154
reconcilable_results
144155
):
145156
return None
146-
del self._pending[deferred_fragment_record]
157+
self.remove_subsequent_result_record(deferred_fragment_record)
147158
new_pending = self._new_pending
148-
extend = self._completed_result_queue.extend
149159
for child in deferred_fragment_record.children:
150160
new_pending[child] = None
151-
extend(child.results)
161+
for result in child.results:
162+
self._enqueue(result)
152163
return reconcilable_results
153164

154165
def remove_subsequent_result_record(
@@ -157,6 +168,8 @@ def remove_subsequent_result_record(
157168
) -> None:
158169
"""Remove a subsequent result record as no longer pending."""
159170
del self._pending[subsequent_result_record]
171+
if not self._pending:
172+
self.stop_incremental_data()
160173

161174
def _add_deferred_fragment_record(
162175
self, deferred_fragment_record: DeferredFragmentRecord
@@ -196,33 +209,35 @@ def _enqueue_completed_deferred_grouped_field_set(
196209
if deferred_fragment_record.id is not None:
197210
has_pending_parent = True
198211
deferred_fragment_record.results.append(result)
199-
append = self._completed_result_queue.append
200212
if has_pending_parent:
201-
append(result)
202-
self._trigger()
203-
204-
def _enqueue_completed_stream_items(self, result: StreamItemsResult) -> None:
205-
"""Enqueue completed stream items result."""
206-
self._completed_result_queue.append(result)
207-
self._trigger()
208-
209-
def _trigger(self) -> None:
210-
"""Trigger the resolve event."""
211-
resolve = self._resolve
212-
if resolve is not None:
213-
resolve.set()
214-
self._resolve = Event()
215-
216-
async def new_completed_result_available(self) -> None:
217-
"""Get an awaitable that resolves when a new completed result is available."""
218-
resolve = self._resolve
219-
if resolve is None:
220-
self._resolve = resolve = Event()
221-
await resolve.wait()
213+
self._enqueue(result)
222214

223215
def _add_task(self, awaitable: Awaitable[Any]) -> None:
224216
"""Add the given task to the tasks set for later execution."""
225217
tasks = self._tasks
226218
task = ensure_future(awaitable)
227219
tasks.add(task)
228220
task.add_done_callback(tasks.discard)
221+
222+
def stop_incremental_data(self) -> None:
223+
"""Stop the delivery of inclremental data."""
224+
for future in self._next_queue:
225+
future.cancel() # pragma: no cover
226+
227+
def _yield_current_completed_incremental_data(
228+
self, first_result: IncrementalDataRecordResult
229+
) -> Generator[IncrementalDataRecordResult, None, None]:
230+
"""Yield the current completed incremental data."""
231+
yield first_result
232+
queue = self._completed_queue
233+
while queue:
234+
yield queue.pop(0)
235+
236+
def _enqueue(self, completed: IncrementalDataRecordResult) -> None:
237+
"""Enqueue completed incremental data record result."""
238+
try:
239+
future = self._next_queue.pop(0)
240+
except IndexError:
241+
self._completed_queue.append(completed)
242+
else:
243+
future.set_result(self._yield_current_completed_incremental_data(completed))

src/graphql/execution/incremental_publisher.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -133,16 +133,22 @@ async def _subscribe(
133133
"""Subscribe to the incremental results."""
134134
try:
135135
incremental_graph = self._incremental_graph
136-
completed_results = incremental_graph.completed_results
137136
get_new_pending = incremental_graph.get_new_pending
138-
new_result_available = incremental_graph.new_completed_result_available
139137
check_has_next = incremental_graph.has_next
140138
pending_sources_to_results = self._pending_sources_to_results
139+
completed_incremental_data = incremental_graph.completed_incremental_data()
140+
# use the raw iterator rather than 'async for' so as not to end the iterator
141+
# when exiting the loop with the next value
142+
get_next_results = completed_incremental_data.__aiter__().__anext__
141143
is_done = False
142144
while not is_done:
145+
try:
146+
completed_results = await get_next_results()
147+
except StopAsyncIteration: # pragma: no cover
148+
break
143149
pending: list[PendingResult] = []
144150

145-
for completed_result in completed_results():
151+
for completed_result in completed_results:
146152
if is_deferred_grouped_field_set_result(completed_result):
147153
self._handle_completed_deferred_grouped_field_set(
148154
completed_result
@@ -173,15 +179,12 @@ async def _subscribe(
173179
self._completed = []
174180

175181
yield subsequent_incremental_execution_result
176-
177-
else:
178-
await new_result_available()
179-
180182
finally:
181-
await self._return_stream_iterators()
183+
await self._stop_async_iterators()
182184

183-
async def _return_stream_iterators(self) -> None:
184-
"""Finish all stream iterators."""
185+
async def _stop_async_iterators(self) -> None:
186+
"""Finish all async iterators."""
187+
self._incremental_graph.stop_incremental_data()
185188
cancellable_streams = self._context.cancellable_streams
186189
if cancellable_streams is None:
187190
return

0 commit comments

Comments
 (0)