11import asyncio
22from datetime import timedelta
33from logging import getLogger
4- from typing import Any , Callable , Coroutine , Dict , Optional , TypeVar
4+ from typing import Any , AsyncGenerator , Callable , Dict , Optional , TypeVar
55
66from aio_pika import DeliveryMode , ExchangeType , Message , connect_robust
7- from aio_pika .abc import (
8- AbstractChannel ,
9- AbstractIncomingMessage ,
10- AbstractQueue ,
11- AbstractRobustConnection ,
12- )
7+ from aio_pika .abc import AbstractChannel , AbstractQueue , AbstractRobustConnection
138from taskiq .abc .broker import AsyncBroker
149from taskiq .abc .result_backend import AsyncResultBackend
1510from taskiq .message import BrokerMessage
@@ -219,30 +214,42 @@ async def kick(self, message: BrokerMessage) -> None:
219214 routing_key = self ._delay_queue_name ,
220215 )
221216
222- async def listen (
223- self ,
224- callback : Callable [[BrokerMessage ], Coroutine [Any , Any , None ]],
225- ) -> None :
217+ async def listen (self ) -> AsyncGenerator [BrokerMessage , None ]: # noqa: WPS210
226218 """
227219 Listen to queue.
228220
229- This function listens to queue and calls
230- callback on every new message.
221+ This function listens to queue and
222+ yields every new message.
231223
232- :param callback: function to call on new message.
224+ :yields: parsed broker message.
233225 :raises ValueError: if startup wasn't called.
234226 """
235- self .callback = callback
236227 if self .read_channel is None :
237228 raise ValueError ("Call startup before starting listening." )
238229 await self .read_channel .set_qos (prefetch_count = self ._qos )
239230 queue = await self .declare_queues (self .read_channel )
240- await queue .consume (self .process_message )
241- try : # noqa: WPS501
242- # Wait until terminate
243- await asyncio .Future ()
244- finally :
245- await self .shutdown ()
231+ async with queue .iterator () as iterator :
232+ async for message in iterator :
233+ async with message .process ():
234+ headers = {}
235+ for header_name , header_value in message .headers .items ():
236+ headers [header_name ] = str (header_value )
237+ try :
238+ broker_message = BrokerMessage (
239+ task_id = headers .pop ("task_id" ),
240+ task_name = headers .pop ("task_name" ),
241+ message = message .body ,
242+ labels = headers ,
243+ )
244+ except (ValueError , LookupError ) as exc :
245+ logger .warning (
246+ "Cannot read broker message %s" ,
247+ exc ,
248+ exc_info = True ,
249+ )
250+ continue
251+
252+ yield broker_message
246253
247254 async def shutdown (self ) -> None :
248255 """Close all connections on shutdown."""
@@ -255,32 +262,3 @@ async def shutdown(self) -> None:
255262 await self .write_conn .close ()
256263 if self .read_conn :
257264 await self .read_conn .close ()
258-
259- async def process_message (self , message : AbstractIncomingMessage ) -> None :
260- """
261- Process received message.
262-
263- This function parses broker message and
264- calls callback.
265-
266- :param message: received message.
267- """
268- async with message .process ():
269- headers = {}
270- for header_name , header_value in message .headers .items ():
271- headers [header_name ] = str (header_value )
272- try :
273- broker_message = BrokerMessage (
274- task_id = headers .pop ("task_id" ),
275- task_name = headers .pop ("task_name" ),
276- message = message .body ,
277- labels = headers ,
278- )
279- except (ValueError , LookupError ) as exc :
280- logger .warning (
281- "Cannot read broker message %s" ,
282- exc ,
283- exc_info = True ,
284- )
285- return
286- await self .callback (broker_message )
0 commit comments