11import asyncio
22import inspect
33from concurrent .futures import Executor
4+ from contextlib import asynccontextmanager
45from logging import getLogger
56from time import time
6- from typing import Any , Callable , Dict , List , Optional , Set , Union , get_type_hints
7+ from typing import (
8+ Any ,
9+ AsyncIterator ,
10+ Callable ,
11+ Dict ,
12+ List ,
13+ Optional ,
14+ Set ,
15+ Union ,
16+ get_type_hints ,
17+ )
718
819import anyio
920from taskiq_dependencies import DependencyGraph
2132
2233logger = getLogger (__name__ )
2334QUEUE_DONE = b"-1"
35+ QUEUE_SKIP = b"-2"
2436
2537
2638def _run_sync (
@@ -83,6 +95,11 @@ def __init__(
8395 "can result in undefined behavior" ,
8496 )
8597 self .sem_prefetch = asyncio .Semaphore (max_prefetch )
98+ self .idle_tasks : "Set[asyncio.Task[Any]]" = set ()
99+ self .sem_lock : asyncio .Lock = asyncio .Lock ()
100+ self .listen_queue : "asyncio.Queue[Union[AckableMessage, bytes]]" = (
101+ asyncio .Queue ()
102+ )
86103
87104 async def callback ( # noqa: C901, PLR0912
88105 self ,
@@ -227,7 +244,7 @@ async def run_task( # noqa: C901, PLR0912, PLR0915
227244 broker_ctx = self .broker .custom_dependency_context
228245 broker_ctx .update (
229246 {
230- Context : Context (message , self .broker ),
247+ Context : Context (message , self .broker , self . idle ),
231248 TaskiqState : self .broker .state ,
232249 },
233250 )
@@ -329,6 +346,7 @@ async def listen(self) -> None: # pragma: no cover
329346 await self .broker .startup ()
330347 logger .info ("Listening started." )
331348 queue : "asyncio.Queue[Union[bytes, AckableMessage]]" = asyncio .Queue ()
349+ self .listen_queue = queue
332350
333351 async with anyio .create_task_group () as gr :
334352 gr .start_soon (self .prefetcher , queue )
@@ -396,7 +414,8 @@ def task_cb(task: "asyncio.Task[Any]") -> None:
396414 while True :
397415 # Waits for semaphore to be released.
398416 if self .sem is not None :
399- await self .sem .acquire ()
417+ async with self .sem_lock :
418+ await self .sem .acquire ()
400419
401420 self .sem_prefetch .release ()
402421 message = await queue .get ()
@@ -407,6 +426,11 @@ def task_cb(task: "asyncio.Task[Any]") -> None:
407426 await asyncio .wait (tasks , timeout = self .wait_tasks_timeout )
408427 break
409428
429+ if message is QUEUE_SKIP :
430+ if self .sem is not None :
431+ self .sem .release ()
432+ continue
433+
410434 task = asyncio .create_task (
411435 self .callback (message = message , raise_err = False ),
412436 )
@@ -420,6 +444,49 @@ def task_cb(task: "asyncio.Task[Any]") -> None:
420444 # https://textual.textualize.io/blog/2023/02/11/the-heisenbug-lurking-in-your-async-code/
421445 task .add_done_callback (task_cb )
422446
447+ @asynccontextmanager
448+ async def idle (self , timeout : Optional [int ] = None ) -> AsyncIterator [None ]:
449+ """Idle task.
450+
451+ :param timeout: idle time
452+ """
453+ if self .sem is not None :
454+ self .sem .release ()
455+
456+ def acquire () -> "asyncio.Task[Any]" :
457+ if self .sem is None :
458+ raise ValueError (self .sem )
459+
460+ task = asyncio .create_task (self .sem .acquire ())
461+ task .add_done_callback (self .idle_tasks .discard )
462+ self .idle_tasks .add (task )
463+ return task
464+
465+ cancelled = False
466+ try :
467+ with anyio .fail_after (timeout ):
468+ yield
469+ except asyncio .CancelledError :
470+ if self .sem :
471+ acquire ()
472+
473+ cancelled = True
474+ raise
475+
476+ finally :
477+ if not cancelled and self .sem is not None :
478+ try :
479+ await self .sem_lock .acquire ()
480+ except asyncio .CancelledError :
481+ acquire ()
482+ raise
483+
484+ try :
485+ self .listen_queue .put_nowait (QUEUE_SKIP )
486+ await acquire ()
487+ finally :
488+ self .sem_lock .release ()
489+
423490 def _prepare_task (self , name : str , handler : Callable [..., Any ]) -> None :
424491 """
425492 Prepare task for execution.
0 commit comments