1010
1111from taskiq .abc .broker import AckableMessage , AsyncBroker
1212from taskiq .abc .middleware import TaskiqMiddleware
13+ from taskiq .acks import AcknowledgeType
1314from taskiq .context import Context
1415from taskiq .exceptions import NoResultError
1516from taskiq .message import TaskiqMessage
@@ -53,6 +54,7 @@ def __init__(
5354 max_prefetch : int = 0 ,
5455 propagate_exceptions : bool = True ,
5556 run_starup : bool = True ,
57+ ack_type : Optional [AcknowledgeType ] = None ,
5658 on_exit : Optional [Callable [["Receiver" ], None ]] = None ,
5759 ) -> None :
5860 self .broker = broker
@@ -64,6 +66,7 @@ def __init__(
6466 self .dependency_graphs : Dict [str , DependencyGraph ] = {}
6567 self .propagate_exceptions = propagate_exceptions
6668 self .on_exit = on_exit
69+ self .ack_time = ack_type or AcknowledgeType .WHEN_SAVED
6770 self .known_tasks : Set [str ] = set ()
6871 for task in self .broker .get_all_tasks ().values ():
6972 self ._prepare_task (task .task_name , task .original_func )
@@ -131,13 +134,21 @@ async def callback( # noqa: C901, PLR0912
131134 taskiq_msg .task_id ,
132135 )
133136
137+ if self .ack_time == AcknowledgeType .WHEN_RECEIVED and isinstance (
138+ message ,
139+ AckableMessage ,
140+ ):
141+ await maybe_awaitable (message .ack ())
142+
134143 result = await self .run_task (
135144 target = task .original_func ,
136145 message = taskiq_msg ,
137146 )
138147
139- # If broker has an ability to ack messages.
140- if isinstance (message , AckableMessage ):
148+ if self .ack_time == AcknowledgeType .WHEN_EXECUTED and isinstance (
149+ message ,
150+ AckableMessage ,
151+ ):
141152 await maybe_awaitable (message .ack ())
142153
143154 for middleware in self .broker .middlewares :
@@ -147,9 +158,11 @@ async def callback( # noqa: C901, PLR0912
147158 try :
148159 if not isinstance (result .error , NoResultError ):
149160 await self .broker .result_backend .set_result (taskiq_msg .task_id , result )
161+
150162 for middleware in self .broker .middlewares :
151163 if middleware .__class__ .post_save != TaskiqMiddleware .post_save :
152164 await maybe_awaitable (middleware .post_save (taskiq_msg , result ))
165+
153166 except Exception as exc :
154167 logger .exception (
155168 "Can't set result in result backend. Cause: %s" ,
@@ -159,6 +172,12 @@ async def callback( # noqa: C901, PLR0912
159172 if raise_err :
160173 raise exc
161174
175+ if self .ack_time == AcknowledgeType .WHEN_SAVED and isinstance (
176+ message ,
177+ AckableMessage ,
178+ ):
179+ await maybe_awaitable (message .ack ())
180+
162181 async def run_task ( # noqa: C901, PLR0912, PLR0915
163182 self ,
164183 target : Callable [..., Any ],
0 commit comments