11import pickle
2+ from abc import abstractmethod
23from logging import getLogger
34from typing import Any , AsyncGenerator , Callable , Optional , TypeVar
45
1213logger = getLogger ("taskiq.redis_broker" )
1314
1415
15- class RedisBroker (AsyncBroker ):
16- """Broker that works with Redis."""
16+ class BaseRedisBroker (AsyncBroker ):
17+ """Base broker that works with Redis."""
1718
1819 def __init__ (
1920 self ,
@@ -44,31 +45,12 @@ def __init__(
4445 max_connections = max_connection_pool_size ,
4546 ** connection_kwargs ,
4647 )
47-
48- self .redis_pubsub_channel = queue_name
48+ self .queue_name = queue_name
4949
5050 async def shutdown (self ) -> None :
5151 """Closes redis connection pool."""
5252 await self .connection_pool .disconnect ()
5353
54- async def kick (self , message : BrokerMessage ) -> None :
55- """
56- Sends a message to the redis broker list.
57-
58- This function constructs message for redis
59- and sends it.
60-
61- The message is pickled dict object with message,
62- task_id, task_name and labels.
63-
64- :param message: message to send.
65- """
66- async with Redis (connection_pool = self .connection_pool ) as redis_conn :
67- await redis_conn .publish (
68- self .redis_pubsub_channel ,
69- pickle .dumps (message ),
70- )
71-
7254 async def listen (self ) -> AsyncGenerator [BrokerMessage , None ]:
7355 """
7456 Listen redis queue for new messages.
@@ -78,24 +60,60 @@ async def listen(self) -> AsyncGenerator[BrokerMessage, None]:
7860
7961 :yields: broker messages.
8062 """
63+ async for message in self ._listen_to_raw_messages ():
64+ try :
65+ redis_message = pickle .loads (message )
66+ if isinstance (redis_message , BrokerMessage ):
67+ yield redis_message
68+ except (
69+ TypeError ,
70+ AttributeError ,
71+ pickle .UnpicklingError ,
72+ ) as exc :
73+ logger .debug (
74+ "Cannot read broker message %s" ,
75+ exc ,
76+ exc_info = True ,
77+ )
78+
79+ @abstractmethod
80+ async def _listen_to_raw_messages (self ) -> AsyncGenerator [bytes , None ]:
81+ """
82+ Generator for reading raw data from Redis.
83+
84+ :yields: raw data.
85+ """
86+ yield # type: ignore
87+
88+
89+ class PubSubBroker (BaseRedisBroker ):
90+ """Broker that works with Redis and broadcasts tasks to all workers."""
91+
92+ async def kick (self , message : BrokerMessage ) -> None : # noqa: D102
93+ async with Redis (connection_pool = self .connection_pool ) as redis_conn :
94+ await redis_conn .publish (self .queue_name , pickle .dumps (message ))
95+
96+ async def _listen_to_raw_messages (self ) -> AsyncGenerator [bytes , None ]:
8197 async with Redis (connection_pool = self .connection_pool ) as redis_conn :
8298 redis_pubsub_channel = redis_conn .pubsub ()
83- await redis_pubsub_channel .subscribe (self .redis_pubsub_channel )
99+ await redis_pubsub_channel .subscribe (self .queue_name )
84100 async for message in redis_pubsub_channel .listen ():
85- if message :
86- try :
87- redis_message = pickle .loads (
88- message ["data" ],
89- )
90- if isinstance (redis_message , BrokerMessage ):
91- yield redis_message
92- except (
93- TypeError ,
94- AttributeError ,
95- pickle .UnpicklingError ,
96- ) as exc :
97- logger .debug (
98- "Cannot read broker message %s" ,
99- exc ,
100- exc_info = True ,
101- )
101+ if not message :
102+ continue
103+ yield message ["data" ]
104+
105+
106+ class ListQueueBroker (BaseRedisBroker ):
107+ """Broker that works with Redis and distributes tasks between workers."""
108+
109+ async def kick (self , message : BrokerMessage ) -> None : # noqa: D102
110+ async with Redis (connection_pool = self .connection_pool ) as redis_conn :
111+ await redis_conn .lpush (self .queue_name , pickle .dumps (message ))
112+
113+ async def _listen_to_raw_messages (self ) -> AsyncGenerator [bytes , None ]:
114+ redis_brpop_data_position = 1
115+ async with Redis (connection_pool = self .connection_pool ) as redis_conn :
116+ while True : # noqa: WPS457
117+ yield (await redis_conn .brpop (self .queue_name ))[
118+ redis_brpop_data_position
119+ ]
0 commit comments