Skip to content

Commit a9e63a3

Browse files
authored
feat: add run_in_parallel parameter to input guardrails (#1986)
1 parent 48164ec commit a9e63a3

File tree

3 files changed

+1220
-7
lines changed

3 files changed

+1220
-7
lines changed

src/agents/guardrail.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ class OutputGuardrailResult:
7070

7171
@dataclass
7272
class InputGuardrail(Generic[TContext]):
73-
"""Input guardrails are checks that run in parallel to the agent's execution.
73+
"""Input guardrails are checks that run either in parallel with the agent or before it starts.
7474
They can be used to do things like:
7575
- Check if input messages are off-topic
7676
- Take over control of the agent's execution if an unexpected input is detected
@@ -97,6 +97,11 @@ class InputGuardrail(Generic[TContext]):
9797
function's name.
9898
"""
9999

100+
run_in_parallel: bool = True
101+
"""Whether the guardrail runs concurrently with the agent (True, default) or before
102+
the agent starts (False).
103+
"""
104+
100105
def get_name(self) -> str:
101106
if self.name:
102107
return self.name
@@ -209,6 +214,7 @@ def input_guardrail(
209214
def input_guardrail(
210215
*,
211216
name: str | None = None,
217+
run_in_parallel: bool = True,
212218
) -> Callable[
213219
[_InputGuardrailFuncSync[TContext_co] | _InputGuardrailFuncAsync[TContext_co]],
214220
InputGuardrail[TContext_co],
@@ -221,6 +227,7 @@ def input_guardrail(
221227
| None = None,
222228
*,
223229
name: str | None = None,
230+
run_in_parallel: bool = True,
224231
) -> (
225232
InputGuardrail[TContext_co]
226233
| Callable[
@@ -235,8 +242,14 @@ def input_guardrail(
235242
@input_guardrail
236243
def my_sync_guardrail(...): ...
237244
238-
@input_guardrail(name="guardrail_name")
245+
@input_guardrail(name="guardrail_name", run_in_parallel=False)
239246
async def my_async_guardrail(...): ...
247+
248+
Args:
249+
func: The guardrail function to wrap.
250+
name: Optional name for the guardrail. If not provided, uses the function's name.
251+
run_in_parallel: Whether to run the guardrail concurrently with the agent (True, default)
252+
or before the agent starts (False).
240253
"""
241254

242255
def decorator(
@@ -246,6 +259,7 @@ def decorator(
246259
guardrail_function=f,
247260
# If not set, guardrail name uses the function’s name by default.
248261
name=name if name else f.__name__,
262+
run_in_parallel=run_in_parallel,
249263
)
250264

251265
if func is not None:

src/agents/run.py

Lines changed: 65 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -614,11 +614,31 @@ async def run(
614614
)
615615

616616
if current_turn == 1:
617+
# Separate guardrails based on execution mode.
618+
all_input_guardrails = starting_agent.input_guardrails + (
619+
run_config.input_guardrails or []
620+
)
621+
sequential_guardrails = [
622+
g for g in all_input_guardrails if not g.run_in_parallel
623+
]
624+
parallel_guardrails = [g for g in all_input_guardrails if g.run_in_parallel]
625+
626+
# Run blocking guardrails first, before agent starts.
627+
# (will raise exception if tripwire triggered).
628+
sequential_results = []
629+
if sequential_guardrails:
630+
sequential_results = await self._run_input_guardrails(
631+
starting_agent,
632+
sequential_guardrails,
633+
_copy_str_or_list(prepared_input),
634+
context_wrapper,
635+
)
636+
637+
# Run parallel guardrails + agent together.
617638
input_guardrail_results, turn_result = await asyncio.gather(
618639
self._run_input_guardrails(
619640
starting_agent,
620-
starting_agent.input_guardrails
621-
+ (run_config.input_guardrails or []),
641+
parallel_guardrails,
622642
_copy_str_or_list(prepared_input),
623643
context_wrapper,
624644
),
@@ -635,6 +655,9 @@ async def run(
635655
server_conversation_tracker=server_conversation_tracker,
636656
),
637657
)
658+
659+
# Combine sequential and parallel results.
660+
input_guardrail_results = sequential_results + input_guardrail_results
638661
else:
639662
turn_result = await self._run_single_turn(
640663
agent=current_agent,
@@ -954,6 +977,11 @@ async def _run_input_guardrails_with_queue(
954977
for done in asyncio.as_completed(guardrail_tasks):
955978
result = await done
956979
if result.output.tripwire_triggered:
980+
# Cancel all remaining guardrail tasks if a tripwire is triggered.
981+
for t in guardrail_tasks:
982+
t.cancel()
983+
# Wait for cancellations to propagate by awaiting the cancelled tasks.
984+
await asyncio.gather(*guardrail_tasks, return_exceptions=True)
957985
_error_tracing.attach_error_to_span(
958986
parent_span,
959987
SpanError(
@@ -964,14 +992,19 @@ async def _run_input_guardrails_with_queue(
964992
},
965993
),
966994
)
995+
queue.put_nowait(result)
996+
guardrail_results.append(result)
997+
break
967998
queue.put_nowait(result)
968999
guardrail_results.append(result)
9691000
except Exception:
9701001
for t in guardrail_tasks:
9711002
t.cancel()
9721003
raise
9731004

974-
streamed_result.input_guardrail_results = guardrail_results
1005+
streamed_result.input_guardrail_results = (
1006+
streamed_result.input_guardrail_results + guardrail_results
1007+
)
9751008

9761009
@classmethod
9771010
async def _start_streaming(
@@ -1063,11 +1096,36 @@ async def _start_streaming(
10631096
break
10641097

10651098
if current_turn == 1:
1066-
# Run the input guardrails in the background and put the results on the queue
1099+
# Separate guardrails based on execution mode.
1100+
all_input_guardrails = starting_agent.input_guardrails + (
1101+
run_config.input_guardrails or []
1102+
)
1103+
sequential_guardrails = [
1104+
g for g in all_input_guardrails if not g.run_in_parallel
1105+
]
1106+
parallel_guardrails = [g for g in all_input_guardrails if g.run_in_parallel]
1107+
1108+
# Run sequential guardrails first.
1109+
if sequential_guardrails:
1110+
await cls._run_input_guardrails_with_queue(
1111+
starting_agent,
1112+
sequential_guardrails,
1113+
ItemHelpers.input_to_new_input_list(prepared_input),
1114+
context_wrapper,
1115+
streamed_result,
1116+
current_span,
1117+
)
1118+
# Check if any blocking guardrail triggered and raise before starting agent.
1119+
for result in streamed_result.input_guardrail_results:
1120+
if result.output.tripwire_triggered:
1121+
streamed_result._event_queue.put_nowait(QueueCompleteSentinel())
1122+
raise InputGuardrailTripwireTriggered(result)
1123+
1124+
# Run parallel guardrails in background.
10671125
streamed_result._input_guardrails_task = asyncio.create_task(
10681126
cls._run_input_guardrails_with_queue(
10691127
starting_agent,
1070-
starting_agent.input_guardrails + (run_config.input_guardrails or []),
1128+
parallel_guardrails,
10711129
ItemHelpers.input_to_new_input_list(prepared_input),
10721130
context_wrapper,
10731131
streamed_result,
@@ -1632,6 +1690,8 @@ async def _run_input_guardrails(
16321690
# Cancel all guardrail tasks if a tripwire is triggered.
16331691
for t in guardrail_tasks:
16341692
t.cancel()
1693+
# Wait for cancellations to propagate by awaiting the cancelled tasks.
1694+
await asyncio.gather(*guardrail_tasks, return_exceptions=True)
16351695
_error_tracing.attach_error_to_current_span(
16361696
SpanError(
16371697
message="Guardrail tripwire triggered",

0 commit comments

Comments
 (0)