Skip to content

Commit ef6e781

Browse files
committed
Added plugin to create delay messages.
1 parent 785e9f5 commit ef6e781

File tree

5 files changed

+194
-47
lines changed

5 files changed

+194
-47
lines changed

.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ jobs:
2929
pytest:
3030
services:
3131
rabbit:
32-
image: rabbitmq:3.9.16-alpine
32+
image: heidiks/rabbitmq-delayed-message-exchange:latest
3333
env:
3434
RABBITMQ_DEFAULT_USER: "guest"
3535
RABBITMQ_DEFAULT_PASS: "guest"

poetry.lock

Lines changed: 14 additions & 14 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

taskiq_aio_pika/broker.py

Lines changed: 75 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def __init__( # noqa: WPS211
5353
routing_key: str = "#",
5454
exchange_type: ExchangeType = ExchangeType.TOPIC,
5555
max_priority: Optional[int] = None,
56+
delayed_message_exchange_plugin: bool = False,
5657
**connection_kwargs: Any,
5758
) -> None:
5859
"""
@@ -79,6 +80,8 @@ def __init__( # noqa: WPS211
7980
:param exchange_type: type of the exchange.
8081
Used only if `declare_exchange` is True.
8182
:param max_priority: maximum priority value for messages.
83+
:param delayed_message_exchange_plugin: turn on or disable
84+
delayed-message-exchange rabbitmq plugin.
8285
:param connection_kwargs: additional keyword arguments,
8386
for connect_robust method of aio-pika.
8487
"""
@@ -95,6 +98,7 @@ def __init__( # noqa: WPS211
9598
self._queue_name = queue_name
9699
self._routing_key = routing_key
97100
self._max_priority = max_priority
101+
self._delayed_message_exchange_plugin = delayed_message_exchange_plugin
98102

99103
self._dead_letter_queue_name = f"{queue_name}.dead_letter"
100104
if dead_letter_queue_name:
@@ -104,6 +108,8 @@ def __init__( # noqa: WPS211
104108
if delay_queue_name:
105109
self._delay_queue_name = delay_queue_name
106110

111+
self._delay_plugin_exchange_name = f"{exchange_name}.plugin_delay"
112+
107113
self.read_conn: Optional[AbstractRobustConnection] = None
108114
self.write_conn: Optional[AbstractRobustConnection] = None
109115
self.write_channel: Optional[AbstractChannel] = None
@@ -132,9 +138,31 @@ async def startup(self) -> None: # noqa: WPS217
132138
self._exchange_name,
133139
type=self._exchange_type,
134140
)
141+
142+
if self._delayed_message_exchange_plugin:
143+
await self.write_channel.declare_exchange(
144+
self._delay_plugin_exchange_name,
145+
type=ExchangeType.X_DELAYED_MESSAGE,
146+
arguments={
147+
"x-delayed-type": "direct",
148+
},
149+
)
150+
135151
if self._declare_queues:
136152
await self.declare_queues(self.write_channel)
137153

154+
async def shutdown(self) -> None:
155+
"""Close all connections on shutdown."""
156+
await super().shutdown()
157+
if self.write_channel:
158+
await self.write_channel.close()
159+
if self.read_channel:
160+
await self.read_channel.close()
161+
if self.write_conn:
162+
await self.write_conn.close()
163+
if self.read_conn:
164+
await self.read_conn.close()
165+
138166
async def declare_queues(
139167
self,
140168
channel: AbstractChannel,
@@ -163,14 +191,24 @@ async def declare_queues(
163191
self._queue_name,
164192
arguments=args,
165193
)
166-
await channel.declare_queue(
167-
self._delay_queue_name,
168-
arguments={
169-
"x-dead-letter-exchange": "",
170-
"x-dead-letter-routing-key": self._queue_name,
171-
},
194+
if self._delayed_message_exchange_plugin:
195+
await queue.bind(
196+
exchange=self._delay_plugin_exchange_name,
197+
routing_key=self._routing_key,
198+
)
199+
else:
200+
await channel.declare_queue(
201+
self._delay_queue_name,
202+
arguments={
203+
"x-dead-letter-exchange": "",
204+
"x-dead-letter-routing-key": self._queue_name,
205+
},
206+
)
207+
208+
await queue.bind(
209+
exchange=self._exchange_name,
210+
routing_key=self._routing_key,
172211
)
173-
await queue.bind(exchange=self._exchange_name, routing_key=self._routing_key)
174212
return queue
175213

176214
async def kick(self, message: BrokerMessage) -> None:
@@ -189,30 +227,47 @@ async def kick(self, message: BrokerMessage) -> None:
189227
"""
190228
if self.write_channel is None:
191229
raise ValueError("Please run startup before kicking.")
192-
priority = parse_val(int, message.labels.get("priority"))
193-
rmq_msg = Message(
194-
body=message.message,
195-
headers={
230+
231+
message_base_params: dict[str, Any] = {
232+
"body": message.message,
233+
"headers": {
196234
"task_id": message.task_id,
197235
"task_name": message.task_name,
198236
**message.labels,
199237
},
200-
delivery_mode=DeliveryMode.PERSISTENT,
201-
priority=priority,
238+
"delivery_mode": DeliveryMode.PERSISTENT,
239+
}
240+
241+
message_base_params["priority"] = parse_val(
242+
int,
243+
message.labels.get("priority"),
202244
)
203-
delay = parse_val(int, message.labels.get("delay"))
245+
246+
delay: Optional[int] = parse_val(int, message.labels.get("delay"))
247+
rmq_message: Message = Message(**message_base_params)
248+
204249
if delay is None:
205250
exchange = await self.write_channel.get_exchange(
206251
self._exchange_name,
207252
ensure=False,
208253
)
209-
await exchange.publish(rmq_msg, routing_key=message.task_name)
254+
await exchange.publish(rmq_message, routing_key=message.task_name)
210255
else:
211-
rmq_msg.expiration = timedelta(seconds=delay)
212-
await self.write_channel.default_exchange.publish(
213-
rmq_msg,
214-
routing_key=self._delay_queue_name,
215-
)
256+
if self._delayed_message_exchange_plugin:
257+
rmq_message.headers["x-delay"] = delay * 1000
258+
exchange = await self.write_channel.get_exchange(
259+
self._delay_plugin_exchange_name,
260+
)
261+
await exchange.publish(
262+
rmq_message,
263+
routing_key=self._routing_key,
264+
)
265+
else:
266+
rmq_message.expiration = timedelta(seconds=delay)
267+
await self.write_channel.default_exchange.publish(
268+
rmq_message,
269+
routing_key=self._delay_queue_name,
270+
)
216271

217272
async def listen(self) -> AsyncGenerator[bytes, None]:
218273
"""
@@ -232,15 +287,3 @@ async def listen(self) -> AsyncGenerator[bytes, None]:
232287
async for message in iterator:
233288
async with message.process():
234289
yield message.body
235-
236-
async def shutdown(self) -> None:
237-
"""Close all connections on shutdown."""
238-
await super().shutdown()
239-
if self.write_channel:
240-
await self.write_channel.close()
241-
if self.read_channel:
242-
await self.read_channel.close()
243-
if self.write_conn:
244-
await self.write_conn.close()
245-
if self.read_conn:
246-
await self.read_conn.close()

tests/conftest.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,16 @@ def queue_name() -> str:
4242
return uuid4().hex
4343

4444

45+
@pytest.fixture
46+
def routing_key() -> str:
47+
"""
48+
Generated routing key.
49+
50+
:return: random routing key.
51+
"""
52+
return uuid4().hex
53+
54+
4555
@pytest.fixture
4656
def delay_queue_name() -> str:
4757
"""
@@ -155,3 +165,67 @@ async def broker(
155165
if_empty=False,
156166
if_unused=False,
157167
)
168+
169+
170+
@pytest.fixture
171+
async def broker_with_delayed_message_plugin(
172+
amqp_url: str,
173+
queue_name: str,
174+
delay_queue_name: str,
175+
dead_queue_name: str,
176+
exchange_name: str,
177+
routing_key: str,
178+
test_channel: Channel,
179+
) -> AsyncGenerator[AioPikaBroker, None]:
180+
"""
181+
Yields new broker instance.
182+
183+
This function is used to
184+
create broker, run startup,
185+
and shutdown after test.
186+
187+
:param amqp_url: current rabbitmq connection string.
188+
:param test_channel: amqp channel for tests.
189+
:param queue_name: test queue name.
190+
:param delay_queue_name: test delay queue name.
191+
:param dead_queue_name: test dead letter queue name.
192+
:param exchange_name: test exchange name.
193+
:param routing_key: routing_key.
194+
:yield: broker.
195+
"""
196+
broker = AioPikaBroker(
197+
url=amqp_url,
198+
declare_exchange=True,
199+
exchange_name=exchange_name,
200+
dead_letter_queue_name=dead_queue_name,
201+
queue_name=queue_name,
202+
delayed_message_exchange_plugin=True,
203+
routing_key=routing_key,
204+
)
205+
broker.is_worker_process = True
206+
207+
await broker.startup()
208+
209+
yield broker
210+
211+
await broker.shutdown()
212+
213+
exchange = await test_channel.get_exchange(exchange_name)
214+
await exchange.delete(
215+
timeout=1,
216+
if_unused=False,
217+
)
218+
plugin_exchange = await test_channel.get_exchange(
219+
broker._delay_plugin_exchange_name,
220+
)
221+
await plugin_exchange.delete(
222+
timeout=1,
223+
if_unused=False,
224+
)
225+
for i_queue_name in (queue_name, delay_queue_name, dead_queue_name):
226+
queue = await test_channel.get_queue(i_queue_name, ensure=False)
227+
await queue.delete(
228+
timeout=1,
229+
if_empty=False,
230+
if_unused=False,
231+
)

tests/test_broker.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,3 +185,33 @@ async def test_delayed_message(
185185

186186
# Check that we can get the message.
187187
await main_queue.get()
188+
189+
190+
@pytest.mark.anyio
191+
async def test_delayed_message_with_plugin(
192+
broker_with_delayed_message_plugin: AioPikaBroker,
193+
test_channel: Channel,
194+
queue_name: str,
195+
) -> None:
196+
"""Test that we can send delayed messages with plugin.
197+
198+
:param broker_with_delayed_message_plugin: broker with
199+
turned on plugin integration.
200+
:param test_channel: amqp channel for tests.
201+
:param queue_name: test queue name.
202+
"""
203+
main_queue = await test_channel.get_queue(queue_name)
204+
broker_msg = BrokerMessage(
205+
task_id="1",
206+
task_name="name",
207+
message=b"message",
208+
labels={"delay": "2"},
209+
)
210+
211+
await broker_with_delayed_message_plugin.kick(broker_msg)
212+
with pytest.raises(QueueEmpty):
213+
await main_queue.get(no_ack=True)
214+
215+
await asyncio.sleep(2)
216+
217+
assert await main_queue.get()

0 commit comments

Comments
 (0)