Skip to content

Commit 5fd94be

Browse files
authored
feat: add labels to TaskiqAdminMiddleware requests (#532)
1 parent 7c92e16 commit 5fd94be

File tree

3 files changed

+25
-5
lines changed

3 files changed

+25
-5
lines changed

taskiq/middlewares/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33
from .prometheus_middleware import PrometheusMiddleware
44
from .simple_retry_middleware import SimpleRetryMiddleware
55
from .smart_retry_middleware import SmartRetryMiddleware
6+
from .taskiq_admin_middleware import TaskiqAdminMiddleware
67

78
__all__ = (
89
"PrometheusMiddleware",
910
"SimpleRetryMiddleware",
1011
"SmartRetryMiddleware",
12+
"TaskiqAdminMiddleware",
1113
)

taskiq/middlewares/taskiq_admin_middleware.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66

77
import aiohttp
88

9-
from taskiq import TaskiqMessage, TaskiqMiddleware, TaskiqResult
9+
from taskiq.abc.middleware import TaskiqMiddleware
10+
from taskiq.message import TaskiqMessage
11+
from taskiq.result import TaskiqResult
1012

1113
__all__ = ("TaskiqAdminMiddleware",)
1214

@@ -118,6 +120,7 @@ async def post_send(self, message: TaskiqMessage) -> None:
118120
{
119121
"args": message.args,
120122
"kwargs": message.kwargs,
123+
"labels": message.labels,
121124
"queuedAt": self._now_iso(),
122125
"taskName": message.task_name,
123126
"worker": self.__ta_broker_name,
@@ -139,6 +142,7 @@ async def pre_execute(self, message: TaskiqMessage) -> TaskiqMessage:
139142
{
140143
"args": message.args,
141144
"kwargs": message.kwargs,
145+
"labels": message.labels,
142146
"startedAt": self._now_iso(),
143147
"taskName": message.task_name,
144148
"worker": self.__ta_broker_name,

tests/middlewares/test_taskiq_admin_middleware.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import datetime
23
from typing import AsyncGenerator
34
from unittest.mock import AsyncMock, Mock, patch
45

@@ -26,7 +27,19 @@ def message() -> TaskiqMessage:
2627
return TaskiqMessage(
2728
task_id="task-123",
2829
task_name="test_task",
29-
labels={},
30+
labels={
31+
"schedule": {
32+
"cron": "*/1 * * * *",
33+
"cron_offset": datetime.timedelta(hours=1),
34+
"time": datetime.datetime.now(datetime.timezone.utc),
35+
"labels": {
36+
"test_bool": True,
37+
"test_int": 1,
38+
"test_str": "str",
39+
"test_bytes": b"bytes",
40+
},
41+
},
42+
},
3043
args=[1, 2, 3],
3144
kwargs={"key": "value"},
3245
)
@@ -80,8 +93,9 @@ async def test_when_post_send_is_called__then_payload_includes_task_info(
8093
call_args = mock_post.call_args
8194
assert call_args is not None
8295
payload = call_args[1]["json"]
83-
assert payload["args"] == [1, 2, 3]
84-
assert payload["kwargs"] == {"key": "value"}
85-
assert payload["taskName"] == "test_task"
96+
assert payload["args"] == message.args
97+
assert payload["kwargs"] == message.kwargs
98+
assert payload["taskName"] == message.task_name
8699
assert payload["worker"] == "test-broker"
100+
assert payload["labels"] == message.labels
87101
assert "queuedAt" in payload

0 commit comments

Comments
 (0)