@@ -614,11 +614,31 @@ async def run(
614614 )
615615
616616 if current_turn == 1 :
617+ # Separate guardrails based on execution mode.
618+ all_input_guardrails = starting_agent .input_guardrails + (
619+ run_config .input_guardrails or []
620+ )
621+ sequential_guardrails = [
622+ g for g in all_input_guardrails if not g .run_in_parallel
623+ ]
624+ parallel_guardrails = [g for g in all_input_guardrails if g .run_in_parallel ]
625+
626+ # Run blocking guardrails first, before agent starts.
627+ # (will raise exception if tripwire triggered).
628+ sequential_results = []
629+ if sequential_guardrails :
630+ sequential_results = await self ._run_input_guardrails (
631+ starting_agent ,
632+ sequential_guardrails ,
633+ _copy_str_or_list (prepared_input ),
634+ context_wrapper ,
635+ )
636+
637+ # Run parallel guardrails + agent together.
617638 input_guardrail_results , turn_result = await asyncio .gather (
618639 self ._run_input_guardrails (
619640 starting_agent ,
620- starting_agent .input_guardrails
621- + (run_config .input_guardrails or []),
641+ parallel_guardrails ,
622642 _copy_str_or_list (prepared_input ),
623643 context_wrapper ,
624644 ),
@@ -635,6 +655,9 @@ async def run(
635655 server_conversation_tracker = server_conversation_tracker ,
636656 ),
637657 )
658+
659+ # Combine sequential and parallel results.
660+ input_guardrail_results = sequential_results + input_guardrail_results
638661 else :
639662 turn_result = await self ._run_single_turn (
640663 agent = current_agent ,
@@ -954,6 +977,11 @@ async def _run_input_guardrails_with_queue(
954977 for done in asyncio .as_completed (guardrail_tasks ):
955978 result = await done
956979 if result .output .tripwire_triggered :
980+ # Cancel all remaining guardrail tasks if a tripwire is triggered.
981+ for t in guardrail_tasks :
982+ t .cancel ()
983+ # Wait for cancellations to propagate by awaiting the cancelled tasks.
984+ await asyncio .gather (* guardrail_tasks , return_exceptions = True )
957985 _error_tracing .attach_error_to_span (
958986 parent_span ,
959987 SpanError (
@@ -964,14 +992,19 @@ async def _run_input_guardrails_with_queue(
964992 },
965993 ),
966994 )
995+ queue .put_nowait (result )
996+ guardrail_results .append (result )
997+ break
967998 queue .put_nowait (result )
968999 guardrail_results .append (result )
9691000 except Exception :
9701001 for t in guardrail_tasks :
9711002 t .cancel ()
9721003 raise
9731004
974- streamed_result .input_guardrail_results = guardrail_results
1005+ streamed_result .input_guardrail_results = (
1006+ streamed_result .input_guardrail_results + guardrail_results
1007+ )
9751008
9761009 @classmethod
9771010 async def _start_streaming (
@@ -1063,11 +1096,36 @@ async def _start_streaming(
10631096 break
10641097
10651098 if current_turn == 1 :
1066- # Run the input guardrails in the background and put the results on the queue
1099+ # Separate guardrails based on execution mode.
1100+ all_input_guardrails = starting_agent .input_guardrails + (
1101+ run_config .input_guardrails or []
1102+ )
1103+ sequential_guardrails = [
1104+ g for g in all_input_guardrails if not g .run_in_parallel
1105+ ]
1106+ parallel_guardrails = [g for g in all_input_guardrails if g .run_in_parallel ]
1107+
1108+ # Run sequential guardrails first.
1109+ if sequential_guardrails :
1110+ await cls ._run_input_guardrails_with_queue (
1111+ starting_agent ,
1112+ sequential_guardrails ,
1113+ ItemHelpers .input_to_new_input_list (prepared_input ),
1114+ context_wrapper ,
1115+ streamed_result ,
1116+ current_span ,
1117+ )
1118+ # Check if any blocking guardrail triggered and raise before starting agent.
1119+ for result in streamed_result .input_guardrail_results :
1120+ if result .output .tripwire_triggered :
1121+ streamed_result ._event_queue .put_nowait (QueueCompleteSentinel ())
1122+ raise InputGuardrailTripwireTriggered (result )
1123+
1124+ # Run parallel guardrails in background.
10671125 streamed_result ._input_guardrails_task = asyncio .create_task (
10681126 cls ._run_input_guardrails_with_queue (
10691127 starting_agent ,
1070- starting_agent . input_guardrails + ( run_config . input_guardrails or []) ,
1128+ parallel_guardrails ,
10711129 ItemHelpers .input_to_new_input_list (prepared_input ),
10721130 context_wrapper ,
10731131 streamed_result ,
@@ -1632,6 +1690,8 @@ async def _run_input_guardrails(
16321690 # Cancel all guardrail tasks if a tripwire is triggered.
16331691 for t in guardrail_tasks :
16341692 t .cancel ()
1693+ # Wait for cancellations to propagate by awaiting the cancelled tasks.
1694+ await asyncio .gather (* guardrail_tasks , return_exceptions = True )
16351695 _error_tracing .attach_error_to_current_span (
16361696 SpanError (
16371697 message = "Guardrail tripwire triggered" ,
0 commit comments