33from concurrent .futures import Executor
44from logging import getLogger
55from time import time
6- from typing import Any , Callable , Dict , Optional , get_type_hints
6+ from typing import Any , Callable , Dict , Optional , Set , get_type_hints
77
8+ import anyio
89from taskiq_dependencies import DependencyGraph
910
1011from taskiq .abc .broker import AsyncBroker
1718from taskiq .utils import maybe_awaitable
1819
1920logger = getLogger (__name__ )
21+ QUEUE_DONE = b"-1"
2022
2123
2224def _run_sync (target : Callable [..., Any ], message : TaskiqMessage ) -> Any :
@@ -36,12 +38,13 @@ def _run_sync(target: Callable[..., Any], message: TaskiqMessage) -> Any:
3638class Receiver :
3739 """Class that uses as a callback handler."""
3840
39- def __init__ (
41+ def __init__ ( # noqa: WPS211
4042 self ,
4143 broker : AsyncBroker ,
4244 executor : Optional [Executor ] = None ,
4345 validate_params : bool = True ,
4446 max_async_tasks : "Optional[int]" = None ,
47+ max_prefetch : int = 0 ,
4548 ) -> None :
4649 self .broker = broker
4750 self .executor = executor
@@ -61,6 +64,7 @@ def __init__(
6164 "Setting unlimited number of async tasks "
6265 + "can result in undefined behavior" ,
6366 )
67+ self .sem_prefetch = asyncio .Semaphore (max_prefetch )
6468
6569 async def callback ( # noqa: C901, WPS213
6670 self ,
@@ -239,7 +243,38 @@ async def listen(self) -> None: # pragma: no cover
239243 """
240244 await self .broker .startup ()
241245 logger .info ("Listening started." )
242- tasks = set ()
246+ queue : asyncio .Queue [bytes ] = asyncio .Queue ()
247+
248+ async with anyio .create_task_group () as gr :
249+ gr .start_soon (self .prefetcher , queue )
250+ gr .start_soon (self .runner , queue )
251+
252+ async def prefetcher (self , queue : "asyncio.Queue[Any]" ) -> None :
253+ """
254+ Prefetch tasks data.
255+
256+ :param queue: queue for prefetched data.
257+ """
258+ iterator = self .broker .listen ()
259+
260+ while True :
261+ try :
262+ await self .sem_prefetch .acquire ()
263+ message = await iterator .__anext__ () # noqa: WPS609
264+ await queue .put (message )
265+
266+ except StopAsyncIteration :
267+ break
268+
269+ await queue .put (QUEUE_DONE )
270+
271+ async def runner (self , queue : "asyncio.Queue[bytes]" ) -> None :
272+ """
273+ Run tasks.
274+
275+ :param queue: queue with prefetched data.
276+ """
277+ tasks : Set [asyncio .Task [Any ]] = set ()
243278
244279 def task_cb (task : "asyncio.Task[Any]" ) -> None :
245280 """
@@ -255,11 +290,19 @@ def task_cb(task: "asyncio.Task[Any]") -> None:
255290 if self .sem is not None :
256291 self .sem .release ()
257292
258- async for message in self . broker . listen () :
293+ while True :
259294 # Waits for semaphore to be released.
260295 if self .sem is not None :
261296 await self .sem .acquire ()
262- task = asyncio .create_task (self .callback (message = message , raise_err = False ))
297+
298+ self .sem_prefetch .release ()
299+ message = await queue .get ()
300+ if message is QUEUE_DONE :
301+ break
302+
303+ task = asyncio .create_task (
304+ self .callback (message = message , raise_err = False ),
305+ )
263306 tasks .add (task )
264307
265308 # We want the task to remove itself from the set when it's done.
0 commit comments