@@ -56,6 +56,8 @@ def __init__(
5656 run_starup : bool = True ,
5757 ack_type : Optional [AcknowledgeType ] = None ,
5858 on_exit : Optional [Callable [["Receiver" ], None ]] = None ,
59+ max_tasks_to_execute : Optional [int ] = None ,
60+ wait_tasks_timeout : Optional [float ] = None ,
5961 ) -> None :
6062 self .broker = broker
6163 self .executor = executor
@@ -68,6 +70,8 @@ def __init__(
6870 self .on_exit = on_exit
6971 self .ack_time = ack_type or AcknowledgeType .WHEN_SAVED
7072 self .known_tasks : Set [str ] = set ()
73+ self .max_tasks_to_execute = max_tasks_to_execute
74+ self .wait_tasks_timeout = wait_tasks_timeout
7175 for task in self .broker .get_all_tasks ().values ():
7276 self ._prepare_task (task .task_name , task .original_func )
7377 self .sem : "Optional[asyncio.Semaphore]" = None
@@ -342,12 +346,20 @@ async def prefetcher(
342346
343347 :param queue: queue for prefetched data.
344348 """
349+ fetched_tasks : int = 0
345350 iterator = self .broker .listen ()
346351
347352 while True :
348353 try :
349354 await self .sem_prefetch .acquire ()
355+ if (
356+ self .max_tasks_to_execute
357+ and fetched_tasks >= self .max_tasks_to_execute
358+ ):
359+ logger .info ("Max number of tasks executed." )
360+ break
350361 message = await iterator .__anext__ ()
362+ fetched_tasks += 1
351363 await queue .put (message )
352364 except asyncio .CancelledError :
353365 break
@@ -389,6 +401,8 @@ def task_cb(task: "asyncio.Task[Any]") -> None:
389401 self .sem_prefetch .release ()
390402 message = await queue .get ()
391403 if message is QUEUE_DONE :
404+ logger .info ("Waiting for running tasks to complete." )
405+ await asyncio .wait (tasks , timeout = self .wait_tasks_timeout )
392406 break
393407
394408 task = asyncio .create_task (
0 commit comments