|
1 | 1 | import asyncio |
2 | 2 | import inspect |
| 3 | +import signal |
3 | 4 | from concurrent.futures import Executor |
4 | 5 | from logging import getLogger |
5 | 6 | from time import time |
@@ -334,6 +335,12 @@ async def listen(self) -> None: # pragma: no cover |
334 | 335 | gr.start_soon(self.prefetcher, queue) |
335 | 336 | gr.start_soon(self.runner, queue) |
336 | 337 |
|
| 338 | + # Propagate cancellation to the prefetcher & runner |
| 339 | + def _cancel(*_: Any) -> None: |
| 340 | + gr.cancel_scope.cancel() |
| 341 | + |
| 342 | + signal.signal(signal.SIGINT, _cancel) |
| 343 | + |
337 | 344 | if self.on_exit is not None: |
338 | 345 | self.on_exit(self) |
339 | 346 |
|
@@ -361,9 +368,7 @@ async def prefetcher( |
361 | 368 | message = await iterator.__anext__() |
362 | 369 | fetched_tasks += 1 |
363 | 370 | await queue.put(message) |
364 | | - except asyncio.CancelledError: |
365 | | - break |
366 | | - except StopAsyncIteration: |
| 371 | + except (asyncio.CancelledError, StopAsyncIteration): |
367 | 372 | break |
368 | 373 |
|
369 | 374 | await queue.put(QUEUE_DONE) |
@@ -394,31 +399,35 @@ def task_cb(task: "asyncio.Task[Any]") -> None: |
394 | 399 | self.sem.release() |
395 | 400 |
|
396 | 401 | while True: |
397 | | - # Waits for semaphore to be released. |
398 | | - if self.sem is not None: |
399 | | - await self.sem.acquire() |
400 | | - |
401 | | - self.sem_prefetch.release() |
402 | | - message = await queue.get() |
403 | | - if message is QUEUE_DONE: |
404 | | - # asyncio.wait will throw an error if there is nothing to wait for |
405 | | - if tasks: |
406 | | - logger.info("Waiting for running tasks to complete.") |
407 | | - await asyncio.wait(tasks, timeout=self.wait_tasks_timeout) |
408 | | - break |
| 402 | + try: |
| 403 | + # Waits for semaphore to be released. |
| 404 | + if self.sem is not None: |
| 405 | + await self.sem.acquire() |
| 406 | + |
| 407 | + self.sem_prefetch.release() |
| 408 | + message = await queue.get() |
| 409 | + if message is QUEUE_DONE: |
| 410 | + # asyncio.wait will throw an error if there is nothing to wait for |
| 411 | + if tasks: |
| 412 | + logger.info("Waiting for running tasks to complete.") |
| 413 | + await asyncio.wait(tasks, timeout=self.wait_tasks_timeout) |
| 414 | + break |
409 | 415 |
|
410 | | - task = asyncio.create_task( |
411 | | - self.callback(message=message, raise_err=False), |
412 | | - ) |
413 | | - tasks.add(task) |
414 | | - |
415 | | - # We want the task to remove itself from the set when it's done. |
416 | | - # |
417 | | - # Because if we won't save it anywhere, |
418 | | - # python's GC can silently cancel task |
419 | | - # and this behaviour considered to be a Hisenbug. |
420 | | - # https://textual.textualize.io/blog/2023/02/11/the-heisenbug-lurking-in-your-async-code/ |
421 | | - task.add_done_callback(task_cb) |
| 416 | + task = asyncio.create_task( |
| 417 | + self.callback(message=message, raise_err=False), |
| 418 | + ) |
| 419 | + tasks.add(task) |
| 420 | + |
| 421 | + # We want the task to remove itself from the set when it's done. |
| 422 | + # |
| 423 | + # Because if we won't save it anywhere, |
| 424 | + # python's GC can silently cancel task |
| 425 | + # and this behaviour considered to be a Hisenbug. |
| 426 | + # https://textual.textualize.io/blog/2023/02/11/the-heisenbug-lurking-in-your-async-code/ |
| 427 | + task.add_done_callback(task_cb) |
| 428 | + |
| 429 | + except asyncio.CancelledError: |
| 430 | + break |
422 | 431 |
|
423 | 432 | def _prepare_task(self, name: str, handler: Callable[..., Any]) -> None: |
424 | 433 | """ |
|
0 commit comments