@@ -46,13 +46,15 @@ def __init__( # noqa: WPS211
4646 validate_params : bool = True ,
4747 max_async_tasks : "Optional[int]" = None ,
4848 max_prefetch : int = 0 ,
49+ propagate_exceptions : bool = True ,
4950 ) -> None :
5051 self .broker = broker
5152 self .executor = executor
5253 self .validate_params = validate_params
5354 self .task_signatures : Dict [str , inspect .Signature ] = {}
5455 self .task_hints : Dict [str , Dict [str , Any ]] = {}
5556 self .dependency_graphs : Dict [str , DependencyGraph ] = {}
57+ self .propagate_exceptions = propagate_exceptions
5658 for task in self .broker .available_tasks .values ():
5759 self .task_signatures [task .task_name ] = inspect .signature (task .original_func )
5860 self .task_hints [task .task_name ] = get_type_hints (task .original_func )
@@ -213,7 +215,14 @@ async def run_task( # noqa: C901, WPS210
213215 # Stop the timer.
214216 execution_time = time () - start_time
215217 if dep_ctx :
216- await dep_ctx .close ()
218+ args = (None , None , None )
219+ if found_exception and self .propagate_exceptions :
220+ args = ( # type: ignore
221+ type (found_exception ),
222+ found_exception ,
223+ found_exception .__traceback__ ,
224+ )
225+ await dep_ctx .close (* args )
217226
218227 # Assemble result.
219228 result : "TaskiqResult[Any]" = TaskiqResult (
0 commit comments